package main

import (
	"context"
	"errors"
	"fmt"
	"log/slog"
	"net/http"
	"os"
	"os/signal"
	"reflect"
	"sync"
	"syscall"
	"time"

	"github.com/99designs/gqlgen/graphql/handler"
	"github.com/99designs/gqlgen/graphql/handler/extension"
	"github.com/99designs/gqlgen/graphql/handler/lru"
	"github.com/99designs/gqlgen/graphql/handler/transport"
	"github.com/99designs/gqlgen/graphql/playground"
	"github.com/alecthomas/kong"
	"github.com/rs/cors"
	"github.com/sparetimecoders/goamqp"
	"github.com/vektah/gqlparser/v2/ast"
	"gitlab.com/unboundsoftware/eventsourced/amqp"
	"gitlab.com/unboundsoftware/eventsourced/eventsourced"
	"gitlab.com/unboundsoftware/eventsourced/pg"

	"gitea.unbound.se/unboundsoftware/schemas/cache"
	"gitea.unbound.se/unboundsoftware/schemas/domain"
	"gitea.unbound.se/unboundsoftware/schemas/graph"
	"gitea.unbound.se/unboundsoftware/schemas/graph/generated"
	"gitea.unbound.se/unboundsoftware/schemas/health"
	"gitea.unbound.se/unboundsoftware/schemas/logging"
	"gitea.unbound.se/unboundsoftware/schemas/middleware"
	"gitea.unbound.se/unboundsoftware/schemas/monitoring"
	"gitea.unbound.se/unboundsoftware/schemas/store"
)

// CLI holds the service configuration, parsed by kong from command-line flags
// or the corresponding environment variables.
type CLI struct {
	AmqpURL            string `name:"amqp-url" env:"AMQP_URL" help:"URL to use to connect to RabbitMQ" default:"amqp://user:password@unbound-control-plane.orb.local:5672/"`
	Port               int    `name:"port" env:"PORT" help:"Listen-port for GraphQL API" default:"8080"`
	LogLevel           string `name:"log-level" env:"LOG_LEVEL" help:"The level of logging to use (debug, info, warn, error, fatal)" default:"info"`
	LogFormat          string `name:"log-format" env:"LOG_FORMAT" help:"The format of logs" default:"text" enum:"otel,json,text"`
	DatabaseURL        string `name:"postgres-url" env:"POSTGRES_URL" help:"URL to use to connect to Postgres" default:"postgres://postgres:postgres@unbound-control-plane.orb.local:5432/schemas?sslmode=disable"`
	DatabaseDriverName string `name:"db-driver" env:"DB_DRIVER" help:"Driver to use to connect to db" default:"postgres"`
	Issuer             string `name:"issuer" env:"ISSUER" help:"The JWT token issuer to use" default:"unbound.eu.auth0.com"`
	StrictSSL          bool   `name:"strict-ssl" env:"STRICT_SSL" help:"Should strict SSL handling be enabled" default:"true"`
	Environment        string `name:"environment" env:"ENVIRONMENT" help:"The environment we are running in" default:"development" enum:"development,staging,production"`
}

// buildVersion identifies the running build. It is "none" unless overridden
// at build time (presumably via -ldflags "-X main.buildVersion=..." — confirm
// against the build scripts).
var buildVersion = "none"

const (
	serviceName = "schemas"

	// shutdownTimeout bounds how long graceful shutdown of the HTTP server and
	// of the telemetry SDK may take.
	shutdownTimeout = 10 * time.Second
)

func main() {
	var cli CLI
	_ = kong.Parse(&cli)
	logger := logging.SetupLogger(cli.LogLevel, cli.LogFormat, serviceName, buildVersion)
	closeEvents := make(chan error)
	if err := start(closeEvents, logger, ConnectAMQP, cli); err != nil {
		logger.With("error", err).Error("process error")
		// Exit non-zero so process supervisors notice the failure.
		os.Exit(1)
	}
}

