package sdlmerge import ( "bytes" "fmt" "strings" "github.com/wundergraph/graphql-go-tools/v2/pkg/asttransform" "github.com/wundergraph/graphql-go-tools/v2/pkg/astvalidation" "github.com/wundergraph/graphql-go-tools/v2/pkg/ast" "github.com/wundergraph/graphql-go-tools/v2/pkg/astnormalization" "github.com/wundergraph/graphql-go-tools/v2/pkg/astparser" "github.com/wundergraph/graphql-go-tools/v2/pkg/astprinter" "github.com/wundergraph/graphql-go-tools/v2/pkg/astvisitor" "github.com/wundergraph/graphql-go-tools/v2/pkg/operationreport" ) const ( rootOperationTypeDefinitions = ` type Query {} type Mutation {} type Subscription {} ` parseDocumentError = "parse graphql document string: %w" ) type Visitor interface { Register(walker *astvisitor.Walker) } func MergeAST(ast *ast.Document) error { normalizer := normalizer{} normalizer.setupWalkers() return normalizer.normalize(ast) } func MergeSDLs(SDLs ...string) (string, error) { rawDocs := make([]string, 0, len(SDLs)+1) rawDocs = append(rawDocs, rootOperationTypeDefinitions) rawDocs = append(rawDocs, SDLs...) if validationError := validateSubgraphs(rawDocs[1:]); validationError != nil { return "", validationError } if normalizationError := normalizeSubgraphs(rawDocs[1:]); normalizationError != nil { return "", normalizationError } doc, report := astparser.ParseGraphqlDocumentString(strings.Join(rawDocs, "\n")) if report.HasErrors() { return "", fmt.Errorf("parse graphql document string: %w", report) } astnormalization.NormalizeSubgraphSDL(&doc, &report) if report.HasErrors() { return "", fmt.Errorf("merge ast: %w", report) } if err := MergeAST(&doc); err != nil { return "", fmt.Errorf("merge ast: %w", err) } // Format with indentation for better readability buf := &bytes.Buffer{} if err := astprinter.PrintIndent(&doc, []byte(" "), buf); err != nil { return "", fmt.Errorf("stringify schema: %w", err) } return buf.String(), nil } func validateSubgraphs(subgraphs []string) error { validator := astvalidation.NewDefinitionValidator( astvalidation.PopulatedTypeBodies(), astvalidation.KnownTypeNames(), ) for _, subgraph := range subgraphs { doc, report := astparser.ParseGraphqlDocumentString(subgraph) if err := asttransform.MergeDefinitionWithBaseSchema(&doc); err != nil { return err } if report.HasErrors() { return fmt.Errorf(parseDocumentError, report) } validator.Validate(&doc, &report) if report.HasErrors() { return fmt.Errorf("validate schema: %w", report) } } return nil } func normalizeSubgraphs(subgraphs []string) error { subgraphNormalizer := astnormalization.NewSubgraphDefinitionNormalizer() for i, subgraph := range subgraphs { doc, report := astparser.ParseGraphqlDocumentString(subgraph) if report.HasErrors() { return fmt.Errorf(parseDocumentError, report) } subgraphNormalizer.NormalizeDefinition(&doc, &report) if report.HasErrors() { return fmt.Errorf("normalize schema: %w", report) } out, err := astprinter.PrintString(&doc) if err != nil { return fmt.Errorf("stringify schema: %w", err) } subgraphs[i] = out } return nil } type normalizer struct { walkers []*astvisitor.Walker } type entitySet map[string]struct{} func (m *normalizer) setupWalkers() { collectedEntities := make(entitySet) visitorGroups := [][]Visitor{ { newCollectEntitiesVisitor(collectedEntities), }, { newExtendEnumTypeDefinition(), newExtendInputObjectTypeDefinition(), newExtendInterfaceTypeDefinition(collectedEntities), newExtendScalarTypeDefinition(), newExtendUnionTypeDefinition(), newExtendObjectTypeDefinition(collectedEntities), newRemoveEmptyObjectTypeDefinition(), newRemoveMergedTypeExtensions(), }, // visitors for cleaning up federated duplicated fields and directives { newRemoveFieldDefinitions("external"), newRemoveDuplicateFieldedSharedTypesVisitor(), newRemoveDuplicateFieldlessSharedTypesVisitor(), newMergeDuplicatedFieldsVisitor(), newRemoveInterfaceDefinitionDirective("key"), newRemoveObjectTypeDefinitionDirective("key"), newRemoveFieldDefinitionDirective("provides", "requires"), }, } for _, visitorGroup := range visitorGroups { walker := astvisitor.NewWalker(48) for _, visitor := range visitorGroup { visitor.Register(&walker) m.walkers = append(m.walkers, &walker) } } } func (m *normalizer) normalize(operation *ast.Document) error { report := operationreport.Report{} for _, walker := range m.walkers { walker.Walk(operation, nil, &report) if report.HasErrors() { return fmt.Errorf("walk: %w", report) } } return nil } func (e entitySet) isExtensionForEntity(nameBytes []byte, directiveRefs []int, document *ast.Document) (bool, *operationreport.ExternalError) { name := string(nameBytes) hasDirectives := len(directiveRefs) > 0 if _, exists := e[name]; !exists { if !hasDirectives || !isEntityExtension(directiveRefs, document) { return false, nil } err := operationreport.ErrExtensionWithKeyDirectiveMustExtendEntity(name) return false, &err } if !hasDirectives { err := operationreport.ErrEntityExtensionMustHaveKeyDirective(name) return false, &err } if isEntityExtension(directiveRefs, document) { return true, nil } err := operationreport.ErrEntityExtensionMustHaveKeyDirective(name) return false, &err } func isEntityExtension(directiveRefs []int, document *ast.Document) bool { for _, directiveRef := range directiveRefs { if document.DirectiveNameString(directiveRef) == "key" { return true } } return false } func multipleExtensionError(isEntity bool, nameBytes []byte) *operationreport.ExternalError { if isEntity { err := operationreport.ErrEntitiesMustNotBeDuplicated(string(nameBytes)) return &err } err := operationreport.ErrSharedTypesMustNotBeExtended(string(nameBytes)) return &err }