2025-02-28 13:10:07 +01:00
|
|
|
package sdlmerge
|
|
|
|
|
|
|
|
|
|
import (
|
2025-11-22 18:37:07 +01:00
|
|
|
"bytes"
|
2025-02-28 13:10:07 +01:00
|
|
|
"fmt"
|
|
|
|
|
"strings"
|
|
|
|
|
|
|
|
|
|
"github.com/wundergraph/graphql-go-tools/v2/pkg/asttransform"
|
|
|
|
|
"github.com/wundergraph/graphql-go-tools/v2/pkg/astvalidation"
|
|
|
|
|
|
|
|
|
|
"github.com/wundergraph/graphql-go-tools/v2/pkg/ast"
|
|
|
|
|
"github.com/wundergraph/graphql-go-tools/v2/pkg/astnormalization"
|
|
|
|
|
"github.com/wundergraph/graphql-go-tools/v2/pkg/astparser"
|
|
|
|
|
"github.com/wundergraph/graphql-go-tools/v2/pkg/astprinter"
|
|
|
|
|
"github.com/wundergraph/graphql-go-tools/v2/pkg/astvisitor"
|
|
|
|
|
"github.com/wundergraph/graphql-go-tools/v2/pkg/operationreport"
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
const (
	// rootOperationTypeDefinitions declares empty Query, Mutation and
	// Subscription object types. MergeSDLs places it in front of the subgraph
	// SDLs before the combined text is parsed as one document.
	rootOperationTypeDefinitions = `
		type Query {}
		type Mutation {}
		type Subscription {}
	`

	// parseDocumentError is the fmt.Errorf format (with a %w verb) used to wrap
	// the report of a GraphQL document string that failed to parse.
	parseDocumentError = "parse graphql document string: %w"
)
|
|
|
|
|
|
|
|
|
|
// Visitor is the common interface of the merge steps that setupWalkers groups
// onto AST walkers.
type Visitor interface {
	// Register is called by setupWalkers with the walker of the group the
	// visitor belongs to; implementations presumably subscribe their callbacks
	// on it (the implementations are not part of this file section).
	Register(walker *astvisitor.Walker)
}
|
|
|
|
|
|
|
|
|
|
func MergeAST(ast *ast.Document) error {
|
|
|
|
|
normalizer := normalizer{}
|
|
|
|
|
normalizer.setupWalkers()
|
|
|
|
|
|
|
|
|
|
return normalizer.normalize(ast)
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func MergeSDLs(SDLs ...string) (string, error) {
|
|
|
|
|
rawDocs := make([]string, 0, len(SDLs)+1)
|
|
|
|
|
rawDocs = append(rawDocs, rootOperationTypeDefinitions)
|
|
|
|
|
rawDocs = append(rawDocs, SDLs...)
|
|
|
|
|
if validationError := validateSubgraphs(rawDocs[1:]); validationError != nil {
|
|
|
|
|
return "", validationError
|
|
|
|
|
}
|
|
|
|
|
if normalizationError := normalizeSubgraphs(rawDocs[1:]); normalizationError != nil {
|
|
|
|
|
return "", normalizationError
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
doc, report := astparser.ParseGraphqlDocumentString(strings.Join(rawDocs, "\n"))
|
|
|
|
|
if report.HasErrors() {
|
|
|
|
|
return "", fmt.Errorf("parse graphql document string: %w", report)
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
astnormalization.NormalizeSubgraphSDL(&doc, &report)
|
|
|
|
|
if report.HasErrors() {
|
|
|
|
|
return "", fmt.Errorf("merge ast: %w", report)
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if err := MergeAST(&doc); err != nil {
|
|
|
|
|
return "", fmt.Errorf("merge ast: %w", err)
|
|
|
|
|
}
|
|
|
|
|
|
2025-11-22 18:37:07 +01:00
|
|
|
// Format with indentation for better readability
|
|
|
|
|
buf := &bytes.Buffer{}
|
|
|
|
|
if err := astprinter.PrintIndent(&doc, []byte(" "), buf); err != nil {
|
2025-02-28 13:10:07 +01:00
|
|
|
return "", fmt.Errorf("stringify schema: %w", err)
|
|
|
|
|
}
|
|
|
|
|
|
2025-11-22 18:37:07 +01:00
|
|
|
return buf.String(), nil
|
2025-02-28 13:10:07 +01:00
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func validateSubgraphs(subgraphs []string) error {
|
|
|
|
|
validator := astvalidation.NewDefinitionValidator(
|
|
|
|
|
astvalidation.PopulatedTypeBodies(), astvalidation.KnownTypeNames(),
|
|
|
|
|
)
|
|
|
|
|
for _, subgraph := range subgraphs {
|
|
|
|
|
doc, report := astparser.ParseGraphqlDocumentString(subgraph)
|
|
|
|
|
if err := asttransform.MergeDefinitionWithBaseSchema(&doc); err != nil {
|
|
|
|
|
return err
|
|
|
|
|
}
|
|
|
|
|
if report.HasErrors() {
|
|
|
|
|
return fmt.Errorf(parseDocumentError, report)
|
|
|
|
|
}
|
|
|
|
|
validator.Validate(&doc, &report)
|
|
|
|
|
if report.HasErrors() {
|
|
|
|
|
return fmt.Errorf("validate schema: %w", report)
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
return nil
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func normalizeSubgraphs(subgraphs []string) error {
|
|
|
|
|
subgraphNormalizer := astnormalization.NewSubgraphDefinitionNormalizer()
|
|
|
|
|
for i, subgraph := range subgraphs {
|
|
|
|
|
doc, report := astparser.ParseGraphqlDocumentString(subgraph)
|
|
|
|
|
if report.HasErrors() {
|
|
|
|
|
return fmt.Errorf(parseDocumentError, report)
|
|
|
|
|
}
|
|
|
|
|
subgraphNormalizer.NormalizeDefinition(&doc, &report)
|
|
|
|
|
if report.HasErrors() {
|
|
|
|
|
return fmt.Errorf("normalize schema: %w", report)
|
|
|
|
|
}
|
|
|
|
|
out, err := astprinter.PrintString(&doc)
|
|
|
|
|
if err != nil {
|
|
|
|
|
return fmt.Errorf("stringify schema: %w", err)
|
|
|
|
|
}
|
|
|
|
|
subgraphs[i] = out
|
|
|
|
|
}
|
|
|
|
|
return nil
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// normalizer holds the walkers that normalize runs, in slice order, over a
// document.
type normalizer struct {
	// walkers is filled by setupWalkers and iterated by normalize.
	walkers []*astvisitor.Walker
}
|
|
|
|
|
|
|
|
|
|
// entitySet is a set of type names. setupWalkers shares one instance between
// the entity-collecting visitor and the extend visitors, and
// isExtensionForEntity treats membership as "this name is a known entity".
type entitySet map[string]struct{}
|
|
|
|
|
|
|
|
|
|
func (m *normalizer) setupWalkers() {
|
|
|
|
|
collectedEntities := make(entitySet)
|
|
|
|
|
visitorGroups := [][]Visitor{
|
|
|
|
|
{
|
|
|
|
|
newCollectEntitiesVisitor(collectedEntities),
|
|
|
|
|
},
|
|
|
|
|
{
|
|
|
|
|
newExtendEnumTypeDefinition(),
|
|
|
|
|
newExtendInputObjectTypeDefinition(),
|
|
|
|
|
newExtendInterfaceTypeDefinition(collectedEntities),
|
|
|
|
|
newExtendScalarTypeDefinition(),
|
|
|
|
|
newExtendUnionTypeDefinition(),
|
|
|
|
|
newExtendObjectTypeDefinition(collectedEntities),
|
|
|
|
|
newRemoveEmptyObjectTypeDefinition(),
|
|
|
|
|
newRemoveMergedTypeExtensions(),
|
|
|
|
|
},
|
|
|
|
|
// visitors for cleaning up federated duplicated fields and directives
|
|
|
|
|
{
|
|
|
|
|
newRemoveFieldDefinitions("external"),
|
|
|
|
|
newRemoveDuplicateFieldedSharedTypesVisitor(),
|
|
|
|
|
newRemoveDuplicateFieldlessSharedTypesVisitor(),
|
|
|
|
|
newMergeDuplicatedFieldsVisitor(),
|
|
|
|
|
newRemoveInterfaceDefinitionDirective("key"),
|
|
|
|
|
newRemoveObjectTypeDefinitionDirective("key"),
|
|
|
|
|
newRemoveFieldDefinitionDirective("provides", "requires"),
|
|
|
|
|
},
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
for _, visitorGroup := range visitorGroups {
|
|
|
|
|
walker := astvisitor.NewWalker(48)
|
|
|
|
|
for _, visitor := range visitorGroup {
|
|
|
|
|
visitor.Register(&walker)
|
|
|
|
|
m.walkers = append(m.walkers, &walker)
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func (m *normalizer) normalize(operation *ast.Document) error {
|
|
|
|
|
report := operationreport.Report{}
|
|
|
|
|
|
|
|
|
|
for _, walker := range m.walkers {
|
|
|
|
|
walker.Walk(operation, nil, &report)
|
|
|
|
|
if report.HasErrors() {
|
|
|
|
|
return fmt.Errorf("walk: %w", report)
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
return nil
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func (e entitySet) isExtensionForEntity(nameBytes []byte, directiveRefs []int, document *ast.Document) (bool, *operationreport.ExternalError) {
|
|
|
|
|
name := string(nameBytes)
|
|
|
|
|
hasDirectives := len(directiveRefs) > 0
|
|
|
|
|
if _, exists := e[name]; !exists {
|
|
|
|
|
if !hasDirectives || !isEntityExtension(directiveRefs, document) {
|
|
|
|
|
return false, nil
|
|
|
|
|
}
|
|
|
|
|
err := operationreport.ErrExtensionWithKeyDirectiveMustExtendEntity(name)
|
|
|
|
|
return false, &err
|
|
|
|
|
}
|
|
|
|
|
if !hasDirectives {
|
|
|
|
|
err := operationreport.ErrEntityExtensionMustHaveKeyDirective(name)
|
|
|
|
|
return false, &err
|
|
|
|
|
}
|
|
|
|
|
if isEntityExtension(directiveRefs, document) {
|
|
|
|
|
return true, nil
|
|
|
|
|
}
|
|
|
|
|
err := operationreport.ErrEntityExtensionMustHaveKeyDirective(name)
|
|
|
|
|
return false, &err
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func isEntityExtension(directiveRefs []int, document *ast.Document) bool {
|
|
|
|
|
for _, directiveRef := range directiveRefs {
|
|
|
|
|
if document.DirectiveNameString(directiveRef) == "key" {
|
|
|
|
|
return true
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
return false
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func multipleExtensionError(isEntity bool, nameBytes []byte) *operationreport.ExternalError {
|
|
|
|
|
if isEntity {
|
|
|
|
|
err := operationreport.ErrEntitiesMustNotBeDuplicated(string(nameBytes))
|
|
|
|
|
return &err
|
|
|
|
|
}
|
|
|
|
|
err := operationreport.ErrSharedTypesMustNotBeExtended(string(nameBytes))
|
|
|
|
|
return &err
|
|
|
|
|
}
|