// start wires up telemetry, the Postgres event store, AMQP messaging, the
// in-memory cache and the GraphQL HTTP API, and then blocks until shutdown.
//
// Shutdown is triggered by SIGINT/SIGTERM, by a non-nil error received on
// closeEvents (registered with goamqp as its close listener), or by the HTTP
// server goroutine exiting. start returns an error when any setup step fails
// or when the HTTP server stops for a reason other than graceful shutdown.
// start never closes closeEvents: the caller owns it and goamqp sends on it.
func start(closeEvents chan error, logger *slog.Logger, connectToAmqpFunc func(url string) (Connection, error), cli CLI) error {
	rootCtx, rootCancel := context.WithCancel(context.Background())
	defer rootCancel()

	shutdownFn, err := monitoring.SetupOTelSDK(rootCtx, cli.LogFormat == "otel", serviceName, buildVersion, cli.Environment)
	if err != nil {
		return err
	}
	defer func() {
		// Bound telemetry flushing so a stuck exporter cannot hang process exit.
		ctx, cancel := context.WithTimeout(context.Background(), shutdownTimeout)
		defer cancel()
		if err := shutdownFn(ctx); err != nil {
			logger.With("error", err).Error("shutdown telemetry")
		}
	}()

	db, err := store.SetupDB(cli.DatabaseDriverName, cli.DatabaseURL)
	if err != nil {
		return fmt.Errorf("failed to setup DB: %w", err)
	}
	eventStore, err := pg.New(
		rootCtx,
		db.DB,
		pg.WithEventTypes(
			&domain.SubGraphUpdated{},
			&domain.OrganizationAdded{},
			&domain.UserAddedToOrganization{},
			&domain.APIKeyAdded{},
			&domain.APIKeyRemoved{},
			&domain.OrganizationRemoved{},
		),
	)
	if err != nil {
		return fmt.Errorf("failed to create eventstore: %w", err)
	}
	if err := store.RunEventStoreMigrations(db); err != nil {
		return fmt.Errorf("event migrations: %w", err)
	}

	publisher := goamqp.NewPublisher()
	eventPublisher, err := amqp.New(publisher)
	if err != nil {
		return fmt.Errorf("failed to create event publisher: %w", err)
	}
	conn, err := connectToAmqpFunc(cli.AmqpURL)
	if err != nil {
		return fmt.Errorf("failed to connect to AMQP: %w", err)
	}
	// Close on every exit path from here on, including failed cache warm-up
	// and a failed conn.Start. Best-effort: nothing useful can be done with a
	// close error while shutting down.
	defer func() { _ = conn.Close() }()

	// Warm the cache from the event store before consuming live AMQP events.
	serviceCache := cache.New(logger)
	if err := loadOrganizations(rootCtx, eventStore, serviceCache); err != nil {
		return fmt.Errorf("caching organizations: %w", err)
	}
	if err := loadSubGraphs(rootCtx, eventStore, serviceCache); err != nil {
		return fmt.Errorf("caching subgraphs: %w", err)
	}

	setups := []goamqp.Setup{
		goamqp.UseLogger(func(s string) { logger.Error(s) }),
		goamqp.CloseListener(closeEvents),
		goamqp.WithPrefetchLimit(20),
		goamqp.EventStreamPublisher(publisher),
		goamqp.TransientEventStreamConsumer("SubGraph.Updated", serviceCache.Update, domain.SubGraphUpdated{}),
		goamqp.TransientEventStreamConsumer("Organization.Added", serviceCache.Update, domain.OrganizationAdded{}),
		goamqp.TransientEventStreamConsumer("Organization.UserAdded", serviceCache.Update, domain.UserAddedToOrganization{}),
		goamqp.TransientEventStreamConsumer("Organization.APIKeyAdded", serviceCache.Update, domain.APIKeyAdded{}),
		goamqp.TransientEventStreamConsumer("Organization.APIKeyRemoved", serviceCache.Update, domain.APIKeyRemoved{}),
		goamqp.TransientEventStreamConsumer("Organization.Removed", serviceCache.Update, domain.OrganizationRemoved{}),
		goamqp.WithTypeMapping("SubGraph.Updated", domain.SubGraphUpdated{}),
		goamqp.WithTypeMapping("Organization.Added", domain.OrganizationAdded{}),
		goamqp.WithTypeMapping("Organization.UserAdded", domain.UserAddedToOrganization{}),
		goamqp.WithTypeMapping("Organization.APIKeyAdded", domain.APIKeyAdded{}),
		goamqp.WithTypeMapping("Organization.APIKeyRemoved", domain.APIKeyRemoved{}),
		goamqp.WithTypeMapping("Organization.Removed", domain.OrganizationRemoved{}),
	}
	if err := conn.Start(rootCtx, setups...); err != nil {
		return fmt.Errorf("failed to setup AMQP: %w", err)
	}

	logger.Info("Started")

	mux := http.NewServeMux()
	// ReadHeaderTimeout guards against slow-header (Slowloris) clients. No
	// ReadTimeout/WriteTimeout is set because WebSocket subscriptions keep
	// connections open for a long time.
	httpSrv := &http.Server{
		Addr:              fmt.Sprintf(":%d", cli.Port),
		Handler:           mux,
		ReadHeaderTimeout: 10 * time.Second,
	}

	var wg sync.WaitGroup

	sigint := make(chan os.Signal, 1)
	signal.Notify(sigint, os.Interrupt, syscall.SIGTERM)
	// Unregister before returning. sigint is never closed, so a late signal
	// can never cause a send on a closed channel.
	defer signal.Stop(sigint)

	// Signal watcher: cancels rootCtx on SIGINT/SIGTERM; exits when rootCtx is
	// cancelled for any other reason.
	wg.Add(1)
	go func() {
		defer wg.Done()
		select {
		case <-sigint:
			// In case our shutdown logic is broken/incomplete we reset signal
			// handlers so next signal goes to go itself. Go is more aggressive when
			// shutting down goroutines
			signal.Reset(os.Interrupt, syscall.SIGTERM)
			logger.Info("Got shutdown signal..")
			rootCancel()
		case <-rootCtx.Done():
		}
	}()

	// AMQP close watcher: a non-nil error from goamqp triggers shutdown. A nil
	// value is ignored, as before.
	wg.Add(1)
	go func() {
		defer wg.Done()
		select {
		case err := <-closeEvents:
			if err != nil {
				logger.With("error", err).Error("received close from AMQP")
				rootCancel()
			}
		case <-rootCtx.Done():
		}
	}()

	// Graceful HTTP shutdown once rootCtx is cancelled. If this runs before
	// ListenAndServe, that call returns http.ErrServerClosed immediately.
	wg.Add(1)
	go func() {
		defer wg.Done()
		<-rootCtx.Done()
		shutdownCtx, shutdownRelease := context.WithTimeout(context.Background(), shutdownTimeout)
		defer shutdownRelease()
		if err := httpSrv.Shutdown(shutdownCtx); err != nil {
			logger.With("error", err).Error("close http server")
		}
	}()

	// Written only by the HTTP goroutine; read after wg.Wait().
	var serveErr error

	wg.Add(1)
	go func() {
		defer wg.Done()
		defer rootCancel()

		resolver := &graph.Resolver{
			EventStore:     eventStore,
			Publisher:      eventPublisher,
			Logger:         logger,
			Cache:          serviceCache,
			PubSub:         graph.NewPubSub(),
			CosmoGenerator: graph.NewCosmoGenerator(&graph.DefaultCommandExecutor{}, 60*time.Second),
			Debouncer:      graph.NewDebouncer(500 * time.Millisecond),
		}

		config := generated.Config{
			Resolvers:  resolver,
			Complexity: generated.ComplexityRoot{},
		}
		apiKeyMiddleware := middleware.NewApiKey()
		mw := middleware.NewAuth0("https://schemas.unbound.se", cli.Issuer, cli.StrictSSL)
		authMiddleware := middleware.NewAuth(serviceCache)
		config.Directives.Auth = authMiddleware.Directive
		srv := handler.New(generated.NewExecutableSchema(config))

		srv.AddTransport(transport.Websocket{
			KeepAlivePingInterval: 10 * time.Second,
			InitFunc:              websocketInitFunc(logger, serviceCache),
		})
		srv.AddTransport(transport.Options{})
		srv.AddTransport(transport.GET{})
		srv.AddTransport(transport.POST{})
		srv.AddTransport(transport.MultipartForm{})

		srv.SetQueryCache(lru.New[*ast.QueryDocument](1000))

		srv.Use(extension.Introspection{})
		srv.Use(extension.AutomaticPersistedQuery{
			Cache: lru.New[string](100),
		})

		healthChecker := health.New(db.DB, logger)

		mux.Handle("/", monitoring.Handler(playground.Handler("GraphQL playground", "/query")))
		mux.Handle("/health", http.HandlerFunc(healthChecker.LivenessHandler))
		mux.Handle("/health/live", http.HandlerFunc(healthChecker.LivenessHandler))
		mux.Handle("/health/ready", http.HandlerFunc(healthChecker.ReadinessHandler))
		mux.Handle("/query", cors.AllowAll().Handler(
			monitoring.Handler(
				mw.Middleware().CheckJWT(
					apiKeyMiddleware.Handler(
						authMiddleware.Handler(srv),
					),
				),
			),
		))

		logger.Info(fmt.Sprintf("connect to http://localhost:%d/ for GraphQL playground", cli.Port))

		// ListenAndServe always returns a non-nil error; only ErrServerClosed
		// (from Shutdown) is expected. Anything else is returned to the caller
		// instead of being silently reduced to a nil return.
		if err := httpSrv.ListenAndServe(); !errors.Is(err, http.ErrServerClosed) {
			serveErr = fmt.Errorf("listen http: %w", err)
		}
	}()

	wg.Wait()
	return serveErr
}

