Files
schemas/cmd/service/service.go
T
argoyle 1549538c70
schemas / vulnerabilities (pull_request) Successful in 2m11s
schemas / check-release (pull_request) Successful in 3m8s
schemas / check (pull_request) Successful in 3m30s
pre-commit / pre-commit (pull_request) Successful in 7m13s
schemas / build (pull_request) Successful in 6m30s
schemas / deploy-prod (pull_request) Has been skipped
perf(graph): warm schema cache on startup to kill cold-start spikes
Following the schema cache PR, warm pods serve from cache (~24/25 hits
on a long-running pod). New pods, however, start cold: the first
LatestSchema query per (orgId, ref) still runs the wgc router compose
subprocess, which costs 100-300m CPU per call.

That cold-start cost is what kept tripping the HPA into TooManyReplicas:
HPA scales up → new pod added → new pod runs wgc on first query →
metrics spike → HPA scales up further → cycle repeats. Even after the
caching PR landed, we observed pods cycling 2→4→2→4 in production, with
fresh pods showing 2 'Fetching latest schema' (cold) entries and 0
cache hits within their first minute.

Add Cache.AllOrgRefs() exposing every tracked (orgId, ref) pair, and
Resolver.WarmCache(ctx) which iterates them after the event-sourced
caches have been populated. For each ref it fetches the subgraphs,
runs sdlmerge, runs CosmoGenerator.Generate, and stores both results
in the cache. Errors per ref are logged and skipped so a single bad
ref does not block warming the rest.

Service startup calls WarmCache right after the Resolver is wired,
before the HTTP server starts accepting traffic, so the first
LatestSchema query a pod receives is already a cache hit.
2026-05-21 17:10:30 +02:00

330 lines
11 KiB
Go

