fix: enhance API key handling and logging in middleware

Refactor API key processing to improve clarity and reduce code 
duplication. Introduce detailed logging for schema updates and 
initializations, capturing relevant context information. Use 
background context for async operations to avoid blocking. 
Implement organization lookup logic in the WebSocket init 
function for consistent API key handling across connections.
This commit is contained in:
2025-11-20 08:09:00 +01:00
parent a9a47c1690
commit bb0c08be06
3 changed files with 69 additions and 6 deletions
+19
View File
@@ -30,6 +30,7 @@ import (
"gitlab.com/unboundsoftware/schemas/domain"
"gitlab.com/unboundsoftware/schemas/graph"
"gitlab.com/unboundsoftware/schemas/graph/generated"
"gitlab.com/unboundsoftware/schemas/hash"
"gitlab.com/unboundsoftware/schemas/logging"
"gitlab.com/unboundsoftware/schemas/middleware"
"gitlab.com/unboundsoftware/schemas/monitoring"
@@ -210,6 +211,24 @@ func start(closeEvents chan error, logger *slog.Logger, connectToAmqpFunc func(u
srv.AddTransport(transport.Websocket{
KeepAlivePingInterval: 10 * time.Second,
InitFunc: func(ctx context.Context, initPayload transport.InitPayload) (context.Context, *transport.InitPayload, error) {
// Extract API key from WebSocket connection_init payload
if apiKey, ok := initPayload["X-Api-Key"].(string); ok && apiKey != "" {
logger.Info("WebSocket connection with API key", "has_key", true)
ctx = context.WithValue(ctx, middleware.ApiKey, apiKey)
// Look up organization by API key (same logic as auth middleware)
if organization := serviceCache.OrganizationByAPIKey(hash.String(apiKey)); organization != nil {
logger.Info("WebSocket: Organization found for API key", "org_id", organization.ID.String())
ctx = context.WithValue(ctx, middleware.OrganizationKey, *organization)
} else {
logger.Warn("WebSocket: No organization found for API key")
}
} else {
logger.Info("WebSocket connection without API key")
}
return ctx, &initPayload, nil
},
})
srv.AddTransport(transport.Options{})
srv.AddTransport(transport.GET{})