// websocketInitFunc builds the gqlgen WebSocket InitFunc.
//
// The returned function reads an optional "X-Api-Key" string from the
// connection_init payload. When present and non-empty, the key is stored in
// the context under middleware.ApiKey. If serviceCache resolves it to an
// organization, that organization is also stored under
// middleware.OrganizationKey. The connection is accepted in every case; the
// returned error is always nil.
func websocketInitFunc(logger *slog.Logger, serviceCache *cache.Cache) func(context.Context, transport.InitPayload) (context.Context, *transport.InitPayload, error) {
	return func(ctx context.Context, initPayload transport.InitPayload) (context.Context, *transport.InitPayload, error) {
		apiKey, ok := initPayload["X-Api-Key"].(string)
		if !ok || apiKey == "" {
			logger.Info("WebSocket connection without API key")
			return ctx, &initPayload, nil
		}

		logger.Info("WebSocket connection with API key", "has_key", true)
		ctx = context.WithValue(ctx, middleware.ApiKey, apiKey)

		// Look up organization by API key (cache handles hash comparison)
		organization := serviceCache.OrganizationByAPIKey(apiKey)
		if organization == nil {
			logger.Warn("WebSocket: No organization found for API key")
			return ctx, &initPayload, nil
		}
		logger.Info("WebSocket: Organization found for API key", "org_id", organization.ID.String())
		ctx = context.WithValue(ctx, middleware.OrganizationKey, *organization)
		return ctx, &initPayload, nil
	}
}

