feat: add Cosmo Router config generation and PubSub support

Creates a new `GenerateCosmoRouterConfig` function to build and 
serialize a Cosmo Router configuration from subgraphs. Implements 
PubSub mechanism for managing schema updates, allowing 
subscription to updates. Adds Subscription resolver and updates 
existing structures to accommodate new functionalities. This 
enhances the system's capabilities for dynamic updates and 
configuration management.
This commit is contained in:
2025-11-19 11:29:30 +01:00
parent f6e4458efa
commit 80daed081d
10 changed files with 1135 additions and 2 deletions
+1
View File
@@ -195,6 +195,7 @@ func start(closeEvents chan error, logger *slog.Logger, connectToAmqpFunc func(u
Publisher: eventPublisher,
Logger: logger,
Cache: serviceCache,
PubSub: graph.NewPubSub(),
}
config := generated.Config{
+54
View File
@@ -0,0 +1,54 @@
package graph
import (
"encoding/json"
"fmt"
"gitlab.com/unboundsoftware/schemas/graph/model"
)
// GenerateCosmoRouterConfig generates a Cosmo Router execution config from subgraphs
func GenerateCosmoRouterConfig(subGraphs []*model.SubGraph) (string, error) {
// Build the Cosmo router config structure
// This is a simplified version - you may need to adjust based on actual Cosmo requirements
config := map[string]interface{}{
"version": "1",
"subgraphs": convertSubGraphsToCosmo(subGraphs),
// Add other Cosmo-specific configuration as needed
}
// Marshal to JSON
configJSON, err := json.MarshalIndent(config, "", " ")
if err != nil {
return "", fmt.Errorf("marshal cosmo config: %w", err)
}
return string(configJSON), nil
}
func convertSubGraphsToCosmo(subGraphs []*model.SubGraph) []map[string]interface{} {
cosmoSubgraphs := make([]map[string]interface{}, 0, len(subGraphs))
for _, sg := range subGraphs {
cosmoSg := map[string]interface{}{
"name": sg.Service,
"sdl": sg.Sdl,
}
if sg.URL != nil {
cosmoSg["routing_url"] = *sg.URL
}
if sg.WsURL != nil {
cosmoSg["subscription"] = map[string]interface{}{
"url": *sg.WsURL,
"protocol": "ws",
"websocket_subprotocol": "graphql-ws",
}
}
cosmoSubgraphs = append(cosmoSubgraphs, cosmoSg)
}
return cosmoSubgraphs
}
+258
View File
@@ -0,0 +1,258 @@
package graph
import (
"encoding/json"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"gitlab.com/unboundsoftware/schemas/graph/model"
)
func TestGenerateCosmoRouterConfig(t *testing.T) {
tests := []struct {
name string
subGraphs []*model.SubGraph
wantErr bool
validate func(t *testing.T, config string)
}{
{
name: "single subgraph with all fields",
subGraphs: []*model.SubGraph{
{
Service: "test-service",
URL: stringPtr("http://localhost:4001/query"),
WsURL: stringPtr("ws://localhost:4001/query"),
Sdl: "type Query { test: String }",
},
},
wantErr: false,
validate: func(t *testing.T, config string) {
var result map[string]interface{}
err := json.Unmarshal([]byte(config), &result)
require.NoError(t, err, "Config should be valid JSON")
assert.Equal(t, "1", result["version"], "Version should be 1")
subgraphs, ok := result["subgraphs"].([]interface{})
require.True(t, ok, "subgraphs should be an array")
require.Len(t, subgraphs, 1, "Should have 1 subgraph")
sg := subgraphs[0].(map[string]interface{})
assert.Equal(t, "test-service", sg["name"])
assert.Equal(t, "http://localhost:4001/query", sg["routing_url"])
assert.Equal(t, "type Query { test: String }", sg["sdl"])
subscription, ok := sg["subscription"].(map[string]interface{})
require.True(t, ok, "Should have subscription config")
assert.Equal(t, "ws://localhost:4001/query", subscription["url"])
assert.Equal(t, "ws", subscription["protocol"])
assert.Equal(t, "graphql-ws", subscription["websocket_subprotocol"])
},
},
{
name: "multiple subgraphs",
subGraphs: []*model.SubGraph{
{
Service: "service-1",
URL: stringPtr("http://localhost:4001/query"),
Sdl: "type Query { field1: String }",
},
{
Service: "service-2",
URL: stringPtr("http://localhost:4002/query"),
Sdl: "type Query { field2: String }",
},
{
Service: "service-3",
URL: stringPtr("http://localhost:4003/query"),
WsURL: stringPtr("ws://localhost:4003/query"),
Sdl: "type Subscription { updates: String }",
},
},
wantErr: false,
validate: func(t *testing.T, config string) {
var result map[string]interface{}
err := json.Unmarshal([]byte(config), &result)
require.NoError(t, err)
subgraphs := result["subgraphs"].([]interface{})
assert.Len(t, subgraphs, 3, "Should have 3 subgraphs")
// Check first service has no subscription
sg1 := subgraphs[0].(map[string]interface{})
assert.Equal(t, "service-1", sg1["name"])
_, hasSubscription := sg1["subscription"]
assert.False(t, hasSubscription, "service-1 should not have subscription config")
// Check third service has subscription
sg3 := subgraphs[2].(map[string]interface{})
assert.Equal(t, "service-3", sg3["name"])
subscription, hasSubscription := sg3["subscription"]
assert.True(t, hasSubscription, "service-3 should have subscription config")
assert.NotNil(t, subscription)
},
},
{
name: "subgraph with no URL",
subGraphs: []*model.SubGraph{
{
Service: "test-service",
URL: nil,
WsURL: nil,
Sdl: "type Query { test: String }",
},
},
wantErr: false,
validate: func(t *testing.T, config string) {
var result map[string]interface{}
err := json.Unmarshal([]byte(config), &result)
require.NoError(t, err)
subgraphs := result["subgraphs"].([]interface{})
sg := subgraphs[0].(map[string]interface{})
// Should not have routing_url or subscription fields if URLs are nil
_, hasRoutingURL := sg["routing_url"]
assert.False(t, hasRoutingURL, "Should not have routing_url when URL is nil")
_, hasSubscription := sg["subscription"]
assert.False(t, hasSubscription, "Should not have subscription when WsURL is nil")
},
},
{
name: "empty subgraphs",
subGraphs: []*model.SubGraph{},
wantErr: false,
validate: func(t *testing.T, config string) {
var result map[string]interface{}
err := json.Unmarshal([]byte(config), &result)
require.NoError(t, err)
subgraphs := result["subgraphs"].([]interface{})
assert.Len(t, subgraphs, 0, "Should have empty subgraphs array")
},
},
{
name: "nil subgraphs",
subGraphs: nil,
wantErr: false,
validate: func(t *testing.T, config string) {
var result map[string]interface{}
err := json.Unmarshal([]byte(config), &result)
require.NoError(t, err)
subgraphs := result["subgraphs"].([]interface{})
assert.Len(t, subgraphs, 0, "Should handle nil subgraphs as empty array")
},
},
{
name: "complex SDL with multiple types",
subGraphs: []*model.SubGraph{
{
Service: "complex-service",
URL: stringPtr("http://localhost:4001/query"),
Sdl: `
type Query {
user(id: ID!): User
users: [User!]!
}
type User {
id: ID!
name: String!
email: String!
}
`,
},
},
wantErr: false,
validate: func(t *testing.T, config string) {
var result map[string]interface{}
err := json.Unmarshal([]byte(config), &result)
require.NoError(t, err)
subgraphs := result["subgraphs"].([]interface{})
sg := subgraphs[0].(map[string]interface{})
sdl := sg["sdl"].(string)
assert.Contains(t, sdl, "type Query")
assert.Contains(t, sdl, "type User")
assert.Contains(t, sdl, "email: String!")
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
config, err := GenerateCosmoRouterConfig(tt.subGraphs)
if tt.wantErr {
assert.Error(t, err)
return
}
require.NoError(t, err)
assert.NotEmpty(t, config, "Config should not be empty")
if tt.validate != nil {
tt.validate(t, config)
}
})
}
}
func TestConvertSubGraphsToCosmo(t *testing.T) {
tests := []struct {
name string
subGraphs []*model.SubGraph
wantLen int
validate func(t *testing.T, result []map[string]interface{})
}{
{
name: "preserves subgraph order",
subGraphs: []*model.SubGraph{
{Service: "alpha", URL: stringPtr("http://a"), Sdl: "a"},
{Service: "beta", URL: stringPtr("http://b"), Sdl: "b"},
{Service: "gamma", URL: stringPtr("http://c"), Sdl: "c"},
},
wantLen: 3,
validate: func(t *testing.T, result []map[string]interface{}) {
assert.Equal(t, "alpha", result[0]["name"])
assert.Equal(t, "beta", result[1]["name"])
assert.Equal(t, "gamma", result[2]["name"])
},
},
{
name: "includes SDL exactly as provided",
subGraphs: []*model.SubGraph{
{
Service: "test",
URL: stringPtr("http://test"),
Sdl: "type Query { special: String! }",
},
},
wantLen: 1,
validate: func(t *testing.T, result []map[string]interface{}) {
assert.Equal(t, "type Query { special: String! }", result[0]["sdl"])
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := convertSubGraphsToCosmo(tt.subGraphs)
assert.Len(t, result, tt.wantLen)
if tt.validate != nil {
tt.validate(t, result)
}
})
}
}
// Helper function for tests
func stringPtr(s string) *string {
return &s
}
+377
View File
@@ -42,6 +42,7 @@ type Config struct {
type ResolverRoot interface {
Mutation() MutationResolver
Query() QueryResolver
Subscription() SubscriptionResolver
}
type DirectiveRoot struct {
@@ -77,6 +78,13 @@ type ComplexityRoot struct {
Supergraph func(childComplexity int, ref string, isAfter *string) int
}
SchemaUpdate struct {
CosmoRouterConfig func(childComplexity int) int
ID func(childComplexity int) int
Ref func(childComplexity int) int
SubGraphs func(childComplexity int) int
}
SubGraph struct {
ChangedAt func(childComplexity int) int
ChangedBy func(childComplexity int) int
@@ -94,6 +102,10 @@ type ComplexityRoot struct {
SubGraphs func(childComplexity int) int
}
Subscription struct {
SchemaUpdates func(childComplexity int, ref string) int
}
Unchanged struct {
ID func(childComplexity int) int
MinDelaySeconds func(childComplexity int) int
@@ -113,6 +125,9 @@ type QueryResolver interface {
Organizations(ctx context.Context) ([]*model.Organization, error)
Supergraph(ctx context.Context, ref string, isAfter *string) (model.Supergraph, error)
}
type SubscriptionResolver interface {
SchemaUpdates(ctx context.Context, ref string) (<-chan *model.SchemaUpdate, error)
}
type executableSchema struct {
schema *ast.Schema
@@ -253,6 +268,31 @@ func (e *executableSchema) Complexity(ctx context.Context, typeName, field strin
return e.complexity.Query.Supergraph(childComplexity, args["ref"].(string), args["isAfter"].(*string)), true
case "SchemaUpdate.cosmoRouterConfig":
if e.complexity.SchemaUpdate.CosmoRouterConfig == nil {
break
}
return e.complexity.SchemaUpdate.CosmoRouterConfig(childComplexity), true
case "SchemaUpdate.id":
if e.complexity.SchemaUpdate.ID == nil {
break
}
return e.complexity.SchemaUpdate.ID(childComplexity), true
case "SchemaUpdate.ref":
if e.complexity.SchemaUpdate.Ref == nil {
break
}
return e.complexity.SchemaUpdate.Ref(childComplexity), true
case "SchemaUpdate.subGraphs":
if e.complexity.SchemaUpdate.SubGraphs == nil {
break
}
return e.complexity.SchemaUpdate.SubGraphs(childComplexity), true
case "SubGraph.changedAt":
if e.complexity.SubGraph.ChangedAt == nil {
break
@@ -321,6 +361,18 @@ func (e *executableSchema) Complexity(ctx context.Context, typeName, field strin
return e.complexity.SubGraphs.SubGraphs(childComplexity), true
case "Subscription.schemaUpdates":
if e.complexity.Subscription.SchemaUpdates == nil {
break
}
args, err := ec.field_Subscription_schemaUpdates_args(ctx, rawArgs)
if err != nil {
return 0, false
}
return e.complexity.Subscription.SchemaUpdates(childComplexity, args["ref"].(string)), true
case "Unchanged.id":
if e.complexity.Unchanged.ID == nil {
break
@@ -396,6 +448,23 @@ func (e *executableSchema) Exec(ctx context.Context) graphql.ResponseHandler {
var buf bytes.Buffer
data.MarshalGQL(&buf)
return &graphql.Response{
Data: buf.Bytes(),
}
}
case ast.Subscription:
next := ec._Subscription(ctx, opCtx.Operation.SelectionSet)
var buf bytes.Buffer
return func(ctx context.Context) *graphql.Response {
buf.Reset()
data := next(ctx)
if data == nil {
return nil
}
data.MarshalGQL(&buf)
return &graphql.Response{
Data: buf.Bytes(),
}
@@ -459,6 +528,10 @@ type Mutation {
updateSubGraph(input: InputSubGraph!): SubGraph! @auth(organization: true)
}
type Subscription {
schemaUpdates(ref: String!): SchemaUpdate! @auth(organization: true)
}
type Organization {
id: ID!
name: String!
@@ -504,6 +577,13 @@ type SubGraph {
changedAt: Time!
}
type SchemaUpdate {
ref: String!
id: ID!
subGraphs: [SubGraph!]!
cosmoRouterConfig: String
}
input InputAPIKey {
name: String!
organizationId: ID!
@@ -607,6 +687,17 @@ func (ec *executionContext) field_Query_supergraph_args(ctx context.Context, raw
return args, nil
}
func (ec *executionContext) field_Subscription_schemaUpdates_args(ctx context.Context, rawArgs map[string]any) (map[string]any, error) {
var err error
args := map[string]any{}
arg0, err := graphql.ProcessArgField(ctx, rawArgs, "ref", ec.unmarshalNString2string)
if err != nil {
return nil, err
}
args["ref"] = arg0
return args, nil
}
func (ec *executionContext) field___Directive_args_args(ctx context.Context, rawArgs map[string]any) (map[string]any, error) {
var err error
args := map[string]any{}
@@ -1451,6 +1542,138 @@ func (ec *executionContext) fieldContext_Query___schema(_ context.Context, field
return fc, nil
}
func (ec *executionContext) _SchemaUpdate_ref(ctx context.Context, field graphql.CollectedField, obj *model.SchemaUpdate) (ret graphql.Marshaler) {
return graphql.ResolveField(
ctx,
ec.OperationContext,
field,
ec.fieldContext_SchemaUpdate_ref,
func(ctx context.Context) (any, error) {
return obj.Ref, nil
},
nil,
ec.marshalNString2string,
true,
true,
)
}
func (ec *executionContext) fieldContext_SchemaUpdate_ref(_ context.Context, field graphql.CollectedField) (fc *graphql.FieldContext, err error) {
fc = &graphql.FieldContext{
Object: "SchemaUpdate",
Field: field,
IsMethod: false,
IsResolver: false,
Child: func(ctx context.Context, field graphql.CollectedField) (*graphql.FieldContext, error) {
return nil, errors.New("field of type String does not have child fields")
},
}
return fc, nil
}
func (ec *executionContext) _SchemaUpdate_id(ctx context.Context, field graphql.CollectedField, obj *model.SchemaUpdate) (ret graphql.Marshaler) {
return graphql.ResolveField(
ctx,
ec.OperationContext,
field,
ec.fieldContext_SchemaUpdate_id,
func(ctx context.Context) (any, error) {
return obj.ID, nil
},
nil,
ec.marshalNID2string,
true,
true,
)
}
func (ec *executionContext) fieldContext_SchemaUpdate_id(_ context.Context, field graphql.CollectedField) (fc *graphql.FieldContext, err error) {
fc = &graphql.FieldContext{
Object: "SchemaUpdate",
Field: field,
IsMethod: false,
IsResolver: false,
Child: func(ctx context.Context, field graphql.CollectedField) (*graphql.FieldContext, error) {
return nil, errors.New("field of type ID does not have child fields")
},
}
return fc, nil
}
func (ec *executionContext) _SchemaUpdate_subGraphs(ctx context.Context, field graphql.CollectedField, obj *model.SchemaUpdate) (ret graphql.Marshaler) {
return graphql.ResolveField(
ctx,
ec.OperationContext,
field,
ec.fieldContext_SchemaUpdate_subGraphs,
func(ctx context.Context) (any, error) {
return obj.SubGraphs, nil
},
nil,
ec.marshalNSubGraph2ᚕᚖgitlabᚗcomᚋunboundsoftwareᚋschemasᚋgraphᚋmodelᚐSubGraphᚄ,
true,
true,
)
}
func (ec *executionContext) fieldContext_SchemaUpdate_subGraphs(_ context.Context, field graphql.CollectedField) (fc *graphql.FieldContext, err error) {
fc = &graphql.FieldContext{
Object: "SchemaUpdate",
Field: field,
IsMethod: false,
IsResolver: false,
Child: func(ctx context.Context, field graphql.CollectedField) (*graphql.FieldContext, error) {
switch field.Name {
case "id":
return ec.fieldContext_SubGraph_id(ctx, field)
case "service":
return ec.fieldContext_SubGraph_service(ctx, field)
case "url":
return ec.fieldContext_SubGraph_url(ctx, field)
case "wsUrl":
return ec.fieldContext_SubGraph_wsUrl(ctx, field)
case "sdl":
return ec.fieldContext_SubGraph_sdl(ctx, field)
case "changedBy":
return ec.fieldContext_SubGraph_changedBy(ctx, field)
case "changedAt":
return ec.fieldContext_SubGraph_changedAt(ctx, field)
}
return nil, fmt.Errorf("no field named %q was found under type SubGraph", field.Name)
},
}
return fc, nil
}
func (ec *executionContext) _SchemaUpdate_cosmoRouterConfig(ctx context.Context, field graphql.CollectedField, obj *model.SchemaUpdate) (ret graphql.Marshaler) {
return graphql.ResolveField(
ctx,
ec.OperationContext,
field,
ec.fieldContext_SchemaUpdate_cosmoRouterConfig,
func(ctx context.Context) (any, error) {
return obj.CosmoRouterConfig, nil
},
nil,
ec.marshalOString2ᚖstring,
true,
false,
)
}
func (ec *executionContext) fieldContext_SchemaUpdate_cosmoRouterConfig(_ context.Context, field graphql.CollectedField) (fc *graphql.FieldContext, err error) {
fc = &graphql.FieldContext{
Object: "SchemaUpdate",
Field: field,
IsMethod: false,
IsResolver: false,
Child: func(ctx context.Context, field graphql.CollectedField) (*graphql.FieldContext, error) {
return nil, errors.New("field of type String does not have child fields")
},
}
return fc, nil
}
func (ec *executionContext) _SubGraph_id(ctx context.Context, field graphql.CollectedField, obj *model.SubGraph) (ret graphql.Marshaler) {
return graphql.ResolveField(
ctx,
@@ -1786,6 +2009,75 @@ func (ec *executionContext) fieldContext_SubGraphs_subGraphs(_ context.Context,
return fc, nil
}
func (ec *executionContext) _Subscription_schemaUpdates(ctx context.Context, field graphql.CollectedField) (ret func(ctx context.Context) graphql.Marshaler) {
return graphql.ResolveFieldStream(
ctx,
ec.OperationContext,
field,
ec.fieldContext_Subscription_schemaUpdates,
func(ctx context.Context) (any, error) {
fc := graphql.GetFieldContext(ctx)
return ec.resolvers.Subscription().SchemaUpdates(ctx, fc.Args["ref"].(string))
},
func(ctx context.Context, next graphql.Resolver) graphql.Resolver {
directive0 := next
directive1 := func(ctx context.Context) (any, error) {
organization, err := ec.unmarshalOBoolean2ᚖbool(ctx, true)
if err != nil {
var zeroVal *model.SchemaUpdate
return zeroVal, err
}
if ec.directives.Auth == nil {
var zeroVal *model.SchemaUpdate
return zeroVal, errors.New("directive auth is not implemented")
}
return ec.directives.Auth(ctx, nil, directive0, nil, organization)
}
next = directive1
return next
},
ec.marshalNSchemaUpdate2ᚖgitlabᚗcomᚋunboundsoftwareᚋschemasᚋgraphᚋmodelᚐSchemaUpdate,
true,
true,
)
}
func (ec *executionContext) fieldContext_Subscription_schemaUpdates(ctx context.Context, field graphql.CollectedField) (fc *graphql.FieldContext, err error) {
fc = &graphql.FieldContext{
Object: "Subscription",
Field: field,
IsMethod: true,
IsResolver: true,
Child: func(ctx context.Context, field graphql.CollectedField) (*graphql.FieldContext, error) {
switch field.Name {
case "ref":
return ec.fieldContext_SchemaUpdate_ref(ctx, field)
case "id":
return ec.fieldContext_SchemaUpdate_id(ctx, field)
case "subGraphs":
return ec.fieldContext_SchemaUpdate_subGraphs(ctx, field)
case "cosmoRouterConfig":
return ec.fieldContext_SchemaUpdate_cosmoRouterConfig(ctx, field)
}
return nil, fmt.Errorf("no field named %q was found under type SchemaUpdate", field.Name)
},
}
defer func() {
if r := recover(); r != nil {
err = ec.Recover(ctx, r)
ec.Error(ctx, err)
}
}()
ctx = graphql.WithFieldContext(ctx, fc)
if fc.Args, err = ec.field_Subscription_schemaUpdates_args(ctx, field.ArgumentMap(ec.Variables)); err != nil {
ec.Error(ctx, err)
return fc, err
}
return fc, nil
}
func (ec *executionContext) _Unchanged_id(ctx context.Context, field graphql.CollectedField, obj *model.Unchanged) (ret graphql.Marshaler) {
return graphql.ResolveField(
ctx,
@@ -3737,6 +4029,57 @@ func (ec *executionContext) _Query(ctx context.Context, sel ast.SelectionSet) gr
return out
}
var schemaUpdateImplementors = []string{"SchemaUpdate"}
func (ec *executionContext) _SchemaUpdate(ctx context.Context, sel ast.SelectionSet, obj *model.SchemaUpdate) graphql.Marshaler {
fields := graphql.CollectFields(ec.OperationContext, sel, schemaUpdateImplementors)
out := graphql.NewFieldSet(fields)
deferred := make(map[string]*graphql.FieldSet)
for i, field := range fields {
switch field.Name {
case "__typename":
out.Values[i] = graphql.MarshalString("SchemaUpdate")
case "ref":
out.Values[i] = ec._SchemaUpdate_ref(ctx, field, obj)
if out.Values[i] == graphql.Null {
out.Invalids++
}
case "id":
out.Values[i] = ec._SchemaUpdate_id(ctx, field, obj)
if out.Values[i] == graphql.Null {
out.Invalids++
}
case "subGraphs":
out.Values[i] = ec._SchemaUpdate_subGraphs(ctx, field, obj)
if out.Values[i] == graphql.Null {
out.Invalids++
}
case "cosmoRouterConfig":
out.Values[i] = ec._SchemaUpdate_cosmoRouterConfig(ctx, field, obj)
default:
panic("unknown field " + strconv.Quote(field.Name))
}
}
out.Dispatch(ctx)
if out.Invalids > 0 {
return graphql.Null
}
atomic.AddInt32(&ec.deferred, int32(len(deferred)))
for label, dfs := range deferred {
ec.processDeferredGroup(graphql.DeferredGroup{
Label: label,
Path: graphql.GetPath(ctx),
FieldSet: dfs,
Context: ctx,
})
}
return out
}
var subGraphImplementors = []string{"SubGraph"}
func (ec *executionContext) _SubGraph(ctx context.Context, sel ast.SelectionSet, obj *model.SubGraph) graphql.Marshaler {
@@ -3854,6 +4197,26 @@ func (ec *executionContext) _SubGraphs(ctx context.Context, sel ast.SelectionSet
return out
}
var subscriptionImplementors = []string{"Subscription"}
func (ec *executionContext) _Subscription(ctx context.Context, sel ast.SelectionSet) func(ctx context.Context) graphql.Marshaler {
fields := graphql.CollectFields(ec.OperationContext, sel, subscriptionImplementors)
ctx = graphql.WithFieldContext(ctx, &graphql.FieldContext{
Object: "Subscription",
})
if len(fields) != 1 {
ec.Errorf(ctx, "must subscribe to exactly one stream")
return nil
}
switch fields[0].Name {
case "schemaUpdates":
return ec._Subscription_schemaUpdates(ctx, fields[0])
default:
panic("unknown field " + strconv.Quote(fields[0].Name))
}
}
var unchangedImplementors = []string{"Unchanged", "Supergraph"}
func (ec *executionContext) _Unchanged(ctx context.Context, sel ast.SelectionSet, obj *model.Unchanged) graphql.Marshaler {
@@ -4441,6 +4804,20 @@ func (ec *executionContext) marshalNOrganization2ᚖgitlabᚗcomᚋunboundsoftwa
return ec._Organization(ctx, sel, v)
}
func (ec *executionContext) marshalNSchemaUpdate2gitlabᚗcomᚋunboundsoftwareᚋschemasᚋgraphᚋmodelᚐSchemaUpdate(ctx context.Context, sel ast.SelectionSet, v model.SchemaUpdate) graphql.Marshaler {
return ec._SchemaUpdate(ctx, sel, &v)
}
func (ec *executionContext) marshalNSchemaUpdate2ᚖgitlabᚗcomᚋunboundsoftwareᚋschemasᚋgraphᚋmodelᚐSchemaUpdate(ctx context.Context, sel ast.SelectionSet, v *model.SchemaUpdate) graphql.Marshaler {
if v == nil {
if !graphql.HasFieldError(ctx, graphql.GetFieldContext(ctx)) {
ec.Errorf(ctx, "the requested element is null which the schema does not allow")
}
return graphql.Null
}
return ec._SchemaUpdate(ctx, sel, v)
}
func (ec *executionContext) unmarshalNString2string(ctx context.Context, v any) (string, error) {
res, err := graphql.UnmarshalString(v)
return res, graphql.ErrorOnPath(ctx, err)
+10
View File
@@ -49,6 +49,13 @@ type Organization struct {
type Query struct {
}
type SchemaUpdate struct {
Ref string `json:"ref"`
ID string `json:"id"`
SubGraphs []*SubGraph `json:"subGraphs"`
CosmoRouterConfig *string `json:"cosmoRouterConfig,omitempty"`
}
type SubGraph struct {
ID string `json:"id"`
Service string `json:"service"`
@@ -68,6 +75,9 @@ type SubGraphs struct {
func (SubGraphs) IsSupergraph() {}
type Subscription struct {
}
type Unchanged struct {
ID string `json:"id"`
MinDelaySeconds int `json:"minDelaySeconds"`
+66
View File
@@ -0,0 +1,66 @@
package graph
import (
"sync"
"gitlab.com/unboundsoftware/schemas/graph/model"
)
// PubSub handles publishing schema updates to subscribers
type PubSub struct {
mu sync.RWMutex
subscribers map[string][]chan *model.SchemaUpdate
}
func NewPubSub() *PubSub {
return &PubSub{
subscribers: make(map[string][]chan *model.SchemaUpdate),
}
}
// Subscribe creates a new subscription channel for a given schema ref
func (ps *PubSub) Subscribe(ref string) chan *model.SchemaUpdate {
ps.mu.Lock()
defer ps.mu.Unlock()
ch := make(chan *model.SchemaUpdate, 10)
ps.subscribers[ref] = append(ps.subscribers[ref], ch)
return ch
}
// Unsubscribe removes a subscription channel
func (ps *PubSub) Unsubscribe(ref string, ch chan *model.SchemaUpdate) {
ps.mu.Lock()
defer ps.mu.Unlock()
subs := ps.subscribers[ref]
for i, sub := range subs {
if sub == ch {
// Remove this subscriber
ps.subscribers[ref] = append(subs[:i], subs[i+1:]...)
close(sub)
break
}
}
// Clean up empty subscriber lists
if len(ps.subscribers[ref]) == 0 {
delete(ps.subscribers, ref)
}
}
// Publish sends a schema update to all subscribers of a given ref
func (ps *PubSub) Publish(ref string, update *model.SchemaUpdate) {
ps.mu.RLock()
defer ps.mu.RUnlock()
for _, ch := range ps.subscribers[ref] {
// Non-blocking send - if subscriber is slow, skip
select {
case ch <- update:
default:
// Channel full, subscriber is too slow - skip this update
}
}
}
+256
View File
@@ -0,0 +1,256 @@
package graph
import (
"sync"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"gitlab.com/unboundsoftware/schemas/graph/model"
)
func TestPubSub_SubscribeAndPublish(t *testing.T) {
ps := NewPubSub()
ref := "Test@dev"
// Subscribe
ch := ps.Subscribe(ref)
require.NotNil(t, ch, "Subscribe should return a channel")
// Publish
update := &model.SchemaUpdate{
Ref: ref,
ID: "test-id-1",
SubGraphs: []*model.SubGraph{
{
ID: "sg1",
Service: "test-service",
Sdl: "type Query { test: String }",
},
},
}
go ps.Publish(ref, update)
// Receive
select {
case received := <-ch:
assert.Equal(t, update.Ref, received.Ref, "Ref should match")
assert.Equal(t, update.ID, received.ID, "ID should match")
assert.Equal(t, len(update.SubGraphs), len(received.SubGraphs), "SubGraphs count should match")
case <-time.After(1 * time.Second):
t.Fatal("Timeout waiting for published update")
}
}
func TestPubSub_MultipleSubscribers(t *testing.T) {
ps := NewPubSub()
ref := "Test@dev"
// Create multiple subscribers
ch1 := ps.Subscribe(ref)
ch2 := ps.Subscribe(ref)
ch3 := ps.Subscribe(ref)
update := &model.SchemaUpdate{
Ref: ref,
ID: "test-id-2",
}
// Publish once
ps.Publish(ref, update)
// All subscribers should receive the update
var wg sync.WaitGroup
wg.Add(3)
checkReceived := func(ch <-chan *model.SchemaUpdate, name string) {
defer wg.Done()
select {
case received := <-ch:
assert.Equal(t, update.ID, received.ID, "%s should receive correct update", name)
case <-time.After(1 * time.Second):
t.Errorf("%s: Timeout waiting for update", name)
}
}
go checkReceived(ch1, "Subscriber 1")
go checkReceived(ch2, "Subscriber 2")
go checkReceived(ch3, "Subscriber 3")
wg.Wait()
}
func TestPubSub_DifferentRefs(t *testing.T) {
ps := NewPubSub()
ref1 := "Test1@dev"
ref2 := "Test2@dev"
ch1 := ps.Subscribe(ref1)
ch2 := ps.Subscribe(ref2)
update1 := &model.SchemaUpdate{Ref: ref1, ID: "id1"}
update2 := &model.SchemaUpdate{Ref: ref2, ID: "id2"}
// Publish to ref1
ps.Publish(ref1, update1)
// Only ch1 should receive
select {
case received := <-ch1:
assert.Equal(t, "id1", received.ID)
case <-time.After(100 * time.Millisecond):
t.Fatal("ch1 should have received update")
}
// ch2 should not receive ref1's update
select {
case <-ch2:
t.Fatal("ch2 should not receive ref1's update")
case <-time.After(100 * time.Millisecond):
// Expected - no update
}
// Publish to ref2
ps.Publish(ref2, update2)
// Now ch2 should receive
select {
case received := <-ch2:
assert.Equal(t, "id2", received.ID)
case <-time.After(100 * time.Millisecond):
t.Fatal("ch2 should have received update")
}
}
func TestPubSub_Unsubscribe(t *testing.T) {
ps := NewPubSub()
ref := "Test@dev"
ch := ps.Subscribe(ref)
// Unsubscribe
ps.Unsubscribe(ref, ch)
// Channel should be closed
_, ok := <-ch
assert.False(t, ok, "Channel should be closed after unsubscribe")
// Publishing after unsubscribe should not panic
assert.NotPanics(t, func() {
ps.Publish(ref, &model.SchemaUpdate{Ref: ref})
})
}
func TestPubSub_BufferedChannel(t *testing.T) {
ps := NewPubSub()
ref := "Test@dev"
ch := ps.Subscribe(ref)
// Publish multiple updates quickly (up to buffer size of 10)
for i := 0; i < 10; i++ {
update := &model.SchemaUpdate{
Ref: ref,
ID: string(rune('a' + i)),
}
ps.Publish(ref, update)
}
// All 10 should be buffered and receivable
received := 0
timeout := time.After(1 * time.Second)
for received < 10 {
select {
case <-ch:
received++
case <-timeout:
t.Fatalf("Only received %d out of 10 updates", received)
}
}
assert.Equal(t, 10, received, "Should receive all buffered updates")
}
func TestPubSub_SlowSubscriber(t *testing.T) {
ps := NewPubSub()
ref := "Test@dev"
ch := ps.Subscribe(ref)
// Fill the buffer (10 items)
for i := 0; i < 10; i++ {
ps.Publish(ref, &model.SchemaUpdate{Ref: ref})
}
// Publish one more - this should be dropped (channel full, non-blocking send)
ps.Publish(ref, &model.SchemaUpdate{Ref: ref, ID: "should-be-dropped"})
// Drain the channel
count := 0
timeout := time.After(500 * time.Millisecond)
drainLoop:
for {
select {
case update := <-ch:
count++
// Should not receive the dropped update
assert.NotEqual(t, "should-be-dropped", update.ID, "Should not receive dropped update")
case <-timeout:
break drainLoop
}
}
// Should have received exactly 10 (the buffer size), not 11
assert.Equal(t, 10, count, "Should only receive buffered updates, not the dropped one")
}
func TestPubSub_ConcurrentPublish(t *testing.T) {
ps := NewPubSub()
ref := "Test@dev"
ch := ps.Subscribe(ref)
numPublishers := 10
updatesPerPublisher := 10
var wg sync.WaitGroup
wg.Add(numPublishers)
// Multiple goroutines publishing concurrently
for i := 0; i < numPublishers; i++ {
go func(publisherID int) {
defer wg.Done()
for j := 0; j < updatesPerPublisher; j++ {
ps.Publish(ref, &model.SchemaUpdate{
Ref: ref,
ID: string(rune('a' + publisherID)),
})
}
}(i)
}
wg.Wait()
// Should not panic and subscriber should receive updates
// (exact count may vary due to buffer and timing)
timeout := time.After(1 * time.Second)
received := 0
receiveLoop:
for {
select {
case <-ch:
received++
case <-timeout:
break receiveLoop
}
}
assert.Greater(t, received, 0, "Should have received some updates")
}
+1
View File
@@ -28,6 +28,7 @@ type Resolver struct {
Publisher Publisher
Logger *slog.Logger
Cache *cache.Cache
PubSub *PubSub
}
func (r *Resolver) apiKeyCanAccessRef(ctx context.Context, ref string, publish bool) (string, error) {
+11
View File
@@ -9,6 +9,10 @@ type Mutation {
updateSubGraph(input: InputSubGraph!): SubGraph! @auth(organization: true)
}
type Subscription {
schemaUpdates(ref: String!): SchemaUpdate! @auth(organization: true)
}
type Organization {
id: ID!
name: String!
@@ -54,6 +58,13 @@ type SubGraph {
changedAt: Time!
}
type SchemaUpdate {
ref: String!
id: ID!
subGraphs: [SubGraph!]!
cosmoRouterConfig: String
}
input InputAPIKey {
name: String!
organizationId: ID!
+99
View File
@@ -119,6 +119,44 @@ func (r *mutationResolver) UpdateSubGraph(ctx context.Context, input model.Input
if err != nil {
return nil, err
}
// Publish schema update to subscribers
go func() {
services, lastUpdate := r.Cache.Services(orgId, input.Ref, "")
subGraphs := make([]*model.SubGraph, len(services))
for i, id := range services {
sg, err := r.fetchSubGraph(context.Background(), id)
if err != nil {
r.Logger.Error("fetch subgraph for update notification", "error", err)
continue
}
subGraphs[i] = &model.SubGraph{
ID: sg.ID.String(),
Service: sg.Service,
URL: sg.Url,
WsURL: sg.WSUrl,
Sdl: sg.Sdl,
ChangedBy: sg.ChangedBy,
ChangedAt: sg.ChangedAt,
}
}
// Generate Cosmo router config
cosmoConfig, err := GenerateCosmoRouterConfig(subGraphs)
if err != nil {
r.Logger.Error("generate cosmo config for update", "error", err)
cosmoConfig = "" // Send empty if generation fails
}
// Publish to all subscribers of this ref
r.PubSub.Publish(input.Ref, &model.SchemaUpdate{
Ref: input.Ref,
ID: lastUpdate,
SubGraphs: subGraphs,
CosmoRouterConfig: &cosmoConfig,
})
}()
return r.toGqlSubGraph(subGraph), nil
}
@@ -184,13 +222,74 @@ func (r *queryResolver) Supergraph(ctx context.Context, ref string, isAfter *str
}, nil
}
// SchemaUpdates is the resolver for the schemaUpdates field.
func (r *subscriptionResolver) SchemaUpdates(ctx context.Context, ref string) (<-chan *model.SchemaUpdate, error) {
orgId := middleware.OrganizationFromContext(ctx)
_, err := r.apiKeyCanAccessRef(ctx, ref, false)
if err != nil {
return nil, err
}
// Subscribe to updates for this ref
ch := r.PubSub.Subscribe(ref)
// Send initial state immediately
go func() {
services, lastUpdate := r.Cache.Services(orgId, ref, "")
subGraphs := make([]*model.SubGraph, len(services))
for i, id := range services {
sg, err := r.fetchSubGraph(ctx, id)
if err != nil {
r.Logger.Error("fetch subgraph for initial update", "error", err)
continue
}
subGraphs[i] = &model.SubGraph{
ID: sg.ID.String(),
Service: sg.Service,
URL: sg.Url,
WsURL: sg.WSUrl,
Sdl: sg.Sdl,
ChangedBy: sg.ChangedBy,
ChangedAt: sg.ChangedAt,
}
}
// Generate Cosmo router config
cosmoConfig, err := GenerateCosmoRouterConfig(subGraphs)
if err != nil {
r.Logger.Error("generate cosmo config", "error", err)
cosmoConfig = "" // Send empty if generation fails
}
// Send initial update
ch <- &model.SchemaUpdate{
Ref: ref,
ID: lastUpdate,
SubGraphs: subGraphs,
CosmoRouterConfig: &cosmoConfig,
}
}()
// Clean up subscription when context is done
go func() {
<-ctx.Done()
r.PubSub.Unsubscribe(ref, ch)
}()
return ch, nil
}
// Mutation returns generated.MutationResolver implementation.
func (r *Resolver) Mutation() generated.MutationResolver { return &mutationResolver{r} }
// Query returns generated.QueryResolver implementation.
func (r *Resolver) Query() generated.QueryResolver { return &queryResolver{r} }
// Subscription returns generated.SubscriptionResolver implementation.
func (r *Resolver) Subscription() generated.SubscriptionResolver { return &subscriptionResolver{r} }
type (
mutationResolver struct{ *Resolver }
queryResolver struct{ *Resolver }
subscriptionResolver struct{ *Resolver }
)