Merge branch 'enhance-api-key-handling-logging' into 'main'
fix: enhance API key handling and logging in middleware See merge request unboundsoftware/schemas!623
This commit was merged in pull request #627.
This commit is contained in:
+9
-1
@@ -24,9 +24,17 @@ RUN GOOS=linux GOARCH=amd64 go build \
|
|||||||
FROM scratch as export
|
FROM scratch as export
|
||||||
COPY --from=build /build/coverage.txt /
|
COPY --from=build /build/coverage.txt /
|
||||||
|
|
||||||
FROM scratch
|
FROM node:22-alpine
|
||||||
ENV TZ Europe/Stockholm
|
ENV TZ Europe/Stockholm
|
||||||
|
|
||||||
|
# Install wgc CLI globally for Cosmo Router composition
|
||||||
|
RUN npm install -g wgc@latest
|
||||||
|
|
||||||
|
# Copy timezone data and certificates
|
||||||
COPY --from=build /usr/share/zoneinfo /usr/share/zoneinfo
|
COPY --from=build /usr/share/zoneinfo /usr/share/zoneinfo
|
||||||
COPY --from=build /etc/ssl/certs/ca-certificates.crt /etc/ssl/certs/
|
COPY --from=build /etc/ssl/certs/ca-certificates.crt /etc/ssl/certs/
|
||||||
|
|
||||||
|
# Copy the service binary
|
||||||
COPY --from=build /release/service /
|
COPY --from=build /release/service /
|
||||||
|
|
||||||
CMD ["/service"]
|
CMD ["/service"]
|
||||||
|
|||||||
@@ -30,6 +30,7 @@ import (
|
|||||||
"gitlab.com/unboundsoftware/schemas/domain"
|
"gitlab.com/unboundsoftware/schemas/domain"
|
||||||
"gitlab.com/unboundsoftware/schemas/graph"
|
"gitlab.com/unboundsoftware/schemas/graph"
|
||||||
"gitlab.com/unboundsoftware/schemas/graph/generated"
|
"gitlab.com/unboundsoftware/schemas/graph/generated"
|
||||||
|
"gitlab.com/unboundsoftware/schemas/hash"
|
||||||
"gitlab.com/unboundsoftware/schemas/logging"
|
"gitlab.com/unboundsoftware/schemas/logging"
|
||||||
"gitlab.com/unboundsoftware/schemas/middleware"
|
"gitlab.com/unboundsoftware/schemas/middleware"
|
||||||
"gitlab.com/unboundsoftware/schemas/monitoring"
|
"gitlab.com/unboundsoftware/schemas/monitoring"
|
||||||
@@ -210,6 +211,24 @@ func start(closeEvents chan error, logger *slog.Logger, connectToAmqpFunc func(u
|
|||||||
|
|
||||||
srv.AddTransport(transport.Websocket{
|
srv.AddTransport(transport.Websocket{
|
||||||
KeepAlivePingInterval: 10 * time.Second,
|
KeepAlivePingInterval: 10 * time.Second,
|
||||||
|
InitFunc: func(ctx context.Context, initPayload transport.InitPayload) (context.Context, *transport.InitPayload, error) {
|
||||||
|
// Extract API key from WebSocket connection_init payload
|
||||||
|
if apiKey, ok := initPayload["X-Api-Key"].(string); ok && apiKey != "" {
|
||||||
|
logger.Info("WebSocket connection with API key", "has_key", true)
|
||||||
|
ctx = context.WithValue(ctx, middleware.ApiKey, apiKey)
|
||||||
|
|
||||||
|
// Look up organization by API key (same logic as auth middleware)
|
||||||
|
if organization := serviceCache.OrganizationByAPIKey(hash.String(apiKey)); organization != nil {
|
||||||
|
logger.Info("WebSocket: Organization found for API key", "org_id", organization.ID.String())
|
||||||
|
ctx = context.WithValue(ctx, middleware.OrganizationKey, *organization)
|
||||||
|
} else {
|
||||||
|
logger.Warn("WebSocket: No organization found for API key")
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
logger.Info("WebSocket connection without API key")
|
||||||
|
}
|
||||||
|
return ctx, &initPayload, nil
|
||||||
|
},
|
||||||
})
|
})
|
||||||
srv.AddTransport(transport.Options{})
|
srv.AddTransport(transport.Options{})
|
||||||
srv.AddTransport(transport.GET{})
|
srv.AddTransport(transport.GET{})
|
||||||
|
|||||||
@@ -0,0 +1,336 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/99designs/gqlgen/graphql/handler/transport"
|
||||||
|
"github.com/google/uuid"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
"gitlab.com/unboundsoftware/eventsourced/eventsourced"
|
||||||
|
|
||||||
|
"gitlab.com/unboundsoftware/schemas/domain"
|
||||||
|
"gitlab.com/unboundsoftware/schemas/hash"
|
||||||
|
"gitlab.com/unboundsoftware/schemas/middleware"
|
||||||
|
)
|
||||||
|
|
||||||
|
// MockCache is a mock implementation for testing
|
||||||
|
type MockCache struct {
|
||||||
|
organizations map[string]*domain.Organization
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MockCache) OrganizationByAPIKey(apiKey string) *domain.Organization {
|
||||||
|
return m.organizations[apiKey]
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWebSocketInitFunc_WithValidAPIKey(t *testing.T) {
|
||||||
|
// Setup
|
||||||
|
orgID := uuid.New()
|
||||||
|
org := &domain.Organization{
|
||||||
|
BaseAggregate: eventsourced.BaseAggregate{
|
||||||
|
ID: eventsourced.IdFromString(orgID.String()),
|
||||||
|
},
|
||||||
|
Name: "Test Organization",
|
||||||
|
}
|
||||||
|
|
||||||
|
apiKey := "test-api-key-123"
|
||||||
|
hashedKey := hash.String(apiKey)
|
||||||
|
|
||||||
|
mockCache := &MockCache{
|
||||||
|
organizations: map[string]*domain.Organization{
|
||||||
|
hashedKey: org,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create InitFunc (simulating the WebSocket InitFunc logic)
|
||||||
|
initFunc := func(ctx context.Context, initPayload transport.InitPayload) (context.Context, *transport.InitPayload, error) {
|
||||||
|
// Extract API key from WebSocket connection_init payload
|
||||||
|
if apiKey, ok := initPayload["X-Api-Key"].(string); ok && apiKey != "" {
|
||||||
|
ctx = context.WithValue(ctx, middleware.ApiKey, apiKey)
|
||||||
|
|
||||||
|
// Look up organization by API key
|
||||||
|
if organization := mockCache.OrganizationByAPIKey(hash.String(apiKey)); organization != nil {
|
||||||
|
ctx = context.WithValue(ctx, middleware.OrganizationKey, *organization)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return ctx, &initPayload, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test
|
||||||
|
ctx := context.Background()
|
||||||
|
initPayload := transport.InitPayload{
|
||||||
|
"X-Api-Key": apiKey,
|
||||||
|
}
|
||||||
|
|
||||||
|
resultCtx, resultPayload, err := initFunc(ctx, initPayload)
|
||||||
|
|
||||||
|
// Assert
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, resultPayload)
|
||||||
|
|
||||||
|
// Check API key is in context
|
||||||
|
if value := resultCtx.Value(middleware.ApiKey); value != nil {
|
||||||
|
assert.Equal(t, apiKey, value.(string))
|
||||||
|
} else {
|
||||||
|
t.Fatal("API key not found in context")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check organization is in context
|
||||||
|
if value := resultCtx.Value(middleware.OrganizationKey); value != nil {
|
||||||
|
capturedOrg, ok := value.(domain.Organization)
|
||||||
|
require.True(t, ok, "Organization should be of correct type")
|
||||||
|
assert.Equal(t, org.Name, capturedOrg.Name)
|
||||||
|
assert.Equal(t, org.ID.String(), capturedOrg.ID.String())
|
||||||
|
} else {
|
||||||
|
t.Fatal("Organization not found in context")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWebSocketInitFunc_WithInvalidAPIKey(t *testing.T) {
|
||||||
|
// Setup
|
||||||
|
mockCache := &MockCache{
|
||||||
|
organizations: map[string]*domain.Organization{},
|
||||||
|
}
|
||||||
|
|
||||||
|
apiKey := "invalid-api-key"
|
||||||
|
|
||||||
|
// Create InitFunc
|
||||||
|
initFunc := func(ctx context.Context, initPayload transport.InitPayload) (context.Context, *transport.InitPayload, error) {
|
||||||
|
// Extract API key from WebSocket connection_init payload
|
||||||
|
if apiKey, ok := initPayload["X-Api-Key"].(string); ok && apiKey != "" {
|
||||||
|
ctx = context.WithValue(ctx, middleware.ApiKey, apiKey)
|
||||||
|
|
||||||
|
// Look up organization by API key
|
||||||
|
if organization := mockCache.OrganizationByAPIKey(hash.String(apiKey)); organization != nil {
|
||||||
|
ctx = context.WithValue(ctx, middleware.OrganizationKey, *organization)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return ctx, &initPayload, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test
|
||||||
|
ctx := context.Background()
|
||||||
|
initPayload := transport.InitPayload{
|
||||||
|
"X-Api-Key": apiKey,
|
||||||
|
}
|
||||||
|
|
||||||
|
resultCtx, resultPayload, err := initFunc(ctx, initPayload)
|
||||||
|
|
||||||
|
// Assert
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, resultPayload)
|
||||||
|
|
||||||
|
// Check API key is in context
|
||||||
|
if value := resultCtx.Value(middleware.ApiKey); value != nil {
|
||||||
|
assert.Equal(t, apiKey, value.(string))
|
||||||
|
} else {
|
||||||
|
t.Fatal("API key not found in context")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check organization is NOT in context (since API key is invalid)
|
||||||
|
value := resultCtx.Value(middleware.OrganizationKey)
|
||||||
|
assert.Nil(t, value, "Organization should not be set for invalid API key")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWebSocketInitFunc_WithoutAPIKey(t *testing.T) {
|
||||||
|
// Setup
|
||||||
|
mockCache := &MockCache{
|
||||||
|
organizations: map[string]*domain.Organization{},
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create InitFunc
|
||||||
|
initFunc := func(ctx context.Context, initPayload transport.InitPayload) (context.Context, *transport.InitPayload, error) {
|
||||||
|
// Extract API key from WebSocket connection_init payload
|
||||||
|
if apiKey, ok := initPayload["X-Api-Key"].(string); ok && apiKey != "" {
|
||||||
|
ctx = context.WithValue(ctx, middleware.ApiKey, apiKey)
|
||||||
|
|
||||||
|
// Look up organization by API key
|
||||||
|
if organization := mockCache.OrganizationByAPIKey(hash.String(apiKey)); organization != nil {
|
||||||
|
ctx = context.WithValue(ctx, middleware.OrganizationKey, *organization)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return ctx, &initPayload, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test
|
||||||
|
ctx := context.Background()
|
||||||
|
initPayload := transport.InitPayload{}
|
||||||
|
|
||||||
|
resultCtx, resultPayload, err := initFunc(ctx, initPayload)
|
||||||
|
|
||||||
|
// Assert
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, resultPayload)
|
||||||
|
|
||||||
|
// Check API key is NOT in context
|
||||||
|
value := resultCtx.Value(middleware.ApiKey)
|
||||||
|
assert.Nil(t, value, "API key should not be set when not provided")
|
||||||
|
|
||||||
|
// Check organization is NOT in context
|
||||||
|
value = resultCtx.Value(middleware.OrganizationKey)
|
||||||
|
assert.Nil(t, value, "Organization should not be set when API key is not provided")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWebSocketInitFunc_WithEmptyAPIKey(t *testing.T) {
|
||||||
|
// Setup
|
||||||
|
mockCache := &MockCache{
|
||||||
|
organizations: map[string]*domain.Organization{},
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create InitFunc
|
||||||
|
initFunc := func(ctx context.Context, initPayload transport.InitPayload) (context.Context, *transport.InitPayload, error) {
|
||||||
|
// Extract API key from WebSocket connection_init payload
|
||||||
|
if apiKey, ok := initPayload["X-Api-Key"].(string); ok && apiKey != "" {
|
||||||
|
ctx = context.WithValue(ctx, middleware.ApiKey, apiKey)
|
||||||
|
|
||||||
|
// Look up organization by API key
|
||||||
|
if organization := mockCache.OrganizationByAPIKey(hash.String(apiKey)); organization != nil {
|
||||||
|
ctx = context.WithValue(ctx, middleware.OrganizationKey, *organization)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return ctx, &initPayload, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test
|
||||||
|
ctx := context.Background()
|
||||||
|
initPayload := transport.InitPayload{
|
||||||
|
"X-Api-Key": "", // Empty string
|
||||||
|
}
|
||||||
|
|
||||||
|
resultCtx, resultPayload, err := initFunc(ctx, initPayload)
|
||||||
|
|
||||||
|
// Assert
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, resultPayload)
|
||||||
|
|
||||||
|
// Check API key is NOT in context (because empty string fails the condition)
|
||||||
|
value := resultCtx.Value(middleware.ApiKey)
|
||||||
|
assert.Nil(t, value, "API key should not be set when empty")
|
||||||
|
|
||||||
|
// Check organization is NOT in context
|
||||||
|
value = resultCtx.Value(middleware.OrganizationKey)
|
||||||
|
assert.Nil(t, value, "Organization should not be set when API key is empty")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWebSocketInitFunc_WithWrongTypeAPIKey(t *testing.T) {
|
||||||
|
// Setup
|
||||||
|
mockCache := &MockCache{
|
||||||
|
organizations: map[string]*domain.Organization{},
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create InitFunc
|
||||||
|
initFunc := func(ctx context.Context, initPayload transport.InitPayload) (context.Context, *transport.InitPayload, error) {
|
||||||
|
// Extract API key from WebSocket connection_init payload
|
||||||
|
if apiKey, ok := initPayload["X-Api-Key"].(string); ok && apiKey != "" {
|
||||||
|
ctx = context.WithValue(ctx, middleware.ApiKey, apiKey)
|
||||||
|
|
||||||
|
// Look up organization by API key
|
||||||
|
if organization := mockCache.OrganizationByAPIKey(hash.String(apiKey)); organization != nil {
|
||||||
|
ctx = context.WithValue(ctx, middleware.OrganizationKey, *organization)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return ctx, &initPayload, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test
|
||||||
|
ctx := context.Background()
|
||||||
|
initPayload := transport.InitPayload{
|
||||||
|
"X-Api-Key": 12345, // Wrong type (int instead of string)
|
||||||
|
}
|
||||||
|
|
||||||
|
resultCtx, resultPayload, err := initFunc(ctx, initPayload)
|
||||||
|
|
||||||
|
// Assert
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, resultPayload)
|
||||||
|
|
||||||
|
// Check API key is NOT in context (type assertion fails)
|
||||||
|
value := resultCtx.Value(middleware.ApiKey)
|
||||||
|
assert.Nil(t, value, "API key should not be set when wrong type")
|
||||||
|
|
||||||
|
// Check organization is NOT in context
|
||||||
|
value = resultCtx.Value(middleware.OrganizationKey)
|
||||||
|
assert.Nil(t, value, "Organization should not be set when API key has wrong type")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWebSocketInitFunc_WithMultipleOrganizations(t *testing.T) {
|
||||||
|
// Setup - create multiple organizations
|
||||||
|
org1ID := uuid.New()
|
||||||
|
org1 := &domain.Organization{
|
||||||
|
BaseAggregate: eventsourced.BaseAggregate{
|
||||||
|
ID: eventsourced.IdFromString(org1ID.String()),
|
||||||
|
},
|
||||||
|
Name: "Organization 1",
|
||||||
|
}
|
||||||
|
|
||||||
|
org2ID := uuid.New()
|
||||||
|
org2 := &domain.Organization{
|
||||||
|
BaseAggregate: eventsourced.BaseAggregate{
|
||||||
|
ID: eventsourced.IdFromString(org2ID.String()),
|
||||||
|
},
|
||||||
|
Name: "Organization 2",
|
||||||
|
}
|
||||||
|
|
||||||
|
apiKey1 := "api-key-org-1"
|
||||||
|
apiKey2 := "api-key-org-2"
|
||||||
|
hashedKey1 := hash.String(apiKey1)
|
||||||
|
hashedKey2 := hash.String(apiKey2)
|
||||||
|
|
||||||
|
mockCache := &MockCache{
|
||||||
|
organizations: map[string]*domain.Organization{
|
||||||
|
hashedKey1: org1,
|
||||||
|
hashedKey2: org2,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create InitFunc
|
||||||
|
initFunc := func(ctx context.Context, initPayload transport.InitPayload) (context.Context, *transport.InitPayload, error) {
|
||||||
|
// Extract API key from WebSocket connection_init payload
|
||||||
|
if apiKey, ok := initPayload["X-Api-Key"].(string); ok && apiKey != "" {
|
||||||
|
ctx = context.WithValue(ctx, middleware.ApiKey, apiKey)
|
||||||
|
|
||||||
|
// Look up organization by API key
|
||||||
|
if organization := mockCache.OrganizationByAPIKey(hash.String(apiKey)); organization != nil {
|
||||||
|
ctx = context.WithValue(ctx, middleware.OrganizationKey, *organization)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return ctx, &initPayload, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test with first API key
|
||||||
|
ctx1 := context.Background()
|
||||||
|
initPayload1 := transport.InitPayload{
|
||||||
|
"X-Api-Key": apiKey1,
|
||||||
|
}
|
||||||
|
|
||||||
|
resultCtx1, _, err := initFunc(ctx1, initPayload1)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
if value := resultCtx1.Value(middleware.OrganizationKey); value != nil {
|
||||||
|
capturedOrg, ok := value.(domain.Organization)
|
||||||
|
require.True(t, ok)
|
||||||
|
assert.Equal(t, org1.Name, capturedOrg.Name)
|
||||||
|
assert.Equal(t, org1.ID.String(), capturedOrg.ID.String())
|
||||||
|
} else {
|
||||||
|
t.Fatal("Organization 1 not found in context")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test with second API key
|
||||||
|
ctx2 := context.Background()
|
||||||
|
initPayload2 := transport.InitPayload{
|
||||||
|
"X-Api-Key": apiKey2,
|
||||||
|
}
|
||||||
|
|
||||||
|
resultCtx2, _, err := initFunc(ctx2, initPayload2)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
if value := resultCtx2.Value(middleware.OrganizationKey); value != nil {
|
||||||
|
capturedOrg, ok := value.(domain.Organization)
|
||||||
|
require.True(t, ok)
|
||||||
|
assert.Equal(t, org2.Name, capturedOrg.Name)
|
||||||
|
assert.Equal(t, org2.ID.String(), capturedOrg.ID.String())
|
||||||
|
} else {
|
||||||
|
t.Fatal("Organization 2 not found in context")
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -9,6 +9,7 @@ require (
|
|||||||
github.com/apex/log v1.9.0
|
github.com/apex/log v1.9.0
|
||||||
github.com/auth0/go-jwt-middleware/v2 v2.3.0
|
github.com/auth0/go-jwt-middleware/v2 v2.3.0
|
||||||
github.com/golang-jwt/jwt/v5 v5.3.0
|
github.com/golang-jwt/jwt/v5 v5.3.0
|
||||||
|
github.com/google/uuid v1.6.0
|
||||||
github.com/jmoiron/sqlx v1.4.0
|
github.com/jmoiron/sqlx v1.4.0
|
||||||
github.com/pkg/errors v0.9.1
|
github.com/pkg/errors v0.9.1
|
||||||
github.com/pressly/goose/v3 v3.26.0
|
github.com/pressly/goose/v3 v3.26.0
|
||||||
@@ -30,6 +31,7 @@ require (
|
|||||||
go.opentelemetry.io/otel/sdk/log v0.14.0
|
go.opentelemetry.io/otel/sdk/log v0.14.0
|
||||||
go.opentelemetry.io/otel/sdk/metric v1.38.0
|
go.opentelemetry.io/otel/sdk/metric v1.38.0
|
||||||
go.opentelemetry.io/otel/trace v1.38.0
|
go.opentelemetry.io/otel/trace v1.38.0
|
||||||
|
gopkg.in/yaml.v3 v3.0.1
|
||||||
)
|
)
|
||||||
|
|
||||||
require (
|
require (
|
||||||
@@ -41,7 +43,6 @@ require (
|
|||||||
github.com/go-logr/logr v1.4.3 // indirect
|
github.com/go-logr/logr v1.4.3 // indirect
|
||||||
github.com/go-logr/stdr v1.2.2 // indirect
|
github.com/go-logr/stdr v1.2.2 // indirect
|
||||||
github.com/go-viper/mapstructure/v2 v2.4.0 // indirect
|
github.com/go-viper/mapstructure/v2 v2.4.0 // indirect
|
||||||
github.com/google/uuid v1.6.0 // indirect
|
|
||||||
github.com/gorilla/websocket v1.5.1 // indirect
|
github.com/gorilla/websocket v1.5.1 // indirect
|
||||||
github.com/grpc-ecosystem/grpc-gateway/v2 v2.27.2 // indirect
|
github.com/grpc-ecosystem/grpc-gateway/v2 v2.27.2 // indirect
|
||||||
github.com/hashicorp/golang-lru/v2 v2.0.7 // indirect
|
github.com/hashicorp/golang-lru/v2 v2.0.7 // indirect
|
||||||
@@ -51,6 +52,7 @@ require (
|
|||||||
github.com/rabbitmq/amqp091-go v1.10.0 // indirect
|
github.com/rabbitmq/amqp091-go v1.10.0 // indirect
|
||||||
github.com/sethvargo/go-retry v0.3.0 // indirect
|
github.com/sethvargo/go-retry v0.3.0 // indirect
|
||||||
github.com/sosodev/duration v1.3.1 // indirect
|
github.com/sosodev/duration v1.3.1 // indirect
|
||||||
|
github.com/stretchr/objx v0.5.2 // indirect
|
||||||
github.com/tidwall/gjson v1.17.0 // indirect
|
github.com/tidwall/gjson v1.17.0 // indirect
|
||||||
github.com/tidwall/match v1.1.1 // indirect
|
github.com/tidwall/match v1.1.1 // indirect
|
||||||
github.com/tidwall/pretty v1.2.1 // indirect
|
github.com/tidwall/pretty v1.2.1 // indirect
|
||||||
@@ -72,5 +74,4 @@ require (
|
|||||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20250825161204-c5933d9347a5 // indirect
|
google.golang.org/genproto/googleapis/rpc v0.0.0-20250825161204-c5933d9347a5 // indirect
|
||||||
google.golang.org/grpc v1.75.0 // indirect
|
google.golang.org/grpc v1.75.0 // indirect
|
||||||
google.golang.org/protobuf v1.36.10 // indirect
|
google.golang.org/protobuf v1.36.10 // indirect
|
||||||
gopkg.in/yaml.v3 v3.0.1 // indirect
|
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -141,6 +141,8 @@ github.com/sosodev/duration v1.3.1/go.mod h1:RQIBBX0+fMLc/D9+Jb/fwvVmo0eZvDDEERA
|
|||||||
github.com/sparetimecoders/goamqp v0.3.3 h1:z/nfTPmrjeU/rIVuNOgsVLCimp3WFoNFvS3ZzXRJ6HE=
|
github.com/sparetimecoders/goamqp v0.3.3 h1:z/nfTPmrjeU/rIVuNOgsVLCimp3WFoNFvS3ZzXRJ6HE=
|
||||||
github.com/sparetimecoders/goamqp v0.3.3/go.mod h1:W9NRCpWLE+Vruv2dcRSbszNil2O826d2Nv6kAkETW5o=
|
github.com/sparetimecoders/goamqp v0.3.3/go.mod h1:W9NRCpWLE+Vruv2dcRSbszNil2O826d2Nv6kAkETW5o=
|
||||||
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
|
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
|
||||||
|
github.com/stretchr/objx v0.5.2 h1:xuMeJ0Sdp5ZMRXx/aWO6RZxdr3beISkG5/G/aIRr3pY=
|
||||||
|
github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA=
|
||||||
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
|
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
|
||||||
github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
|
github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
|
||||||
github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=
|
github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=
|
||||||
|
|||||||
+98
-27
@@ -1,54 +1,125 @@
|
|||||||
package graph
|
package graph
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/json"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"os"
|
||||||
|
"os/exec"
|
||||||
|
"path/filepath"
|
||||||
|
|
||||||
|
"gopkg.in/yaml.v3"
|
||||||
|
|
||||||
"gitlab.com/unboundsoftware/schemas/graph/model"
|
"gitlab.com/unboundsoftware/schemas/graph/model"
|
||||||
)
|
)
|
||||||
|
|
||||||
// GenerateCosmoRouterConfig generates a Cosmo Router execution config from subgraphs
|
// CommandExecutor is an interface for executing external commands
|
||||||
func GenerateCosmoRouterConfig(subGraphs []*model.SubGraph) (string, error) {
|
// This allows for mocking in tests
|
||||||
// Build the Cosmo router config structure
|
type CommandExecutor interface {
|
||||||
// This is a simplified version - you may need to adjust based on actual Cosmo requirements
|
Execute(name string, args ...string) ([]byte, error)
|
||||||
config := map[string]interface{}{
|
|
||||||
"version": "1",
|
|
||||||
"subgraphs": convertSubGraphsToCosmo(subGraphs),
|
|
||||||
// Add other Cosmo-specific configuration as needed
|
|
||||||
}
|
|
||||||
|
|
||||||
// Marshal to JSON
|
|
||||||
configJSON, err := json.MarshalIndent(config, "", " ")
|
|
||||||
if err != nil {
|
|
||||||
return "", fmt.Errorf("marshal cosmo config: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return string(configJSON), nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func convertSubGraphsToCosmo(subGraphs []*model.SubGraph) []map[string]interface{} {
|
// DefaultCommandExecutor implements CommandExecutor using os/exec
|
||||||
cosmoSubgraphs := make([]map[string]interface{}, 0, len(subGraphs))
|
type DefaultCommandExecutor struct{}
|
||||||
|
|
||||||
|
// Execute runs a command and returns its combined output
|
||||||
|
func (e *DefaultCommandExecutor) Execute(name string, args ...string) ([]byte, error) {
|
||||||
|
cmd := exec.Command(name, args...)
|
||||||
|
return cmd.CombinedOutput()
|
||||||
|
}
|
||||||
|
|
||||||
|
// GenerateCosmoRouterConfig generates a Cosmo Router execution config from subgraphs
|
||||||
|
// using the official wgc CLI tool via npx
|
||||||
|
func GenerateCosmoRouterConfig(subGraphs []*model.SubGraph) (string, error) {
|
||||||
|
return GenerateCosmoRouterConfigWithExecutor(subGraphs, &DefaultCommandExecutor{})
|
||||||
|
}
|
||||||
|
|
||||||
|
// GenerateCosmoRouterConfigWithExecutor generates a Cosmo Router execution config from subgraphs
|
||||||
|
// using the provided command executor (useful for testing)
|
||||||
|
func GenerateCosmoRouterConfigWithExecutor(subGraphs []*model.SubGraph, executor CommandExecutor) (string, error) {
|
||||||
|
if len(subGraphs) == 0 {
|
||||||
|
return "", fmt.Errorf("no subgraphs provided")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create a temporary directory for composition
|
||||||
|
tmpDir, err := os.MkdirTemp("", "cosmo-compose-*")
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("create temp dir: %w", err)
|
||||||
|
}
|
||||||
|
defer os.RemoveAll(tmpDir)
|
||||||
|
|
||||||
|
// Write each subgraph SDL to a file
|
||||||
|
type SubgraphConfig struct {
|
||||||
|
Name string `yaml:"name"`
|
||||||
|
RoutingURL string `yaml:"routing_url,omitempty"`
|
||||||
|
Schema map[string]string `yaml:"schema"`
|
||||||
|
Subscription map[string]interface{} `yaml:"subscription,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type InputConfig struct {
|
||||||
|
Version int `yaml:"version"`
|
||||||
|
Subgraphs []SubgraphConfig `yaml:"subgraphs"`
|
||||||
|
}
|
||||||
|
|
||||||
|
inputConfig := InputConfig{
|
||||||
|
Version: 1,
|
||||||
|
Subgraphs: make([]SubgraphConfig, 0, len(subGraphs)),
|
||||||
|
}
|
||||||
|
|
||||||
for _, sg := range subGraphs {
|
for _, sg := range subGraphs {
|
||||||
cosmoSg := map[string]interface{}{
|
// Write SDL to a temp file
|
||||||
"name": sg.Service,
|
schemaFile := filepath.Join(tmpDir, fmt.Sprintf("%s.graphql", sg.Service))
|
||||||
"sdl": sg.Sdl,
|
if err := os.WriteFile(schemaFile, []byte(sg.Sdl), 0o644); err != nil {
|
||||||
|
return "", fmt.Errorf("write schema file for %s: %w", sg.Service, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
subgraphCfg := SubgraphConfig{
|
||||||
|
Name: sg.Service,
|
||||||
|
Schema: map[string]string{
|
||||||
|
"file": schemaFile,
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
if sg.URL != nil {
|
if sg.URL != nil {
|
||||||
cosmoSg["routing_url"] = *sg.URL
|
subgraphCfg.RoutingURL = *sg.URL
|
||||||
}
|
}
|
||||||
|
|
||||||
if sg.WsURL != nil {
|
if sg.WsURL != nil {
|
||||||
cosmoSg["subscription"] = map[string]interface{}{
|
subgraphCfg.Subscription = map[string]interface{}{
|
||||||
"url": *sg.WsURL,
|
"url": *sg.WsURL,
|
||||||
"protocol": "ws",
|
"protocol": "ws",
|
||||||
"websocket_subprotocol": "graphql-ws",
|
"websocket_subprotocol": "graphql-ws",
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
cosmoSubgraphs = append(cosmoSubgraphs, cosmoSg)
|
inputConfig.Subgraphs = append(inputConfig.Subgraphs, subgraphCfg)
|
||||||
}
|
}
|
||||||
|
|
||||||
return cosmoSubgraphs
|
// Write input config YAML
|
||||||
|
inputFile := filepath.Join(tmpDir, "input.yaml")
|
||||||
|
inputYAML, err := yaml.Marshal(inputConfig)
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("marshal input config: %w", err)
|
||||||
|
}
|
||||||
|
if err := os.WriteFile(inputFile, inputYAML, 0o644); err != nil {
|
||||||
|
return "", fmt.Errorf("write input config: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Execute wgc router compose
|
||||||
|
// wgc is installed globally in the Docker image
|
||||||
|
outputFile := filepath.Join(tmpDir, "config.json")
|
||||||
|
output, err := executor.Execute("wgc", "router", "compose",
|
||||||
|
"--input", inputFile,
|
||||||
|
"--out", outputFile,
|
||||||
|
"--suppress-warnings",
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("wgc router compose failed: %w\nOutput: %s", err, string(output))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Read the generated config
|
||||||
|
configJSON, err := os.ReadFile(outputFile)
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("read output config: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return string(configJSON), nil
|
||||||
}
|
}
|
||||||
|
|||||||
+293
-86
@@ -2,14 +2,185 @@ package graph
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
|
"gopkg.in/yaml.v3"
|
||||||
|
|
||||||
"gitlab.com/unboundsoftware/schemas/graph/model"
|
"gitlab.com/unboundsoftware/schemas/graph/model"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// MockCommandExecutor implements CommandExecutor for testing
|
||||||
|
type MockCommandExecutor struct {
|
||||||
|
// CallCount tracks how many times Execute was called
|
||||||
|
CallCount int
|
||||||
|
// LastArgs stores the arguments from the last call
|
||||||
|
LastArgs []string
|
||||||
|
// Error can be set to simulate command failures
|
||||||
|
Error error
|
||||||
|
}
|
||||||
|
|
||||||
|
// Execute mocks the wgc command by generating a realistic config.json file
|
||||||
|
func (m *MockCommandExecutor) Execute(name string, args ...string) ([]byte, error) {
|
||||||
|
m.CallCount++
|
||||||
|
m.LastArgs = append([]string{name}, args...)
|
||||||
|
|
||||||
|
if m.Error != nil {
|
||||||
|
return nil, m.Error
|
||||||
|
}
|
||||||
|
|
||||||
|
// Parse the input file to understand what subgraphs we're composing
|
||||||
|
var inputFile, outputFile string
|
||||||
|
for i, arg := range args {
|
||||||
|
if arg == "--input" && i+1 < len(args) {
|
||||||
|
inputFile = args[i+1]
|
||||||
|
}
|
||||||
|
if arg == "--out" && i+1 < len(args) {
|
||||||
|
outputFile = args[i+1]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if inputFile == "" || outputFile == "" {
|
||||||
|
return nil, fmt.Errorf("missing required arguments")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Read the input YAML to get subgraph information
|
||||||
|
inputData, err := os.ReadFile(inputFile)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to read input file: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var input struct {
|
||||||
|
Version int `yaml:"version"`
|
||||||
|
Subgraphs []struct {
|
||||||
|
Name string `yaml:"name"`
|
||||||
|
RoutingURL string `yaml:"routing_url,omitempty"`
|
||||||
|
Schema map[string]string `yaml:"schema"`
|
||||||
|
Subscription map[string]interface{} `yaml:"subscription,omitempty"`
|
||||||
|
} `yaml:"subgraphs"`
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := yaml.Unmarshal(inputData, &input); err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to parse input YAML: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Generate a realistic Cosmo Router config based on the input
|
||||||
|
config := map[string]interface{}{
|
||||||
|
"version": "mock-version-uuid",
|
||||||
|
"subgraphs": func() []map[string]interface{} {
|
||||||
|
subgraphs := make([]map[string]interface{}, len(input.Subgraphs))
|
||||||
|
for i, sg := range input.Subgraphs {
|
||||||
|
subgraph := map[string]interface{}{
|
||||||
|
"id": fmt.Sprintf("mock-id-%d", i),
|
||||||
|
"name": sg.Name,
|
||||||
|
}
|
||||||
|
if sg.RoutingURL != "" {
|
||||||
|
subgraph["routingUrl"] = sg.RoutingURL
|
||||||
|
}
|
||||||
|
subgraphs[i] = subgraph
|
||||||
|
}
|
||||||
|
return subgraphs
|
||||||
|
}(),
|
||||||
|
"engineConfig": map[string]interface{}{
|
||||||
|
"graphqlSchema": generateMockSchema(input.Subgraphs),
|
||||||
|
"datasourceConfigurations": func() []map[string]interface{} {
|
||||||
|
dsConfigs := make([]map[string]interface{}, len(input.Subgraphs))
|
||||||
|
for i, sg := range input.Subgraphs {
|
||||||
|
// Read SDL from file
|
||||||
|
sdl := ""
|
||||||
|
if schemaFile, ok := sg.Schema["file"]; ok {
|
||||||
|
if sdlData, err := os.ReadFile(schemaFile); err == nil {
|
||||||
|
sdl = string(sdlData)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
dsConfig := map[string]interface{}{
|
||||||
|
"id": fmt.Sprintf("datasource-%d", i),
|
||||||
|
"kind": "GRAPHQL",
|
||||||
|
"customGraphql": map[string]interface{}{
|
||||||
|
"federation": map[string]interface{}{
|
||||||
|
"enabled": true,
|
||||||
|
"serviceSdl": sdl,
|
||||||
|
},
|
||||||
|
"subscription": func() map[string]interface{} {
|
||||||
|
if len(sg.Subscription) > 0 {
|
||||||
|
return map[string]interface{}{
|
||||||
|
"enabled": true,
|
||||||
|
"url": map[string]interface{}{
|
||||||
|
"staticVariableContent": sg.Subscription["url"],
|
||||||
|
},
|
||||||
|
"protocol": sg.Subscription["protocol"],
|
||||||
|
"websocketSubprotocol": sg.Subscription["websocket_subprotocol"],
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return map[string]interface{}{
|
||||||
|
"enabled": false,
|
||||||
|
}
|
||||||
|
}(),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
dsConfigs[i] = dsConfig
|
||||||
|
}
|
||||||
|
return dsConfigs
|
||||||
|
}(),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
// Write the config to the output file
|
||||||
|
configJSON, err := json.Marshal(config)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to marshal config: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := os.WriteFile(outputFile, configJSON, 0o644); err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to write output file: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return []byte("Success"), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// generateMockSchema creates a simple merged schema from subgraphs
|
||||||
|
func generateMockSchema(subgraphs []struct {
|
||||||
|
Name string `yaml:"name"`
|
||||||
|
RoutingURL string `yaml:"routing_url,omitempty"`
|
||||||
|
Schema map[string]string `yaml:"schema"`
|
||||||
|
Subscription map[string]interface{} `yaml:"subscription,omitempty"`
|
||||||
|
},
|
||||||
|
) string {
|
||||||
|
schema := strings.Builder{}
|
||||||
|
schema.WriteString("schema {\n query: Query\n")
|
||||||
|
|
||||||
|
// Check if any subgraph has subscriptions
|
||||||
|
hasSubscriptions := false
|
||||||
|
for _, sg := range subgraphs {
|
||||||
|
if len(sg.Subscription) > 0 {
|
||||||
|
hasSubscriptions = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if hasSubscriptions {
|
||||||
|
schema.WriteString(" subscription: Subscription\n")
|
||||||
|
}
|
||||||
|
schema.WriteString("}\n\n")
|
||||||
|
|
||||||
|
// Add types by reading SDL files
|
||||||
|
for _, sg := range subgraphs {
|
||||||
|
if schemaFile, ok := sg.Schema["file"]; ok {
|
||||||
|
if sdlData, err := os.ReadFile(schemaFile); err == nil {
|
||||||
|
schema.WriteString(string(sdlData))
|
||||||
|
schema.WriteString("\n")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return schema.String()
|
||||||
|
}
|
||||||
|
|
||||||
func TestGenerateCosmoRouterConfig(t *testing.T) {
|
func TestGenerateCosmoRouterConfig(t *testing.T) {
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
@@ -33,7 +204,10 @@ func TestGenerateCosmoRouterConfig(t *testing.T) {
|
|||||||
err := json.Unmarshal([]byte(config), &result)
|
err := json.Unmarshal([]byte(config), &result)
|
||||||
require.NoError(t, err, "Config should be valid JSON")
|
require.NoError(t, err, "Config should be valid JSON")
|
||||||
|
|
||||||
assert.Equal(t, "1", result["version"], "Version should be 1")
|
// Version is a UUID string from wgc
|
||||||
|
version, ok := result["version"].(string)
|
||||||
|
require.True(t, ok, "Version should be a string")
|
||||||
|
assert.NotEmpty(t, version, "Version should not be empty")
|
||||||
|
|
||||||
subgraphs, ok := result["subgraphs"].([]interface{})
|
subgraphs, ok := result["subgraphs"].([]interface{})
|
||||||
require.True(t, ok, "subgraphs should be an array")
|
require.True(t, ok, "subgraphs should be an array")
|
||||||
@@ -41,14 +215,26 @@ func TestGenerateCosmoRouterConfig(t *testing.T) {
|
|||||||
|
|
||||||
sg := subgraphs[0].(map[string]interface{})
|
sg := subgraphs[0].(map[string]interface{})
|
||||||
assert.Equal(t, "test-service", sg["name"])
|
assert.Equal(t, "test-service", sg["name"])
|
||||||
assert.Equal(t, "http://localhost:4001/query", sg["routing_url"])
|
assert.Equal(t, "http://localhost:4001/query", sg["routingUrl"])
|
||||||
assert.Equal(t, "type Query { test: String }", sg["sdl"])
|
|
||||||
|
|
||||||
subscription, ok := sg["subscription"].(map[string]interface{})
|
// Check that datasource configurations include subscription settings
|
||||||
|
engineConfig, ok := result["engineConfig"].(map[string]interface{})
|
||||||
|
require.True(t, ok, "Should have engineConfig")
|
||||||
|
|
||||||
|
dsConfigs, ok := engineConfig["datasourceConfigurations"].([]interface{})
|
||||||
|
require.True(t, ok && len(dsConfigs) > 0, "Should have datasource configurations")
|
||||||
|
|
||||||
|
ds := dsConfigs[0].(map[string]interface{})
|
||||||
|
customGraphql, ok := ds["customGraphql"].(map[string]interface{})
|
||||||
|
require.True(t, ok, "Should have customGraphql config")
|
||||||
|
|
||||||
|
subscription, ok := customGraphql["subscription"].(map[string]interface{})
|
||||||
require.True(t, ok, "Should have subscription config")
|
require.True(t, ok, "Should have subscription config")
|
||||||
assert.Equal(t, "ws://localhost:4001/query", subscription["url"])
|
assert.True(t, subscription["enabled"].(bool), "Subscription should be enabled")
|
||||||
assert.Equal(t, "ws", subscription["protocol"])
|
|
||||||
assert.Equal(t, "graphql-ws", subscription["websocket_subprotocol"])
|
subUrl, ok := subscription["url"].(map[string]interface{})
|
||||||
|
require.True(t, ok, "Should have subscription URL")
|
||||||
|
assert.Equal(t, "ws://localhost:4001/query", subUrl["staticVariableContent"])
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@@ -80,18 +266,28 @@ func TestGenerateCosmoRouterConfig(t *testing.T) {
|
|||||||
subgraphs := result["subgraphs"].([]interface{})
|
subgraphs := result["subgraphs"].([]interface{})
|
||||||
assert.Len(t, subgraphs, 3, "Should have 3 subgraphs")
|
assert.Len(t, subgraphs, 3, "Should have 3 subgraphs")
|
||||||
|
|
||||||
// Check first service has no subscription
|
// Check service names
|
||||||
sg1 := subgraphs[0].(map[string]interface{})
|
sg1 := subgraphs[0].(map[string]interface{})
|
||||||
assert.Equal(t, "service-1", sg1["name"])
|
assert.Equal(t, "service-1", sg1["name"])
|
||||||
_, hasSubscription := sg1["subscription"]
|
|
||||||
assert.False(t, hasSubscription, "service-1 should not have subscription config")
|
|
||||||
|
|
||||||
// Check third service has subscription
|
|
||||||
sg3 := subgraphs[2].(map[string]interface{})
|
sg3 := subgraphs[2].(map[string]interface{})
|
||||||
assert.Equal(t, "service-3", sg3["name"])
|
assert.Equal(t, "service-3", sg3["name"])
|
||||||
subscription, hasSubscription := sg3["subscription"]
|
|
||||||
assert.True(t, hasSubscription, "service-3 should have subscription config")
|
// Check that datasource configurations include subscription for service-3
|
||||||
assert.NotNil(t, subscription)
|
engineConfig, ok := result["engineConfig"].(map[string]interface{})
|
||||||
|
require.True(t, ok, "Should have engineConfig")
|
||||||
|
|
||||||
|
dsConfigs, ok := engineConfig["datasourceConfigurations"].([]interface{})
|
||||||
|
require.True(t, ok && len(dsConfigs) == 3, "Should have 3 datasource configurations")
|
||||||
|
|
||||||
|
// Find service-3's datasource config (should have subscription enabled)
|
||||||
|
ds3 := dsConfigs[2].(map[string]interface{})
|
||||||
|
customGraphql, ok := ds3["customGraphql"].(map[string]interface{})
|
||||||
|
require.True(t, ok, "Service-3 should have customGraphql config")
|
||||||
|
|
||||||
|
subscription, ok := customGraphql["subscription"].(map[string]interface{})
|
||||||
|
require.True(t, ok, "Service-3 should have subscription config")
|
||||||
|
assert.True(t, subscription["enabled"].(bool), "Service-3 subscription should be enabled")
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@@ -113,39 +309,43 @@ func TestGenerateCosmoRouterConfig(t *testing.T) {
|
|||||||
subgraphs := result["subgraphs"].([]interface{})
|
subgraphs := result["subgraphs"].([]interface{})
|
||||||
sg := subgraphs[0].(map[string]interface{})
|
sg := subgraphs[0].(map[string]interface{})
|
||||||
|
|
||||||
// Should not have routing_url or subscription fields if URLs are nil
|
// Should not have routing URL when URL is nil
|
||||||
_, hasRoutingURL := sg["routing_url"]
|
_, hasRoutingURL := sg["routingUrl"]
|
||||||
assert.False(t, hasRoutingURL, "Should not have routing_url when URL is nil")
|
assert.False(t, hasRoutingURL, "Should not have routingUrl when URL is nil")
|
||||||
|
|
||||||
_, hasSubscription := sg["subscription"]
|
// Check datasource configurations don't have subscription enabled
|
||||||
assert.False(t, hasSubscription, "Should not have subscription when WsURL is nil")
|
engineConfig, ok := result["engineConfig"].(map[string]interface{})
|
||||||
|
require.True(t, ok, "Should have engineConfig")
|
||||||
|
|
||||||
|
dsConfigs, ok := engineConfig["datasourceConfigurations"].([]interface{})
|
||||||
|
require.True(t, ok && len(dsConfigs) > 0, "Should have datasource configurations")
|
||||||
|
|
||||||
|
ds := dsConfigs[0].(map[string]interface{})
|
||||||
|
customGraphql, ok := ds["customGraphql"].(map[string]interface{})
|
||||||
|
require.True(t, ok, "Should have customGraphql config")
|
||||||
|
|
||||||
|
subscription, ok := customGraphql["subscription"].(map[string]interface{})
|
||||||
|
if ok {
|
||||||
|
// wgc always enables subscription but URL should be empty when WsURL is nil
|
||||||
|
subUrl, hasUrl := subscription["url"].(map[string]interface{})
|
||||||
|
if hasUrl {
|
||||||
|
_, hasStaticContent := subUrl["staticVariableContent"]
|
||||||
|
assert.False(t, hasStaticContent, "Subscription URL should be empty when WsURL is nil")
|
||||||
|
}
|
||||||
|
}
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "empty subgraphs",
|
name: "empty subgraphs",
|
||||||
subGraphs: []*model.SubGraph{},
|
subGraphs: []*model.SubGraph{},
|
||||||
wantErr: false,
|
wantErr: true,
|
||||||
validate: func(t *testing.T, config string) {
|
validate: nil,
|
||||||
var result map[string]interface{}
|
|
||||||
err := json.Unmarshal([]byte(config), &result)
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
subgraphs := result["subgraphs"].([]interface{})
|
|
||||||
assert.Len(t, subgraphs, 0, "Should have empty subgraphs array")
|
|
||||||
},
|
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "nil subgraphs",
|
name: "nil subgraphs",
|
||||||
subGraphs: nil,
|
subGraphs: nil,
|
||||||
wantErr: false,
|
wantErr: true,
|
||||||
validate: func(t *testing.T, config string) {
|
validate: nil,
|
||||||
var result map[string]interface{}
|
|
||||||
err := json.Unmarshal([]byte(config), &result)
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
subgraphs := result["subgraphs"].([]interface{})
|
|
||||||
assert.Len(t, subgraphs, 0, "Should handle nil subgraphs as empty array")
|
|
||||||
},
|
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "complex SDL with multiple types",
|
name: "complex SDL with multiple types",
|
||||||
@@ -173,29 +373,58 @@ func TestGenerateCosmoRouterConfig(t *testing.T) {
|
|||||||
err := json.Unmarshal([]byte(config), &result)
|
err := json.Unmarshal([]byte(config), &result)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
subgraphs := result["subgraphs"].([]interface{})
|
// Check the composed graphqlSchema contains the types
|
||||||
sg := subgraphs[0].(map[string]interface{})
|
engineConfig, ok := result["engineConfig"].(map[string]interface{})
|
||||||
sdl := sg["sdl"].(string)
|
require.True(t, ok, "Should have engineConfig")
|
||||||
|
|
||||||
assert.Contains(t, sdl, "type Query")
|
graphqlSchema, ok := engineConfig["graphqlSchema"].(string)
|
||||||
assert.Contains(t, sdl, "type User")
|
require.True(t, ok, "Should have graphqlSchema")
|
||||||
assert.Contains(t, sdl, "email: String!")
|
|
||||||
|
assert.Contains(t, graphqlSchema, "Query", "Schema should contain Query type")
|
||||||
|
assert.Contains(t, graphqlSchema, "User", "Schema should contain User type")
|
||||||
|
|
||||||
|
// Check datasource has the original SDL
|
||||||
|
dsConfigs, ok := engineConfig["datasourceConfigurations"].([]interface{})
|
||||||
|
require.True(t, ok && len(dsConfigs) > 0, "Should have datasource configurations")
|
||||||
|
|
||||||
|
ds := dsConfigs[0].(map[string]interface{})
|
||||||
|
customGraphql, ok := ds["customGraphql"].(map[string]interface{})
|
||||||
|
require.True(t, ok, "Should have customGraphql config")
|
||||||
|
|
||||||
|
federation, ok := customGraphql["federation"].(map[string]interface{})
|
||||||
|
require.True(t, ok, "Should have federation config")
|
||||||
|
|
||||||
|
serviceSdl, ok := federation["serviceSdl"].(string)
|
||||||
|
require.True(t, ok, "Should have serviceSdl")
|
||||||
|
assert.Contains(t, serviceSdl, "email: String!", "SDL should contain email field")
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
config, err := GenerateCosmoRouterConfig(tt.subGraphs)
|
// Use mock executor for all tests
|
||||||
|
mockExecutor := &MockCommandExecutor{}
|
||||||
|
config, err := GenerateCosmoRouterConfigWithExecutor(tt.subGraphs, mockExecutor)
|
||||||
|
|
||||||
if tt.wantErr {
|
if tt.wantErr {
|
||||||
assert.Error(t, err)
|
assert.Error(t, err)
|
||||||
|
// Verify executor was not called for error cases
|
||||||
|
if len(tt.subGraphs) == 0 {
|
||||||
|
assert.Equal(t, 0, mockExecutor.CallCount, "Should not call executor for empty subgraphs")
|
||||||
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.NotEmpty(t, config, "Config should not be empty")
|
assert.NotEmpty(t, config, "Config should not be empty")
|
||||||
|
|
||||||
|
// Verify executor was called correctly
|
||||||
|
assert.Equal(t, 1, mockExecutor.CallCount, "Should call executor once")
|
||||||
|
assert.Equal(t, "wgc", mockExecutor.LastArgs[0], "Should call wgc command")
|
||||||
|
assert.Contains(t, mockExecutor.LastArgs, "router", "Should include 'router' arg")
|
||||||
|
assert.Contains(t, mockExecutor.LastArgs, "compose", "Should include 'compose' arg")
|
||||||
|
|
||||||
if tt.validate != nil {
|
if tt.validate != nil {
|
||||||
tt.validate(t, config)
|
tt.validate(t, config)
|
||||||
}
|
}
|
||||||
@@ -203,53 +432,31 @@ func TestGenerateCosmoRouterConfig(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestConvertSubGraphsToCosmo(t *testing.T) {
|
// TestGenerateCosmoRouterConfig_MockError tests error handling with mock executor
|
||||||
tests := []struct {
|
func TestGenerateCosmoRouterConfig_MockError(t *testing.T) {
|
||||||
name string
|
subGraphs := []*model.SubGraph{
|
||||||
subGraphs []*model.SubGraph
|
|
||||||
wantLen int
|
|
||||||
validate func(t *testing.T, result []map[string]interface{})
|
|
||||||
}{
|
|
||||||
{
|
{
|
||||||
name: "preserves subgraph order",
|
Service: "test-service",
|
||||||
subGraphs: []*model.SubGraph{
|
URL: stringPtr("http://localhost:4001/query"),
|
||||||
{Service: "alpha", URL: stringPtr("http://a"), Sdl: "a"},
|
Sdl: "type Query { test: String }",
|
||||||
{Service: "beta", URL: stringPtr("http://b"), Sdl: "b"},
|
|
||||||
{Service: "gamma", URL: stringPtr("http://c"), Sdl: "c"},
|
|
||||||
},
|
|
||||||
wantLen: 3,
|
|
||||||
validate: func(t *testing.T, result []map[string]interface{}) {
|
|
||||||
assert.Equal(t, "alpha", result[0]["name"])
|
|
||||||
assert.Equal(t, "beta", result[1]["name"])
|
|
||||||
assert.Equal(t, "gamma", result[2]["name"])
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "includes SDL exactly as provided",
|
|
||||||
subGraphs: []*model.SubGraph{
|
|
||||||
{
|
|
||||||
Service: "test",
|
|
||||||
URL: stringPtr("http://test"),
|
|
||||||
Sdl: "type Query { special: String! }",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
wantLen: 1,
|
|
||||||
validate: func(t *testing.T, result []map[string]interface{}) {
|
|
||||||
assert.Equal(t, "type Query { special: String! }", result[0]["sdl"])
|
|
||||||
},
|
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, tt := range tests {
|
// Create a mock executor that returns an error
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
mockExecutor := &MockCommandExecutor{
|
||||||
result := convertSubGraphsToCosmo(tt.subGraphs)
|
Error: fmt.Errorf("simulated wgc failure"),
|
||||||
assert.Len(t, result, tt.wantLen)
|
|
||||||
|
|
||||||
if tt.validate != nil {
|
|
||||||
tt.validate(t, result)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
config, err := GenerateCosmoRouterConfigWithExecutor(subGraphs, mockExecutor)
|
||||||
|
|
||||||
|
// Verify error is propagated
|
||||||
|
assert.Error(t, err)
|
||||||
|
assert.Contains(t, err.Error(), "wgc router compose failed")
|
||||||
|
assert.Contains(t, err.Error(), "simulated wgc failure")
|
||||||
|
assert.Empty(t, config)
|
||||||
|
|
||||||
|
// Verify executor was called
|
||||||
|
assert.Equal(t, 1, mockExecutor.CallCount, "Should have attempted to call executor")
|
||||||
}
|
}
|
||||||
|
|
||||||
// Helper function for tests
|
// Helper function for tests
|
||||||
|
|||||||
@@ -74,6 +74,7 @@ type ComplexityRoot struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
Query struct {
|
Query struct {
|
||||||
|
LatestSchema func(childComplexity int, ref string) int
|
||||||
Organizations func(childComplexity int) int
|
Organizations func(childComplexity int) int
|
||||||
Supergraph func(childComplexity int, ref string, isAfter *string) int
|
Supergraph func(childComplexity int, ref string, isAfter *string) int
|
||||||
}
|
}
|
||||||
@@ -124,6 +125,7 @@ type MutationResolver interface {
|
|||||||
type QueryResolver interface {
|
type QueryResolver interface {
|
||||||
Organizations(ctx context.Context) ([]*model.Organization, error)
|
Organizations(ctx context.Context) ([]*model.Organization, error)
|
||||||
Supergraph(ctx context.Context, ref string, isAfter *string) (model.Supergraph, error)
|
Supergraph(ctx context.Context, ref string, isAfter *string) (model.Supergraph, error)
|
||||||
|
LatestSchema(ctx context.Context, ref string) (*model.SchemaUpdate, error)
|
||||||
}
|
}
|
||||||
type SubscriptionResolver interface {
|
type SubscriptionResolver interface {
|
||||||
SchemaUpdates(ctx context.Context, ref string) (<-chan *model.SchemaUpdate, error)
|
SchemaUpdates(ctx context.Context, ref string) (<-chan *model.SchemaUpdate, error)
|
||||||
@@ -250,6 +252,17 @@ func (e *executableSchema) Complexity(ctx context.Context, typeName, field strin
|
|||||||
|
|
||||||
return e.complexity.Organization.Users(childComplexity), true
|
return e.complexity.Organization.Users(childComplexity), true
|
||||||
|
|
||||||
|
case "Query.latestSchema":
|
||||||
|
if e.complexity.Query.LatestSchema == nil {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
args, err := ec.field_Query_latestSchema_args(ctx, rawArgs)
|
||||||
|
if err != nil {
|
||||||
|
return 0, false
|
||||||
|
}
|
||||||
|
|
||||||
|
return e.complexity.Query.LatestSchema(childComplexity, args["ref"].(string)), true
|
||||||
case "Query.organizations":
|
case "Query.organizations":
|
||||||
if e.complexity.Query.Organizations == nil {
|
if e.complexity.Query.Organizations == nil {
|
||||||
break
|
break
|
||||||
@@ -520,6 +533,7 @@ var sources = []*ast.Source{
|
|||||||
{Name: "../schema.graphqls", Input: `type Query {
|
{Name: "../schema.graphqls", Input: `type Query {
|
||||||
organizations: [Organization!]! @auth(user: true)
|
organizations: [Organization!]! @auth(user: true)
|
||||||
supergraph(ref: String!, isAfter: String): Supergraph! @auth(organization: true)
|
supergraph(ref: String!, isAfter: String): Supergraph! @auth(organization: true)
|
||||||
|
latestSchema(ref: String!): SchemaUpdate! @auth(organization: true)
|
||||||
}
|
}
|
||||||
|
|
||||||
type Mutation {
|
type Mutation {
|
||||||
@@ -671,6 +685,17 @@ func (ec *executionContext) field_Query___type_args(ctx context.Context, rawArgs
|
|||||||
return args, nil
|
return args, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (ec *executionContext) field_Query_latestSchema_args(ctx context.Context, rawArgs map[string]any) (map[string]any, error) {
|
||||||
|
var err error
|
||||||
|
args := map[string]any{}
|
||||||
|
arg0, err := graphql.ProcessArgField(ctx, rawArgs, "ref", ec.unmarshalNString2string)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
args["ref"] = arg0
|
||||||
|
return args, nil
|
||||||
|
}
|
||||||
|
|
||||||
func (ec *executionContext) field_Query_supergraph_args(ctx context.Context, rawArgs map[string]any) (map[string]any, error) {
|
func (ec *executionContext) field_Query_supergraph_args(ctx context.Context, rawArgs map[string]any) (map[string]any, error) {
|
||||||
var err error
|
var err error
|
||||||
args := map[string]any{}
|
args := map[string]any{}
|
||||||
@@ -1434,6 +1459,75 @@ func (ec *executionContext) fieldContext_Query_supergraph(ctx context.Context, f
|
|||||||
return fc, nil
|
return fc, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (ec *executionContext) _Query_latestSchema(ctx context.Context, field graphql.CollectedField) (ret graphql.Marshaler) {
|
||||||
|
return graphql.ResolveField(
|
||||||
|
ctx,
|
||||||
|
ec.OperationContext,
|
||||||
|
field,
|
||||||
|
ec.fieldContext_Query_latestSchema,
|
||||||
|
func(ctx context.Context) (any, error) {
|
||||||
|
fc := graphql.GetFieldContext(ctx)
|
||||||
|
return ec.resolvers.Query().LatestSchema(ctx, fc.Args["ref"].(string))
|
||||||
|
},
|
||||||
|
func(ctx context.Context, next graphql.Resolver) graphql.Resolver {
|
||||||
|
directive0 := next
|
||||||
|
|
||||||
|
directive1 := func(ctx context.Context) (any, error) {
|
||||||
|
organization, err := ec.unmarshalOBoolean2ᚖbool(ctx, true)
|
||||||
|
if err != nil {
|
||||||
|
var zeroVal *model.SchemaUpdate
|
||||||
|
return zeroVal, err
|
||||||
|
}
|
||||||
|
if ec.directives.Auth == nil {
|
||||||
|
var zeroVal *model.SchemaUpdate
|
||||||
|
return zeroVal, errors.New("directive auth is not implemented")
|
||||||
|
}
|
||||||
|
return ec.directives.Auth(ctx, nil, directive0, nil, organization)
|
||||||
|
}
|
||||||
|
|
||||||
|
next = directive1
|
||||||
|
return next
|
||||||
|
},
|
||||||
|
ec.marshalNSchemaUpdate2ᚖgitlabᚗcomᚋunboundsoftwareᚋschemasᚋgraphᚋmodelᚐSchemaUpdate,
|
||||||
|
true,
|
||||||
|
true,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ec *executionContext) fieldContext_Query_latestSchema(ctx context.Context, field graphql.CollectedField) (fc *graphql.FieldContext, err error) {
|
||||||
|
fc = &graphql.FieldContext{
|
||||||
|
Object: "Query",
|
||||||
|
Field: field,
|
||||||
|
IsMethod: true,
|
||||||
|
IsResolver: true,
|
||||||
|
Child: func(ctx context.Context, field graphql.CollectedField) (*graphql.FieldContext, error) {
|
||||||
|
switch field.Name {
|
||||||
|
case "ref":
|
||||||
|
return ec.fieldContext_SchemaUpdate_ref(ctx, field)
|
||||||
|
case "id":
|
||||||
|
return ec.fieldContext_SchemaUpdate_id(ctx, field)
|
||||||
|
case "subGraphs":
|
||||||
|
return ec.fieldContext_SchemaUpdate_subGraphs(ctx, field)
|
||||||
|
case "cosmoRouterConfig":
|
||||||
|
return ec.fieldContext_SchemaUpdate_cosmoRouterConfig(ctx, field)
|
||||||
|
}
|
||||||
|
return nil, fmt.Errorf("no field named %q was found under type SchemaUpdate", field.Name)
|
||||||
|
},
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
if r := recover(); r != nil {
|
||||||
|
err = ec.Recover(ctx, r)
|
||||||
|
ec.Error(ctx, err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
ctx = graphql.WithFieldContext(ctx, fc)
|
||||||
|
if fc.Args, err = ec.field_Query_latestSchema_args(ctx, field.ArgumentMap(ec.Variables)); err != nil {
|
||||||
|
ec.Error(ctx, err)
|
||||||
|
return fc, err
|
||||||
|
}
|
||||||
|
return fc, nil
|
||||||
|
}
|
||||||
|
|
||||||
func (ec *executionContext) _Query___type(ctx context.Context, field graphql.CollectedField) (ret graphql.Marshaler) {
|
func (ec *executionContext) _Query___type(ctx context.Context, field graphql.CollectedField) (ret graphql.Marshaler) {
|
||||||
return graphql.ResolveField(
|
return graphql.ResolveField(
|
||||||
ctx,
|
ctx,
|
||||||
@@ -3997,6 +4091,28 @@ func (ec *executionContext) _Query(ctx context.Context, sel ast.SelectionSet) gr
|
|||||||
func(ctx context.Context) graphql.Marshaler { return innerFunc(ctx, out) })
|
func(ctx context.Context) graphql.Marshaler { return innerFunc(ctx, out) })
|
||||||
}
|
}
|
||||||
|
|
||||||
|
out.Concurrently(i, func(ctx context.Context) graphql.Marshaler { return rrm(innerCtx) })
|
||||||
|
case "latestSchema":
|
||||||
|
field := field
|
||||||
|
|
||||||
|
innerFunc := func(ctx context.Context, fs *graphql.FieldSet) (res graphql.Marshaler) {
|
||||||
|
defer func() {
|
||||||
|
if r := recover(); r != nil {
|
||||||
|
ec.Error(ctx, ec.Recover(ctx, r))
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
res = ec._Query_latestSchema(ctx, field)
|
||||||
|
if res == graphql.Null {
|
||||||
|
atomic.AddUint32(&fs.Invalids, 1)
|
||||||
|
}
|
||||||
|
return res
|
||||||
|
}
|
||||||
|
|
||||||
|
rrm := func(ctx context.Context) graphql.Marshaler {
|
||||||
|
return ec.OperationContext.RootResolverMiddleware(ctx,
|
||||||
|
func(ctx context.Context) graphql.Marshaler { return innerFunc(ctx, out) })
|
||||||
|
}
|
||||||
|
|
||||||
out.Concurrently(i, func(ctx context.Context) graphql.Marshaler { return rrm(innerCtx) })
|
out.Concurrently(i, func(ctx context.Context) graphql.Marshaler { return rrm(innerCtx) })
|
||||||
case "__type":
|
case "__type":
|
||||||
out.Values[i] = ec.OperationContext.RootResolverMiddleware(innerCtx, func(ctx context.Context) (res graphql.Marshaler) {
|
out.Values[i] = ec.OperationContext.RootResolverMiddleware(innerCtx, func(ctx context.Context) (res graphql.Marshaler) {
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
type Query {
|
type Query {
|
||||||
organizations: [Organization!]! @auth(user: true)
|
organizations: [Organization!]! @auth(user: true)
|
||||||
supergraph(ref: String!, isAfter: String): Supergraph! @auth(organization: true)
|
supergraph(ref: String!, isAfter: String): Supergraph! @auth(organization: true)
|
||||||
|
latestSchema(ref: String!): SchemaUpdate! @auth(organization: true)
|
||||||
}
|
}
|
||||||
|
|
||||||
type Mutation {
|
type Mutation {
|
||||||
|
|||||||
+113
-5
@@ -123,6 +123,13 @@ func (r *mutationResolver) UpdateSubGraph(ctx context.Context, input model.Input
|
|||||||
// Publish schema update to subscribers
|
// Publish schema update to subscribers
|
||||||
go func() {
|
go func() {
|
||||||
services, lastUpdate := r.Cache.Services(orgId, input.Ref, "")
|
services, lastUpdate := r.Cache.Services(orgId, input.Ref, "")
|
||||||
|
r.Logger.Info("Publishing schema update after subgraph change",
|
||||||
|
"ref", input.Ref,
|
||||||
|
"orgId", orgId,
|
||||||
|
"lastUpdate", lastUpdate,
|
||||||
|
"servicesCount", len(services),
|
||||||
|
)
|
||||||
|
|
||||||
subGraphs := make([]*model.SubGraph, len(services))
|
subGraphs := make([]*model.SubGraph, len(services))
|
||||||
for i, id := range services {
|
for i, id := range services {
|
||||||
sg, err := r.fetchSubGraph(context.Background(), id)
|
sg, err := r.fetchSubGraph(context.Background(), id)
|
||||||
@@ -149,12 +156,21 @@ func (r *mutationResolver) UpdateSubGraph(ctx context.Context, input model.Input
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Publish to all subscribers of this ref
|
// Publish to all subscribers of this ref
|
||||||
r.PubSub.Publish(input.Ref, &model.SchemaUpdate{
|
update := &model.SchemaUpdate{
|
||||||
Ref: input.Ref,
|
Ref: input.Ref,
|
||||||
ID: lastUpdate,
|
ID: lastUpdate,
|
||||||
SubGraphs: subGraphs,
|
SubGraphs: subGraphs,
|
||||||
CosmoRouterConfig: &cosmoConfig,
|
CosmoRouterConfig: &cosmoConfig,
|
||||||
})
|
}
|
||||||
|
|
||||||
|
r.Logger.Info("Publishing schema update to subscribers",
|
||||||
|
"ref", update.Ref,
|
||||||
|
"id", update.ID,
|
||||||
|
"subGraphsCount", len(update.SubGraphs),
|
||||||
|
"cosmoConfigLength", len(cosmoConfig),
|
||||||
|
)
|
||||||
|
|
||||||
|
r.PubSub.Publish(input.Ref, update)
|
||||||
}()
|
}()
|
||||||
|
|
||||||
return r.toGqlSubGraph(subGraph), nil
|
return r.toGqlSubGraph(subGraph), nil
|
||||||
@@ -222,11 +238,84 @@ func (r *queryResolver) Supergraph(ctx context.Context, ref string, isAfter *str
|
|||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// LatestSchema is the resolver for the latestSchema field.
|
||||||
|
func (r *queryResolver) LatestSchema(ctx context.Context, ref string) (*model.SchemaUpdate, error) {
|
||||||
|
orgId := middleware.OrganizationFromContext(ctx)
|
||||||
|
|
||||||
|
r.Logger.Info("LatestSchema query",
|
||||||
|
"ref", ref,
|
||||||
|
"orgId", orgId,
|
||||||
|
)
|
||||||
|
|
||||||
|
_, err := r.apiKeyCanAccessRef(ctx, ref, false)
|
||||||
|
if err != nil {
|
||||||
|
r.Logger.Error("API key cannot access ref", "error", err, "ref", ref)
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get current services and schema
|
||||||
|
services, lastUpdate := r.Cache.Services(orgId, ref, "")
|
||||||
|
r.Logger.Info("Fetching latest schema",
|
||||||
|
"ref", ref,
|
||||||
|
"orgId", orgId,
|
||||||
|
"lastUpdate", lastUpdate,
|
||||||
|
"servicesCount", len(services),
|
||||||
|
)
|
||||||
|
|
||||||
|
subGraphs := make([]*model.SubGraph, len(services))
|
||||||
|
for i, id := range services {
|
||||||
|
sg, err := r.fetchSubGraph(ctx, id)
|
||||||
|
if err != nil {
|
||||||
|
r.Logger.Error("fetch subgraph", "error", err, "id", id)
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
subGraphs[i] = &model.SubGraph{
|
||||||
|
ID: sg.ID.String(),
|
||||||
|
Service: sg.Service,
|
||||||
|
URL: sg.Url,
|
||||||
|
WsURL: sg.WSUrl,
|
||||||
|
Sdl: sg.Sdl,
|
||||||
|
ChangedBy: sg.ChangedBy,
|
||||||
|
ChangedAt: sg.ChangedAt,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Generate Cosmo router config
|
||||||
|
cosmoConfig, err := GenerateCosmoRouterConfig(subGraphs)
|
||||||
|
if err != nil {
|
||||||
|
r.Logger.Error("generate cosmo config", "error", err)
|
||||||
|
cosmoConfig = "" // Return empty if generation fails
|
||||||
|
}
|
||||||
|
|
||||||
|
update := &model.SchemaUpdate{
|
||||||
|
Ref: ref,
|
||||||
|
ID: lastUpdate,
|
||||||
|
SubGraphs: subGraphs,
|
||||||
|
CosmoRouterConfig: &cosmoConfig,
|
||||||
|
}
|
||||||
|
|
||||||
|
r.Logger.Info("Latest schema fetched",
|
||||||
|
"ref", update.Ref,
|
||||||
|
"id", update.ID,
|
||||||
|
"subGraphsCount", len(update.SubGraphs),
|
||||||
|
"cosmoConfigLength", len(cosmoConfig),
|
||||||
|
)
|
||||||
|
|
||||||
|
return update, nil
|
||||||
|
}
|
||||||
|
|
||||||
// SchemaUpdates is the resolver for the schemaUpdates field.
|
// SchemaUpdates is the resolver for the schemaUpdates field.
|
||||||
func (r *subscriptionResolver) SchemaUpdates(ctx context.Context, ref string) (<-chan *model.SchemaUpdate, error) {
|
func (r *subscriptionResolver) SchemaUpdates(ctx context.Context, ref string) (<-chan *model.SchemaUpdate, error) {
|
||||||
orgId := middleware.OrganizationFromContext(ctx)
|
orgId := middleware.OrganizationFromContext(ctx)
|
||||||
|
|
||||||
|
r.Logger.Info("SchemaUpdates subscription started",
|
||||||
|
"ref", ref,
|
||||||
|
"orgId", orgId,
|
||||||
|
)
|
||||||
|
|
||||||
_, err := r.apiKeyCanAccessRef(ctx, ref, false)
|
_, err := r.apiKeyCanAccessRef(ctx, ref, false)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
r.Logger.Error("API key cannot access ref", "error", err, "ref", ref)
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -235,12 +324,22 @@ func (r *subscriptionResolver) SchemaUpdates(ctx context.Context, ref string) (<
|
|||||||
|
|
||||||
// Send initial state immediately
|
// Send initial state immediately
|
||||||
go func() {
|
go func() {
|
||||||
|
// Use background context for async operation
|
||||||
|
bgCtx := context.Background()
|
||||||
|
|
||||||
services, lastUpdate := r.Cache.Services(orgId, ref, "")
|
services, lastUpdate := r.Cache.Services(orgId, ref, "")
|
||||||
|
r.Logger.Info("Preparing initial schema update",
|
||||||
|
"ref", ref,
|
||||||
|
"orgId", orgId,
|
||||||
|
"lastUpdate", lastUpdate,
|
||||||
|
"servicesCount", len(services),
|
||||||
|
)
|
||||||
|
|
||||||
subGraphs := make([]*model.SubGraph, len(services))
|
subGraphs := make([]*model.SubGraph, len(services))
|
||||||
for i, id := range services {
|
for i, id := range services {
|
||||||
sg, err := r.fetchSubGraph(ctx, id)
|
sg, err := r.fetchSubGraph(bgCtx, id)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
r.Logger.Error("fetch subgraph for initial update", "error", err)
|
r.Logger.Error("fetch subgraph for initial update", "error", err, "id", id)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
subGraphs[i] = &model.SubGraph{
|
subGraphs[i] = &model.SubGraph{
|
||||||
@@ -262,12 +361,21 @@ func (r *subscriptionResolver) SchemaUpdates(ctx context.Context, ref string) (<
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Send initial update
|
// Send initial update
|
||||||
ch <- &model.SchemaUpdate{
|
update := &model.SchemaUpdate{
|
||||||
Ref: ref,
|
Ref: ref,
|
||||||
ID: lastUpdate,
|
ID: lastUpdate,
|
||||||
SubGraphs: subGraphs,
|
SubGraphs: subGraphs,
|
||||||
CosmoRouterConfig: &cosmoConfig,
|
CosmoRouterConfig: &cosmoConfig,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
r.Logger.Info("Sending initial schema update",
|
||||||
|
"ref", update.Ref,
|
||||||
|
"id", update.ID,
|
||||||
|
"subGraphsCount", len(update.SubGraphs),
|
||||||
|
"cosmoConfigLength", len(cosmoConfig),
|
||||||
|
)
|
||||||
|
|
||||||
|
ch <- update
|
||||||
}()
|
}()
|
||||||
|
|
||||||
// Clean up subscription when context is done
|
// Clean up subscription when context is done
|
||||||
|
|||||||
+3
-1
@@ -49,7 +49,9 @@ func (m *AuthMiddleware) Handler(next http.Handler) http.Handler {
|
|||||||
_, _ = w.Write([]byte("Invalid API Key format"))
|
_, _ = w.Write([]byte("Invalid API Key format"))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if organization := m.cache.OrganizationByAPIKey(hash.String(apiKey)); organization != nil {
|
hashedKey := hash.String(apiKey)
|
||||||
|
organization := m.cache.OrganizationByAPIKey(hashedKey)
|
||||||
|
if organization != nil {
|
||||||
ctx = context.WithValue(ctx, OrganizationKey, *organization)
|
ctx = context.WithValue(ctx, OrganizationKey, *organization)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -0,0 +1,467 @@
|
|||||||
|
package middleware
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
mw "github.com/auth0/go-jwt-middleware/v2"
|
||||||
|
"github.com/golang-jwt/jwt/v5"
|
||||||
|
"github.com/google/uuid"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/mock"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
"gitlab.com/unboundsoftware/eventsourced/eventsourced"
|
||||||
|
|
||||||
|
"gitlab.com/unboundsoftware/schemas/domain"
|
||||||
|
"gitlab.com/unboundsoftware/schemas/hash"
|
||||||
|
)
|
||||||
|
|
||||||
|
// MockCache is a mock implementation of the Cache interface
|
||||||
|
type MockCache struct {
|
||||||
|
mock.Mock
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MockCache) OrganizationByAPIKey(apiKey string) *domain.Organization {
|
||||||
|
args := m.Called(apiKey)
|
||||||
|
if args.Get(0) == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return args.Get(0).(*domain.Organization)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAuthMiddleware_Handler_WithValidAPIKey(t *testing.T) {
|
||||||
|
// Setup
|
||||||
|
mockCache := new(MockCache)
|
||||||
|
authMiddleware := NewAuth(mockCache)
|
||||||
|
|
||||||
|
orgID := uuid.New()
|
||||||
|
expectedOrg := &domain.Organization{
|
||||||
|
BaseAggregate: eventsourced.BaseAggregate{
|
||||||
|
ID: eventsourced.IdFromString(orgID.String()),
|
||||||
|
},
|
||||||
|
Name: "Test Organization",
|
||||||
|
}
|
||||||
|
|
||||||
|
apiKey := "test-api-key-123"
|
||||||
|
hashedKey := hash.String(apiKey)
|
||||||
|
|
||||||
|
mockCache.On("OrganizationByAPIKey", hashedKey).Return(expectedOrg)
|
||||||
|
|
||||||
|
// Create a test handler that checks the context
|
||||||
|
var capturedOrg *domain.Organization
|
||||||
|
testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if org := r.Context().Value(OrganizationKey); org != nil {
|
||||||
|
if o, ok := org.(domain.Organization); ok {
|
||||||
|
capturedOrg = &o
|
||||||
|
}
|
||||||
|
}
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
})
|
||||||
|
|
||||||
|
// Create request with API key in context
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/test", nil)
|
||||||
|
ctx := context.WithValue(req.Context(), ApiKey, apiKey)
|
||||||
|
req = req.WithContext(ctx)
|
||||||
|
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
|
||||||
|
// Execute
|
||||||
|
authMiddleware.Handler(testHandler).ServeHTTP(rec, req)
|
||||||
|
|
||||||
|
// Assert
|
||||||
|
assert.Equal(t, http.StatusOK, rec.Code)
|
||||||
|
require.NotNil(t, capturedOrg)
|
||||||
|
assert.Equal(t, expectedOrg.Name, capturedOrg.Name)
|
||||||
|
assert.Equal(t, expectedOrg.ID.String(), capturedOrg.ID.String())
|
||||||
|
mockCache.AssertExpectations(t)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAuthMiddleware_Handler_WithInvalidAPIKey(t *testing.T) {
|
||||||
|
// Setup
|
||||||
|
mockCache := new(MockCache)
|
||||||
|
authMiddleware := NewAuth(mockCache)
|
||||||
|
|
||||||
|
apiKey := "invalid-api-key"
|
||||||
|
hashedKey := hash.String(apiKey)
|
||||||
|
|
||||||
|
mockCache.On("OrganizationByAPIKey", hashedKey).Return(nil)
|
||||||
|
|
||||||
|
// Create a test handler that checks the context
|
||||||
|
var capturedOrg *domain.Organization
|
||||||
|
testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if org := r.Context().Value(OrganizationKey); org != nil {
|
||||||
|
if o, ok := org.(domain.Organization); ok {
|
||||||
|
capturedOrg = &o
|
||||||
|
}
|
||||||
|
}
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
})
|
||||||
|
|
||||||
|
// Create request with API key in context
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/test", nil)
|
||||||
|
ctx := context.WithValue(req.Context(), ApiKey, apiKey)
|
||||||
|
req = req.WithContext(ctx)
|
||||||
|
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
|
||||||
|
// Execute
|
||||||
|
authMiddleware.Handler(testHandler).ServeHTTP(rec, req)
|
||||||
|
|
||||||
|
// Assert
|
||||||
|
assert.Equal(t, http.StatusOK, rec.Code)
|
||||||
|
assert.Nil(t, capturedOrg, "Organization should not be set for invalid API key")
|
||||||
|
mockCache.AssertExpectations(t)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAuthMiddleware_Handler_WithoutAPIKey(t *testing.T) {
|
||||||
|
// Setup
|
||||||
|
mockCache := new(MockCache)
|
||||||
|
authMiddleware := NewAuth(mockCache)
|
||||||
|
|
||||||
|
// The middleware always hashes the API key (even if empty) and calls the cache
|
||||||
|
emptyKeyHash := hash.String("")
|
||||||
|
mockCache.On("OrganizationByAPIKey", emptyKeyHash).Return(nil)
|
||||||
|
|
||||||
|
// Create a test handler that checks the context
|
||||||
|
var capturedOrg *domain.Organization
|
||||||
|
testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if org := r.Context().Value(OrganizationKey); org != nil {
|
||||||
|
if o, ok := org.(domain.Organization); ok {
|
||||||
|
capturedOrg = &o
|
||||||
|
}
|
||||||
|
}
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
})
|
||||||
|
|
||||||
|
// Create request without API key
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/test", nil)
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
|
||||||
|
// Execute
|
||||||
|
authMiddleware.Handler(testHandler).ServeHTTP(rec, req)
|
||||||
|
|
||||||
|
// Assert
|
||||||
|
assert.Equal(t, http.StatusOK, rec.Code)
|
||||||
|
assert.Nil(t, capturedOrg, "Organization should not be set without API key")
|
||||||
|
mockCache.AssertExpectations(t)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAuthMiddleware_Handler_WithValidJWT(t *testing.T) {
|
||||||
|
// Setup
|
||||||
|
mockCache := new(MockCache)
|
||||||
|
authMiddleware := NewAuth(mockCache)
|
||||||
|
|
||||||
|
// The middleware always hashes the API key (even if empty) and calls the cache
|
||||||
|
emptyKeyHash := hash.String("")
|
||||||
|
mockCache.On("OrganizationByAPIKey", emptyKeyHash).Return(nil)
|
||||||
|
|
||||||
|
userID := "user-123"
|
||||||
|
token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{
|
||||||
|
"sub": userID,
|
||||||
|
})
|
||||||
|
|
||||||
|
// Create a test handler that checks the context
|
||||||
|
var capturedUser string
|
||||||
|
testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if user := r.Context().Value(UserKey); user != nil {
|
||||||
|
if u, ok := user.(string); ok {
|
||||||
|
capturedUser = u
|
||||||
|
}
|
||||||
|
}
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
})
|
||||||
|
|
||||||
|
// Create request with JWT token in context
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/test", nil)
|
||||||
|
ctx := context.WithValue(req.Context(), mw.ContextKey{}, token)
|
||||||
|
req = req.WithContext(ctx)
|
||||||
|
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
|
||||||
|
// Execute
|
||||||
|
authMiddleware.Handler(testHandler).ServeHTTP(rec, req)
|
||||||
|
|
||||||
|
// Assert
|
||||||
|
assert.Equal(t, http.StatusOK, rec.Code)
|
||||||
|
assert.Equal(t, userID, capturedUser)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAuthMiddleware_Handler_APIKeyErrorHandling(t *testing.T) {
|
||||||
|
// Setup
|
||||||
|
mockCache := new(MockCache)
|
||||||
|
authMiddleware := NewAuth(mockCache)
|
||||||
|
|
||||||
|
testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
})
|
||||||
|
|
||||||
|
// Create request with invalid API key type in context
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/test", nil)
|
||||||
|
ctx := context.WithValue(req.Context(), ApiKey, 12345) // Invalid type
|
||||||
|
req = req.WithContext(ctx)
|
||||||
|
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
|
||||||
|
// Execute
|
||||||
|
authMiddleware.Handler(testHandler).ServeHTTP(rec, req)
|
||||||
|
|
||||||
|
// Assert
|
||||||
|
assert.Equal(t, http.StatusInternalServerError, rec.Code)
|
||||||
|
assert.Contains(t, rec.Body.String(), "Invalid API Key format")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAuthMiddleware_Handler_JWTErrorHandling(t *testing.T) {
|
||||||
|
// Setup
|
||||||
|
mockCache := new(MockCache)
|
||||||
|
authMiddleware := NewAuth(mockCache)
|
||||||
|
|
||||||
|
testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
})
|
||||||
|
|
||||||
|
// Create request with invalid JWT token type in context
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/test", nil)
|
||||||
|
ctx := context.WithValue(req.Context(), mw.ContextKey{}, "not-a-token") // Invalid type
|
||||||
|
req = req.WithContext(ctx)
|
||||||
|
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
|
||||||
|
// Execute
|
||||||
|
authMiddleware.Handler(testHandler).ServeHTTP(rec, req)
|
||||||
|
|
||||||
|
// Assert
|
||||||
|
assert.Equal(t, http.StatusInternalServerError, rec.Code)
|
||||||
|
assert.Contains(t, rec.Body.String(), "Invalid JWT token format")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAuthMiddleware_Handler_BothJWTAndAPIKey(t *testing.T) {
|
||||||
|
// Setup
|
||||||
|
mockCache := new(MockCache)
|
||||||
|
authMiddleware := NewAuth(mockCache)
|
||||||
|
|
||||||
|
orgID := uuid.New()
|
||||||
|
expectedOrg := &domain.Organization{
|
||||||
|
BaseAggregate: eventsourced.BaseAggregate{
|
||||||
|
ID: eventsourced.IdFromString(orgID.String()),
|
||||||
|
},
|
||||||
|
Name: "Test Organization",
|
||||||
|
}
|
||||||
|
|
||||||
|
userID := "user-123"
|
||||||
|
apiKey := "test-api-key-123"
|
||||||
|
hashedKey := hash.String(apiKey)
|
||||||
|
|
||||||
|
token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{
|
||||||
|
"sub": userID,
|
||||||
|
})
|
||||||
|
|
||||||
|
mockCache.On("OrganizationByAPIKey", hashedKey).Return(expectedOrg)
|
||||||
|
|
||||||
|
// Create a test handler that checks both user and organization in context
|
||||||
|
var capturedUser string
|
||||||
|
var capturedOrg *domain.Organization
|
||||||
|
testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if user := r.Context().Value(UserKey); user != nil {
|
||||||
|
if u, ok := user.(string); ok {
|
||||||
|
capturedUser = u
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if org := r.Context().Value(OrganizationKey); org != nil {
|
||||||
|
if o, ok := org.(domain.Organization); ok {
|
||||||
|
capturedOrg = &o
|
||||||
|
}
|
||||||
|
}
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
})
|
||||||
|
|
||||||
|
// Create request with both JWT and API key in context
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/test", nil)
|
||||||
|
ctx := context.WithValue(req.Context(), mw.ContextKey{}, token)
|
||||||
|
ctx = context.WithValue(ctx, ApiKey, apiKey)
|
||||||
|
req = req.WithContext(ctx)
|
||||||
|
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
|
||||||
|
// Execute
|
||||||
|
authMiddleware.Handler(testHandler).ServeHTTP(rec, req)
|
||||||
|
|
||||||
|
// Assert
|
||||||
|
assert.Equal(t, http.StatusOK, rec.Code)
|
||||||
|
assert.Equal(t, userID, capturedUser)
|
||||||
|
require.NotNil(t, capturedOrg)
|
||||||
|
assert.Equal(t, expectedOrg.Name, capturedOrg.Name)
|
||||||
|
mockCache.AssertExpectations(t)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUserFromContext(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
ctx context.Context
|
||||||
|
expected string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "with valid user",
|
||||||
|
ctx: context.WithValue(context.Background(), UserKey, "user-123"),
|
||||||
|
expected: "user-123",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "without user",
|
||||||
|
ctx: context.Background(),
|
||||||
|
expected: "",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "with invalid type",
|
||||||
|
ctx: context.WithValue(context.Background(), UserKey, 123),
|
||||||
|
expected: "",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
result := UserFromContext(tt.ctx)
|
||||||
|
assert.Equal(t, tt.expected, result)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestOrganizationFromContext(t *testing.T) {
|
||||||
|
orgID := uuid.New()
|
||||||
|
org := domain.Organization{
|
||||||
|
BaseAggregate: eventsourced.BaseAggregate{
|
||||||
|
ID: eventsourced.IdFromString(orgID.String()),
|
||||||
|
},
|
||||||
|
Name: "Test Org",
|
||||||
|
}
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
ctx context.Context
|
||||||
|
expected string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "with valid organization",
|
||||||
|
ctx: context.WithValue(context.Background(), OrganizationKey, org),
|
||||||
|
expected: orgID.String(),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "without organization",
|
||||||
|
ctx: context.Background(),
|
||||||
|
expected: "",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "with invalid type",
|
||||||
|
ctx: context.WithValue(context.Background(), OrganizationKey, "not-an-org"),
|
||||||
|
expected: "",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
result := OrganizationFromContext(tt.ctx)
|
||||||
|
assert.Equal(t, tt.expected, result)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAuthMiddleware_Directive_RequiresUser(t *testing.T) {
|
||||||
|
mockCache := new(MockCache)
|
||||||
|
authMiddleware := NewAuth(mockCache)
|
||||||
|
|
||||||
|
requireUser := true
|
||||||
|
|
||||||
|
// Test with user present
|
||||||
|
ctx := context.WithValue(context.Background(), UserKey, "user-123")
|
||||||
|
_, err := authMiddleware.Directive(ctx, nil, func(ctx context.Context) (interface{}, error) {
|
||||||
|
return "success", nil
|
||||||
|
}, &requireUser, nil)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
// Test without user
|
||||||
|
ctx = context.Background()
|
||||||
|
_, err = authMiddleware.Directive(ctx, nil, func(ctx context.Context) (interface{}, error) {
|
||||||
|
return "success", nil
|
||||||
|
}, &requireUser, nil)
|
||||||
|
assert.Error(t, err)
|
||||||
|
assert.Contains(t, err.Error(), "no user available in request")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAuthMiddleware_Directive_RequiresOrganization(t *testing.T) {
|
||||||
|
mockCache := new(MockCache)
|
||||||
|
authMiddleware := NewAuth(mockCache)
|
||||||
|
|
||||||
|
requireOrg := true
|
||||||
|
orgID := uuid.New()
|
||||||
|
org := domain.Organization{
|
||||||
|
BaseAggregate: eventsourced.BaseAggregate{
|
||||||
|
ID: eventsourced.IdFromString(orgID.String()),
|
||||||
|
},
|
||||||
|
Name: "Test Org",
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test with organization present
|
||||||
|
ctx := context.WithValue(context.Background(), OrganizationKey, org)
|
||||||
|
_, err := authMiddleware.Directive(ctx, nil, func(ctx context.Context) (interface{}, error) {
|
||||||
|
return "success", nil
|
||||||
|
}, nil, &requireOrg)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
// Test without organization
|
||||||
|
ctx = context.Background()
|
||||||
|
_, err = authMiddleware.Directive(ctx, nil, func(ctx context.Context) (interface{}, error) {
|
||||||
|
return "success", nil
|
||||||
|
}, nil, &requireOrg)
|
||||||
|
assert.Error(t, err)
|
||||||
|
assert.Contains(t, err.Error(), "no organization available in request")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAuthMiddleware_Directive_RequiresBoth(t *testing.T) {
|
||||||
|
mockCache := new(MockCache)
|
||||||
|
authMiddleware := NewAuth(mockCache)
|
||||||
|
|
||||||
|
requireUser := true
|
||||||
|
requireOrg := true
|
||||||
|
orgID := uuid.New()
|
||||||
|
org := domain.Organization{
|
||||||
|
BaseAggregate: eventsourced.BaseAggregate{
|
||||||
|
ID: eventsourced.IdFromString(orgID.String()),
|
||||||
|
},
|
||||||
|
Name: "Test Org",
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test with both present
|
||||||
|
ctx := context.WithValue(context.Background(), UserKey, "user-123")
|
||||||
|
ctx = context.WithValue(ctx, OrganizationKey, org)
|
||||||
|
_, err := authMiddleware.Directive(ctx, nil, func(ctx context.Context) (interface{}, error) {
|
||||||
|
return "success", nil
|
||||||
|
}, &requireUser, &requireOrg)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
// Test with only user
|
||||||
|
ctx = context.WithValue(context.Background(), UserKey, "user-123")
|
||||||
|
_, err = authMiddleware.Directive(ctx, nil, func(ctx context.Context) (interface{}, error) {
|
||||||
|
return "success", nil
|
||||||
|
}, &requireUser, &requireOrg)
|
||||||
|
assert.Error(t, err)
|
||||||
|
|
||||||
|
// Test with only organization
|
||||||
|
ctx = context.WithValue(context.Background(), OrganizationKey, org)
|
||||||
|
_, err = authMiddleware.Directive(ctx, nil, func(ctx context.Context) (interface{}, error) {
|
||||||
|
return "success", nil
|
||||||
|
}, &requireUser, &requireOrg)
|
||||||
|
assert.Error(t, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAuthMiddleware_Directive_NoRequirements(t *testing.T) {
|
||||||
|
mockCache := new(MockCache)
|
||||||
|
authMiddleware := NewAuth(mockCache)
|
||||||
|
|
||||||
|
// Test with no requirements
|
||||||
|
ctx := context.Background()
|
||||||
|
result, err := authMiddleware.Directive(ctx, nil, func(ctx context.Context) (interface{}, error) {
|
||||||
|
return "success", nil
|
||||||
|
}, nil, nil)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, "success", result)
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user