bb0c08be06
Refactor API key processing to improve clarity and reduce code duplication. Introduce detailed logging for schema updates and initializations, capturing relevant context information. Use background context for async operations to avoid blocking. Implement organization lookup logic in the WebSocket init function for consistent API key handling across connections.
317 lines
10 KiB
Go
317 lines
10 KiB
Go
package main
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"fmt"
|
|
"log/slog"
|
|
"net/http"
|
|
"os"
|
|
"os/signal"
|
|
"reflect"
|
|
"sync"
|
|
"syscall"
|
|
"time"
|
|
|
|
"github.com/99designs/gqlgen/graphql/handler"
|
|
"github.com/99designs/gqlgen/graphql/handler/extension"
|
|
"github.com/99designs/gqlgen/graphql/handler/lru"
|
|
"github.com/99designs/gqlgen/graphql/handler/transport"
|
|
"github.com/99designs/gqlgen/graphql/playground"
|
|
"github.com/alecthomas/kong"
|
|
"github.com/rs/cors"
|
|
"github.com/sparetimecoders/goamqp"
|
|
"github.com/vektah/gqlparser/v2/ast"
|
|
"gitlab.com/unboundsoftware/eventsourced/amqp"
|
|
"gitlab.com/unboundsoftware/eventsourced/eventsourced"
|
|
"gitlab.com/unboundsoftware/eventsourced/pg"
|
|
|
|
"gitlab.com/unboundsoftware/schemas/cache"
|
|
"gitlab.com/unboundsoftware/schemas/domain"
|
|
"gitlab.com/unboundsoftware/schemas/graph"
|
|
"gitlab.com/unboundsoftware/schemas/graph/generated"
|
|
"gitlab.com/unboundsoftware/schemas/hash"
|
|
"gitlab.com/unboundsoftware/schemas/logging"
|
|
"gitlab.com/unboundsoftware/schemas/middleware"
|
|
"gitlab.com/unboundsoftware/schemas/monitoring"
|
|
"gitlab.com/unboundsoftware/schemas/store"
|
|
)
|
|
|
|
// CLI holds the service configuration. The struct tags are interpreted by
// kong (see main), so every setting can be supplied as a command line flag
// or through the environment variable named in its `env` tag.
type CLI struct {
	// AmqpURL is the RabbitMQ connection URL.
	AmqpURL string `name:"amqp-url" env:"AMQP_URL" help:"URL to use to connect to RabbitMQ" default:"amqp://user:password@unbound-control-plane.orb.local:5672/"`
	// Port is the TCP port the HTTP/GraphQL server listens on.
	Port int `name:"port" env:"PORT" help:"Listen-port for GraphQL API" default:"8080"`
	// LogLevel selects the logging verbosity.
	LogLevel string `name:"log-level" env:"LOG_LEVEL" help:"The level of logging to use (debug, info, warn, error, fatal)" default:"info"`
	// LogFormat selects the log output format; the value "otel" also
	// switches on the OTel SDK setup in start.
	LogFormat string `name:"log-format" env:"LOG_FORMAT" help:"The format of logs" default:"text" enum:"otel,json,text"`
	// DatabaseURL is the Postgres connection URL.
	DatabaseURL string `name:"postgres-url" env:"POSTGRES_URL" help:"URL to use to connect to Postgres" default:"postgres://postgres:postgres@unbound-control-plane.orb.local:5432/schemas?sslmode=disable"`
	// DatabaseDriverName is the database driver handed to store.SetupDB.
	DatabaseDriverName string `name:"db-driver" env:"DB_DRIVER" help:"Driver to use to connect to db" default:"postgres"`
	// Issuer is the JWT token issuer handed to the Auth0 middleware.
	Issuer string `name:"issuer" env:"ISSUER" help:"The JWT token issuer to use" default:"unbound.eu.auth0.com"`
	// StrictSSL toggles strict SSL handling in the Auth0 middleware.
	StrictSSL bool `name:"strict-ssl" env:"STRICT_SSL" help:"Should strict SSL handling be enabled" default:"true"`
	// Environment names the deployment environment reported to the OTel SDK setup.
	Environment string `name:"environment" env:"ENVIRONMENT" help:"The environment we are running in" default:"development" enum:"development,staging,production"`
}
|
|
|
|
// buildVersion is the version string handed to the logger and OTel SDK setup.
// It defaults to "none"; presumably it is overridden at build time (for
// example via -ldflags "-X main.buildVersion=...") — confirm against the
// build scripts.
var buildVersion = "none"
|
|
|
|
// serviceName is the name this service passes to the logger setup, the OTel
// SDK setup and the AMQP connection.
const serviceName = "schemas"
|
|
|
|
func main() {
|
|
var cli CLI
|
|
_ = kong.Parse(&cli)
|
|
logger := logging.SetupLogger(cli.LogLevel, cli.LogFormat, serviceName, buildVersion)
|
|
closeEvents := make(chan error)
|
|
|
|
if err := start(
|
|
closeEvents,
|
|
logger,
|
|
ConnectAMQP,
|
|
cli,
|
|
); err != nil {
|
|
logger.With("error", err).Error("process error")
|
|
}
|
|
}
|
|
|
|
func start(closeEvents chan error, logger *slog.Logger, connectToAmqpFunc func(url string) (Connection, error), cli CLI) error {
|
|
rootCtx, rootCancel := context.WithCancel(context.Background())
|
|
defer rootCancel()
|
|
|
|
shutdownFn, err := monitoring.SetupOTelSDK(rootCtx, cli.LogFormat == "otel", serviceName, buildVersion, cli.Environment)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
defer func() {
|
|
_ = errors.Join(shutdownFn(context.Background()))
|
|
}()
|
|
|
|
db, err := store.SetupDB(cli.DatabaseDriverName, cli.DatabaseURL)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to setup DB: %v", err)
|
|
}
|
|
|
|
eventStore, err := pg.New(
|
|
rootCtx,
|
|
db.DB,
|
|
pg.WithEventTypes(
|
|
&domain.SubGraphUpdated{},
|
|
&domain.OrganizationAdded{},
|
|
&domain.APIKeyAdded{},
|
|
),
|
|
)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to create eventstore: %v", err)
|
|
}
|
|
|
|
if err := store.RunEventStoreMigrations(db); err != nil {
|
|
return fmt.Errorf("event migrations: %w", err)
|
|
}
|
|
|
|
publisher := goamqp.NewPublisher()
|
|
eventPublisher, err := amqp.New(publisher)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to create event publisher: %v", err)
|
|
}
|
|
conn, err := connectToAmqpFunc(cli.AmqpURL)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to connect to AMQP: %v", err)
|
|
}
|
|
|
|
serviceCache := cache.New(logger)
|
|
if err := loadOrganizations(rootCtx, eventStore, serviceCache); err != nil {
|
|
return fmt.Errorf("caching organizations: %w", err)
|
|
}
|
|
if err := loadSubGraphs(rootCtx, eventStore, serviceCache); err != nil {
|
|
return fmt.Errorf("caching subgraphs: %w", err)
|
|
}
|
|
setups := []goamqp.Setup{
|
|
goamqp.UseLogger(func(s string) { logger.Error(s) }),
|
|
goamqp.CloseListener(closeEvents),
|
|
goamqp.WithPrefetchLimit(20),
|
|
goamqp.EventStreamPublisher(publisher),
|
|
goamqp.TransientEventStreamConsumer("SubGraph.Updated", serviceCache.Update, domain.SubGraphUpdated{}),
|
|
goamqp.TransientEventStreamConsumer("Organization.Added", serviceCache.Update, domain.OrganizationAdded{}),
|
|
goamqp.TransientEventStreamConsumer("Organization.APIKeyAdded", serviceCache.Update, domain.APIKeyAdded{}),
|
|
goamqp.WithTypeMapping("SubGraph.Updated", domain.SubGraphUpdated{}),
|
|
goamqp.WithTypeMapping("Organization.Added", domain.OrganizationAdded{}),
|
|
goamqp.WithTypeMapping("Organization.APIKeyAdded", domain.APIKeyAdded{}),
|
|
}
|
|
if err := conn.Start(rootCtx, setups...); err != nil {
|
|
return fmt.Errorf("failed to setup AMQP: %v", err)
|
|
}
|
|
|
|
defer func() { _ = conn.Close() }()
|
|
|
|
logger.Info("Started")
|
|
|
|
mux := http.NewServeMux()
|
|
httpSrv := &http.Server{Addr: fmt.Sprintf(":%d", cli.Port), Handler: mux}
|
|
|
|
wg := sync.WaitGroup{}
|
|
|
|
sigint := make(chan os.Signal, 1)
|
|
signal.Notify(sigint, os.Interrupt, syscall.SIGTERM)
|
|
|
|
wg.Add(1)
|
|
go func() {
|
|
defer wg.Done()
|
|
sig := <-sigint
|
|
if sig != nil {
|
|
// In case our shutdown logic is broken/incomplete we reset signal
|
|
// handlers so next signal goes to go itself. Go is more aggressive when
|
|
// shutting down goroutines
|
|
signal.Reset(os.Interrupt, syscall.SIGTERM)
|
|
logger.Info("Got shutdown signal..")
|
|
rootCancel()
|
|
}
|
|
}()
|
|
|
|
wg.Add(1)
|
|
go func() {
|
|
defer wg.Done()
|
|
err := <-closeEvents
|
|
if err != nil {
|
|
logger.With("error", err).Error("received close from AMQP")
|
|
rootCancel()
|
|
}
|
|
}()
|
|
|
|
wg.Add(1)
|
|
go func() {
|
|
defer wg.Done()
|
|
<-rootCtx.Done()
|
|
|
|
shutdownCtx, shutdownRelease := context.WithTimeout(context.Background(), 10*time.Second)
|
|
defer shutdownRelease()
|
|
|
|
if err := httpSrv.Shutdown(shutdownCtx); err != nil {
|
|
logger.With("error", err).Error("close http server")
|
|
}
|
|
close(sigint)
|
|
close(closeEvents)
|
|
}()
|
|
wg.Add(1)
|
|
go func() {
|
|
defer wg.Done()
|
|
defer rootCancel()
|
|
|
|
resolver := &graph.Resolver{
|
|
EventStore: eventStore,
|
|
Publisher: eventPublisher,
|
|
Logger: logger,
|
|
Cache: serviceCache,
|
|
PubSub: graph.NewPubSub(),
|
|
}
|
|
|
|
config := generated.Config{
|
|
Resolvers: resolver,
|
|
Complexity: generated.ComplexityRoot{},
|
|
}
|
|
apiKeyMiddleware := middleware.NewApiKey()
|
|
mw := middleware.NewAuth0("https://schemas.unbound.se", cli.Issuer, cli.StrictSSL)
|
|
authMiddleware := middleware.NewAuth(serviceCache)
|
|
config.Directives.Auth = authMiddleware.Directive
|
|
srv := handler.New(generated.NewExecutableSchema(config))
|
|
|
|
srv.AddTransport(transport.Websocket{
|
|
KeepAlivePingInterval: 10 * time.Second,
|
|
InitFunc: func(ctx context.Context, initPayload transport.InitPayload) (context.Context, *transport.InitPayload, error) {
|
|
// Extract API key from WebSocket connection_init payload
|
|
if apiKey, ok := initPayload["X-Api-Key"].(string); ok && apiKey != "" {
|
|
logger.Info("WebSocket connection with API key", "has_key", true)
|
|
ctx = context.WithValue(ctx, middleware.ApiKey, apiKey)
|
|
|
|
// Look up organization by API key (same logic as auth middleware)
|
|
if organization := serviceCache.OrganizationByAPIKey(hash.String(apiKey)); organization != nil {
|
|
logger.Info("WebSocket: Organization found for API key", "org_id", organization.ID.String())
|
|
ctx = context.WithValue(ctx, middleware.OrganizationKey, *organization)
|
|
} else {
|
|
logger.Warn("WebSocket: No organization found for API key")
|
|
}
|
|
} else {
|
|
logger.Info("WebSocket connection without API key")
|
|
}
|
|
return ctx, &initPayload, nil
|
|
},
|
|
})
|
|
srv.AddTransport(transport.Options{})
|
|
srv.AddTransport(transport.GET{})
|
|
srv.AddTransport(transport.POST{})
|
|
srv.AddTransport(transport.MultipartForm{})
|
|
|
|
srv.SetQueryCache(lru.New[*ast.QueryDocument](1000))
|
|
|
|
srv.Use(extension.Introspection{})
|
|
srv.Use(extension.AutomaticPersistedQuery{
|
|
Cache: lru.New[string](100),
|
|
})
|
|
|
|
mux.Handle("/", monitoring.Handler(playground.Handler("GraphQL playground", "/query")))
|
|
mux.Handle("/health", http.HandlerFunc(healthFunc))
|
|
mux.Handle("/query", cors.AllowAll().Handler(
|
|
monitoring.Handler(
|
|
mw.Middleware().CheckJWT(
|
|
apiKeyMiddleware.Handler(
|
|
authMiddleware.Handler(srv),
|
|
),
|
|
),
|
|
),
|
|
))
|
|
|
|
logger.Info(fmt.Sprintf("connect to http://localhost:%d/ for GraphQL playground", cli.Port))
|
|
|
|
if err := httpSrv.ListenAndServe(); !errors.Is(err, http.ErrServerClosed) {
|
|
logger.With("error", err).Error("listen http")
|
|
}
|
|
}()
|
|
|
|
wg.Wait()
|
|
|
|
return nil
|
|
}
|
|
|
|
func loadOrganizations(ctx context.Context, eventStore eventsourced.EventStore, serviceCache *cache.Cache) error {
|
|
roots, err := eventStore.GetAggregateRoots(ctx, reflect.TypeOf(domain.Organization{}))
|
|
if err != nil {
|
|
return err
|
|
}
|
|
for _, root := range roots {
|
|
organization := &domain.Organization{BaseAggregate: eventsourced.BaseAggregateFromString(root.String())}
|
|
if _, err := eventsourced.NewHandler(ctx, organization, eventStore); err != nil {
|
|
return err
|
|
}
|
|
_, err := serviceCache.Update(organization, nil)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func loadSubGraphs(ctx context.Context, eventStore eventsourced.EventStore, serviceCache *cache.Cache) error {
|
|
roots, err := eventStore.GetAggregateRoots(ctx, reflect.TypeOf(domain.SubGraph{}))
|
|
if err != nil {
|
|
return err
|
|
}
|
|
for _, root := range roots {
|
|
subGraph := &domain.SubGraph{BaseAggregate: eventsourced.BaseAggregateFromString(root.String())}
|
|
if _, err := eventsourced.NewHandler(ctx, subGraph, eventStore); err != nil {
|
|
return err
|
|
}
|
|
_, err := serviceCache.Update(subGraph, nil)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// healthFunc is the handler behind /health. It ignores the request and
// writes the body "OK"; no status code is set explicitly, and a failed write
// is deliberately ignored because there is nothing useful to do about it.
func healthFunc(w http.ResponseWriter, _ *http.Request) {
	body := []byte("OK")
	_, _ = w.Write(body)
}
|
|
|
|
func ConnectAMQP(url string) (Connection, error) {
|
|
return goamqp.NewFromURL(serviceName, url)
|
|
}
|
|
|
|
type Connection interface {
|
|
Start(ctx context.Context, opts ...goamqp.Setup) error
|
|
Close() error
|
|
}
|