// loadOrganizations hydrates every Organization aggregate root known to
// eventStore, via eventsourced.NewHandler, and passes each one to
// serviceCache.Update. It stops at the first error.
func loadOrganizations(ctx context.Context, eventStore eventsourced.EventStore, serviceCache *cache.Cache) error {
	roots, err := eventStore.GetAggregateRoots(ctx, reflect.TypeOf(domain.Organization{}))
	if err != nil {
		return err
	}
	for _, root := range roots {
		id := root.String()
		organization := &domain.Organization{BaseAggregate: eventsourced.BaseAggregateFromString(id)}
		if _, err := eventsourced.NewHandler(ctx, organization, eventStore); err != nil {
			return fmt.Errorf("load organization %s: %w", id, err)
		}
		if _, err := serviceCache.Update(organization, nil); err != nil {
			return fmt.Errorf("cache organization %s: %w", id, err)
		}
	}
	return nil
}

// loadSubGraphs hydrates every SubGraph aggregate root known to eventStore,
// via eventsourced.NewHandler, and passes each one to serviceCache.Update. It
// stops at the first error.
func loadSubGraphs(ctx context.Context, eventStore eventsourced.EventStore, serviceCache *cache.Cache) error {
	roots, err := eventStore.GetAggregateRoots(ctx, reflect.TypeOf(domain.SubGraph{}))
	if err != nil {
		return err
	}
	for _, root := range roots {
		id := root.String()
		subGraph := &domain.SubGraph{BaseAggregate: eventsourced.BaseAggregateFromString(id)}
		if _, err := eventsourced.NewHandler(ctx, subGraph, eventStore); err != nil {
			return fmt.Errorf("load subgraph %s: %w", id, err)
		}
		if _, err := serviceCache.Update(subGraph, nil); err != nil {
			return fmt.Errorf("cache subgraph %s: %w", id, err)
		}
	}
	return nil
}

// ConnectAMQP opens a goamqp connection to url, identifying the client as
// serviceName.
func ConnectAMQP(url string) (Connection, error) {
	conn, err := goamqp.NewFromURL(serviceName, url)
	if err != nil {
		// Return a literal nil. A typed nil pointer wrapped in the Connection
		// interface would compare non-nil.
		return nil, err
	}
	return conn, nil
}

// Connection is the subset of the goamqp connection API that start uses.
type Connection interface {
	Start(ctx context.Context, opts ...goamqp.Setup) error
	Close() error
}