diff --git a/cmd/service/service.go b/cmd/service/service.go index 43cee3a..dbefcf2 100644 --- a/cmd/service/service.go +++ b/cmd/service/service.go @@ -195,6 +195,7 @@ func start(closeEvents chan error, logger *slog.Logger, connectToAmqpFunc func(u Publisher: eventPublisher, Logger: logger, Cache: serviceCache, + PubSub: graph.NewPubSub(), } config := generated.Config{ diff --git a/graph/cosmo.go b/graph/cosmo.go new file mode 100644 index 0000000..150e162 --- /dev/null +++ b/graph/cosmo.go @@ -0,0 +1,54 @@ +package graph + +import ( + "encoding/json" + "fmt" + + "gitlab.com/unboundsoftware/schemas/graph/model" +) + +// GenerateCosmoRouterConfig generates a Cosmo Router execution config from subgraphs +func GenerateCosmoRouterConfig(subGraphs []*model.SubGraph) (string, error) { + // Build the Cosmo router config structure + // This is a simplified version - you may need to adjust based on actual Cosmo requirements + config := map[string]interface{}{ + "version": "1", + "subgraphs": convertSubGraphsToCosmo(subGraphs), + // Add other Cosmo-specific configuration as needed + } + + // Marshal to JSON + configJSON, err := json.MarshalIndent(config, "", " ") + if err != nil { + return "", fmt.Errorf("marshal cosmo config: %w", err) + } + + return string(configJSON), nil +} + +func convertSubGraphsToCosmo(subGraphs []*model.SubGraph) []map[string]interface{} { + cosmoSubgraphs := make([]map[string]interface{}, 0, len(subGraphs)) + + for _, sg := range subGraphs { + cosmoSg := map[string]interface{}{ + "name": sg.Service, + "sdl": sg.Sdl, + } + + if sg.URL != nil { + cosmoSg["routing_url"] = *sg.URL + } + + if sg.WsURL != nil { + cosmoSg["subscription"] = map[string]interface{}{ + "url": *sg.WsURL, + "protocol": "ws", + "websocket_subprotocol": "graphql-ws", + } + } + + cosmoSubgraphs = append(cosmoSubgraphs, cosmoSg) + } + + return cosmoSubgraphs +} diff --git a/graph/cosmo_test.go b/graph/cosmo_test.go new file mode 100644 index 0000000..0f6de36 --- /dev/null +++ b/graph/cosmo_test.go @@ -0,0 +1,258 @@ +package graph + +import ( + "encoding/json" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "gitlab.com/unboundsoftware/schemas/graph/model" +) + +func TestGenerateCosmoRouterConfig(t *testing.T) { + tests := []struct { + name string + subGraphs []*model.SubGraph + wantErr bool + validate func(t *testing.T, config string) + }{ + { + name: "single subgraph with all fields", + subGraphs: []*model.SubGraph{ + { + Service: "test-service", + URL: stringPtr("http://localhost:4001/query"), + WsURL: stringPtr("ws://localhost:4001/query"), + Sdl: "type Query { test: String }", + }, + }, + wantErr: false, + validate: func(t *testing.T, config string) { + var result map[string]interface{} + err := json.Unmarshal([]byte(config), &result) + require.NoError(t, err, "Config should be valid JSON") + + assert.Equal(t, "1", result["version"], "Version should be 1") + + subgraphs, ok := result["subgraphs"].([]interface{}) + require.True(t, ok, "subgraphs should be an array") + require.Len(t, subgraphs, 1, "Should have 1 subgraph") + + sg := subgraphs[0].(map[string]interface{}) + assert.Equal(t, "test-service", sg["name"]) + assert.Equal(t, "http://localhost:4001/query", sg["routing_url"]) + assert.Equal(t, "type Query { test: String }", sg["sdl"]) + + subscription, ok := sg["subscription"].(map[string]interface{}) + require.True(t, ok, "Should have subscription config") + assert.Equal(t, "ws://localhost:4001/query", subscription["url"]) + assert.Equal(t, "ws", subscription["protocol"]) + assert.Equal(t, "graphql-ws", subscription["websocket_subprotocol"]) + }, + }, + { + name: "multiple subgraphs", + subGraphs: []*model.SubGraph{ + { + Service: "service-1", + URL: stringPtr("http://localhost:4001/query"), + Sdl: "type Query { field1: String }", + }, + { + Service: "service-2", + URL: stringPtr("http://localhost:4002/query"), + Sdl: "type Query { field2: String }", + }, + { + Service: "service-3", + URL: stringPtr("http://localhost:4003/query"), + WsURL: stringPtr("ws://localhost:4003/query"), + Sdl: "type Subscription { updates: String }", + }, + }, + wantErr: false, + validate: func(t *testing.T, config string) { + var result map[string]interface{} + err := json.Unmarshal([]byte(config), &result) + require.NoError(t, err) + + subgraphs := result["subgraphs"].([]interface{}) + assert.Len(t, subgraphs, 3, "Should have 3 subgraphs") + + // Check first service has no subscription + sg1 := subgraphs[0].(map[string]interface{}) + assert.Equal(t, "service-1", sg1["name"]) + _, hasSubscription := sg1["subscription"] + assert.False(t, hasSubscription, "service-1 should not have subscription config") + + // Check third service has subscription + sg3 := subgraphs[2].(map[string]interface{}) + assert.Equal(t, "service-3", sg3["name"]) + subscription, hasSubscription := sg3["subscription"] + assert.True(t, hasSubscription, "service-3 should have subscription config") + assert.NotNil(t, subscription) + }, + }, + { + name: "subgraph with no URL", + subGraphs: []*model.SubGraph{ + { + Service: "test-service", + URL: nil, + WsURL: nil, + Sdl: "type Query { test: String }", + }, + }, + wantErr: false, + validate: func(t *testing.T, config string) { + var result map[string]interface{} + err := json.Unmarshal([]byte(config), &result) + require.NoError(t, err) + + subgraphs := result["subgraphs"].([]interface{}) + sg := subgraphs[0].(map[string]interface{}) + + // Should not have routing_url or subscription fields if URLs are nil + _, hasRoutingURL := sg["routing_url"] + assert.False(t, hasRoutingURL, "Should not have routing_url when URL is nil") + + _, hasSubscription := sg["subscription"] + assert.False(t, hasSubscription, "Should not have subscription when WsURL is nil") + }, + }, + { + name: "empty subgraphs", + subGraphs: []*model.SubGraph{}, + wantErr: false, + validate: func(t *testing.T, config string) { + var result map[string]interface{} + err := json.Unmarshal([]byte(config), &result) + require.NoError(t, err) + + subgraphs := result["subgraphs"].([]interface{}) + assert.Len(t, subgraphs, 0, "Should have empty subgraphs array") + }, + }, + { + name: "nil subgraphs", + subGraphs: nil, + wantErr: false, + validate: func(t *testing.T, config string) { + var result map[string]interface{} + err := json.Unmarshal([]byte(config), &result) + require.NoError(t, err) + + subgraphs := result["subgraphs"].([]interface{}) + assert.Len(t, subgraphs, 0, "Should handle nil subgraphs as empty array") + }, + }, + { + name: "complex SDL with multiple types", + subGraphs: []*model.SubGraph{ + { + Service: "complex-service", + URL: stringPtr("http://localhost:4001/query"), + Sdl: ` + type Query { + user(id: ID!): User + users: [User!]! + } + + type User { + id: ID! + name: String! + email: String! + } + `, + }, + }, + wantErr: false, + validate: func(t *testing.T, config string) { + var result map[string]interface{} + err := json.Unmarshal([]byte(config), &result) + require.NoError(t, err) + + subgraphs := result["subgraphs"].([]interface{}) + sg := subgraphs[0].(map[string]interface{}) + sdl := sg["sdl"].(string) + + assert.Contains(t, sdl, "type Query") + assert.Contains(t, sdl, "type User") + assert.Contains(t, sdl, "email: String!") + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + config, err := GenerateCosmoRouterConfig(tt.subGraphs) + + if tt.wantErr { + assert.Error(t, err) + return + } + + require.NoError(t, err) + assert.NotEmpty(t, config, "Config should not be empty") + + if tt.validate != nil { + tt.validate(t, config) + } + }) + } +} + +func TestConvertSubGraphsToCosmo(t *testing.T) { + tests := []struct { + name string + subGraphs []*model.SubGraph + wantLen int + validate func(t *testing.T, result []map[string]interface{}) + }{ + { + name: "preserves subgraph order", + subGraphs: []*model.SubGraph{ + {Service: "alpha", URL: stringPtr("http://a"), Sdl: "a"}, + {Service: "beta", URL: stringPtr("http://b"), Sdl: "b"}, + {Service: "gamma", URL: stringPtr("http://c"), Sdl: "c"}, + }, + wantLen: 3, + validate: func(t *testing.T, result []map[string]interface{}) { + assert.Equal(t, "alpha", result[0]["name"]) + assert.Equal(t, "beta", result[1]["name"]) + assert.Equal(t, "gamma", result[2]["name"]) + }, + }, + { + name: "includes SDL exactly as provided", + subGraphs: []*model.SubGraph{ + { + Service: "test", + URL: stringPtr("http://test"), + Sdl: "type Query { special: String! }", + }, + }, + wantLen: 1, + validate: func(t *testing.T, result []map[string]interface{}) { + assert.Equal(t, "type Query { special: String! }", result[0]["sdl"]) + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := convertSubGraphsToCosmo(tt.subGraphs) + assert.Len(t, result, tt.wantLen) + + if tt.validate != nil { + tt.validate(t, result) + } + }) + } +} + +// Helper function for tests +func stringPtr(s string) *string { + return &s +} diff --git a/graph/generated/generated.go b/graph/generated/generated.go index acf2a92..b4e5fec 100644 --- a/graph/generated/generated.go +++ b/graph/generated/generated.go @@ -42,6 +42,7 @@ type Config struct { type ResolverRoot interface { Mutation() MutationResolver Query() QueryResolver + Subscription() SubscriptionResolver } type DirectiveRoot struct { @@ -77,6 +78,13 @@ type ComplexityRoot struct { Supergraph func(childComplexity int, ref string, isAfter *string) int } + SchemaUpdate struct { + CosmoRouterConfig func(childComplexity int) int + ID func(childComplexity int) int + Ref func(childComplexity int) int + SubGraphs func(childComplexity int) int + } + SubGraph struct { ChangedAt func(childComplexity int) int ChangedBy func(childComplexity int) int @@ -94,6 +102,10 @@ type ComplexityRoot struct { SubGraphs func(childComplexity int) int } + Subscription struct { + SchemaUpdates func(childComplexity int, ref string) int + } + Unchanged struct { ID func(childComplexity int) int MinDelaySeconds func(childComplexity int) int @@ -113,6 +125,9 @@ type QueryResolver interface { Organizations(ctx context.Context) ([]*model.Organization, error) Supergraph(ctx context.Context, ref string, isAfter *string) (model.Supergraph, error) } +type SubscriptionResolver interface { + SchemaUpdates(ctx context.Context, ref string) (<-chan *model.SchemaUpdate, error) +} type executableSchema struct { schema *ast.Schema @@ -253,6 +268,31 @@ func (e *executableSchema) Complexity(ctx context.Context, typeName, field strin return e.complexity.Query.Supergraph(childComplexity, args["ref"].(string), args["isAfter"].(*string)), true + case "SchemaUpdate.cosmoRouterConfig": + if e.complexity.SchemaUpdate.CosmoRouterConfig == nil { + break + } + + return e.complexity.SchemaUpdate.CosmoRouterConfig(childComplexity), true + case "SchemaUpdate.id": + if e.complexity.SchemaUpdate.ID == nil { + break + } + + return e.complexity.SchemaUpdate.ID(childComplexity), true + case "SchemaUpdate.ref": + if e.complexity.SchemaUpdate.Ref == nil { + break + } + + return e.complexity.SchemaUpdate.Ref(childComplexity), true + case "SchemaUpdate.subGraphs": + if e.complexity.SchemaUpdate.SubGraphs == nil { + break + } + + return e.complexity.SchemaUpdate.SubGraphs(childComplexity), true + case "SubGraph.changedAt": if e.complexity.SubGraph.ChangedAt == nil { break @@ -321,6 +361,18 @@ func (e *executableSchema) Complexity(ctx context.Context, typeName, field strin return e.complexity.SubGraphs.SubGraphs(childComplexity), true + case "Subscription.schemaUpdates": + if e.complexity.Subscription.SchemaUpdates == nil { + break + } + + args, err := ec.field_Subscription_schemaUpdates_args(ctx, rawArgs) + if err != nil { + return 0, false + } + + return e.complexity.Subscription.SchemaUpdates(childComplexity, args["ref"].(string)), true + case "Unchanged.id": if e.complexity.Unchanged.ID == nil { break @@ -396,6 +448,23 @@ func (e *executableSchema) Exec(ctx context.Context) graphql.ResponseHandler { var buf bytes.Buffer data.MarshalGQL(&buf) + return &graphql.Response{ + Data: buf.Bytes(), + } + } + case ast.Subscription: + next := ec._Subscription(ctx, opCtx.Operation.SelectionSet) + + var buf bytes.Buffer + return func(ctx context.Context) *graphql.Response { + buf.Reset() + data := next(ctx) + + if data == nil { + return nil + } + data.MarshalGQL(&buf) + return &graphql.Response{ Data: buf.Bytes(), } @@ -459,6 +528,10 @@ type Mutation { updateSubGraph(input: InputSubGraph!): SubGraph! @auth(organization: true) } +type Subscription { + schemaUpdates(ref: String!): SchemaUpdate! @auth(organization: true) +} + type Organization { id: ID! name: String! @@ -504,6 +577,13 @@ type SubGraph { changedAt: Time! } +type SchemaUpdate { + ref: String! + id: ID! + subGraphs: [SubGraph!]! + cosmoRouterConfig: String +} + input InputAPIKey { name: String! organizationId: ID! @@ -607,6 +687,17 @@ func (ec *executionContext) field_Query_supergraph_args(ctx context.Context, raw return args, nil } +func (ec *executionContext) field_Subscription_schemaUpdates_args(ctx context.Context, rawArgs map[string]any) (map[string]any, error) { + var err error + args := map[string]any{} + arg0, err := graphql.ProcessArgField(ctx, rawArgs, "ref", ec.unmarshalNString2string) + if err != nil { + return nil, err + } + args["ref"] = arg0 + return args, nil +} + func (ec *executionContext) field___Directive_args_args(ctx context.Context, rawArgs map[string]any) (map[string]any, error) { var err error args := map[string]any{} @@ -1451,6 +1542,138 @@ func (ec *executionContext) fieldContext_Query___schema(_ context.Context, field return fc, nil } +func (ec *executionContext) _SchemaUpdate_ref(ctx context.Context, field graphql.CollectedField, obj *model.SchemaUpdate) (ret graphql.Marshaler) { + return graphql.ResolveField( + ctx, + ec.OperationContext, + field, + ec.fieldContext_SchemaUpdate_ref, + func(ctx context.Context) (any, error) { + return obj.Ref, nil + }, + nil, + ec.marshalNString2string, + true, + true, + ) +} + +func (ec *executionContext) fieldContext_SchemaUpdate_ref(_ context.Context, field graphql.CollectedField) (fc *graphql.FieldContext, err error) { + fc = &graphql.FieldContext{ + Object: "SchemaUpdate", + Field: field, + IsMethod: false, + IsResolver: false, + Child: func(ctx context.Context, field graphql.CollectedField) (*graphql.FieldContext, error) { + return nil, errors.New("field of type String does not have child fields") + }, + } + return fc, nil +} + +func (ec *executionContext) _SchemaUpdate_id(ctx context.Context, field graphql.CollectedField, obj *model.SchemaUpdate) (ret graphql.Marshaler) { + return graphql.ResolveField( + ctx, + ec.OperationContext, + field, + ec.fieldContext_SchemaUpdate_id, + func(ctx context.Context) (any, error) { + return obj.ID, nil + }, + nil, + ec.marshalNID2string, + true, + true, + ) +} + +func (ec *executionContext) fieldContext_SchemaUpdate_id(_ context.Context, field graphql.CollectedField) (fc *graphql.FieldContext, err error) { + fc = &graphql.FieldContext{ + Object: "SchemaUpdate", + Field: field, + IsMethod: false, + IsResolver: false, + Child: func(ctx context.Context, field graphql.CollectedField) (*graphql.FieldContext, error) { + return nil, errors.New("field of type ID does not have child fields") + }, + } + return fc, nil +} + +func (ec *executionContext) _SchemaUpdate_subGraphs(ctx context.Context, field graphql.CollectedField, obj *model.SchemaUpdate) (ret graphql.Marshaler) { + return graphql.ResolveField( + ctx, + ec.OperationContext, + field, + ec.fieldContext_SchemaUpdate_subGraphs, + func(ctx context.Context) (any, error) { + return obj.SubGraphs, nil + }, + nil, + ec.marshalNSubGraph2ᚕᚖgitlabᚗcomᚋunboundsoftwareᚋschemasᚋgraphᚋmodelᚐSubGraphᚄ, + true, + true, + ) +} + +func (ec *executionContext) fieldContext_SchemaUpdate_subGraphs(_ context.Context, field graphql.CollectedField) (fc *graphql.FieldContext, err error) { + fc = &graphql.FieldContext{ + Object: "SchemaUpdate", + Field: field, + IsMethod: false, + IsResolver: false, + Child: func(ctx context.Context, field graphql.CollectedField) (*graphql.FieldContext, error) { + switch field.Name { + case "id": + return ec.fieldContext_SubGraph_id(ctx, field) + case "service": + return ec.fieldContext_SubGraph_service(ctx, field) + case "url": + return ec.fieldContext_SubGraph_url(ctx, field) + case "wsUrl": + return ec.fieldContext_SubGraph_wsUrl(ctx, field) + case "sdl": + return ec.fieldContext_SubGraph_sdl(ctx, field) + case "changedBy": + return ec.fieldContext_SubGraph_changedBy(ctx, field) + case "changedAt": + return ec.fieldContext_SubGraph_changedAt(ctx, field) + } + return nil, fmt.Errorf("no field named %q was found under type SubGraph", field.Name) + }, + } + return fc, nil +} + +func (ec *executionContext) _SchemaUpdate_cosmoRouterConfig(ctx context.Context, field graphql.CollectedField, obj *model.SchemaUpdate) (ret graphql.Marshaler) { + return graphql.ResolveField( + ctx, + ec.OperationContext, + field, + ec.fieldContext_SchemaUpdate_cosmoRouterConfig, + func(ctx context.Context) (any, error) { + return obj.CosmoRouterConfig, nil + }, + nil, + ec.marshalOString2ᚖstring, + true, + false, + ) +} + +func (ec *executionContext) fieldContext_SchemaUpdate_cosmoRouterConfig(_ context.Context, field graphql.CollectedField) (fc *graphql.FieldContext, err error) { + fc = &graphql.FieldContext{ + Object: "SchemaUpdate", + Field: field, + IsMethod: false, + IsResolver: false, + Child: func(ctx context.Context, field graphql.CollectedField) (*graphql.FieldContext, error) { + return nil, errors.New("field of type String does not have child fields") + }, + } + return fc, nil +} + func (ec *executionContext) _SubGraph_id(ctx context.Context, field graphql.CollectedField, obj *model.SubGraph) (ret graphql.Marshaler) { return graphql.ResolveField( ctx, @@ -1786,6 +2009,75 @@ func (ec *executionContext) fieldContext_SubGraphs_subGraphs(_ context.Context, return fc, nil } +func (ec *executionContext) _Subscription_schemaUpdates(ctx context.Context, field graphql.CollectedField) (ret func(ctx context.Context) graphql.Marshaler) { + return graphql.ResolveFieldStream( + ctx, + ec.OperationContext, + field, + ec.fieldContext_Subscription_schemaUpdates, + func(ctx context.Context) (any, error) { + fc := graphql.GetFieldContext(ctx) + return ec.resolvers.Subscription().SchemaUpdates(ctx, fc.Args["ref"].(string)) + }, + func(ctx context.Context, next graphql.Resolver) graphql.Resolver { + directive0 := next + + directive1 := func(ctx context.Context) (any, error) { + organization, err := ec.unmarshalOBoolean2ᚖbool(ctx, true) + if err != nil { + var zeroVal *model.SchemaUpdate + return zeroVal, err + } + if ec.directives.Auth == nil { + var zeroVal *model.SchemaUpdate + return zeroVal, errors.New("directive auth is not implemented") + } + return ec.directives.Auth(ctx, nil, directive0, nil, organization) + } + + next = directive1 + return next + }, + ec.marshalNSchemaUpdate2ᚖgitlabᚗcomᚋunboundsoftwareᚋschemasᚋgraphᚋmodelᚐSchemaUpdate, + true, + true, + ) +} + +func (ec *executionContext) fieldContext_Subscription_schemaUpdates(ctx context.Context, field graphql.CollectedField) (fc *graphql.FieldContext, err error) { + fc = &graphql.FieldContext{ + Object: "Subscription", + Field: field, + IsMethod: true, + IsResolver: true, + Child: func(ctx context.Context, field graphql.CollectedField) (*graphql.FieldContext, error) { + switch field.Name { + case "ref": + return ec.fieldContext_SchemaUpdate_ref(ctx, field) + case "id": + return ec.fieldContext_SchemaUpdate_id(ctx, field) + case "subGraphs": + return ec.fieldContext_SchemaUpdate_subGraphs(ctx, field) + case "cosmoRouterConfig": + return ec.fieldContext_SchemaUpdate_cosmoRouterConfig(ctx, field) + } + return nil, fmt.Errorf("no field named %q was found under type SchemaUpdate", field.Name) + }, + } + defer func() { + if r := recover(); r != nil { + err = ec.Recover(ctx, r) + ec.Error(ctx, err) + } + }() + ctx = graphql.WithFieldContext(ctx, fc) + if fc.Args, err = ec.field_Subscription_schemaUpdates_args(ctx, field.ArgumentMap(ec.Variables)); err != nil { + ec.Error(ctx, err) + return fc, err + } + return fc, nil +} + func (ec *executionContext) _Unchanged_id(ctx context.Context, field graphql.CollectedField, obj *model.Unchanged) (ret graphql.Marshaler) { return graphql.ResolveField( ctx, @@ -3737,6 +4029,57 @@ func (ec *executionContext) _Query(ctx context.Context, sel ast.SelectionSet) gr return out } +var schemaUpdateImplementors = []string{"SchemaUpdate"} + +func (ec *executionContext) _SchemaUpdate(ctx context.Context, sel ast.SelectionSet, obj *model.SchemaUpdate) graphql.Marshaler { + fields := graphql.CollectFields(ec.OperationContext, sel, schemaUpdateImplementors) + + out := graphql.NewFieldSet(fields) + deferred := make(map[string]*graphql.FieldSet) + for i, field := range fields { + switch field.Name { + case "__typename": + out.Values[i] = graphql.MarshalString("SchemaUpdate") + case "ref": + out.Values[i] = ec._SchemaUpdate_ref(ctx, field, obj) + if out.Values[i] == graphql.Null { + out.Invalids++ + } + case "id": + out.Values[i] = ec._SchemaUpdate_id(ctx, field, obj) + if out.Values[i] == graphql.Null { + out.Invalids++ + } + case "subGraphs": + out.Values[i] = ec._SchemaUpdate_subGraphs(ctx, field, obj) + if out.Values[i] == graphql.Null { + out.Invalids++ + } + case "cosmoRouterConfig": + out.Values[i] = ec._SchemaUpdate_cosmoRouterConfig(ctx, field, obj) + default: + panic("unknown field " + strconv.Quote(field.Name)) + } + } + out.Dispatch(ctx) + if out.Invalids > 0 { + return graphql.Null + } + + atomic.AddInt32(&ec.deferred, int32(len(deferred))) + + for label, dfs := range deferred { + ec.processDeferredGroup(graphql.DeferredGroup{ + Label: label, + Path: graphql.GetPath(ctx), + FieldSet: dfs, + Context: ctx, + }) + } + + return out +} + var subGraphImplementors = []string{"SubGraph"} func (ec *executionContext) _SubGraph(ctx context.Context, sel ast.SelectionSet, obj *model.SubGraph) graphql.Marshaler { @@ -3854,6 +4197,26 @@ func (ec *executionContext) _SubGraphs(ctx context.Context, sel ast.SelectionSet return out } +var subscriptionImplementors = []string{"Subscription"} + +func (ec *executionContext) _Subscription(ctx context.Context, sel ast.SelectionSet) func(ctx context.Context) graphql.Marshaler { + fields := graphql.CollectFields(ec.OperationContext, sel, subscriptionImplementors) + ctx = graphql.WithFieldContext(ctx, &graphql.FieldContext{ + Object: "Subscription", + }) + if len(fields) != 1 { + ec.Errorf(ctx, "must subscribe to exactly one stream") + return nil + } + + switch fields[0].Name { + case "schemaUpdates": + return ec._Subscription_schemaUpdates(ctx, fields[0]) + default: + panic("unknown field " + strconv.Quote(fields[0].Name)) + } +} + var unchangedImplementors = []string{"Unchanged", "Supergraph"} func (ec *executionContext) _Unchanged(ctx context.Context, sel ast.SelectionSet, obj *model.Unchanged) graphql.Marshaler { @@ -4441,6 +4804,20 @@ func (ec *executionContext) marshalNOrganization2ᚖgitlabᚗcomᚋunboundsoftwa return ec._Organization(ctx, sel, v) } +func (ec *executionContext) marshalNSchemaUpdate2gitlabᚗcomᚋunboundsoftwareᚋschemasᚋgraphᚋmodelᚐSchemaUpdate(ctx context.Context, sel ast.SelectionSet, v model.SchemaUpdate) graphql.Marshaler { + return ec._SchemaUpdate(ctx, sel, &v) +} + +func (ec *executionContext) marshalNSchemaUpdate2ᚖgitlabᚗcomᚋunboundsoftwareᚋschemasᚋgraphᚋmodelᚐSchemaUpdate(ctx context.Context, sel ast.SelectionSet, v *model.SchemaUpdate) graphql.Marshaler { + if v == nil { + if !graphql.HasFieldError(ctx, graphql.GetFieldContext(ctx)) { + ec.Errorf(ctx, "the requested element is null which the schema does not allow") + } + return graphql.Null + } + return ec._SchemaUpdate(ctx, sel, v) +} + func (ec *executionContext) unmarshalNString2string(ctx context.Context, v any) (string, error) { res, err := graphql.UnmarshalString(v) return res, graphql.ErrorOnPath(ctx, err) diff --git a/graph/model/models_gen.go b/graph/model/models_gen.go index c1dcb42..4231bc4 100644 --- a/graph/model/models_gen.go +++ b/graph/model/models_gen.go @@ -49,6 +49,13 @@ type Organization struct { type Query struct { } +type SchemaUpdate struct { + Ref string `json:"ref"` + ID string `json:"id"` + SubGraphs []*SubGraph `json:"subGraphs"` + CosmoRouterConfig *string `json:"cosmoRouterConfig,omitempty"` +} + type SubGraph struct { ID string `json:"id"` Service string `json:"service"` @@ -68,6 +75,9 @@ type SubGraphs struct { func (SubGraphs) IsSupergraph() {} +type Subscription struct { +} + type Unchanged struct { ID string `json:"id"` MinDelaySeconds int `json:"minDelaySeconds"` diff --git a/graph/pubsub.go b/graph/pubsub.go new file mode 100644 index 0000000..c142574 --- /dev/null +++ b/graph/pubsub.go @@ -0,0 +1,66 @@ +package graph + +import ( + "sync" + + "gitlab.com/unboundsoftware/schemas/graph/model" +) + +// PubSub handles publishing schema updates to subscribers +type PubSub struct { + mu sync.RWMutex + subscribers map[string][]chan *model.SchemaUpdate +} + +func NewPubSub() *PubSub { + return &PubSub{ + subscribers: make(map[string][]chan *model.SchemaUpdate), + } +} + +// Subscribe creates a new subscription channel for a given schema ref +func (ps *PubSub) Subscribe(ref string) chan *model.SchemaUpdate { + ps.mu.Lock() + defer ps.mu.Unlock() + + ch := make(chan *model.SchemaUpdate, 10) + ps.subscribers[ref] = append(ps.subscribers[ref], ch) + + return ch +} + +// Unsubscribe removes a subscription channel +func (ps *PubSub) Unsubscribe(ref string, ch chan *model.SchemaUpdate) { + ps.mu.Lock() + defer ps.mu.Unlock() + + subs := ps.subscribers[ref] + for i, sub := range subs { + if sub == ch { + // Remove this subscriber + ps.subscribers[ref] = append(subs[:i], subs[i+1:]...) + close(sub) + break + } + } + + // Clean up empty subscriber lists + if len(ps.subscribers[ref]) == 0 { + delete(ps.subscribers, ref) + } +} + +// Publish sends a schema update to all subscribers of a given ref +func (ps *PubSub) Publish(ref string, update *model.SchemaUpdate) { + ps.mu.RLock() + defer ps.mu.RUnlock() + + for _, ch := range ps.subscribers[ref] { + // Non-blocking send - if subscriber is slow, skip + select { + case ch <- update: + default: + // Channel full, subscriber is too slow - skip this update + } + } +} diff --git a/graph/pubsub_test.go b/graph/pubsub_test.go new file mode 100644 index 0000000..0368835 --- /dev/null +++ b/graph/pubsub_test.go @@ -0,0 +1,256 @@ +package graph + +import ( + "sync" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "gitlab.com/unboundsoftware/schemas/graph/model" +) + +func TestPubSub_SubscribeAndPublish(t *testing.T) { + ps := NewPubSub() + ref := "Test@dev" + + // Subscribe + ch := ps.Subscribe(ref) + require.NotNil(t, ch, "Subscribe should return a channel") + + // Publish + update := &model.SchemaUpdate{ + Ref: ref, + ID: "test-id-1", + SubGraphs: []*model.SubGraph{ + { + ID: "sg1", + Service: "test-service", + Sdl: "type Query { test: String }", + }, + }, + } + + go ps.Publish(ref, update) + + // Receive + select { + case received := <-ch: + assert.Equal(t, update.Ref, received.Ref, "Ref should match") + assert.Equal(t, update.ID, received.ID, "ID should match") + assert.Equal(t, len(update.SubGraphs), len(received.SubGraphs), "SubGraphs count should match") + case <-time.After(1 * time.Second): + t.Fatal("Timeout waiting for published update") + } +} + +func TestPubSub_MultipleSubscribers(t *testing.T) { + ps := NewPubSub() + ref := "Test@dev" + + // Create multiple subscribers + ch1 := ps.Subscribe(ref) + ch2 := ps.Subscribe(ref) + ch3 := ps.Subscribe(ref) + + update := &model.SchemaUpdate{ + Ref: ref, + ID: "test-id-2", + } + + // Publish once + ps.Publish(ref, update) + + // All subscribers should receive the update + var wg sync.WaitGroup + wg.Add(3) + + checkReceived := func(ch <-chan *model.SchemaUpdate, name string) { + defer wg.Done() + select { + case received := <-ch: + assert.Equal(t, update.ID, received.ID, "%s should receive correct update", name) + case <-time.After(1 * time.Second): + t.Errorf("%s: Timeout waiting for update", name) + } + } + + go checkReceived(ch1, "Subscriber 1") + go checkReceived(ch2, "Subscriber 2") + go checkReceived(ch3, "Subscriber 3") + + wg.Wait() +} + +func TestPubSub_DifferentRefs(t *testing.T) { + ps := NewPubSub() + + ref1 := "Test1@dev" + ref2 := "Test2@dev" + + ch1 := ps.Subscribe(ref1) + ch2 := ps.Subscribe(ref2) + + update1 := &model.SchemaUpdate{Ref: ref1, ID: "id1"} + update2 := &model.SchemaUpdate{Ref: ref2, ID: "id2"} + + // Publish to ref1 + ps.Publish(ref1, update1) + + // Only ch1 should receive + select { + case received := <-ch1: + assert.Equal(t, "id1", received.ID) + case <-time.After(100 * time.Millisecond): + t.Fatal("ch1 should have received update") + } + + // ch2 should not receive ref1's update + select { + case <-ch2: + t.Fatal("ch2 should not receive ref1's update") + case <-time.After(100 * time.Millisecond): + // Expected - no update + } + + // Publish to ref2 + ps.Publish(ref2, update2) + + // Now ch2 should receive + select { + case received := <-ch2: + assert.Equal(t, "id2", received.ID) + case <-time.After(100 * time.Millisecond): + t.Fatal("ch2 should have received update") + } +} + +func TestPubSub_Unsubscribe(t *testing.T) { + ps := NewPubSub() + ref := "Test@dev" + + ch := ps.Subscribe(ref) + + // Unsubscribe + ps.Unsubscribe(ref, ch) + + // Channel should be closed + _, ok := <-ch + assert.False(t, ok, "Channel should be closed after unsubscribe") + + // Publishing after unsubscribe should not panic + assert.NotPanics(t, func() { + ps.Publish(ref, &model.SchemaUpdate{Ref: ref}) + }) +} + +func TestPubSub_BufferedChannel(t *testing.T) { + ps := NewPubSub() + ref := "Test@dev" + + ch := ps.Subscribe(ref) + + // Publish multiple updates quickly (up to buffer size of 10) + for i := 0; i < 10; i++ { + update := &model.SchemaUpdate{ + Ref: ref, + ID: string(rune('a' + i)), + } + ps.Publish(ref, update) + } + + // All 10 should be buffered and receivable + received := 0 + timeout := time.After(1 * time.Second) + + for received < 10 { + select { + case <-ch: + received++ + case <-timeout: + t.Fatalf("Only received %d out of 10 updates", received) + } + } + + assert.Equal(t, 10, received, "Should receive all buffered updates") +} + +func TestPubSub_SlowSubscriber(t *testing.T) { + ps := NewPubSub() + ref := "Test@dev" + + ch := ps.Subscribe(ref) + + // Fill the buffer (10 items) + for i := 0; i < 10; i++ { + ps.Publish(ref, &model.SchemaUpdate{Ref: ref}) + } + + // Publish one more - this should be dropped (channel full, non-blocking send) + ps.Publish(ref, &model.SchemaUpdate{Ref: ref, ID: "should-be-dropped"}) + + // Drain the channel + count := 0 + timeout := time.After(500 * time.Millisecond) + +drainLoop: + for { + select { + case update := <-ch: + count++ + // Should not receive the dropped update + assert.NotEqual(t, "should-be-dropped", update.ID, "Should not receive dropped update") + case <-timeout: + break drainLoop + } + } + + // Should have received exactly 10 (the buffer size), not 11 + assert.Equal(t, 10, count, "Should only receive buffered updates, not the dropped one") +} + +func TestPubSub_ConcurrentPublish(t *testing.T) { + ps := NewPubSub() + ref := "Test@dev" + + ch := ps.Subscribe(ref) + + numPublishers := 10 + updatesPerPublisher := 10 + + var wg sync.WaitGroup + wg.Add(numPublishers) + + // Multiple goroutines publishing concurrently + for i := 0; i < numPublishers; i++ { + go func(publisherID int) { + defer wg.Done() + for j := 0; j < updatesPerPublisher; j++ { + ps.Publish(ref, &model.SchemaUpdate{ + Ref: ref, + ID: string(rune('a' + publisherID)), + }) + } + }(i) + } + + wg.Wait() + + // Should not panic and subscriber should receive updates + // (exact count may vary due to buffer and timing) + timeout := time.After(1 * time.Second) + received := 0 + +receiveLoop: + for { + select { + case <-ch: + received++ + case <-timeout: + break receiveLoop + } + } + + assert.Greater(t, received, 0, "Should have received some updates") +} diff --git a/graph/resolver.go b/graph/resolver.go index 5dc5e59..7b30ade 100644 --- a/graph/resolver.go +++ b/graph/resolver.go @@ -28,6 +28,7 @@ type Resolver struct { Publisher Publisher Logger *slog.Logger Cache *cache.Cache + PubSub *PubSub } func (r *Resolver) apiKeyCanAccessRef(ctx context.Context, ref string, publish bool) (string, error) { diff --git a/graph/schema.graphqls b/graph/schema.graphqls index 79409f7..97d82cb 100644 --- a/graph/schema.graphqls +++ b/graph/schema.graphqls @@ -9,6 +9,10 @@ type Mutation { updateSubGraph(input: InputSubGraph!): SubGraph! @auth(organization: true) } +type Subscription { + schemaUpdates(ref: String!): SchemaUpdate! @auth(organization: true) +} + type Organization { id: ID! name: String! @@ -54,6 +58,13 @@ type SubGraph { changedAt: Time! } +type SchemaUpdate { + ref: String! + id: ID! + subGraphs: [SubGraph!]! + cosmoRouterConfig: String +} + input InputAPIKey { name: String! organizationId: ID! diff --git a/graph/schema.resolvers.go b/graph/schema.resolvers.go index 4daac31..6c68229 100644 --- a/graph/schema.resolvers.go +++ b/graph/schema.resolvers.go @@ -119,6 +119,44 @@ func (r *mutationResolver) UpdateSubGraph(ctx context.Context, input model.Input if err != nil { return nil, err } + + // Publish schema update to subscribers + go func() { + services, lastUpdate := r.Cache.Services(orgId, input.Ref, "") + subGraphs := make([]*model.SubGraph, len(services)) + for i, id := range services { + sg, err := r.fetchSubGraph(context.Background(), id) + if err != nil { + r.Logger.Error("fetch subgraph for update notification", "error", err) + continue + } + subGraphs[i] = &model.SubGraph{ + ID: sg.ID.String(), + Service: sg.Service, + URL: sg.Url, + WsURL: sg.WSUrl, + Sdl: sg.Sdl, + ChangedBy: sg.ChangedBy, + ChangedAt: sg.ChangedAt, + } + } + + // Generate Cosmo router config + cosmoConfig, err := GenerateCosmoRouterConfig(subGraphs) + if err != nil { + r.Logger.Error("generate cosmo config for update", "error", err) + cosmoConfig = "" // Send empty if generation fails + } + + // Publish to all subscribers of this ref + r.PubSub.Publish(input.Ref, &model.SchemaUpdate{ + Ref: input.Ref, + ID: lastUpdate, + SubGraphs: subGraphs, + CosmoRouterConfig: &cosmoConfig, + }) + }() + return r.toGqlSubGraph(subGraph), nil } @@ -184,13 +222,74 @@ func (r *queryResolver) Supergraph(ctx context.Context, ref string, isAfter *str }, nil } +// SchemaUpdates is the resolver for the schemaUpdates field. +func (r *subscriptionResolver) SchemaUpdates(ctx context.Context, ref string) (<-chan *model.SchemaUpdate, error) { + orgId := middleware.OrganizationFromContext(ctx) + _, err := r.apiKeyCanAccessRef(ctx, ref, false) + if err != nil { + return nil, err + } + + // Subscribe to updates for this ref + ch := r.PubSub.Subscribe(ref) + + // Send initial state immediately + go func() { + services, lastUpdate := r.Cache.Services(orgId, ref, "") + subGraphs := make([]*model.SubGraph, len(services)) + for i, id := range services { + sg, err := r.fetchSubGraph(ctx, id) + if err != nil { + r.Logger.Error("fetch subgraph for initial update", "error", err) + continue + } + subGraphs[i] = &model.SubGraph{ + ID: sg.ID.String(), + Service: sg.Service, + URL: sg.Url, + WsURL: sg.WSUrl, + Sdl: sg.Sdl, + ChangedBy: sg.ChangedBy, + ChangedAt: sg.ChangedAt, + } + } + + // Generate Cosmo router config + cosmoConfig, err := GenerateCosmoRouterConfig(subGraphs) + if err != nil { + r.Logger.Error("generate cosmo config", "error", err) + cosmoConfig = "" // Send empty if generation fails + } + + // Send initial update + ch <- &model.SchemaUpdate{ + Ref: ref, + ID: lastUpdate, + SubGraphs: subGraphs, + CosmoRouterConfig: &cosmoConfig, + } + }() + + // Clean up subscription when context is done + go func() { + <-ctx.Done() + r.PubSub.Unsubscribe(ref, ch) + }() + + return ch, nil +} + // Mutation returns generated.MutationResolver implementation. func (r *Resolver) Mutation() generated.MutationResolver { return &mutationResolver{r} } // Query returns generated.QueryResolver implementation. func (r *Resolver) Query() generated.QueryResolver { return &queryResolver{r} } +// Subscription returns generated.SubscriptionResolver implementation. +func (r *Resolver) Subscription() generated.SubscriptionResolver { return &subscriptionResolver{r} } + type ( - mutationResolver struct{ *Resolver } - queryResolver struct{ *Resolver } + mutationResolver struct{ *Resolver } + queryResolver struct{ *Resolver } + subscriptionResolver struct{ *Resolver } )