package main
import (
"context"
"errors"
"fmt"
"log/slog"
"net/http"
"os"
"os/signal"
"reflect"
"sync"
"syscall"
"time"
"github.com/99designs/gqlgen/graphql/handler"
"github.com/99designs/gqlgen/graphql/handler/extension"
"github.com/99designs/gqlgen/graphql/handler/lru"
"github.com/99designs/gqlgen/graphql/handler/transport"
"github.com/99designs/gqlgen/graphql/playground"
"github.com/alecthomas/kong"
"github.com/rs/cors"
"github.com/sparetimecoders/goamqp"
"github.com/vektah/gqlparser/v2/ast"
"gitlab.com/unboundsoftware/eventsourced/amqp"
"gitlab.com/unboundsoftware/eventsourced/eventsourced"
"gitlab.com/unboundsoftware/eventsourced/pg"
"gitea.unbound.se/unboundsoftware/schemas/cache"
"gitea.unbound.se/unboundsoftware/schemas/domain"
"gitea.unbound.se/unboundsoftware/schemas/graph"
"gitea.unbound.se/unboundsoftware/schemas/graph/generated"
"gitea.unbound.se/unboundsoftware/schemas/health"
"gitea.unbound.se/unboundsoftware/schemas/logging"
"gitea.unbound.se/unboundsoftware/schemas/middleware"
"gitea.unbound.se/unboundsoftware/schemas/monitoring"
"gitea.unbound.se/unboundsoftware/schemas/store"
)
// CLI holds the runtime configuration of the service. main hands a pointer to
// it to kong.Parse, which fills each field from the command-line flag named in
// the `name` tag or the environment variable named in the `env` tag, falling
// back to the value in the `default` tag.
type CLI struct {
// AmqpURL is the RabbitMQ connection URL handed to the AMQP connect function.
AmqpURL string `name:"amqp-url" env:"AMQP_URL" help:"URL to use to connect to RabbitMQ" default:"amqp://user:password@unbound-control-plane.orb.local:5672/"`
// Port is the TCP port the HTTP server (GraphQL API, playground, health) listens on.
Port int `name:"port" env:"PORT" help:"Listen-port for GraphQL API" default:"8080"`
// LogLevel is passed to logging.SetupLogger.
LogLevel string `name:"log-level" env:"LOG_LEVEL" help:"The level of logging to use (debug, info, warn, error, fatal)" default:"info"`
// LogFormat is passed to logging.SetupLogger; the value "otel" additionally
// turns on the corresponding flag of monitoring.SetupOTelSDK.
LogFormat string `name:"log-format" env:"LOG_FORMAT" help:"The format of logs" default:"text" enum:"otel,json,text"`
// DatabaseURL is the Postgres connection URL passed to store.SetupDB.
DatabaseURL string `name:"postgres-url" env:"POSTGRES_URL" help:"URL to use to connect to Postgres" default:"postgres://postgres:postgres@unbound-control-plane.orb.local:5432/schemas?sslmode=disable"`
// DatabaseDriverName is the database driver name passed to store.SetupDB.
DatabaseDriverName string `name:"db-driver" env:"DB_DRIVER" help:"Driver to use to connect to db" default:"postgres"`
// Issuer is the JWT issuer handed to the Auth0 middleware.
Issuer string `name:"issuer" env:"ISSUER" help:"The JWT token issuer to use" default:"unbound.eu.auth0.com"`
// StrictSSL is handed to the Auth0 middleware.
StrictSSL bool `name:"strict-ssl" env:"STRICT_SSL" help:"Should strict SSL handling be enabled" default:"true"`
// Environment names the deployment environment and is passed to monitoring.SetupOTelSDK.
Environment string `name:"environment" env:"ENVIRONMENT" help:"The environment we are running in" default:"development" enum:"development,staging,production"`
}
// buildVersion identifies the running build; it is passed to the logger and to
// the OTel SDK setup. It defaults to "none" and is presumably overridden at
// link time (e.g. -ldflags "-X main.buildVersion=...") — TODO confirm against
// the build scripts.
var buildVersion = "none"
// serviceName is the name this service uses for itself when setting up
// logging, telemetry and the AMQP connection.
const serviceName = "schemas"
// main is the process entry point: it parses flags and environment into a CLI,
// sets up the logger and runs the service via start until it returns.
//
// If start returns an error it is logged and the process exits with status 1,
// so that a supervisor can tell a failed start-up from a clean shutdown.
func main() {
	var cli CLI
	// The returned kong context is not needed (there are no sub-commands to
	// dispatch); kong.Parse terminates the process itself on invalid input.
	_ = kong.Parse(&cli)

	logger := logging.SetupLogger(cli.LogLevel, cli.LogFormat, serviceName, buildVersion)

	// Receives AMQP close notifications; start owns reading from and closing it.
	closeEvents := make(chan error)

	if err := start(closeEvents, logger, ConnectAMQP, cli); err != nil {
		logger.With("error", err).Error("process error")
		// No deferred calls are pending in main, so exiting here skips nothing.
		os.Exit(1)
	}
}
// start wires up every dependency of the service and blocks until the service
// has shut down.
//
// Parameters:
//   - closeEvents: channel on which the AMQP connection reports close errors
//     (registered through goamqp.CloseListener). A non-nil error received on
//     it triggers shutdown. start closes the channel during shutdown.
//   - logger: logger used for all output of the service.
//   - connectToAmqpFunc: factory for the AMQP connection; injectable so another
//     Connection implementation can be substituted.
//   - cli: parsed configuration.
//
// Start-up order: OTel SDK, database, event store (+ migrations), AMQP
// publisher/connection, in-memory cache populated from the event store, AMQP
// consumers, and finally the HTTP server, which warms the schema cache before
// it starts listening.
//
// Shutdown is triggered by SIGINT/SIGTERM, by an error on closeEvents, or by
// the HTTP server goroutine returning; all of them cancel the root context.
//
// It returns an error if any start-up step fails or if the HTTP server stops
// for a reason other than a regular shutdown; otherwise nil.
func start(closeEvents chan error, logger *slog.Logger, connectToAmqpFunc func(url string) (Connection, error), cli CLI) error {
	rootCtx, rootCancel := context.WithCancel(context.Background())
	defer rootCancel()

	shutdownFn, err := monitoring.SetupOTelSDK(rootCtx, cli.LogFormat == "otel", serviceName, buildVersion, cli.Environment)
	if err != nil {
		return err
	}
	defer func() {
		// Use a fresh context: after a regular shutdown rootCtx is already
		// cancelled when this runs.
		if err := shutdownFn(context.Background()); err != nil {
			logger.With("error", err).Error("shutdown OTel SDK")
		}
	}()

	db, err := store.SetupDB(cli.DatabaseDriverName, cli.DatabaseURL)
	if err != nil {
		return fmt.Errorf("failed to setup DB: %w", err)
	}

	eventStore, err := pg.New(
		rootCtx,
		db.DB,
		pg.WithEventTypes(
			&domain.SubGraphUpdated{},
			&domain.OrganizationAdded{},
			&domain.UserAddedToOrganization{},
			&domain.APIKeyAdded{},
			&domain.APIKeyRemoved{},
			&domain.OrganizationRemoved{},
		),
	)
	if err != nil {
		return fmt.Errorf("failed to create eventstore: %w", err)
	}

	if err := store.RunEventStoreMigrations(db); err != nil {
		return fmt.Errorf("event migrations: %w", err)
	}

	publisher := goamqp.NewPublisher()
	eventPublisher, err := amqp.New(publisher)
	if err != nil {
		return fmt.Errorf("failed to create event publisher: %w", err)
	}

	conn, err := connectToAmqpFunc(cli.AmqpURL)
	if err != nil {
		return fmt.Errorf("failed to connect to AMQP: %w", err)
	}

	// Populate the in-memory cache from the event store before any consumer or
	// HTTP handler can read from it.
	serviceCache := cache.New(logger)
	if err := loadOrganizations(rootCtx, eventStore, serviceCache); err != nil {
		return fmt.Errorf("caching organizations: %w", err)
	}
	if err := loadSubGraphs(rootCtx, eventStore, serviceCache); err != nil {
		return fmt.Errorf("caching subgraphs: %w", err)
	}

	// Every event type gets a transient consumer that feeds the cache, plus a
	// routing-key-to-type mapping.
	setups := []goamqp.Setup{
		goamqp.UseLogger(func(s string) { logger.Error(s) }),
		goamqp.CloseListener(closeEvents),
		goamqp.WithPrefetchLimit(20),
		goamqp.EventStreamPublisher(publisher),
		goamqp.TransientEventStreamConsumer("SubGraph.Updated", serviceCache.Update, domain.SubGraphUpdated{}),
		goamqp.TransientEventStreamConsumer("Organization.Added", serviceCache.Update, domain.OrganizationAdded{}),
		goamqp.TransientEventStreamConsumer("Organization.UserAdded", serviceCache.Update, domain.UserAddedToOrganization{}),
		goamqp.TransientEventStreamConsumer("Organization.APIKeyAdded", serviceCache.Update, domain.APIKeyAdded{}),
		goamqp.TransientEventStreamConsumer("Organization.APIKeyRemoved", serviceCache.Update, domain.APIKeyRemoved{}),
		goamqp.TransientEventStreamConsumer("Organization.Removed", serviceCache.Update, domain.OrganizationRemoved{}),
		goamqp.WithTypeMapping("SubGraph.Updated", domain.SubGraphUpdated{}),
		goamqp.WithTypeMapping("Organization.Added", domain.OrganizationAdded{}),
		goamqp.WithTypeMapping("Organization.UserAdded", domain.UserAddedToOrganization{}),
		goamqp.WithTypeMapping("Organization.APIKeyAdded", domain.APIKeyAdded{}),
		goamqp.WithTypeMapping("Organization.APIKeyRemoved", domain.APIKeyRemoved{}),
		goamqp.WithTypeMapping("Organization.Removed", domain.OrganizationRemoved{}),
	}
	if err := conn.Start(rootCtx, setups...); err != nil {
		return fmt.Errorf("failed to setup AMQP: %w", err)
	}
	defer func() { _ = conn.Close() }()

	logger.Info("Started")

	mux := http.NewServeMux()
	httpSrv := &http.Server{
		Addr:    fmt.Sprintf(":%d", cli.Port),
		Handler: mux,
		// Bound the time a client may take to send request headers. Read and
		// write timeouts are deliberately left unset because the server also
		// carries long-lived WebSocket connections.
		ReadHeaderTimeout: 10 * time.Second,
	}

	// serveErr is written by the HTTP goroutine before it signals the
	// WaitGroup and read only after wg.Wait, so no further synchronisation is
	// needed.
	var serveErr error

	var wg sync.WaitGroup

	sigint := make(chan os.Signal, 1)
	signal.Notify(sigint, os.Interrupt, syscall.SIGTERM)

	// Signal watcher: cancels the root context on SIGINT/SIGTERM.
	wg.Add(1)
	go func() {
		defer wg.Done()
		sig := <-sigint
		// sig is nil when the channel was closed by the shutdown goroutine.
		if sig != nil {
			// In case our shutdown logic is broken/incomplete we reset signal
			// handlers so next signal goes to go itself. Go is more aggressive when
			// shutting down goroutines
			signal.Reset(os.Interrupt, syscall.SIGTERM)
			logger.Info("Got shutdown signal..")
			rootCancel()
		}
	}()

	// AMQP watcher: cancels the root context when the connection reports an error.
	wg.Add(1)
	go func() {
		defer wg.Done()
		err := <-closeEvents
		// err is nil when the channel was closed by the shutdown goroutine.
		if err != nil {
			logger.With("error", err).Error("received close from AMQP")
			rootCancel()
		}
	}()

	// Shutdown coordinator: once the root context is cancelled, stop the HTTP
	// server and release the two watcher goroutines above.
	wg.Add(1)
	go func() {
		defer wg.Done()
		<-rootCtx.Done()

		shutdownCtx, shutdownRelease := context.WithTimeout(context.Background(), 10*time.Second)
		defer shutdownRelease()
		if err := httpSrv.Shutdown(shutdownCtx); err != nil {
			logger.With("error", err).Error("close http server")
		}

		// Unregister the channel before closing it: when shutdown was not
		// caused by a signal the channel is still registered with
		// signal.Notify, and a signal delivered after close would make the
		// signal package send on a closed channel and panic.
		signal.Stop(sigint)
		close(sigint)
		// NOTE(review): closeEvents is still registered as the AMQP close
		// listener until conn.Close runs; a send on it after this close would
		// panic. Left unchanged — confirm against goamqp's CloseListener.
		close(closeEvents)
	}()

	// HTTP server: builds the GraphQL handler stack and serves until shutdown.
	wg.Add(1)
	go func() {
		defer wg.Done()
		// If serving stops for any reason, bring the rest of the service down too.
		defer rootCancel()

		resolver := &graph.Resolver{
			EventStore:     eventStore,
			Publisher:      eventPublisher,
			Logger:         logger,
			Cache:          serviceCache,
			PubSub:         graph.NewPubSub(),
			CosmoGenerator: graph.NewCosmoGenerator(&graph.DefaultCommandExecutor{}, 60*time.Second),
			Debouncer:      graph.NewDebouncer(500 * time.Millisecond),
		}

		// Warm the schema cache before the listener is opened so the first
		// query a new pod receives does not have to compose the schema.
		// NOTE(review): the health endpoints are registered and served only
		// after this call returns — confirm probe timeouts tolerate the
		// warm-up duration.
		resolver.WarmCache(rootCtx)

		config := generated.Config{
			Resolvers:  resolver,
			Complexity: generated.ComplexityRoot{},
		}

		apiKeyMiddleware := middleware.NewApiKey()
		mw := middleware.NewAuth0("https://schemas.unbound.se", cli.Issuer, cli.StrictSSL)
		authMiddleware := middleware.NewAuth(serviceCache)
		config.Directives.Auth = authMiddleware.Directive

		srv := handler.New(generated.NewExecutableSchema(config))
		srv.AddTransport(transport.Websocket{
			KeepAlivePingInterval: 10 * time.Second,
			InitFunc: func(ctx context.Context, initPayload transport.InitPayload) (context.Context, *transport.InitPayload, error) {
				// Extract API key from WebSocket connection_init payload
				if apiKey, ok := initPayload["X-Api-Key"].(string); ok && apiKey != "" {
					logger.Info("WebSocket connection with API key", "has_key", true)
					ctx = context.WithValue(ctx, middleware.ApiKey, apiKey)
					// Look up organization by API key (cache handles hash comparison)
					if organization := serviceCache.OrganizationByAPIKey(apiKey); organization != nil {
						logger.Info("WebSocket: Organization found for API key", "org_id", organization.ID.String())
						ctx = context.WithValue(ctx, middleware.OrganizationKey, *organization)
					} else {
						logger.Warn("WebSocket: No organization found for API key")
					}
				} else {
					logger.Info("WebSocket connection without API key")
				}
				// The connection is accepted either way; authorisation is left
				// to whatever reads these context values later.
				return ctx, &initPayload, nil
			},
		})
		srv.AddTransport(transport.Options{})
		srv.AddTransport(transport.GET{})
		srv.AddTransport(transport.POST{})
		srv.AddTransport(transport.MultipartForm{})
		srv.SetQueryCache(lru.New[*ast.QueryDocument](1000))
		srv.Use(extension.Introspection{})
		srv.Use(extension.AutomaticPersistedQuery{
			Cache: lru.New[string](100),
		})

		healthChecker := health.New(db.DB, logger)
		mux.Handle("/", monitoring.Handler(playground.Handler("GraphQL playground", "/query")))
		mux.Handle("/health", http.HandlerFunc(healthChecker.LivenessHandler))
		mux.Handle("/health/live", http.HandlerFunc(healthChecker.LivenessHandler))
		mux.Handle("/health/ready", http.HandlerFunc(healthChecker.ReadinessHandler))
		// Outermost to innermost: CORS, monitoring, JWT check, API key, auth, GraphQL.
		mux.Handle("/query", cors.AllowAll().Handler(
			monitoring.Handler(
				mw.Middleware().CheckJWT(
					apiKeyMiddleware.Handler(
						authMiddleware.Handler(srv),
					),
				),
			),
		))

		logger.Info(fmt.Sprintf("connect to http://localhost:%d/ for GraphQL playground", cli.Port))

		// http.ErrServerClosed is the expected result of httpSrv.Shutdown;
		// anything else (e.g. the port being in use) is a real failure that is
		// both logged and reported to the caller of start.
		if err := httpSrv.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) {
			logger.With("error", err).Error("listen http")
			serveErr = fmt.Errorf("listen http: %w", err)
		}
	}()

	wg.Wait()
	return serveErr
}
// loadOrganizations feeds every Organization aggregate known to the event
// store into the service cache.
//
// For each aggregate root ID it builds an Organization carrying that ID, hands
// it to eventsourced.NewHandler together with the event store (presumably
// hydrating it from its stored events — confirm against the eventsourced
// library) and then passes it to serviceCache.Update.
//
// The first error from any step is returned unchanged and stops the loading.
func loadOrganizations(ctx context.Context, eventStore eventsourced.EventStore, serviceCache *cache.Cache) error {
	ids, err := eventStore.GetAggregateRoots(ctx, reflect.TypeOf(domain.Organization{}))
	if err != nil {
		return err
	}

	for _, id := range ids {
		org := &domain.Organization{
			BaseAggregate: eventsourced.BaseAggregateFromString(id.String()),
		}

		_, handlerErr := eventsourced.NewHandler(ctx, org, eventStore)
		if handlerErr != nil {
			return handlerErr
		}

		if _, updateErr := serviceCache.Update(org, nil); updateErr != nil {
			return updateErr
		}
	}

	return nil
}
// loadSubGraphs feeds every SubGraph aggregate known to the event store into
// the service cache.
//
// For each aggregate root ID it builds a SubGraph carrying that ID, hands it
// to eventsourced.NewHandler together with the event store (presumably
// hydrating it from its stored events — confirm against the eventsourced
// library) and then passes it to serviceCache.Update.
//
// The first error from any step is returned unchanged and stops the loading.
func loadSubGraphs(ctx context.Context, eventStore eventsourced.EventStore, serviceCache *cache.Cache) error {
	ids, err := eventStore.GetAggregateRoots(ctx, reflect.TypeOf(domain.SubGraph{}))
	if err != nil {
		return err
	}

	for _, id := range ids {
		sg := &domain.SubGraph{
			BaseAggregate: eventsourced.BaseAggregateFromString(id.String()),
		}

		_, handlerErr := eventsourced.NewHandler(ctx, sg, eventStore)
		if handlerErr != nil {
			return handlerErr
		}

		if _, updateErr := serviceCache.Update(sg, nil); updateErr != nil {
			return updateErr
		}
	}

	return nil
}
// ConnectAMQP is the production AMQP connection factory handed to start: it
// creates a goamqp connection for this service (identified by serviceName)
// from the given URL and returns it as a Connection, together with any error
// reported by goamqp.NewFromURL.
func ConnectAMQP(url string) (Connection, error) {
	conn, err := goamqp.NewFromURL(serviceName, url)
	return conn, err
}
// Connection is the part of an AMQP connection that start depends on. It is the
// type returned by ConnectAMQP and by the connectToAmqpFunc parameter of start,
// which allows another implementation to be substituted.
type Connection interface {
// Start is called once by start with the logger, close listener, prefetch
// limit, publisher, consumers and type mappings to apply to the connection.
Start(ctx context.Context, opts ...goamqp.Setup) error
// Close is deferred by start once Start has succeeded; its error is ignored there.
Close() error
}