Compare commits

...

26 Commits

Author SHA1 Message Date
argoyle 06aeedc3b0 Merge branch 'next-release' into 'main'
chore(release): prepare for v0.8.0

See merge request unboundsoftware/schemas!624
2025-11-21 11:32:33 +01:00
Unbound Release fce85782f0 chore(release): prepare for v0.8.0 2025-11-21 11:32:33 +01:00
argoyle 9cd8218eb4 Merge branch 'test/cache-reduce-goroutines-stability' into 'main'
test(cache): reduce goroutines for race detector stability

See merge request unboundsoftware/schemas!633
2025-11-21 11:19:28 +01:00
argoyle 98ef62b144 Merge branch 'renovate/github.com-auth0-go-jwt-middleware-v2-2.x' into 'main'
fix(deps): update module github.com/auth0/go-jwt-middleware/v2 to v2.3.1

See merge request unboundsoftware/schemas!630
2025-11-21 11:10:23 +01:00
argoyle e0cdd2aa58 Merge branch 'renovate/golang.org-x-crypto-0.x' into 'main'
fix(deps): update module golang.org/x/crypto to v0.45.0

See merge request unboundsoftware/schemas!629
2025-11-21 11:09:17 +01:00
argoyle e22e8b339c Merge branch 'renovate/node-24.x' into 'main'
chore(deps): update node.js to v24

See merge request unboundsoftware/schemas!627
2025-11-21 11:09:03 +01:00
argoyle 6404f7a497 test(cache): reduce goroutines for race detector stability
Decrease the number of goroutines in concurrent read and write tests to 
minimize race conditions during testing. This ensures more reliable 
test results and makes it easier to identify concurrency issues.
2025-11-21 11:06:36 +01:00
argoyle 5dc5043d46 Merge branch 'feat/cache-concurrency-logging' into 'main'
feat(cache): add concurrency safety and logging improvements

See merge request unboundsoftware/schemas!631
2025-11-21 10:45:48 +01:00
argoyle bcca005256 Merge branch 'feat/add-health-check-endpoints' into 'main'
feat(health): add health checking endpoints and logic

See merge request unboundsoftware/schemas!632
2025-11-21 10:38:01 +01:00
argoyle a9dea19531 feat(health): add health checking endpoints and logic
Introduce health checking functionality with liveness and readiness 
endpoints to monitor the application's status. Implement a health 
checker that verifies database connectivity and provides a simple 
liveness check. Update service routing to use the new health 
checker functionality. Add corresponding unit tests for validation.
2025-11-21 10:24:34 +01:00
argoyle 130e92dc5f feat(cache): add concurrency safety and logging improvements
Implement read-write mutex locks for cache functions to ensure
concurrency safety. Add debug logging for cache updates to enhance
traceability of operations. Optimize user addition logic to prevent
duplicates. Introduce a new test file for comprehensive cache
functionality testing, ensuring reliable behavior.
2025-11-21 10:21:08 +01:00
Renovate c4112a005f fix(deps): update module github.com/auth0/go-jwt-middleware/v2 to v2.3.1 2025-11-21 08:08:20 +00:00
Renovate 549f6617df fix(deps): update module golang.org/x/crypto to v0.45.0 2025-11-20 22:08:30 +00:00
argoyle a1b0f49aab Merge branch 'feat/cache-hashed-api-key-storage' into 'main'
feat(cache): implement hashed API key storage and retrieval

See merge request unboundsoftware/schemas!628
2025-11-20 22:30:49 +01:00
argoyle 4468903535 feat(cache): implement hashed API key storage and retrieval
Adds a new hashed key storage mechanism for API keys in the cache. 
Replaces direct mapping to API keys with composite keys based on 
organizationId and name. Implements searching of API keys using 
hash comparisons for improved security. Updates related tests to 
ensure correct functionality and validate the hashing. Also, 
adds support for a new dependency `golang.org/x/crypto`.
2025-11-20 22:11:24 +01:00
Renovate df054ca451 chore(deps): update node.js to v24 2025-11-20 21:09:40 +00:00
argoyle 1e2236dc9e Merge branch 'add-claude-md-documentation' into 'main'
feat: add CLAUDE.md for project documentation and guidelines

See merge request unboundsoftware/schemas!625
2025-11-20 21:50:16 +01:00
argoyle 6ccd7f4f25 feat: add CLAUDE.md for project documentation and guidelines
Adds CLAUDE.md to provide comprehensive documentation for the 
GraphQL schema registry service, covering architecture, event 
sourcing, GraphQL layer, schema merging, authentication, 
Cosmo Router integration, and development workflow. Updates 
.gitignore to include the claude directory.
2025-11-20 21:36:58 +01:00
argoyle b1a46f9d4e Merge branch 'enhance-api-key-handling-logging' into 'main'
fix: enhance API key handling and logging in middleware

See merge request unboundsoftware/schemas!623
2025-11-20 21:26:21 +01:00
argoyle 47dbf827f2 fix: add command executor interface for better testing
Introduce the CommandExecutor interface to abstract command execution, 
allowing for easier mocking in tests. Implement DefaultCommandExecutor 
to use the os/exec package for executing commands. Update the 
GenerateCosmoRouterConfig function to utilize the new 
GenerateCosmoRouterConfigWithExecutor function that accepts a command 
executor parameter. Add a MockCommandExecutor for simulating command 
execution in unit tests with realistic behavior based on input YAML 
files. This enhances test coverage and simplifies error handling.
2025-11-20 21:09:00 +01:00
argoyle df44ddbb8e test: enhance assertions for version and subscription config
Update version check to validate it is a non-empty string. Improve 
assertions for the subscription configuration by ensuring the presence 
of required fields and correct types. Adapt checks for routing URLs 
and decentralize subscription validation for more robust testing. 
These changes ensure better verification of configuration 
integrity and correctness in test scenarios.
2025-11-20 18:24:49 +01:00
argoyle 9368d77bc8 feat: add latestSchema query for retrieving schema updates
Implements the `latestSchema` query to fetch the latest schema 
updates for an organization. This change enhances the GraphQL API by
allowing clients to retrieve the most recent schema version and its 
associated subgraphs. The resolver performs necessary access checks, 
logs relevant information, and generates the Cosmo router configuration 
from fetched subgraph SDLs, returning structured schema update details.
2025-11-20 18:24:36 +01:00
argoyle 4d18cf4175 feat(tests): add unit tests for WebSocket initialization logic
Adds unit tests for the WebSocket initialization function to validate
behavior with valid, invalid, and absent API keys. Introduces a mock
cache implementation to simulate organization retrieval based on
hashed API keys. Ensures proper context value setting upon
initialization, enhancing test coverage and reliability for API key
handling in WebSocket connections.
2025-11-20 14:25:02 +01:00
argoyle bb0c08be06 fix: enhance API key handling and logging in middleware
Refactor API key processing to improve clarity and reduce code 
duplication. Introduce detailed logging for schema updates and 
initializations, capturing relevant context information. Use 
background context for async operations to avoid blocking. 
Implement organization lookup logic in the WebSocket init 
function for consistent API key handling across connections.
2025-11-20 12:58:15 +01:00
argoyle a9a47c1690 Merge branch 'renovate/gitleaks-gitleaks-8.x' into 'main'
chore(deps): update pre-commit hook gitleaks/gitleaks to v8.29.1

See merge request unboundsoftware/schemas!622
2025-11-20 09:21:24 +01:00
Renovate de073ce2da chore(deps): update pre-commit hook gitleaks/gitleaks to v8.29.1 2025-11-19 21:58:54 +00:00
27 changed files with 2617 additions and 181 deletions
+2
View File
@@ -1,9 +1,11 @@
.idea
.claude
.testCoverage.txt
.testCoverage.txt.tmp
coverage.html
/exported
/release
/schemactl
/service
CHANGES.md
VERSION
+1 -1
View File
@@ -41,7 +41,7 @@ repos:
hooks:
- id: golangci-lint-full
- repo: https://github.com/gitleaks/gitleaks
rev: v8.29.0
rev: v8.29.1
hooks:
- id: gitleaks
exclude: '^ctl/generated.go|graph/generated/.*$|^graph/model/models_gen.go|^tools/.*$$'
+1 -1
View File
@@ -1 +1 @@
{"version":"v0.7.0"}
{"version":"v0.8.0"}
+28
View File
@@ -2,6 +2,34 @@
All notable changes to this project will be documented in this file.
## [0.8.0] - 2025-11-21
### 🚀 Features
- *(tests)* Add unit tests for WebSocket initialization logic
- Add latestSchema query for retrieving schema updates
- Add CLAUDE.md for project documentation and guidelines
- *(cache)* Implement hashed API key storage and retrieval
- *(health)* Add health checking endpoints and logic
- *(cache)* Add concurrency safety and logging improvements
### 🐛 Bug Fixes
- Enhance API key handling and logging in middleware
- Add command executor interface for better testing
- *(deps)* Update module golang.org/x/crypto to v0.45.0
- *(deps)* Update module github.com/auth0/go-jwt-middleware/v2 to v2.3.1
### 🧪 Testing
- Enhance assertions for version and subscription config
- *(cache)* Reduce goroutines for race detector stability
### ⚙️ Miscellaneous Tasks
- *(deps)* Update pre-commit hook gitleaks/gitleaks to v8.29.1
- *(deps)* Update node.js to v24
## [0.7.0] - 2025-11-19
### 🚀 Features
+136
View File
@@ -0,0 +1,136 @@
# CLAUDE.md
This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository.
## Project Overview
This is a GraphQL schema registry service that manages federated GraphQL schemas for microservices. It allows services to publish their subgraph schemas and provides merged supergraphs with Cosmo Router configuration for federated GraphQL gateways.
## Architecture
### Event Sourcing
The system uses event sourcing via `gitlab.com/unboundsoftware/eventsourced`. Key domain aggregates are:
- **Organization** (domain/aggregates.go): Manages organizations, users, and API keys
- **SubGraph** (domain/aggregates.go): Tracks subgraph schemas with versioning
All state changes flow through events (domain/events.go) and commands (domain/commands.go). The EventStore persists events to PostgreSQL, and events are published to RabbitMQ for downstream consumers.
### GraphQL Layer
- **Schema**: graph/schema.graphqls defines the API
- **Resolvers**: graph/schema.resolvers.go implements mutations/queries
- **Generated Code**: graph/generated/ and graph/model/ (auto-generated by gqlgen)
The resolver (graph/resolver.go) coordinates between the EventStore, Publisher (RabbitMQ), Cache, and PubSub for subscriptions.
### Schema Merging
The sdlmerge/ package handles GraphQL schema federation:
- Merges multiple subgraph SDL schemas into a unified supergraph
- Uses wundergraph/graphql-go-tools for AST manipulation
- Removes duplicates, extends types, and applies federation directives
### Authentication & Authorization
- **Auth0 JWT** (middleware/auth0.go): Validates user tokens from Auth0
- **API Keys** (middleware/apikey.go): Validates service API keys
- **Auth Middleware** (middleware/auth.go): Routes auth based on context
The @auth directive controls field-level access (user vs organization API key).
### Cosmo Router Integration
The service generates Cosmo Router configuration (graph/cosmo.go) using the wgc CLI tool installed in the Docker container. This config enables federated query execution across subgraphs.
### PubSub for Real-time Updates
graph/pubsub.go implements subscription support for schemaUpdates, allowing clients to receive real-time notifications when schemas change.
## Commands
### Code Generation
```bash
# Generate GraphQL server code (gqlgen), format, and organize imports
go generate ./...
```
Always run this after modifying graph/schema.graphqls. The go:generate directives are in:
- graph/resolver.go: runs gqlgen, gofumpt, and goimports
- ctl/ctl.go: generates genqlient client code
### Testing
```bash
# Run all tests
go test ./... -v
# Run tests with race detection and coverage (as used in CI)
CGO_ENABLED=1 go test -race -coverprofile=coverage.txt -covermode=atomic ./...
# Run specific package tests
go test ./middleware -v
go test ./graph -v -run TestGenerateCosmoRouterConfig
# Run single test
go test ./cmd/service -v -run TestWebSocket
```
### Building
```bash
# Build the service binary
go build -o service ./cmd/service/service.go
# Build the CLI tool
go build -o schemactl ./cmd/schemactl/schemactl.go
# Docker build (multi-stage)
docker build -t schemas .
```
The Dockerfile runs tests with coverage before building the production binary.
### Running the Service
```bash
# Start the service (requires PostgreSQL and RabbitMQ)
go run ./cmd/service/service.go \
--postgres-url="postgres://user:pass@localhost:5432/schemas?sslmode=disable" \
--amqp-url="amqp://user:pass@localhost:5672/" \
--issuer="your-auth0-domain.auth0.com"
# The service listens on port 8080 by default
# GraphQL Playground available at http://localhost:8080/
```
### Using the schemactl CLI
```bash
# Publish a subgraph schema
schemactl publish \
--api-key="your-api-key" \
--schema-ref="production" \
--service="users" \
--url="http://users-service:8080/query" \
--sdl=schema.graphql
# List subgraphs for a ref
schemactl list \
--api-key="your-api-key" \
--schema-ref="production"
```
## Development Workflow
1. **Schema Changes**: Edit graph/schema.graphqls → run `go generate ./...`
2. **Resolver Implementation**: Implement in graph/schema.resolvers.go
3. **Testing**: Write tests, run `go test ./...`
4. **Pre-commit**: Hooks run go-mod-tidy, goimports, gofumpt, golangci-lint, and tests
## Key Dependencies
- **gqlgen**: GraphQL server generation
- **genqlient**: GraphQL client generation (for ctl package)
- **eventsourced**: Event sourcing framework
- **wundergraph/graphql-go-tools**: Schema federation and composition
- **wgc CLI**: Cosmo Router config generation (Node.js tool)
- **Auth0**: JWT authentication
- **OpenTelemetry**: Observability (traces, metrics, logs)
## Important Files
- gqlgen.yml: gqlgen configuration
- graph/tools.go: Declares build-time tool dependencies
- .pre-commit-config.yaml: Pre-commit hooks configuration
- cliff.toml: Changelog generation config
+9 -1
View File
@@ -24,9 +24,17 @@ RUN GOOS=linux GOARCH=amd64 go build \
FROM scratch as export
COPY --from=build /build/coverage.txt /
FROM scratch
FROM node:24-alpine@sha256:2867d550cf9d8bb50059a0fff528741f11a84d985c732e60e19e8e75c7239c43
ENV TZ Europe/Stockholm
# Install wgc CLI globally for Cosmo Router composition
RUN npm install -g wgc@latest
# Copy timezone data and certificates
COPY --from=build /usr/share/zoneinfo /usr/share/zoneinfo
COPY --from=build /etc/ssl/certs/ca-certificates.crt /etc/ssl/certs/
# Copy the service binary
COPY --from=build /release/service /
CMD ["/service"]
+67 -20
View File
@@ -3,18 +3,21 @@ package cache
import (
"fmt"
"log/slog"
"sync"
"time"
"github.com/sparetimecoders/goamqp"
"gitlab.com/unboundsoftware/eventsourced/eventsourced"
"gitlab.com/unboundsoftware/schemas/domain"
"gitlab.com/unboundsoftware/schemas/hash"
)
type Cache struct {
mu sync.RWMutex
organizations map[string]domain.Organization
users map[string][]string
apiKeys map[string]domain.APIKey
apiKeys map[string]domain.APIKey // keyed by organizationId-name
services map[string]map[string]map[string]struct{}
subGraphs map[string]string
lastUpdate map[string]string
@@ -22,18 +25,26 @@ type Cache struct {
}
func (c *Cache) OrganizationByAPIKey(apiKey string) *domain.Organization {
key, exists := c.apiKeys[apiKey]
if !exists {
return nil
c.mu.RLock()
defer c.mu.RUnlock()
// Find the API key by comparing hashes
for _, key := range c.apiKeys {
if hash.CompareAPIKey(key.Key, apiKey) {
org, exists := c.organizations[key.OrganizationId]
if !exists {
return nil
}
return &org
}
}
org, exists := c.organizations[key.OrganizationId]
if !exists {
return nil
}
return &org
return nil
}
func (c *Cache) OrganizationsByUser(sub string) []domain.Organization {
c.mu.RLock()
defer c.mu.RUnlock()
orgIds := c.users[sub]
orgs := make([]domain.Organization, len(orgIds))
for i, id := range orgIds {
@@ -43,14 +54,22 @@ func (c *Cache) OrganizationsByUser(sub string) []domain.Organization {
}
func (c *Cache) ApiKeyByKey(key string) *domain.APIKey {
k, exists := c.apiKeys[hash.String(key)]
if !exists {
return nil
c.mu.RLock()
defer c.mu.RUnlock()
// Find the API key by comparing hashes
for _, apiKey := range c.apiKeys {
if hash.CompareAPIKey(apiKey.Key, key) {
return &apiKey
}
}
return &k
return nil
}
func (c *Cache) Services(orgId, ref, lastUpdate string) ([]string, string) {
c.mu.RLock()
defer c.mu.RUnlock()
key := refKey(orgId, ref)
var services []string
if lastUpdate == "" || c.lastUpdate[key] > lastUpdate {
@@ -62,41 +81,56 @@ func (c *Cache) Services(orgId, ref, lastUpdate string) ([]string, string) {
}
func (c *Cache) SubGraphId(orgId, ref, service string) string {
c.mu.RLock()
defer c.mu.RUnlock()
return c.subGraphs[subGraphKey(orgId, ref, service)]
}
func (c *Cache) Update(msg any, _ goamqp.Headers) (any, error) {
c.mu.Lock()
defer c.mu.Unlock()
switch m := msg.(type) {
case *domain.OrganizationAdded:
o := domain.Organization{}
o := domain.Organization{
BaseAggregate: eventsourced.BaseAggregateFromString(m.ID.String()),
}
m.UpdateOrganization(&o)
c.organizations[m.ID.String()] = o
c.addUser(m.Initiator, o)
c.logger.With("org_id", m.ID.String(), "event", "OrganizationAdded").Debug("cache updated")
case *domain.APIKeyAdded:
key := domain.APIKey{
Name: m.Name,
OrganizationId: m.OrganizationId,
Key: m.Key,
Key: m.Key, // This is now the hashed key
Refs: m.Refs,
Read: m.Read,
Publish: m.Publish,
CreatedBy: m.Initiator,
CreatedAt: m.When(),
}
c.apiKeys[m.Key] = key
// Use composite key: organizationId-name
c.apiKeys[apiKeyId(m.OrganizationId, m.Name)] = key
org := c.organizations[m.OrganizationId]
org.APIKeys = append(org.APIKeys, key)
c.organizations[m.OrganizationId] = org
c.logger.With("org_id", m.OrganizationId, "key_name", m.Name, "event", "APIKeyAdded").Debug("cache updated")
case *domain.SubGraphUpdated:
c.updateSubGraph(m.OrganizationId, m.Ref, m.ID.String(), m.Service, m.Time)
c.logger.With("org_id", m.OrganizationId, "ref", m.Ref, "service", m.Service, "event", "SubGraphUpdated").Debug("cache updated")
case *domain.Organization:
c.organizations[m.ID.String()] = *m
c.addUser(m.CreatedBy, *m)
for _, k := range m.APIKeys {
c.apiKeys[k.Key] = k
// Use composite key: organizationId-name
c.apiKeys[apiKeyId(k.OrganizationId, k.Name)] = k
}
c.logger.With("org_id", m.ID.String(), "event", "Organization aggregate loaded").Debug("cache updated")
case *domain.SubGraph:
c.updateSubGraph(m.OrganizationId, m.Ref, m.ID.String(), m.Service, m.ChangedAt)
c.logger.With("org_id", m.OrganizationId, "ref", m.Ref, "service", m.Service, "event", "SubGraph aggregate loaded").Debug("cache updated")
default:
c.logger.With("msg", msg).Warn("unexpected message received")
}
@@ -117,11 +151,20 @@ func (c *Cache) updateSubGraph(orgId string, ref string, subGraphId string, serv
func (c *Cache) addUser(sub string, organization domain.Organization) {
user, exists := c.users[sub]
orgId := organization.ID.String()
if !exists {
c.users[sub] = []string{organization.ID.String()}
} else {
c.users[sub] = append(user, organization.ID.String())
c.users[sub] = []string{orgId}
return
}
// Check if organization already exists for this user
for _, id := range user {
if id == orgId {
return // Already exists, no need to add
}
}
c.users[sub] = append(user, orgId)
}
func New(logger *slog.Logger) *Cache {
@@ -143,3 +186,7 @@ func refKey(orgId string, ref string) string {
func subGraphKey(orgId string, ref string, service string) string {
return fmt.Sprintf("%s<->%s<->%s", orgId, ref, service)
}
func apiKeyId(orgId string, name string) string {
return fmt.Sprintf("%s<->%s", orgId, name)
}
+447
View File
@@ -0,0 +1,447 @@
package cache
import (
"log/slog"
"os"
"sync"
"testing"
"time"
"github.com/google/uuid"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"gitlab.com/unboundsoftware/eventsourced/eventsourced"
"gitlab.com/unboundsoftware/schemas/domain"
"gitlab.com/unboundsoftware/schemas/hash"
)
func TestCache_OrganizationByAPIKey(t *testing.T) {
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
c := New(logger)
orgID := uuid.New().String()
apiKey := "test-api-key-123" // gitleaks:allow
hashedKey, err := hash.APIKey(apiKey)
require.NoError(t, err)
// Add organization to cache
org := domain.Organization{
BaseAggregate: eventsourced.BaseAggregateFromString(orgID),
Name: "Test Org",
}
c.organizations[orgID] = org
// Add API key to cache
c.apiKeys[apiKeyId(orgID, "test-key")] = domain.APIKey{
Name: "test-key",
OrganizationId: orgID,
Key: hashedKey,
Refs: []string{"main"},
Read: true,
Publish: true,
}
// Test finding organization by plaintext API key
foundOrg := c.OrganizationByAPIKey(apiKey)
require.NotNil(t, foundOrg)
assert.Equal(t, org.Name, foundOrg.Name)
assert.Equal(t, orgID, foundOrg.ID.String())
// Test with wrong API key
notFoundOrg := c.OrganizationByAPIKey("wrong-key")
assert.Nil(t, notFoundOrg)
}
func TestCache_OrganizationByAPIKey_Legacy(t *testing.T) {
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
c := New(logger)
orgID := uuid.New().String()
apiKey := "legacy-api-key-456" // gitleaks:allow
legacyHash := hash.String(apiKey)
// Add organization to cache
org := domain.Organization{
BaseAggregate: eventsourced.BaseAggregateFromString(orgID),
Name: "Legacy Org",
}
c.organizations[orgID] = org
// Add API key with legacy SHA256 hash
c.apiKeys[apiKeyId(orgID, "legacy-key")] = domain.APIKey{
Name: "legacy-key",
OrganizationId: orgID,
Key: legacyHash,
Refs: []string{"main"},
Read: true,
Publish: false,
}
// Test finding organization with legacy hash
foundOrg := c.OrganizationByAPIKey(apiKey)
require.NotNil(t, foundOrg)
assert.Equal(t, org.Name, foundOrg.Name)
}
func TestCache_OrganizationsByUser(t *testing.T) {
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
c := New(logger)
userSub := "user-123"
org1ID := uuid.New().String()
org2ID := uuid.New().String()
org1 := domain.Organization{
BaseAggregate: eventsourced.BaseAggregateFromString(org1ID),
Name: "Org 1",
}
org2 := domain.Organization{
BaseAggregate: eventsourced.BaseAggregateFromString(org2ID),
Name: "Org 2",
}
c.organizations[org1ID] = org1
c.organizations[org2ID] = org2
c.users[userSub] = []string{org1ID, org2ID}
orgs := c.OrganizationsByUser(userSub)
assert.Len(t, orgs, 2)
assert.Contains(t, []string{orgs[0].Name, orgs[1].Name}, "Org 1")
assert.Contains(t, []string{orgs[0].Name, orgs[1].Name}, "Org 2")
}
func TestCache_ApiKeyByKey(t *testing.T) {
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
c := New(logger)
orgID := uuid.New().String()
apiKey := "test-api-key-789" // gitleaks:allow
hashedKey, err := hash.APIKey(apiKey)
require.NoError(t, err)
expectedKey := domain.APIKey{
Name: "test-key",
OrganizationId: orgID,
Key: hashedKey,
Refs: []string{"main", "dev"},
Read: true,
Publish: true,
}
c.apiKeys[apiKeyId(orgID, "test-key")] = expectedKey
foundKey := c.ApiKeyByKey(apiKey)
require.NotNil(t, foundKey)
assert.Equal(t, expectedKey.Name, foundKey.Name)
assert.Equal(t, expectedKey.OrganizationId, foundKey.OrganizationId)
assert.Equal(t, expectedKey.Refs, foundKey.Refs)
// Test with wrong key
notFoundKey := c.ApiKeyByKey("wrong-key")
assert.Nil(t, notFoundKey)
}
func TestCache_Services(t *testing.T) {
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
c := New(logger)
orgID := uuid.New().String()
ref := "main"
service1 := "service-1"
service2 := "service-2"
lastUpdate := "2024-01-01T12:00:00Z"
c.services[orgID] = map[string]map[string]struct{}{
ref: {
service1: {},
service2: {},
},
}
c.lastUpdate[refKey(orgID, ref)] = lastUpdate
// Test getting services with empty lastUpdate
services, returnedLastUpdate := c.Services(orgID, ref, "")
assert.Len(t, services, 2)
assert.Contains(t, services, service1)
assert.Contains(t, services, service2)
assert.Equal(t, lastUpdate, returnedLastUpdate)
// Test with older lastUpdate (should return services)
services, returnedLastUpdate = c.Services(orgID, ref, "2023-12-31T12:00:00Z")
assert.Len(t, services, 2)
assert.Equal(t, lastUpdate, returnedLastUpdate)
// Test with newer lastUpdate (should return empty)
services, returnedLastUpdate = c.Services(orgID, ref, "2024-01-02T12:00:00Z")
assert.Len(t, services, 0)
assert.Equal(t, lastUpdate, returnedLastUpdate)
}
func TestCache_SubGraphId(t *testing.T) {
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
c := New(logger)
orgID := uuid.New().String()
ref := "main"
service := "test-service"
subGraphID := uuid.New().String()
c.subGraphs[subGraphKey(orgID, ref, service)] = subGraphID
foundID := c.SubGraphId(orgID, ref, service)
assert.Equal(t, subGraphID, foundID)
// Test with non-existent key
notFoundID := c.SubGraphId("wrong-org", ref, service)
assert.Empty(t, notFoundID)
}
func TestCache_Update_OrganizationAdded(t *testing.T) {
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
c := New(logger)
orgID := uuid.New().String()
event := &domain.OrganizationAdded{
Name: "New Org",
Initiator: "user-123",
}
event.ID = *eventsourced.IdFromString(orgID)
_, err := c.Update(event, nil)
require.NoError(t, err)
// Verify organization was added
org, exists := c.organizations[orgID]
assert.True(t, exists)
assert.Equal(t, "New Org", org.Name)
// Verify user was added
assert.Contains(t, c.users["user-123"], orgID)
}
func TestCache_Update_APIKeyAdded(t *testing.T) {
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
c := New(logger)
orgID := uuid.New().String()
keyName := "test-key"
hashedKey := "hashed-key-value"
// Add organization first
org := domain.Organization{
BaseAggregate: eventsourced.BaseAggregateFromString(orgID),
Name: "Test Org",
APIKeys: []domain.APIKey{},
}
c.organizations[orgID] = org
event := &domain.APIKeyAdded{
OrganizationId: orgID,
Name: keyName,
Key: hashedKey,
Refs: []string{"main"},
Read: true,
Publish: false,
Initiator: "user-123",
}
event.ID = *eventsourced.IdFromString(uuid.New().String())
_, err := c.Update(event, nil)
require.NoError(t, err)
// Verify API key was added to cache
key, exists := c.apiKeys[apiKeyId(orgID, keyName)]
assert.True(t, exists)
assert.Equal(t, keyName, key.Name)
assert.Equal(t, hashedKey, key.Key)
assert.Equal(t, []string{"main"}, key.Refs)
// Verify API key was added to organization
updatedOrg := c.organizations[orgID]
assert.Len(t, updatedOrg.APIKeys, 1)
assert.Equal(t, keyName, updatedOrg.APIKeys[0].Name)
}
func TestCache_Update_SubGraphUpdated(t *testing.T) {
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
c := New(logger)
orgID := uuid.New().String()
ref := "main"
service := "test-service"
subGraphID := uuid.New().String()
event := &domain.SubGraphUpdated{
OrganizationId: orgID,
Ref: ref,
Service: service,
Initiator: "user-123",
}
event.ID = *eventsourced.IdFromString(subGraphID)
event.SetWhen(time.Now())
_, err := c.Update(event, nil)
require.NoError(t, err)
// Verify subgraph was added to services
assert.Contains(t, c.services[orgID][ref], subGraphID)
// Verify subgraph ID was stored
assert.Equal(t, subGraphID, c.subGraphs[subGraphKey(orgID, ref, service)])
// Verify lastUpdate was set
assert.NotEmpty(t, c.lastUpdate[refKey(orgID, ref)])
}
func TestCache_AddUser_NoDuplicates(t *testing.T) {
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
c := New(logger)
userSub := "user-123"
orgID := uuid.New().String()
org := domain.Organization{
BaseAggregate: eventsourced.BaseAggregateFromString(orgID),
Name: "Test Org",
}
// Add user first time
c.addUser(userSub, org)
assert.Len(t, c.users[userSub], 1)
assert.Equal(t, orgID, c.users[userSub][0])
// Add same user/org again - should not create duplicate
c.addUser(userSub, org)
assert.Len(t, c.users[userSub], 1, "Should not add duplicate organization")
assert.Equal(t, orgID, c.users[userSub][0])
}
func TestCache_ConcurrentReads(t *testing.T) {
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
c := New(logger)
// Setup test data
orgID := uuid.New().String()
apiKey := "test-concurrent-key" // gitleaks:allow
hashedKey, err := hash.APIKey(apiKey)
require.NoError(t, err)
org := domain.Organization{
BaseAggregate: eventsourced.BaseAggregateFromString(orgID),
Name: "Concurrent Test Org",
}
c.organizations[orgID] = org
c.apiKeys[apiKeyId(orgID, "test-key")] = domain.APIKey{
Name: "test-key",
OrganizationId: orgID,
Key: hashedKey,
}
// Run concurrent reads (reduced for race detector)
var wg sync.WaitGroup
numGoroutines := 20
for i := 0; i < numGoroutines; i++ {
wg.Add(1)
go func() {
defer wg.Done()
org := c.OrganizationByAPIKey(apiKey)
assert.NotNil(t, org)
assert.Equal(t, "Concurrent Test Org", org.Name)
}()
}
wg.Wait()
}
func TestCache_ConcurrentWrites(t *testing.T) {
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
c := New(logger)
var wg sync.WaitGroup
numGoroutines := 10 // Reduced for race detector
// Concurrent organization additions
for i := 0; i < numGoroutines; i++ {
wg.Add(1)
go func(index int) {
defer wg.Done()
orgID := uuid.New().String()
event := &domain.OrganizationAdded{
Name: "Org " + string(rune(index)),
Initiator: "user-" + string(rune(index)),
}
event.ID = *eventsourced.IdFromString(orgID)
_, err := c.Update(event, nil)
assert.NoError(t, err)
}(i)
}
wg.Wait()
// Verify all organizations were added
assert.Equal(t, numGoroutines, len(c.organizations))
}
func TestCache_ConcurrentReadsAndWrites(t *testing.T) {
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
c := New(logger)
// Setup initial data
orgID := uuid.New().String()
apiKey := "test-rw-key" // gitleaks:allow
hashedKey, err := hash.APIKey(apiKey)
require.NoError(t, err)
org := domain.Organization{
BaseAggregate: eventsourced.BaseAggregateFromString(orgID),
Name: "RW Test Org",
}
c.organizations[orgID] = org
c.apiKeys[apiKeyId(orgID, "test-key")] = domain.APIKey{
Name: "test-key",
OrganizationId: orgID,
Key: hashedKey,
}
c.users["user-initial"] = []string{orgID}
var wg sync.WaitGroup
numReaders := 10 // Reduced for race detector
numWriters := 5 // Reduced for race detector
iterations := 3 // Reduced for race detector
// Concurrent readers
for i := 0; i < numReaders; i++ {
wg.Add(1)
go func() {
defer wg.Done()
for j := 0; j < iterations; j++ {
org := c.OrganizationByAPIKey(apiKey)
assert.NotNil(t, org)
orgs := c.OrganizationsByUser("user-initial")
assert.NotEmpty(t, orgs)
}
}()
}
// Concurrent writers
for i := 0; i < numWriters; i++ {
wg.Add(1)
go func(index int) {
defer wg.Done()
newOrgID := uuid.New().String()
event := &domain.OrganizationAdded{
Name: "New Org " + string(rune(index)),
Initiator: "user-new-" + string(rune(index)),
}
event.ID = *eventsourced.IdFromString(newOrgID)
_, err := c.Update(event, nil)
assert.NoError(t, err)
}(i)
}
wg.Wait()
// Verify cache is in consistent state
assert.GreaterOrEqual(t, len(c.organizations), numWriters)
}
+24 -5
View File
@@ -30,6 +30,7 @@ import (
"gitlab.com/unboundsoftware/schemas/domain"
"gitlab.com/unboundsoftware/schemas/graph"
"gitlab.com/unboundsoftware/schemas/graph/generated"
"gitlab.com/unboundsoftware/schemas/health"
"gitlab.com/unboundsoftware/schemas/logging"
"gitlab.com/unboundsoftware/schemas/middleware"
"gitlab.com/unboundsoftware/schemas/monitoring"
@@ -210,6 +211,24 @@ func start(closeEvents chan error, logger *slog.Logger, connectToAmqpFunc func(u
srv.AddTransport(transport.Websocket{
KeepAlivePingInterval: 10 * time.Second,
InitFunc: func(ctx context.Context, initPayload transport.InitPayload) (context.Context, *transport.InitPayload, error) {
// Extract API key from WebSocket connection_init payload
if apiKey, ok := initPayload["X-Api-Key"].(string); ok && apiKey != "" {
logger.Info("WebSocket connection with API key", "has_key", true)
ctx = context.WithValue(ctx, middleware.ApiKey, apiKey)
// Look up organization by API key (cache handles hash comparison)
if organization := serviceCache.OrganizationByAPIKey(apiKey); organization != nil {
logger.Info("WebSocket: Organization found for API key", "org_id", organization.ID.String())
ctx = context.WithValue(ctx, middleware.OrganizationKey, *organization)
} else {
logger.Warn("WebSocket: No organization found for API key")
}
} else {
logger.Info("WebSocket connection without API key")
}
return ctx, &initPayload, nil
},
})
srv.AddTransport(transport.Options{})
srv.AddTransport(transport.GET{})
@@ -223,8 +242,12 @@ func start(closeEvents chan error, logger *slog.Logger, connectToAmqpFunc func(u
Cache: lru.New[string](100),
})
healthChecker := health.New(db.DB, logger)
mux.Handle("/", monitoring.Handler(playground.Handler("GraphQL playground", "/query")))
mux.Handle("/health", http.HandlerFunc(healthFunc))
mux.Handle("/health", http.HandlerFunc(healthChecker.LivenessHandler))
mux.Handle("/health/live", http.HandlerFunc(healthChecker.LivenessHandler))
mux.Handle("/health/ready", http.HandlerFunc(healthChecker.ReadinessHandler))
mux.Handle("/query", cors.AllowAll().Handler(
monitoring.Handler(
mw.Middleware().CheckJWT(
@@ -283,10 +306,6 @@ func loadSubGraphs(ctx context.Context, eventStore eventsourced.EventStore, serv
return nil
}
func healthFunc(w http.ResponseWriter, _ *http.Request) {
_, _ = w.Write([]byte("OK"))
}
func ConnectAMQP(url string) (Connection, error) {
return goamqp.NewFromURL(serviceName, url)
}
+362
View File
@@ -0,0 +1,362 @@
package main
import (
"context"
"testing"
"github.com/99designs/gqlgen/graphql/handler/transport"
"github.com/google/uuid"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"gitlab.com/unboundsoftware/eventsourced/eventsourced"
"gitlab.com/unboundsoftware/schemas/domain"
"gitlab.com/unboundsoftware/schemas/hash"
"gitlab.com/unboundsoftware/schemas/middleware"
)
// MockCache is a mock implementation for testing
type MockCache struct {
organizations map[string]*domain.Organization // keyed by orgId-name composite
apiKeys map[string]string // maps orgId-name to hashed key
}
func (m *MockCache) OrganizationByAPIKey(plainKey string) *domain.Organization {
// Find organization by comparing plaintext key with stored hash
for compositeKey, hashedKey := range m.apiKeys {
if hash.CompareAPIKey(hashedKey, plainKey) {
return m.organizations[compositeKey]
}
}
return nil
}
func TestWebSocketInitFunc_WithValidAPIKey(t *testing.T) {
// Setup
orgID := uuid.New()
org := &domain.Organization{
BaseAggregate: eventsourced.BaseAggregate{
ID: eventsourced.IdFromString(orgID.String()),
},
Name: "Test Organization",
}
apiKey := "test-api-key-123"
hashedKey, err := hash.APIKey(apiKey)
require.NoError(t, err)
compositeKey := orgID.String() + "-test-key"
mockCache := &MockCache{
organizations: map[string]*domain.Organization{
compositeKey: org,
},
apiKeys: map[string]string{
compositeKey: hashedKey,
},
}
// Create InitFunc (simulating the WebSocket InitFunc logic)
initFunc := func(ctx context.Context, initPayload transport.InitPayload) (context.Context, *transport.InitPayload, error) {
// Extract API key from WebSocket connection_init payload
if apiKey, ok := initPayload["X-Api-Key"].(string); ok && apiKey != "" {
ctx = context.WithValue(ctx, middleware.ApiKey, apiKey)
// Look up organization by API key (cache handles hash comparison)
if organization := mockCache.OrganizationByAPIKey(apiKey); organization != nil {
ctx = context.WithValue(ctx, middleware.OrganizationKey, *organization)
}
}
return ctx, &initPayload, nil
}
// Test
ctx := context.Background()
initPayload := transport.InitPayload{
"X-Api-Key": apiKey,
}
resultCtx, resultPayload, err := initFunc(ctx, initPayload)
// Assert
require.NoError(t, err)
require.NotNil(t, resultPayload)
// Check API key is in context
if value := resultCtx.Value(middleware.ApiKey); value != nil {
assert.Equal(t, apiKey, value.(string))
} else {
t.Fatal("API key not found in context")
}
// Check organization is in context
if value := resultCtx.Value(middleware.OrganizationKey); value != nil {
capturedOrg, ok := value.(domain.Organization)
require.True(t, ok, "Organization should be of correct type")
assert.Equal(t, org.Name, capturedOrg.Name)
assert.Equal(t, org.ID.String(), capturedOrg.ID.String())
} else {
t.Fatal("Organization not found in context")
}
}
func TestWebSocketInitFunc_WithInvalidAPIKey(t *testing.T) {
// Setup
mockCache := &MockCache{
organizations: map[string]*domain.Organization{},
apiKeys: map[string]string{},
}
apiKey := "invalid-api-key"
// Create InitFunc
initFunc := func(ctx context.Context, initPayload transport.InitPayload) (context.Context, *transport.InitPayload, error) {
// Extract API key from WebSocket connection_init payload
if apiKey, ok := initPayload["X-Api-Key"].(string); ok && apiKey != "" {
ctx = context.WithValue(ctx, middleware.ApiKey, apiKey)
// Look up organization by API key (cache handles hash comparison)
if organization := mockCache.OrganizationByAPIKey(apiKey); organization != nil {
ctx = context.WithValue(ctx, middleware.OrganizationKey, *organization)
}
}
return ctx, &initPayload, nil
}
// Test
ctx := context.Background()
initPayload := transport.InitPayload{
"X-Api-Key": apiKey,
}
resultCtx, resultPayload, err := initFunc(ctx, initPayload)
// Assert
require.NoError(t, err)
require.NotNil(t, resultPayload)
// Check API key is in context
if value := resultCtx.Value(middleware.ApiKey); value != nil {
assert.Equal(t, apiKey, value.(string))
} else {
t.Fatal("API key not found in context")
}
// Check organization is NOT in context (since API key is invalid)
value := resultCtx.Value(middleware.OrganizationKey)
assert.Nil(t, value, "Organization should not be set for invalid API key")
}
func TestWebSocketInitFunc_WithoutAPIKey(t *testing.T) {
// Setup
mockCache := &MockCache{
organizations: map[string]*domain.Organization{},
apiKeys: map[string]string{},
}
// Create InitFunc
initFunc := func(ctx context.Context, initPayload transport.InitPayload) (context.Context, *transport.InitPayload, error) {
// Extract API key from WebSocket connection_init payload
if apiKey, ok := initPayload["X-Api-Key"].(string); ok && apiKey != "" {
ctx = context.WithValue(ctx, middleware.ApiKey, apiKey)
// Look up organization by API key (cache handles hash comparison)
if organization := mockCache.OrganizationByAPIKey(apiKey); organization != nil {
ctx = context.WithValue(ctx, middleware.OrganizationKey, *organization)
}
}
return ctx, &initPayload, nil
}
// Test
ctx := context.Background()
initPayload := transport.InitPayload{}
resultCtx, resultPayload, err := initFunc(ctx, initPayload)
// Assert
require.NoError(t, err)
require.NotNil(t, resultPayload)
// Check API key is NOT in context
value := resultCtx.Value(middleware.ApiKey)
assert.Nil(t, value, "API key should not be set when not provided")
// Check organization is NOT in context
value = resultCtx.Value(middleware.OrganizationKey)
assert.Nil(t, value, "Organization should not be set when API key is not provided")
}
func TestWebSocketInitFunc_WithEmptyAPIKey(t *testing.T) {
// Setup
mockCache := &MockCache{
organizations: map[string]*domain.Organization{},
apiKeys: map[string]string{},
}
// Create InitFunc
initFunc := func(ctx context.Context, initPayload transport.InitPayload) (context.Context, *transport.InitPayload, error) {
// Extract API key from WebSocket connection_init payload
if apiKey, ok := initPayload["X-Api-Key"].(string); ok && apiKey != "" {
ctx = context.WithValue(ctx, middleware.ApiKey, apiKey)
// Look up organization by API key (cache handles hash comparison)
if organization := mockCache.OrganizationByAPIKey(apiKey); organization != nil {
ctx = context.WithValue(ctx, middleware.OrganizationKey, *organization)
}
}
return ctx, &initPayload, nil
}
// Test
ctx := context.Background()
initPayload := transport.InitPayload{
"X-Api-Key": "", // Empty string
}
resultCtx, resultPayload, err := initFunc(ctx, initPayload)
// Assert
require.NoError(t, err)
require.NotNil(t, resultPayload)
// Check API key is NOT in context (because empty string fails the condition)
value := resultCtx.Value(middleware.ApiKey)
assert.Nil(t, value, "API key should not be set when empty")
// Check organization is NOT in context
value = resultCtx.Value(middleware.OrganizationKey)
assert.Nil(t, value, "Organization should not be set when API key is empty")
}
func TestWebSocketInitFunc_WithWrongTypeAPIKey(t *testing.T) {
// Setup
mockCache := &MockCache{
organizations: map[string]*domain.Organization{},
apiKeys: map[string]string{},
}
// Create InitFunc
initFunc := func(ctx context.Context, initPayload transport.InitPayload) (context.Context, *transport.InitPayload, error) {
// Extract API key from WebSocket connection_init payload
if apiKey, ok := initPayload["X-Api-Key"].(string); ok && apiKey != "" {
ctx = context.WithValue(ctx, middleware.ApiKey, apiKey)
// Look up organization by API key (cache handles hash comparison)
if organization := mockCache.OrganizationByAPIKey(apiKey); organization != nil {
ctx = context.WithValue(ctx, middleware.OrganizationKey, *organization)
}
}
return ctx, &initPayload, nil
}
// Test
ctx := context.Background()
initPayload := transport.InitPayload{
"X-Api-Key": 12345, // Wrong type (int instead of string)
}
resultCtx, resultPayload, err := initFunc(ctx, initPayload)
// Assert
require.NoError(t, err)
require.NotNil(t, resultPayload)
// Check API key is NOT in context (type assertion fails)
value := resultCtx.Value(middleware.ApiKey)
assert.Nil(t, value, "API key should not be set when wrong type")
// Check organization is NOT in context
value = resultCtx.Value(middleware.OrganizationKey)
assert.Nil(t, value, "Organization should not be set when API key has wrong type")
}
func TestWebSocketInitFunc_WithMultipleOrganizations(t *testing.T) {
// Setup - create multiple organizations
org1ID := uuid.New()
org1 := &domain.Organization{
BaseAggregate: eventsourced.BaseAggregate{
ID: eventsourced.IdFromString(org1ID.String()),
},
Name: "Organization 1",
}
org2ID := uuid.New()
org2 := &domain.Organization{
BaseAggregate: eventsourced.BaseAggregate{
ID: eventsourced.IdFromString(org2ID.String()),
},
Name: "Organization 2",
}
apiKey1 := "api-key-org-1"
apiKey2 := "api-key-org-2"
hashedKey1, err := hash.APIKey(apiKey1)
require.NoError(t, err)
hashedKey2, err := hash.APIKey(apiKey2)
require.NoError(t, err)
compositeKey1 := org1ID.String() + "-key1"
compositeKey2 := org2ID.String() + "-key2"
mockCache := &MockCache{
organizations: map[string]*domain.Organization{
compositeKey1: org1,
compositeKey2: org2,
},
apiKeys: map[string]string{
compositeKey1: hashedKey1,
compositeKey2: hashedKey2,
},
}
// Create InitFunc
initFunc := func(ctx context.Context, initPayload transport.InitPayload) (context.Context, *transport.InitPayload, error) {
// Extract API key from WebSocket connection_init payload
if apiKey, ok := initPayload["X-Api-Key"].(string); ok && apiKey != "" {
ctx = context.WithValue(ctx, middleware.ApiKey, apiKey)
// Look up organization by API key (cache handles hash comparison)
if organization := mockCache.OrganizationByAPIKey(apiKey); organization != nil {
ctx = context.WithValue(ctx, middleware.OrganizationKey, *organization)
}
}
return ctx, &initPayload, nil
}
// Test with first API key
ctx1 := context.Background()
initPayload1 := transport.InitPayload{
"X-Api-Key": apiKey1,
}
resultCtx1, _, err := initFunc(ctx1, initPayload1)
require.NoError(t, err)
if value := resultCtx1.Value(middleware.OrganizationKey); value != nil {
capturedOrg, ok := value.(domain.Organization)
require.True(t, ok)
assert.Equal(t, org1.Name, capturedOrg.Name)
assert.Equal(t, org1.ID.String(), capturedOrg.ID.String())
} else {
t.Fatal("Organization 1 not found in context")
}
// Test with second API key
ctx2 := context.Background()
initPayload2 := transport.InitPayload{
"X-Api-Key": apiKey2,
}
resultCtx2, _, err := initFunc(ctx2, initPayload2)
require.NoError(t, err)
if value := resultCtx2.Value(middleware.OrganizationKey); value != nil {
capturedOrg, ok := value.(domain.Organization)
require.True(t, ok)
assert.Equal(t, org2.Name, capturedOrg.Name)
assert.Equal(t, org2.ID.String(), capturedOrg.ID.String())
} else {
t.Fatal("Organization 2 not found in context")
}
}
+12 -1
View File
@@ -56,9 +56,20 @@ func (a AddAPIKey) Validate(_ context.Context, aggregate eventsourced.Aggregate)
}
func (a AddAPIKey) Event(context.Context) eventsourced.Event {
// Hash the API key using bcrypt for secure storage
// Note: We can't return an error here, but bcrypt errors are extremely rare
// (only if system runs out of memory or bcrypt cost is invalid)
// We use a fixed cost of 12 which is always valid
hashedKey, err := hash.APIKey(a.Key)
if err != nil {
// This should never happen with bcrypt cost 12, but if it does,
// we'll store an empty hash which will fail validation later
hashedKey = ""
}
return &APIKeyAdded{
Name: a.Name,
Key: hash.String(a.Key),
Key: hashedKey,
Refs: a.Refs,
Read: a.Read,
Publish: a.Publish,
+24 -11
View File
@@ -2,10 +2,13 @@ package domain
import (
"context"
"strings"
"testing"
"github.com/stretchr/testify/assert"
"gitlab.com/unboundsoftware/eventsourced/eventsourced"
"github.com/stretchr/testify/require"
"gitlab.com/unboundsoftware/schemas/hash"
)
func TestAddAPIKey_Event(t *testing.T) {
@@ -24,7 +27,6 @@ func TestAddAPIKey_Event(t *testing.T) {
name string
fields fields
args args
want eventsourced.Event
}{
{
name: "event",
@@ -37,14 +39,6 @@ func TestAddAPIKey_Event(t *testing.T) {
Initiator: "jim@example.org",
},
args: args{},
want: &APIKeyAdded{
Name: "test",
Key: "dXNfYWtfMTIzNDU2Nzg5MDEyMzQ1NuOwxEKY/BwUmvv0yJlvuSQnrkHkZJuTTKSVmRt4UrhV",
Refs: []string{"Example@dev"},
Read: true,
Publish: true,
Initiator: "jim@example.org",
},
},
}
for _, tt := range tests {
@@ -57,7 +51,26 @@ func TestAddAPIKey_Event(t *testing.T) {
Publish: tt.fields.Publish,
Initiator: tt.fields.Initiator,
}
assert.Equalf(t, tt.want, a.Event(tt.args.in0), "Event(%v)", tt.args.in0)
event := a.Event(tt.args.in0)
require.NotNil(t, event)
// Cast to APIKeyAdded to verify fields
apiKeyEvent, ok := event.(*APIKeyAdded)
require.True(t, ok, "Event should be *APIKeyAdded")
// Verify non-key fields match exactly
assert.Equal(t, tt.fields.Name, apiKeyEvent.Name)
assert.Equal(t, tt.fields.Refs, apiKeyEvent.Refs)
assert.Equal(t, tt.fields.Read, apiKeyEvent.Read)
assert.Equal(t, tt.fields.Publish, apiKeyEvent.Publish)
assert.Equal(t, tt.fields.Initiator, apiKeyEvent.Initiator)
// Verify the key is hashed correctly (bcrypt format)
assert.True(t, strings.HasPrefix(apiKeyEvent.Key, "$2"), "Key should be bcrypt hashed")
assert.NotEqual(t, tt.fields.Key, apiKeyEvent.Key, "Key should be hashed, not plaintext")
// Verify the hash matches the original key
assert.True(t, hash.CompareAPIKey(apiKeyEvent.Key, tt.fields.Key), "Hashed key should match original")
})
}
}
+10 -7
View File
@@ -4,11 +4,13 @@ go 1.25
require (
github.com/99designs/gqlgen v0.17.83
github.com/DATA-DOG/go-sqlmock v1.5.2
github.com/Khan/genqlient v0.8.1
github.com/alecthomas/kong v1.13.0
github.com/apex/log v1.9.0
github.com/auth0/go-jwt-middleware/v2 v2.3.0
github.com/auth0/go-jwt-middleware/v2 v2.3.1
github.com/golang-jwt/jwt/v5 v5.3.0
github.com/google/uuid v1.6.0
github.com/jmoiron/sqlx v1.4.0
github.com/pkg/errors v0.9.1
github.com/pressly/goose/v3 v3.26.0
@@ -30,6 +32,8 @@ require (
go.opentelemetry.io/otel/sdk/log v0.14.0
go.opentelemetry.io/otel/sdk/metric v1.38.0
go.opentelemetry.io/otel/trace v1.38.0
golang.org/x/crypto v0.45.0
gopkg.in/yaml.v3 v3.0.1
)
require (
@@ -41,7 +45,6 @@ require (
github.com/go-logr/logr v1.4.3 // indirect
github.com/go-logr/stdr v1.2.2 // indirect
github.com/go-viper/mapstructure/v2 v2.4.0 // indirect
github.com/google/uuid v1.6.0 // indirect
github.com/gorilla/websocket v1.5.1 // indirect
github.com/grpc-ecosystem/grpc-gateway/v2 v2.27.2 // indirect
github.com/hashicorp/golang-lru/v2 v2.0.7 // indirect
@@ -51,6 +54,7 @@ require (
github.com/rabbitmq/amqp091-go v1.10.0 // indirect
github.com/sethvargo/go-retry v0.3.0 // indirect
github.com/sosodev/duration v1.3.1 // indirect
github.com/stretchr/objx v0.5.2 // indirect
github.com/tidwall/gjson v1.17.0 // indirect
github.com/tidwall/match v1.1.1 // indirect
github.com/tidwall/pretty v1.2.1 // indirect
@@ -63,14 +67,13 @@ require (
go.opentelemetry.io/proto/otlp v1.7.1 // indirect
go.uber.org/multierr v1.11.0 // indirect
golang.org/x/mod v0.29.0 // indirect
golang.org/x/net v0.46.0 // indirect
golang.org/x/sync v0.17.0 // indirect
golang.org/x/sys v0.37.0 // indirect
golang.org/x/text v0.30.0 // indirect
golang.org/x/net v0.47.0 // indirect
golang.org/x/sync v0.18.0 // indirect
golang.org/x/sys v0.38.0 // indirect
golang.org/x/text v0.31.0 // indirect
golang.org/x/tools v0.38.0 // indirect
google.golang.org/genproto/googleapis/api v0.0.0-20250825161204-c5933d9347a5 // indirect
google.golang.org/genproto/googleapis/rpc v0.0.0-20250825161204-c5933d9347a5 // indirect
google.golang.org/grpc v1.75.0 // indirect
google.golang.org/protobuf v1.36.10 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
)
+15 -12
View File
@@ -27,8 +27,8 @@ github.com/aphistic/golf v0.0.0-20180712155816-02c07f170c5a/go.mod h1:3NqKYiepwy
github.com/aphistic/sweet v0.2.0/go.mod h1:fWDlIh/isSE9n6EPsRmC0det+whmX6dJid3stzu0Xys=
github.com/arbovm/levenshtein v0.0.0-20160628152529-48b4e1c0c4d0 h1:jfIu9sQUG6Ig+0+Ap1h4unLjW6YQJpKZVmUzxsD4E/Q=
github.com/arbovm/levenshtein v0.0.0-20160628152529-48b4e1c0c4d0/go.mod h1:t2tdKJDJF9BV14lnkjHmOQgcvEKgtqs5a1N3LNdJhGE=
github.com/auth0/go-jwt-middleware/v2 v2.3.0 h1:4QREj6cS3d8dS05bEm443jhnqQF97FX9sMBeWqnNRzE=
github.com/auth0/go-jwt-middleware/v2 v2.3.0/go.mod h1:dL4ObBs1/dj4/W4cYxd8rqAdDGXYyd5rqbpMIxcbVrU=
github.com/auth0/go-jwt-middleware/v2 v2.3.1 h1:lbDyWE9aLydb3zrank+Gufb9qGJN9u//7EbJK07pRrw=
github.com/auth0/go-jwt-middleware/v2 v2.3.1/go.mod h1:mqVr0gdB5zuaFyQFWMJH/c/2hehNjbYUD4i8Dpyf+Hc=
github.com/aws/aws-sdk-go v1.20.6/go.mod h1:KmX6BPdI08NWTb3/sm4ZGu5ShLoqVDhKgpiN924inxo=
github.com/aybabtme/rgbterm v0.0.0-20170906152045-cc83f3b3ce59/go.mod h1:q/89r3U2H7sSsE2t6Kca0lfwTK8JdoNGS/yzM/4iH5I=
github.com/buger/jsonparser v1.1.1 h1:2PnMjfWD7wBILjqQbt530v576A/cAbQvEW9gGIpYMUs=
@@ -83,6 +83,7 @@ github.com/jmespath/go-jmespath v0.0.0-20180206201540-c2b33e8439af/go.mod h1:Nht
github.com/jmoiron/sqlx v1.4.0 h1:1PLqN7S1UYp5t4SrVVnt4nUVNemrDAtxlulVe+Qgm3o=
github.com/jmoiron/sqlx v1.4.0/go.mod h1:ZrZ7UsYB/weZdl2Bxg6jCRO9c3YHl8r3ahlKmRT4JLY=
github.com/jpillora/backoff v0.0.0-20180909062703-3050d21c67d7/go.mod h1:2iMrUgbbvHEiQClaW2NsSzMyGHqN+rDFqY705q49KG0=
github.com/kisielk/sqlstruct v0.0.0-20201105191214-5f3e10d3ab46/go.mod h1:yyMNCyc/Ib3bDTKd379tNMpB/7/H5TjM2Y9QJ5THLbE=
github.com/kr/logfmt v0.0.0-20140226030751-b84e30acd515/go.mod h1:+0opPa2QZZtGFBFZlji/RkVcI2GknAs/DXo4wKdlNEc=
github.com/kr/pretty v0.2.0/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI=
github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE=
@@ -141,6 +142,8 @@ github.com/sosodev/duration v1.3.1/go.mod h1:RQIBBX0+fMLc/D9+Jb/fwvVmo0eZvDDEERA
github.com/sparetimecoders/goamqp v0.3.3 h1:z/nfTPmrjeU/rIVuNOgsVLCimp3WFoNFvS3ZzXRJ6HE=
github.com/sparetimecoders/goamqp v0.3.3/go.mod h1:W9NRCpWLE+Vruv2dcRSbszNil2O826d2Nv6kAkETW5o=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/objx v0.5.2 h1:xuMeJ0Sdp5ZMRXx/aWO6RZxdr3beISkG5/G/aIRr3pY=
github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA=
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=
@@ -212,8 +215,8 @@ go.uber.org/multierr v1.11.0 h1:blXXJkSxSSfBVBlC76pxqeO+LN3aDfLQo+309xJstO0=
go.uber.org/multierr v1.11.0/go.mod h1:20+QtiLqy0Nd6FdQB9TLXag12DsQkrbs3htMFfDN80Y=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
golang.org/x/crypto v0.0.0-20190426145343-a29dc8fdc734/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
golang.org/x/crypto v0.43.0 h1:dduJYIi3A3KOfdGOHX8AVZ/jGiyPa3IbBozJ5kNuE04=
golang.org/x/crypto v0.43.0/go.mod h1:BFbav4mRNlXJL4wNeejLpWxB7wMbc79PdRGhWKncxR0=
golang.org/x/crypto v0.45.0 h1:jMBrvKuj23MTlT0bQEOBcAE0mjg8mK9RXFhRH6nyF3Q=
golang.org/x/crypto v0.45.0/go.mod h1:XTGrrkGJve7CYK7J8PEww4aY7gM3qMCElcJQ8n8JdX4=
golang.org/x/exp v0.0.0-20250620022241-b7579e27df2b h1:M2rDM6z3Fhozi9O7NWsxAkg/yqS/lQJ6PmkyIV3YP+o=
golang.org/x/exp v0.0.0-20250620022241-b7579e27df2b/go.mod h1:3//PLf8L/X+8b4vuAfHzxeRUl04Adcb341+IGKfnqS8=
golang.org/x/mod v0.29.0 h1:HV8lRxZC4l2cr3Zq1LvtOsi/ThTgWnUk/y64QSs8GwA=
@@ -221,21 +224,21 @@ golang.org/x/mod v0.29.0/go.mod h1:NyhrlYXJ2H4eJiRy/WDBO6HMqZQ6q9nk4JzS3NuCK+w=
golang.org/x/net v0.0.0-20180906233101-161cd47e91fd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
golang.org/x/net v0.46.0 h1:giFlY12I07fugqwPuWJi68oOnpfqFnJIJzaIIm2JVV4=
golang.org/x/net v0.46.0/go.mod h1:Q9BGdFy1y4nkUwiLvT5qtyhAnEHgnQ/zd8PfU6nc210=
golang.org/x/net v0.47.0 h1:Mx+4dIFzqraBXUugkia1OOvlD6LemFo1ALMHjrXDOhY=
golang.org/x/net v0.47.0/go.mod h1:/jNxtkgq5yWUGYkaZGqo27cfGZ1c5Nen03aYrrKpVRU=
golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.17.0 h1:l60nONMj9l5drqw6jlhIELNv9I0A4OFgRsG9k2oT9Ug=
golang.org/x/sync v0.17.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI=
golang.org/x/sync v0.18.0 h1:kr88TuHDroi+UVf+0hZnirlk8o8T+4MrK6mr60WkH/I=
golang.org/x/sync v0.18.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI=
golang.org/x/sys v0.0.0-20180909124046-d0be0721c37e/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190222072716-a9d3bda3a223/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.37.0 h1:fdNQudmxPjkdUTPnLn5mdQv7Zwvbvpaxqs831goi9kQ=
golang.org/x/sys v0.37.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
golang.org/x/sys v0.38.0 h1:3yZWxaJjBmCWXqhN1qh02AkOnCQ1poK6oF+a7xWL6Gc=
golang.org/x/sys v0.38.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk=
golang.org/x/text v0.30.0 h1:yznKA/E9zq54KzlzBEAWn1NXSQ8DIp/NYMy88xJjl4k=
golang.org/x/text v0.30.0/go.mod h1:yDdHFIX9t+tORqspjENWgzaCVXgk0yYnYuSZ8UzzBVM=
golang.org/x/text v0.31.0 h1:aC8ghyu4JhP8VojJ2lEHBnochRno1sgL6nEi9WGFGMM=
golang.org/x/text v0.31.0/go.mod h1:tKRAlv61yKIjGGHX/4tP1LTbc13YSec1pxVEWXzfoeM=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/tools v0.38.0 h1:Hx2Xv8hISq8Lm16jvBZ2VQf+RLmbd7wVUsALibYI/IQ=
golang.org/x/tools v0.38.0/go.mod h1:yEsQ/d/YK8cjh0L6rZlY8tgtlKiBNTL14pGDJPJpYQs=
+1 -1
View File
@@ -38,7 +38,7 @@ func ToGqlAPIKeys(keys []domain.APIKey) []*model.APIKey {
result[i] = &model.APIKey{
ID: apiKeyId(k.OrganizationId, k.Name),
Name: k.Name,
Key: &k.Key,
Key: nil, // Never return the hashed key - only return plaintext on creation
Organization: nil,
Refs: k.Refs,
Read: k.Read,
+98 -27
View File
@@ -1,54 +1,125 @@
package graph
import (
"encoding/json"
"fmt"
"os"
"os/exec"
"path/filepath"
"gopkg.in/yaml.v3"
"gitlab.com/unboundsoftware/schemas/graph/model"
)
// GenerateCosmoRouterConfig generates a Cosmo Router execution config from subgraphs
func GenerateCosmoRouterConfig(subGraphs []*model.SubGraph) (string, error) {
// Build the Cosmo router config structure
// This is a simplified version - you may need to adjust based on actual Cosmo requirements
config := map[string]interface{}{
"version": "1",
"subgraphs": convertSubGraphsToCosmo(subGraphs),
// Add other Cosmo-specific configuration as needed
}
// Marshal to JSON
configJSON, err := json.MarshalIndent(config, "", " ")
if err != nil {
return "", fmt.Errorf("marshal cosmo config: %w", err)
}
return string(configJSON), nil
// CommandExecutor is an interface for executing external commands
// This allows for mocking in tests
type CommandExecutor interface {
Execute(name string, args ...string) ([]byte, error)
}
func convertSubGraphsToCosmo(subGraphs []*model.SubGraph) []map[string]interface{} {
cosmoSubgraphs := make([]map[string]interface{}, 0, len(subGraphs))
// DefaultCommandExecutor implements CommandExecutor using os/exec
type DefaultCommandExecutor struct{}
// Execute runs a command and returns its combined output
func (e *DefaultCommandExecutor) Execute(name string, args ...string) ([]byte, error) {
cmd := exec.Command(name, args...)
return cmd.CombinedOutput()
}
// GenerateCosmoRouterConfig generates a Cosmo Router execution config from subgraphs
// using the official wgc CLI tool via npx
func GenerateCosmoRouterConfig(subGraphs []*model.SubGraph) (string, error) {
return GenerateCosmoRouterConfigWithExecutor(subGraphs, &DefaultCommandExecutor{})
}
// GenerateCosmoRouterConfigWithExecutor generates a Cosmo Router execution config from subgraphs
// using the provided command executor (useful for testing)
func GenerateCosmoRouterConfigWithExecutor(subGraphs []*model.SubGraph, executor CommandExecutor) (string, error) {
if len(subGraphs) == 0 {
return "", fmt.Errorf("no subgraphs provided")
}
// Create a temporary directory for composition
tmpDir, err := os.MkdirTemp("", "cosmo-compose-*")
if err != nil {
return "", fmt.Errorf("create temp dir: %w", err)
}
defer os.RemoveAll(tmpDir)
// Write each subgraph SDL to a file
type SubgraphConfig struct {
Name string `yaml:"name"`
RoutingURL string `yaml:"routing_url,omitempty"`
Schema map[string]string `yaml:"schema"`
Subscription map[string]interface{} `yaml:"subscription,omitempty"`
}
type InputConfig struct {
Version int `yaml:"version"`
Subgraphs []SubgraphConfig `yaml:"subgraphs"`
}
inputConfig := InputConfig{
Version: 1,
Subgraphs: make([]SubgraphConfig, 0, len(subGraphs)),
}
for _, sg := range subGraphs {
cosmoSg := map[string]interface{}{
"name": sg.Service,
"sdl": sg.Sdl,
// Write SDL to a temp file
schemaFile := filepath.Join(tmpDir, fmt.Sprintf("%s.graphql", sg.Service))
if err := os.WriteFile(schemaFile, []byte(sg.Sdl), 0o644); err != nil {
return "", fmt.Errorf("write schema file for %s: %w", sg.Service, err)
}
subgraphCfg := SubgraphConfig{
Name: sg.Service,
Schema: map[string]string{
"file": schemaFile,
},
}
if sg.URL != nil {
cosmoSg["routing_url"] = *sg.URL
subgraphCfg.RoutingURL = *sg.URL
}
if sg.WsURL != nil {
cosmoSg["subscription"] = map[string]interface{}{
subgraphCfg.Subscription = map[string]interface{}{
"url": *sg.WsURL,
"protocol": "ws",
"websocket_subprotocol": "graphql-ws",
}
}
cosmoSubgraphs = append(cosmoSubgraphs, cosmoSg)
inputConfig.Subgraphs = append(inputConfig.Subgraphs, subgraphCfg)
}
return cosmoSubgraphs
// Write input config YAML
inputFile := filepath.Join(tmpDir, "input.yaml")
inputYAML, err := yaml.Marshal(inputConfig)
if err != nil {
return "", fmt.Errorf("marshal input config: %w", err)
}
if err := os.WriteFile(inputFile, inputYAML, 0o644); err != nil {
return "", fmt.Errorf("write input config: %w", err)
}
// Execute wgc router compose
// wgc is installed globally in the Docker image
outputFile := filepath.Join(tmpDir, "config.json")
output, err := executor.Execute("wgc", "router", "compose",
"--input", inputFile,
"--out", outputFile,
"--suppress-warnings",
)
if err != nil {
return "", fmt.Errorf("wgc router compose failed: %w\nOutput: %s", err, string(output))
}
// Read the generated config
configJSON, err := os.ReadFile(outputFile)
if err != nil {
return "", fmt.Errorf("read output config: %w", err)
}
return string(configJSON), nil
}
+293 -86
View File
@@ -2,14 +2,185 @@ package graph
import (
"encoding/json"
"fmt"
"os"
"strings"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"gopkg.in/yaml.v3"
"gitlab.com/unboundsoftware/schemas/graph/model"
)
// MockCommandExecutor implements CommandExecutor for testing
type MockCommandExecutor struct {
// CallCount tracks how many times Execute was called
CallCount int
// LastArgs stores the arguments from the last call
LastArgs []string
// Error can be set to simulate command failures
Error error
}
// Execute mocks the wgc command by generating a realistic config.json file
func (m *MockCommandExecutor) Execute(name string, args ...string) ([]byte, error) {
m.CallCount++
m.LastArgs = append([]string{name}, args...)
if m.Error != nil {
return nil, m.Error
}
// Parse the input file to understand what subgraphs we're composing
var inputFile, outputFile string
for i, arg := range args {
if arg == "--input" && i+1 < len(args) {
inputFile = args[i+1]
}
if arg == "--out" && i+1 < len(args) {
outputFile = args[i+1]
}
}
if inputFile == "" || outputFile == "" {
return nil, fmt.Errorf("missing required arguments")
}
// Read the input YAML to get subgraph information
inputData, err := os.ReadFile(inputFile)
if err != nil {
return nil, fmt.Errorf("failed to read input file: %w", err)
}
var input struct {
Version int `yaml:"version"`
Subgraphs []struct {
Name string `yaml:"name"`
RoutingURL string `yaml:"routing_url,omitempty"`
Schema map[string]string `yaml:"schema"`
Subscription map[string]interface{} `yaml:"subscription,omitempty"`
} `yaml:"subgraphs"`
}
if err := yaml.Unmarshal(inputData, &input); err != nil {
return nil, fmt.Errorf("failed to parse input YAML: %w", err)
}
// Generate a realistic Cosmo Router config based on the input
config := map[string]interface{}{
"version": "mock-version-uuid",
"subgraphs": func() []map[string]interface{} {
subgraphs := make([]map[string]interface{}, len(input.Subgraphs))
for i, sg := range input.Subgraphs {
subgraph := map[string]interface{}{
"id": fmt.Sprintf("mock-id-%d", i),
"name": sg.Name,
}
if sg.RoutingURL != "" {
subgraph["routingUrl"] = sg.RoutingURL
}
subgraphs[i] = subgraph
}
return subgraphs
}(),
"engineConfig": map[string]interface{}{
"graphqlSchema": generateMockSchema(input.Subgraphs),
"datasourceConfigurations": func() []map[string]interface{} {
dsConfigs := make([]map[string]interface{}, len(input.Subgraphs))
for i, sg := range input.Subgraphs {
// Read SDL from file
sdl := ""
if schemaFile, ok := sg.Schema["file"]; ok {
if sdlData, err := os.ReadFile(schemaFile); err == nil {
sdl = string(sdlData)
}
}
dsConfig := map[string]interface{}{
"id": fmt.Sprintf("datasource-%d", i),
"kind": "GRAPHQL",
"customGraphql": map[string]interface{}{
"federation": map[string]interface{}{
"enabled": true,
"serviceSdl": sdl,
},
"subscription": func() map[string]interface{} {
if len(sg.Subscription) > 0 {
return map[string]interface{}{
"enabled": true,
"url": map[string]interface{}{
"staticVariableContent": sg.Subscription["url"],
},
"protocol": sg.Subscription["protocol"],
"websocketSubprotocol": sg.Subscription["websocket_subprotocol"],
}
}
return map[string]interface{}{
"enabled": false,
}
}(),
},
}
dsConfigs[i] = dsConfig
}
return dsConfigs
}(),
},
}
// Write the config to the output file
configJSON, err := json.Marshal(config)
if err != nil {
return nil, fmt.Errorf("failed to marshal config: %w", err)
}
if err := os.WriteFile(outputFile, configJSON, 0o644); err != nil {
return nil, fmt.Errorf("failed to write output file: %w", err)
}
return []byte("Success"), nil
}
// generateMockSchema creates a simple merged schema from subgraphs
func generateMockSchema(subgraphs []struct {
Name string `yaml:"name"`
RoutingURL string `yaml:"routing_url,omitempty"`
Schema map[string]string `yaml:"schema"`
Subscription map[string]interface{} `yaml:"subscription,omitempty"`
},
) string {
schema := strings.Builder{}
schema.WriteString("schema {\n query: Query\n")
// Check if any subgraph has subscriptions
hasSubscriptions := false
for _, sg := range subgraphs {
if len(sg.Subscription) > 0 {
hasSubscriptions = true
break
}
}
if hasSubscriptions {
schema.WriteString(" subscription: Subscription\n")
}
schema.WriteString("}\n\n")
// Add types by reading SDL files
for _, sg := range subgraphs {
if schemaFile, ok := sg.Schema["file"]; ok {
if sdlData, err := os.ReadFile(schemaFile); err == nil {
schema.WriteString(string(sdlData))
schema.WriteString("\n")
}
}
}
return schema.String()
}
func TestGenerateCosmoRouterConfig(t *testing.T) {
tests := []struct {
name string
@@ -33,7 +204,10 @@ func TestGenerateCosmoRouterConfig(t *testing.T) {
err := json.Unmarshal([]byte(config), &result)
require.NoError(t, err, "Config should be valid JSON")
assert.Equal(t, "1", result["version"], "Version should be 1")
// Version is a UUID string from wgc
version, ok := result["version"].(string)
require.True(t, ok, "Version should be a string")
assert.NotEmpty(t, version, "Version should not be empty")
subgraphs, ok := result["subgraphs"].([]interface{})
require.True(t, ok, "subgraphs should be an array")
@@ -41,14 +215,26 @@ func TestGenerateCosmoRouterConfig(t *testing.T) {
sg := subgraphs[0].(map[string]interface{})
assert.Equal(t, "test-service", sg["name"])
assert.Equal(t, "http://localhost:4001/query", sg["routing_url"])
assert.Equal(t, "type Query { test: String }", sg["sdl"])
assert.Equal(t, "http://localhost:4001/query", sg["routingUrl"])
subscription, ok := sg["subscription"].(map[string]interface{})
// Check that datasource configurations include subscription settings
engineConfig, ok := result["engineConfig"].(map[string]interface{})
require.True(t, ok, "Should have engineConfig")
dsConfigs, ok := engineConfig["datasourceConfigurations"].([]interface{})
require.True(t, ok && len(dsConfigs) > 0, "Should have datasource configurations")
ds := dsConfigs[0].(map[string]interface{})
customGraphql, ok := ds["customGraphql"].(map[string]interface{})
require.True(t, ok, "Should have customGraphql config")
subscription, ok := customGraphql["subscription"].(map[string]interface{})
require.True(t, ok, "Should have subscription config")
assert.Equal(t, "ws://localhost:4001/query", subscription["url"])
assert.Equal(t, "ws", subscription["protocol"])
assert.Equal(t, "graphql-ws", subscription["websocket_subprotocol"])
assert.True(t, subscription["enabled"].(bool), "Subscription should be enabled")
subUrl, ok := subscription["url"].(map[string]interface{})
require.True(t, ok, "Should have subscription URL")
assert.Equal(t, "ws://localhost:4001/query", subUrl["staticVariableContent"])
},
},
{
@@ -80,18 +266,28 @@ func TestGenerateCosmoRouterConfig(t *testing.T) {
subgraphs := result["subgraphs"].([]interface{})
assert.Len(t, subgraphs, 3, "Should have 3 subgraphs")
// Check first service has no subscription
// Check service names
sg1 := subgraphs[0].(map[string]interface{})
assert.Equal(t, "service-1", sg1["name"])
_, hasSubscription := sg1["subscription"]
assert.False(t, hasSubscription, "service-1 should not have subscription config")
// Check third service has subscription
sg3 := subgraphs[2].(map[string]interface{})
assert.Equal(t, "service-3", sg3["name"])
subscription, hasSubscription := sg3["subscription"]
assert.True(t, hasSubscription, "service-3 should have subscription config")
assert.NotNil(t, subscription)
// Check that datasource configurations include subscription for service-3
engineConfig, ok := result["engineConfig"].(map[string]interface{})
require.True(t, ok, "Should have engineConfig")
dsConfigs, ok := engineConfig["datasourceConfigurations"].([]interface{})
require.True(t, ok && len(dsConfigs) == 3, "Should have 3 datasource configurations")
// Find service-3's datasource config (should have subscription enabled)
ds3 := dsConfigs[2].(map[string]interface{})
customGraphql, ok := ds3["customGraphql"].(map[string]interface{})
require.True(t, ok, "Service-3 should have customGraphql config")
subscription, ok := customGraphql["subscription"].(map[string]interface{})
require.True(t, ok, "Service-3 should have subscription config")
assert.True(t, subscription["enabled"].(bool), "Service-3 subscription should be enabled")
},
},
{
@@ -113,39 +309,43 @@ func TestGenerateCosmoRouterConfig(t *testing.T) {
subgraphs := result["subgraphs"].([]interface{})
sg := subgraphs[0].(map[string]interface{})
// Should not have routing_url or subscription fields if URLs are nil
_, hasRoutingURL := sg["routing_url"]
assert.False(t, hasRoutingURL, "Should not have routing_url when URL is nil")
// Should not have routing URL when URL is nil
_, hasRoutingURL := sg["routingUrl"]
assert.False(t, hasRoutingURL, "Should not have routingUrl when URL is nil")
_, hasSubscription := sg["subscription"]
assert.False(t, hasSubscription, "Should not have subscription when WsURL is nil")
// Check datasource configurations don't have subscription enabled
engineConfig, ok := result["engineConfig"].(map[string]interface{})
require.True(t, ok, "Should have engineConfig")
dsConfigs, ok := engineConfig["datasourceConfigurations"].([]interface{})
require.True(t, ok && len(dsConfigs) > 0, "Should have datasource configurations")
ds := dsConfigs[0].(map[string]interface{})
customGraphql, ok := ds["customGraphql"].(map[string]interface{})
require.True(t, ok, "Should have customGraphql config")
subscription, ok := customGraphql["subscription"].(map[string]interface{})
if ok {
// wgc always enables subscription but URL should be empty when WsURL is nil
subUrl, hasUrl := subscription["url"].(map[string]interface{})
if hasUrl {
_, hasStaticContent := subUrl["staticVariableContent"]
assert.False(t, hasStaticContent, "Subscription URL should be empty when WsURL is nil")
}
}
},
},
{
name: "empty subgraphs",
subGraphs: []*model.SubGraph{},
wantErr: false,
validate: func(t *testing.T, config string) {
var result map[string]interface{}
err := json.Unmarshal([]byte(config), &result)
require.NoError(t, err)
subgraphs := result["subgraphs"].([]interface{})
assert.Len(t, subgraphs, 0, "Should have empty subgraphs array")
},
wantErr: true,
validate: nil,
},
{
name: "nil subgraphs",
subGraphs: nil,
wantErr: false,
validate: func(t *testing.T, config string) {
var result map[string]interface{}
err := json.Unmarshal([]byte(config), &result)
require.NoError(t, err)
subgraphs := result["subgraphs"].([]interface{})
assert.Len(t, subgraphs, 0, "Should handle nil subgraphs as empty array")
},
wantErr: true,
validate: nil,
},
{
name: "complex SDL with multiple types",
@@ -173,29 +373,58 @@ func TestGenerateCosmoRouterConfig(t *testing.T) {
err := json.Unmarshal([]byte(config), &result)
require.NoError(t, err)
subgraphs := result["subgraphs"].([]interface{})
sg := subgraphs[0].(map[string]interface{})
sdl := sg["sdl"].(string)
// Check the composed graphqlSchema contains the types
engineConfig, ok := result["engineConfig"].(map[string]interface{})
require.True(t, ok, "Should have engineConfig")
assert.Contains(t, sdl, "type Query")
assert.Contains(t, sdl, "type User")
assert.Contains(t, sdl, "email: String!")
graphqlSchema, ok := engineConfig["graphqlSchema"].(string)
require.True(t, ok, "Should have graphqlSchema")
assert.Contains(t, graphqlSchema, "Query", "Schema should contain Query type")
assert.Contains(t, graphqlSchema, "User", "Schema should contain User type")
// Check datasource has the original SDL
dsConfigs, ok := engineConfig["datasourceConfigurations"].([]interface{})
require.True(t, ok && len(dsConfigs) > 0, "Should have datasource configurations")
ds := dsConfigs[0].(map[string]interface{})
customGraphql, ok := ds["customGraphql"].(map[string]interface{})
require.True(t, ok, "Should have customGraphql config")
federation, ok := customGraphql["federation"].(map[string]interface{})
require.True(t, ok, "Should have federation config")
serviceSdl, ok := federation["serviceSdl"].(string)
require.True(t, ok, "Should have serviceSdl")
assert.Contains(t, serviceSdl, "email: String!", "SDL should contain email field")
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
config, err := GenerateCosmoRouterConfig(tt.subGraphs)
// Use mock executor for all tests
mockExecutor := &MockCommandExecutor{}
config, err := GenerateCosmoRouterConfigWithExecutor(tt.subGraphs, mockExecutor)
if tt.wantErr {
assert.Error(t, err)
// Verify executor was not called for error cases
if len(tt.subGraphs) == 0 {
assert.Equal(t, 0, mockExecutor.CallCount, "Should not call executor for empty subgraphs")
}
return
}
require.NoError(t, err)
assert.NotEmpty(t, config, "Config should not be empty")
// Verify executor was called correctly
assert.Equal(t, 1, mockExecutor.CallCount, "Should call executor once")
assert.Equal(t, "wgc", mockExecutor.LastArgs[0], "Should call wgc command")
assert.Contains(t, mockExecutor.LastArgs, "router", "Should include 'router' arg")
assert.Contains(t, mockExecutor.LastArgs, "compose", "Should include 'compose' arg")
if tt.validate != nil {
tt.validate(t, config)
}
@@ -203,53 +432,31 @@ func TestGenerateCosmoRouterConfig(t *testing.T) {
}
}
func TestConvertSubGraphsToCosmo(t *testing.T) {
tests := []struct {
name string
subGraphs []*model.SubGraph
wantLen int
validate func(t *testing.T, result []map[string]interface{})
}{
// TestGenerateCosmoRouterConfig_MockError tests error handling with mock executor
func TestGenerateCosmoRouterConfig_MockError(t *testing.T) {
subGraphs := []*model.SubGraph{
{
name: "preserves subgraph order",
subGraphs: []*model.SubGraph{
{Service: "alpha", URL: stringPtr("http://a"), Sdl: "a"},
{Service: "beta", URL: stringPtr("http://b"), Sdl: "b"},
{Service: "gamma", URL: stringPtr("http://c"), Sdl: "c"},
},
wantLen: 3,
validate: func(t *testing.T, result []map[string]interface{}) {
assert.Equal(t, "alpha", result[0]["name"])
assert.Equal(t, "beta", result[1]["name"])
assert.Equal(t, "gamma", result[2]["name"])
},
},
{
name: "includes SDL exactly as provided",
subGraphs: []*model.SubGraph{
{
Service: "test",
URL: stringPtr("http://test"),
Sdl: "type Query { special: String! }",
},
},
wantLen: 1,
validate: func(t *testing.T, result []map[string]interface{}) {
assert.Equal(t, "type Query { special: String! }", result[0]["sdl"])
},
Service: "test-service",
URL: stringPtr("http://localhost:4001/query"),
Sdl: "type Query { test: String }",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := convertSubGraphsToCosmo(tt.subGraphs)
assert.Len(t, result, tt.wantLen)
if tt.validate != nil {
tt.validate(t, result)
}
})
// Create a mock executor that returns an error
mockExecutor := &MockCommandExecutor{
Error: fmt.Errorf("simulated wgc failure"),
}
config, err := GenerateCosmoRouterConfigWithExecutor(subGraphs, mockExecutor)
// Verify error is propagated
assert.Error(t, err)
assert.Contains(t, err.Error(), "wgc router compose failed")
assert.Contains(t, err.Error(), "simulated wgc failure")
assert.Empty(t, config)
// Verify executor was called
assert.Equal(t, 1, mockExecutor.CallCount, "Should have attempted to call executor")
}
// Helper function for tests
+116
View File
@@ -74,6 +74,7 @@ type ComplexityRoot struct {
}
Query struct {
LatestSchema func(childComplexity int, ref string) int
Organizations func(childComplexity int) int
Supergraph func(childComplexity int, ref string, isAfter *string) int
}
@@ -124,6 +125,7 @@ type MutationResolver interface {
type QueryResolver interface {
Organizations(ctx context.Context) ([]*model.Organization, error)
Supergraph(ctx context.Context, ref string, isAfter *string) (model.Supergraph, error)
LatestSchema(ctx context.Context, ref string) (*model.SchemaUpdate, error)
}
type SubscriptionResolver interface {
SchemaUpdates(ctx context.Context, ref string) (<-chan *model.SchemaUpdate, error)
@@ -250,6 +252,17 @@ func (e *executableSchema) Complexity(ctx context.Context, typeName, field strin
return e.complexity.Organization.Users(childComplexity), true
case "Query.latestSchema":
if e.complexity.Query.LatestSchema == nil {
break
}
args, err := ec.field_Query_latestSchema_args(ctx, rawArgs)
if err != nil {
return 0, false
}
return e.complexity.Query.LatestSchema(childComplexity, args["ref"].(string)), true
case "Query.organizations":
if e.complexity.Query.Organizations == nil {
break
@@ -520,6 +533,7 @@ var sources = []*ast.Source{
{Name: "../schema.graphqls", Input: `type Query {
organizations: [Organization!]! @auth(user: true)
supergraph(ref: String!, isAfter: String): Supergraph! @auth(organization: true)
latestSchema(ref: String!): SchemaUpdate! @auth(organization: true)
}
type Mutation {
@@ -671,6 +685,17 @@ func (ec *executionContext) field_Query___type_args(ctx context.Context, rawArgs
return args, nil
}
func (ec *executionContext) field_Query_latestSchema_args(ctx context.Context, rawArgs map[string]any) (map[string]any, error) {
var err error
args := map[string]any{}
arg0, err := graphql.ProcessArgField(ctx, rawArgs, "ref", ec.unmarshalNString2string)
if err != nil {
return nil, err
}
args["ref"] = arg0
return args, nil
}
func (ec *executionContext) field_Query_supergraph_args(ctx context.Context, rawArgs map[string]any) (map[string]any, error) {
var err error
args := map[string]any{}
@@ -1434,6 +1459,75 @@ func (ec *executionContext) fieldContext_Query_supergraph(ctx context.Context, f
return fc, nil
}
func (ec *executionContext) _Query_latestSchema(ctx context.Context, field graphql.CollectedField) (ret graphql.Marshaler) {
return graphql.ResolveField(
ctx,
ec.OperationContext,
field,
ec.fieldContext_Query_latestSchema,
func(ctx context.Context) (any, error) {
fc := graphql.GetFieldContext(ctx)
return ec.resolvers.Query().LatestSchema(ctx, fc.Args["ref"].(string))
},
func(ctx context.Context, next graphql.Resolver) graphql.Resolver {
directive0 := next
directive1 := func(ctx context.Context) (any, error) {
organization, err := ec.unmarshalOBoolean2ᚖbool(ctx, true)
if err != nil {
var zeroVal *model.SchemaUpdate
return zeroVal, err
}
if ec.directives.Auth == nil {
var zeroVal *model.SchemaUpdate
return zeroVal, errors.New("directive auth is not implemented")
}
return ec.directives.Auth(ctx, nil, directive0, nil, organization)
}
next = directive1
return next
},
ec.marshalNSchemaUpdate2ᚖgitlabᚗcomᚋunboundsoftwareᚋschemasᚋgraphᚋmodelᚐSchemaUpdate,
true,
true,
)
}
func (ec *executionContext) fieldContext_Query_latestSchema(ctx context.Context, field graphql.CollectedField) (fc *graphql.FieldContext, err error) {
fc = &graphql.FieldContext{
Object: "Query",
Field: field,
IsMethod: true,
IsResolver: true,
Child: func(ctx context.Context, field graphql.CollectedField) (*graphql.FieldContext, error) {
switch field.Name {
case "ref":
return ec.fieldContext_SchemaUpdate_ref(ctx, field)
case "id":
return ec.fieldContext_SchemaUpdate_id(ctx, field)
case "subGraphs":
return ec.fieldContext_SchemaUpdate_subGraphs(ctx, field)
case "cosmoRouterConfig":
return ec.fieldContext_SchemaUpdate_cosmoRouterConfig(ctx, field)
}
return nil, fmt.Errorf("no field named %q was found under type SchemaUpdate", field.Name)
},
}
defer func() {
if r := recover(); r != nil {
err = ec.Recover(ctx, r)
ec.Error(ctx, err)
}
}()
ctx = graphql.WithFieldContext(ctx, fc)
if fc.Args, err = ec.field_Query_latestSchema_args(ctx, field.ArgumentMap(ec.Variables)); err != nil {
ec.Error(ctx, err)
return fc, err
}
return fc, nil
}
func (ec *executionContext) _Query___type(ctx context.Context, field graphql.CollectedField) (ret graphql.Marshaler) {
return graphql.ResolveField(
ctx,
@@ -3997,6 +4091,28 @@ func (ec *executionContext) _Query(ctx context.Context, sel ast.SelectionSet) gr
func(ctx context.Context) graphql.Marshaler { return innerFunc(ctx, out) })
}
out.Concurrently(i, func(ctx context.Context) graphql.Marshaler { return rrm(innerCtx) })
case "latestSchema":
field := field
innerFunc := func(ctx context.Context, fs *graphql.FieldSet) (res graphql.Marshaler) {
defer func() {
if r := recover(); r != nil {
ec.Error(ctx, ec.Recover(ctx, r))
}
}()
res = ec._Query_latestSchema(ctx, field)
if res == graphql.Null {
atomic.AddUint32(&fs.Invalids, 1)
}
return res
}
rrm := func(ctx context.Context) graphql.Marshaler {
return ec.OperationContext.RootResolverMiddleware(ctx,
func(ctx context.Context) graphql.Marshaler { return innerFunc(ctx, out) })
}
out.Concurrently(i, func(ctx context.Context) graphql.Marshaler { return rrm(innerCtx) })
case "__type":
out.Values[i] = ec.OperationContext.RootResolverMiddleware(innerCtx, func(ctx context.Context) (res graphql.Marshaler) {
+1
View File
@@ -1,6 +1,7 @@
type Query {
organizations: [Organization!]! @auth(user: true)
supergraph(ref: String!, isAfter: String): Supergraph! @auth(organization: true)
latestSchema(ref: String!): SchemaUpdate! @auth(organization: true)
}
type Mutation {
+113 -5
View File
@@ -123,6 +123,13 @@ func (r *mutationResolver) UpdateSubGraph(ctx context.Context, input model.Input
// Publish schema update to subscribers
go func() {
services, lastUpdate := r.Cache.Services(orgId, input.Ref, "")
r.Logger.Info("Publishing schema update after subgraph change",
"ref", input.Ref,
"orgId", orgId,
"lastUpdate", lastUpdate,
"servicesCount", len(services),
)
subGraphs := make([]*model.SubGraph, len(services))
for i, id := range services {
sg, err := r.fetchSubGraph(context.Background(), id)
@@ -149,12 +156,21 @@ func (r *mutationResolver) UpdateSubGraph(ctx context.Context, input model.Input
}
// Publish to all subscribers of this ref
r.PubSub.Publish(input.Ref, &model.SchemaUpdate{
update := &model.SchemaUpdate{
Ref: input.Ref,
ID: lastUpdate,
SubGraphs: subGraphs,
CosmoRouterConfig: &cosmoConfig,
})
}
r.Logger.Info("Publishing schema update to subscribers",
"ref", update.Ref,
"id", update.ID,
"subGraphsCount", len(update.SubGraphs),
"cosmoConfigLength", len(cosmoConfig),
)
r.PubSub.Publish(input.Ref, update)
}()
return r.toGqlSubGraph(subGraph), nil
@@ -222,11 +238,84 @@ func (r *queryResolver) Supergraph(ctx context.Context, ref string, isAfter *str
}, nil
}
// LatestSchema is the resolver for the latestSchema field.
func (r *queryResolver) LatestSchema(ctx context.Context, ref string) (*model.SchemaUpdate, error) {
orgId := middleware.OrganizationFromContext(ctx)
r.Logger.Info("LatestSchema query",
"ref", ref,
"orgId", orgId,
)
_, err := r.apiKeyCanAccessRef(ctx, ref, false)
if err != nil {
r.Logger.Error("API key cannot access ref", "error", err, "ref", ref)
return nil, err
}
// Get current services and schema
services, lastUpdate := r.Cache.Services(orgId, ref, "")
r.Logger.Info("Fetching latest schema",
"ref", ref,
"orgId", orgId,
"lastUpdate", lastUpdate,
"servicesCount", len(services),
)
subGraphs := make([]*model.SubGraph, len(services))
for i, id := range services {
sg, err := r.fetchSubGraph(ctx, id)
if err != nil {
r.Logger.Error("fetch subgraph", "error", err, "id", id)
return nil, err
}
subGraphs[i] = &model.SubGraph{
ID: sg.ID.String(),
Service: sg.Service,
URL: sg.Url,
WsURL: sg.WSUrl,
Sdl: sg.Sdl,
ChangedBy: sg.ChangedBy,
ChangedAt: sg.ChangedAt,
}
}
// Generate Cosmo router config
cosmoConfig, err := GenerateCosmoRouterConfig(subGraphs)
if err != nil {
r.Logger.Error("generate cosmo config", "error", err)
cosmoConfig = "" // Return empty if generation fails
}
update := &model.SchemaUpdate{
Ref: ref,
ID: lastUpdate,
SubGraphs: subGraphs,
CosmoRouterConfig: &cosmoConfig,
}
r.Logger.Info("Latest schema fetched",
"ref", update.Ref,
"id", update.ID,
"subGraphsCount", len(update.SubGraphs),
"cosmoConfigLength", len(cosmoConfig),
)
return update, nil
}
// SchemaUpdates is the resolver for the schemaUpdates field.
func (r *subscriptionResolver) SchemaUpdates(ctx context.Context, ref string) (<-chan *model.SchemaUpdate, error) {
orgId := middleware.OrganizationFromContext(ctx)
r.Logger.Info("SchemaUpdates subscription started",
"ref", ref,
"orgId", orgId,
)
_, err := r.apiKeyCanAccessRef(ctx, ref, false)
if err != nil {
r.Logger.Error("API key cannot access ref", "error", err, "ref", ref)
return nil, err
}
@@ -235,12 +324,22 @@ func (r *subscriptionResolver) SchemaUpdates(ctx context.Context, ref string) (<
// Send initial state immediately
go func() {
// Use background context for async operation
bgCtx := context.Background()
services, lastUpdate := r.Cache.Services(orgId, ref, "")
r.Logger.Info("Preparing initial schema update",
"ref", ref,
"orgId", orgId,
"lastUpdate", lastUpdate,
"servicesCount", len(services),
)
subGraphs := make([]*model.SubGraph, len(services))
for i, id := range services {
sg, err := r.fetchSubGraph(ctx, id)
sg, err := r.fetchSubGraph(bgCtx, id)
if err != nil {
r.Logger.Error("fetch subgraph for initial update", "error", err)
r.Logger.Error("fetch subgraph for initial update", "error", err, "id", id)
continue
}
subGraphs[i] = &model.SubGraph{
@@ -262,12 +361,21 @@ func (r *subscriptionResolver) SchemaUpdates(ctx context.Context, ref string) (<
}
// Send initial update
ch <- &model.SchemaUpdate{
update := &model.SchemaUpdate{
Ref: ref,
ID: lastUpdate,
SubGraphs: subGraphs,
CosmoRouterConfig: &cosmoConfig,
}
r.Logger.Info("Sending initial schema update",
"ref", update.Ref,
"id", update.ID,
"subGraphsCount", len(update.SubGraphs),
"cosmoConfigLength", len(cosmoConfig),
)
ch <- update
}()
// Clean up subscription when context is done
+63
View File
@@ -3,9 +3,72 @@ package hash
import (
"crypto/sha256"
"encoding/base64"
"golang.org/x/crypto/bcrypt"
)
// String creates a SHA256 hash of a string (legacy, for non-sensitive data)
func String(s string) string {
encoded := sha256.New().Sum([]byte(s))
return base64.StdEncoding.EncodeToString(encoded)
}
// APIKey hashes an API key using bcrypt for secure storage
// Cost of 12 provides a good balance between security and performance
func APIKey(key string) (string, error) {
hash, err := bcrypt.GenerateFromPassword([]byte(key), 12)
if err != nil {
return "", err
}
return string(hash), nil
}
// CompareAPIKey compares a plaintext API key with a hash
// Supports both bcrypt (new) and SHA256 (legacy) hashes for backwards compatibility
// Returns true if they match, false otherwise
//
// Migration Strategy:
// Old API keys stored with SHA256 will continue to work. To upgrade them to bcrypt:
// 1. Keys are automatically upgraded when users re-authenticate (if implemented)
// 2. Or, run a one-time migration using MigrateAPIKeyHash when convenient
func CompareAPIKey(hashedKey, plainKey string) bool {
// Bcrypt hashes start with $2a$, $2b$, or $2y$
// If the hash starts with $2, it's a bcrypt hash
if len(hashedKey) > 2 && hashedKey[0] == '$' && hashedKey[1] == '2' {
// New bcrypt hash
err := bcrypt.CompareHashAndPassword([]byte(hashedKey), []byte(plainKey))
return err == nil
}
// Legacy SHA256 hash - compare using the old method
legacyHash := String(plainKey)
return hashedKey == legacyHash
}
// IsLegacyHash returns true if the hash is a legacy SHA256 hash (not bcrypt)
func IsLegacyHash(hashedKey string) bool {
return len(hashedKey) <= 2 || hashedKey[0] != '$' || hashedKey[1] != '2'
}
// MigrateAPIKeyHash can be used to upgrade a legacy SHA256 hash to bcrypt
// This is useful for one-time migrations of existing keys
// Returns the new bcrypt hash if the key is legacy, otherwise returns the original
func MigrateAPIKeyHash(currentHash, plainKey string) (string, bool, error) {
// If already bcrypt, no migration needed
if !IsLegacyHash(currentHash) {
return currentHash, false, nil
}
// Verify the legacy hash is correct before migrating
if !CompareAPIKey(currentHash, plainKey) {
return "", false, nil // Invalid key, don't migrate
}
// Generate new bcrypt hash
newHash, err := APIKey(plainKey)
if err != nil {
return "", false, err
}
return newHash, true, nil
}
+169
View File
@@ -0,0 +1,169 @@
package hash
import (
"strings"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestAPIKey(t *testing.T) {
key := "test_api_key_12345" // gitleaks:allow
hash1, err := APIKey(key)
require.NoError(t, err)
assert.NotEmpty(t, hash1)
assert.NotEqual(t, key, hash1, "Hash should not equal plaintext")
// Bcrypt hashes should start with $2
assert.True(t, strings.HasPrefix(hash1, "$2"), "Should be a bcrypt hash")
// Same key should produce different hashes (due to salt)
hash2, err := APIKey(key)
require.NoError(t, err)
assert.NotEqual(t, hash1, hash2, "Bcrypt should produce different hashes with different salts")
}
func TestCompareAPIKey_Bcrypt(t *testing.T) {
key := "test_api_key_12345" // gitleaks:allow
hash, err := APIKey(key)
require.NoError(t, err)
// Correct key should match
assert.True(t, CompareAPIKey(hash, key))
// Wrong key should not match
assert.False(t, CompareAPIKey(hash, "wrong_key"))
}
func TestCompareAPIKey_Legacy(t *testing.T) {
key := "test_api_key_12345" // gitleaks:allow
// Create a legacy SHA256 hash
legacyHash := String(key)
// Should still work with legacy hashes
assert.True(t, CompareAPIKey(legacyHash, key))
// Wrong key should not match
assert.False(t, CompareAPIKey(legacyHash, "wrong_key"))
}
func TestCompareAPIKey_BackwardCompatibility(t *testing.T) {
tests := []struct {
name string
hashFunc func(string) string
expectOK bool
}{
{
name: "bcrypt hash",
hashFunc: func(k string) string {
h, _ := APIKey(k)
return h
},
expectOK: true,
},
{
name: "legacy SHA256 hash",
hashFunc: func(k string) string {
return String(k)
},
expectOK: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
key := "test_key_123"
hash := tt.hashFunc(key)
result := CompareAPIKey(hash, key)
assert.Equal(t, tt.expectOK, result)
})
}
}
func TestString(t *testing.T) {
// Test that String function still works (for non-sensitive data)
input := "test_string"
hash1 := String(input)
hash2 := String(input)
// SHA256 should be deterministic
assert.Equal(t, hash1, hash2)
assert.NotEmpty(t, hash1)
assert.NotEqual(t, input, hash1)
}
func TestIsLegacyHash(t *testing.T) {
tests := []struct {
name string
hash string
isLegacy bool
}{
{
name: "bcrypt hash",
hash: "$2a$12$abcdefghijklmnopqrstuv",
isLegacy: false,
},
{
name: "SHA256 hash",
hash: "dXNfYWtfMTIzNDU2Nzg5MDEyMzQ1NuOwxEKY",
isLegacy: true,
},
{
name: "empty string",
hash: "",
isLegacy: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
assert.Equal(t, tt.isLegacy, IsLegacyHash(tt.hash))
})
}
}
func TestMigrateAPIKeyHash(t *testing.T) {
plainKey := "test_api_key_123"
t.Run("migrate legacy hash", func(t *testing.T) {
// Create a legacy SHA256 hash
legacyHash := String(plainKey)
// Migrate it
newHash, migrated, err := MigrateAPIKeyHash(legacyHash, plainKey)
require.NoError(t, err)
assert.True(t, migrated, "Should indicate migration occurred")
assert.NotEqual(t, legacyHash, newHash, "New hash should differ from legacy")
assert.True(t, strings.HasPrefix(newHash, "$2"), "New hash should be bcrypt")
// Verify new hash works
assert.True(t, CompareAPIKey(newHash, plainKey))
})
t.Run("no migration needed for bcrypt", func(t *testing.T) {
// Create a bcrypt hash
bcryptHash, err := APIKey(plainKey)
require.NoError(t, err)
// Try to migrate it
newHash, migrated, err := MigrateAPIKeyHash(bcryptHash, plainKey)
require.NoError(t, err)
assert.False(t, migrated, "Should not migrate bcrypt hash")
assert.Equal(t, bcryptHash, newHash, "Hash should remain unchanged")
})
t.Run("invalid key does not migrate", func(t *testing.T) {
legacyHash := String("correct_key")
// Try to migrate with wrong plaintext
newHash, migrated, err := MigrateAPIKeyHash(legacyHash, "wrong_key")
require.NoError(t, err)
assert.False(t, migrated, "Should not migrate invalid key")
assert.Empty(t, newHash, "Should return empty for invalid key")
})
}
+73
View File
@@ -0,0 +1,73 @@
package health
import (
"context"
"database/sql"
"encoding/json"
"log/slog"
"net/http"
"time"
)
type Checker struct {
db *sql.DB
logger *slog.Logger
}
func New(db *sql.DB, logger *slog.Logger) *Checker {
return &Checker{
db: db,
logger: logger,
}
}
type HealthStatus struct {
Status string `json:"status"`
Checks map[string]string `json:"checks,omitempty"`
}
// LivenessHandler checks if the application is running
// This is a simple check that always returns OK if the handler is reached
func (h *Checker) LivenessHandler(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
_ = json.NewEncoder(w).Encode(HealthStatus{
Status: "UP",
})
}
// ReadinessHandler checks if the application is ready to accept traffic
// This checks database connectivity and other critical dependencies
func (h *Checker) ReadinessHandler(w http.ResponseWriter, r *http.Request) {
ctx, cancel := context.WithTimeout(r.Context(), 5*time.Second)
defer cancel()
checks := make(map[string]string)
allHealthy := true
// Check database connectivity
if err := h.db.PingContext(ctx); err != nil {
h.logger.With("error", err).Warn("database health check failed")
checks["database"] = "DOWN"
allHealthy = false
} else {
checks["database"] = "UP"
}
status := HealthStatus{
Status: "UP",
Checks: checks,
}
if !allHealthy {
status.Status = "DOWN"
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusServiceUnavailable)
_ = json.NewEncoder(w).Encode(status)
return
}
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
_ = json.NewEncoder(w).Encode(status)
}
+75
View File
@@ -0,0 +1,75 @@
package health
import (
"database/sql"
"log/slog"
"net/http"
"net/http/httptest"
"os"
"testing"
"github.com/DATA-DOG/go-sqlmock"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestLivenessHandler(t *testing.T) {
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
db, _, err := sqlmock.New()
require.NoError(t, err)
defer db.Close()
checker := New(db, logger)
req := httptest.NewRequest(http.MethodGet, "/health/live", nil)
rec := httptest.NewRecorder()
checker.LivenessHandler(rec, req)
assert.Equal(t, http.StatusOK, rec.Code)
assert.Contains(t, rec.Body.String(), `"status":"UP"`)
}
func TestReadinessHandler_Healthy(t *testing.T) {
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
db, mock, err := sqlmock.New(sqlmock.MonitorPingsOption(true))
require.NoError(t, err)
defer db.Close()
// Expect a ping and return success
mock.ExpectPing().WillReturnError(nil)
checker := New(db, logger)
req := httptest.NewRequest(http.MethodGet, "/health/ready", nil)
rec := httptest.NewRecorder()
checker.ReadinessHandler(rec, req)
assert.Equal(t, http.StatusOK, rec.Code)
assert.Contains(t, rec.Body.String(), `"status":"UP"`)
assert.Contains(t, rec.Body.String(), `"database":"UP"`)
assert.NoError(t, mock.ExpectationsWereMet())
}
func TestReadinessHandler_DatabaseDown(t *testing.T) {
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
db, mock, err := sqlmock.New(sqlmock.MonitorPingsOption(true))
require.NoError(t, err)
defer db.Close()
// Expect a ping and return error
mock.ExpectPing().WillReturnError(sql.ErrConnDone)
checker := New(db, logger)
req := httptest.NewRequest(http.MethodGet, "/health/ready", nil)
rec := httptest.NewRecorder()
checker.ReadinessHandler(rec, req)
assert.Equal(t, http.StatusServiceUnavailable, rec.Code)
assert.Contains(t, rec.Body.String(), `"status":"DOWN"`)
assert.Contains(t, rec.Body.String(), `"database":"DOWN"`)
assert.NoError(t, mock.ExpectationsWereMet())
}
+10 -1
View File
@@ -44,13 +44,22 @@ spec:
requests:
cpu: "20m"
memory: "20Mi"
livenessProbe:
httpGet:
path: /health/live
port: 8080
initialDelaySeconds: 10
periodSeconds: 10
timeoutSeconds: 5
failureThreshold: 3
readinessProbe:
httpGet:
path: /health
path: /health/ready
port: 8080
initialDelaySeconds: 5
periodSeconds: 5
timeoutSeconds: 5
failureThreshold: 3
imagePullPolicy: IfNotPresent
image: registry.gitlab.com/unboundsoftware/schemas:${COMMIT}
ports:
+3 -2
View File
@@ -9,7 +9,6 @@ import (
"github.com/golang-jwt/jwt/v5"
"gitlab.com/unboundsoftware/schemas/domain"
"gitlab.com/unboundsoftware/schemas/hash"
)
const (
@@ -49,7 +48,9 @@ func (m *AuthMiddleware) Handler(next http.Handler) http.Handler {
_, _ = w.Write([]byte("Invalid API Key format"))
return
}
if organization := m.cache.OrganizationByAPIKey(hash.String(apiKey)); organization != nil {
// Cache handles hash comparison internally
organization := m.cache.OrganizationByAPIKey(apiKey)
if organization != nil {
ctx = context.WithValue(ctx, OrganizationKey, *organization)
}
+464
View File
@@ -0,0 +1,464 @@
package middleware
import (
"context"
"net/http"
"net/http/httptest"
"testing"
mw "github.com/auth0/go-jwt-middleware/v2"
"github.com/golang-jwt/jwt/v5"
"github.com/google/uuid"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require"
"gitlab.com/unboundsoftware/eventsourced/eventsourced"
"gitlab.com/unboundsoftware/schemas/domain"
)
// MockCache is a mock implementation of the Cache interface
type MockCache struct {
mock.Mock
}
func (m *MockCache) OrganizationByAPIKey(apiKey string) *domain.Organization {
args := m.Called(apiKey)
if args.Get(0) == nil {
return nil
}
return args.Get(0).(*domain.Organization)
}
func TestAuthMiddleware_Handler_WithValidAPIKey(t *testing.T) {
// Setup
mockCache := new(MockCache)
authMiddleware := NewAuth(mockCache)
orgID := uuid.New()
expectedOrg := &domain.Organization{
BaseAggregate: eventsourced.BaseAggregate{
ID: eventsourced.IdFromString(orgID.String()),
},
Name: "Test Organization",
}
apiKey := "test-api-key-123"
// Mock expects plaintext key (cache handles hashing internally)
mockCache.On("OrganizationByAPIKey", apiKey).Return(expectedOrg)
// Create a test handler that checks the context
var capturedOrg *domain.Organization
testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if org := r.Context().Value(OrganizationKey); org != nil {
if o, ok := org.(domain.Organization); ok {
capturedOrg = &o
}
}
w.WriteHeader(http.StatusOK)
})
// Create request with API key in context
req := httptest.NewRequest(http.MethodGet, "/test", nil)
ctx := context.WithValue(req.Context(), ApiKey, apiKey)
req = req.WithContext(ctx)
rec := httptest.NewRecorder()
// Execute
authMiddleware.Handler(testHandler).ServeHTTP(rec, req)
// Assert
assert.Equal(t, http.StatusOK, rec.Code)
require.NotNil(t, capturedOrg)
assert.Equal(t, expectedOrg.Name, capturedOrg.Name)
assert.Equal(t, expectedOrg.ID.String(), capturedOrg.ID.String())
mockCache.AssertExpectations(t)
}
func TestAuthMiddleware_Handler_WithInvalidAPIKey(t *testing.T) {
// Setup
mockCache := new(MockCache)
authMiddleware := NewAuth(mockCache)
apiKey := "invalid-api-key"
// Mock expects plaintext key (cache handles hashing internally)
mockCache.On("OrganizationByAPIKey", apiKey).Return(nil)
// Create a test handler that checks the context
var capturedOrg *domain.Organization
testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if org := r.Context().Value(OrganizationKey); org != nil {
if o, ok := org.(domain.Organization); ok {
capturedOrg = &o
}
}
w.WriteHeader(http.StatusOK)
})
// Create request with API key in context
req := httptest.NewRequest(http.MethodGet, "/test", nil)
ctx := context.WithValue(req.Context(), ApiKey, apiKey)
req = req.WithContext(ctx)
rec := httptest.NewRecorder()
// Execute
authMiddleware.Handler(testHandler).ServeHTTP(rec, req)
// Assert
assert.Equal(t, http.StatusOK, rec.Code)
assert.Nil(t, capturedOrg, "Organization should not be set for invalid API key")
mockCache.AssertExpectations(t)
}
func TestAuthMiddleware_Handler_WithoutAPIKey(t *testing.T) {
// Setup
mockCache := new(MockCache)
authMiddleware := NewAuth(mockCache)
// The middleware passes the plaintext API key (cache handles hashing)
mockCache.On("OrganizationByAPIKey", "").Return(nil)
// Create a test handler that checks the context
var capturedOrg *domain.Organization
testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if org := r.Context().Value(OrganizationKey); org != nil {
if o, ok := org.(domain.Organization); ok {
capturedOrg = &o
}
}
w.WriteHeader(http.StatusOK)
})
// Create request without API key
req := httptest.NewRequest(http.MethodGet, "/test", nil)
rec := httptest.NewRecorder()
// Execute
authMiddleware.Handler(testHandler).ServeHTTP(rec, req)
// Assert
assert.Equal(t, http.StatusOK, rec.Code)
assert.Nil(t, capturedOrg, "Organization should not be set without API key")
mockCache.AssertExpectations(t)
}
func TestAuthMiddleware_Handler_WithValidJWT(t *testing.T) {
// Setup
mockCache := new(MockCache)
authMiddleware := NewAuth(mockCache)
// The middleware passes the plaintext API key (cache handles hashing)
mockCache.On("OrganizationByAPIKey", "").Return(nil)
userID := "user-123"
token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{
"sub": userID,
})
// Create a test handler that checks the context
var capturedUser string
testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if user := r.Context().Value(UserKey); user != nil {
if u, ok := user.(string); ok {
capturedUser = u
}
}
w.WriteHeader(http.StatusOK)
})
// Create request with JWT token in context
req := httptest.NewRequest(http.MethodGet, "/test", nil)
ctx := context.WithValue(req.Context(), mw.ContextKey{}, token)
req = req.WithContext(ctx)
rec := httptest.NewRecorder()
// Execute
authMiddleware.Handler(testHandler).ServeHTTP(rec, req)
// Assert
assert.Equal(t, http.StatusOK, rec.Code)
assert.Equal(t, userID, capturedUser)
}
func TestAuthMiddleware_Handler_APIKeyErrorHandling(t *testing.T) {
// Setup
mockCache := new(MockCache)
authMiddleware := NewAuth(mockCache)
testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
})
// Create request with invalid API key type in context
req := httptest.NewRequest(http.MethodGet, "/test", nil)
ctx := context.WithValue(req.Context(), ApiKey, 12345) // Invalid type
req = req.WithContext(ctx)
rec := httptest.NewRecorder()
// Execute
authMiddleware.Handler(testHandler).ServeHTTP(rec, req)
// Assert
assert.Equal(t, http.StatusInternalServerError, rec.Code)
assert.Contains(t, rec.Body.String(), "Invalid API Key format")
}
func TestAuthMiddleware_Handler_JWTErrorHandling(t *testing.T) {
// Setup
mockCache := new(MockCache)
authMiddleware := NewAuth(mockCache)
testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
})
// Create request with invalid JWT token type in context
req := httptest.NewRequest(http.MethodGet, "/test", nil)
ctx := context.WithValue(req.Context(), mw.ContextKey{}, "not-a-token") // Invalid type
req = req.WithContext(ctx)
rec := httptest.NewRecorder()
// Execute
authMiddleware.Handler(testHandler).ServeHTTP(rec, req)
// Assert
assert.Equal(t, http.StatusInternalServerError, rec.Code)
assert.Contains(t, rec.Body.String(), "Invalid JWT token format")
}
func TestAuthMiddleware_Handler_BothJWTAndAPIKey(t *testing.T) {
// Setup
mockCache := new(MockCache)
authMiddleware := NewAuth(mockCache)
orgID := uuid.New()
expectedOrg := &domain.Organization{
BaseAggregate: eventsourced.BaseAggregate{
ID: eventsourced.IdFromString(orgID.String()),
},
Name: "Test Organization",
}
userID := "user-123"
apiKey := "test-api-key-123"
token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{
"sub": userID,
})
// Mock expects plaintext key (cache handles hashing internally)
mockCache.On("OrganizationByAPIKey", apiKey).Return(expectedOrg)
// Create a test handler that checks both user and organization in context
var capturedUser string
var capturedOrg *domain.Organization
testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if user := r.Context().Value(UserKey); user != nil {
if u, ok := user.(string); ok {
capturedUser = u
}
}
if org := r.Context().Value(OrganizationKey); org != nil {
if o, ok := org.(domain.Organization); ok {
capturedOrg = &o
}
}
w.WriteHeader(http.StatusOK)
})
// Create request with both JWT and API key in context
req := httptest.NewRequest(http.MethodGet, "/test", nil)
ctx := context.WithValue(req.Context(), mw.ContextKey{}, token)
ctx = context.WithValue(ctx, ApiKey, apiKey)
req = req.WithContext(ctx)
rec := httptest.NewRecorder()
// Execute
authMiddleware.Handler(testHandler).ServeHTTP(rec, req)
// Assert
assert.Equal(t, http.StatusOK, rec.Code)
assert.Equal(t, userID, capturedUser)
require.NotNil(t, capturedOrg)
assert.Equal(t, expectedOrg.Name, capturedOrg.Name)
mockCache.AssertExpectations(t)
}
func TestUserFromContext(t *testing.T) {
tests := []struct {
name string
ctx context.Context
expected string
}{
{
name: "with valid user",
ctx: context.WithValue(context.Background(), UserKey, "user-123"),
expected: "user-123",
},
{
name: "without user",
ctx: context.Background(),
expected: "",
},
{
name: "with invalid type",
ctx: context.WithValue(context.Background(), UserKey, 123),
expected: "",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := UserFromContext(tt.ctx)
assert.Equal(t, tt.expected, result)
})
}
}
func TestOrganizationFromContext(t *testing.T) {
orgID := uuid.New()
org := domain.Organization{
BaseAggregate: eventsourced.BaseAggregate{
ID: eventsourced.IdFromString(orgID.String()),
},
Name: "Test Org",
}
tests := []struct {
name string
ctx context.Context
expected string
}{
{
name: "with valid organization",
ctx: context.WithValue(context.Background(), OrganizationKey, org),
expected: orgID.String(),
},
{
name: "without organization",
ctx: context.Background(),
expected: "",
},
{
name: "with invalid type",
ctx: context.WithValue(context.Background(), OrganizationKey, "not-an-org"),
expected: "",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := OrganizationFromContext(tt.ctx)
assert.Equal(t, tt.expected, result)
})
}
}
func TestAuthMiddleware_Directive_RequiresUser(t *testing.T) {
mockCache := new(MockCache)
authMiddleware := NewAuth(mockCache)
requireUser := true
// Test with user present
ctx := context.WithValue(context.Background(), UserKey, "user-123")
_, err := authMiddleware.Directive(ctx, nil, func(ctx context.Context) (interface{}, error) {
return "success", nil
}, &requireUser, nil)
assert.NoError(t, err)
// Test without user
ctx = context.Background()
_, err = authMiddleware.Directive(ctx, nil, func(ctx context.Context) (interface{}, error) {
return "success", nil
}, &requireUser, nil)
assert.Error(t, err)
assert.Contains(t, err.Error(), "no user available in request")
}
func TestAuthMiddleware_Directive_RequiresOrganization(t *testing.T) {
mockCache := new(MockCache)
authMiddleware := NewAuth(mockCache)
requireOrg := true
orgID := uuid.New()
org := domain.Organization{
BaseAggregate: eventsourced.BaseAggregate{
ID: eventsourced.IdFromString(orgID.String()),
},
Name: "Test Org",
}
// Test with organization present
ctx := context.WithValue(context.Background(), OrganizationKey, org)
_, err := authMiddleware.Directive(ctx, nil, func(ctx context.Context) (interface{}, error) {
return "success", nil
}, nil, &requireOrg)
assert.NoError(t, err)
// Test without organization
ctx = context.Background()
_, err = authMiddleware.Directive(ctx, nil, func(ctx context.Context) (interface{}, error) {
return "success", nil
}, nil, &requireOrg)
assert.Error(t, err)
assert.Contains(t, err.Error(), "no organization available in request")
}
func TestAuthMiddleware_Directive_RequiresBoth(t *testing.T) {
mockCache := new(MockCache)
authMiddleware := NewAuth(mockCache)
requireUser := true
requireOrg := true
orgID := uuid.New()
org := domain.Organization{
BaseAggregate: eventsourced.BaseAggregate{
ID: eventsourced.IdFromString(orgID.String()),
},
Name: "Test Org",
}
// Test with both present
ctx := context.WithValue(context.Background(), UserKey, "user-123")
ctx = context.WithValue(ctx, OrganizationKey, org)
_, err := authMiddleware.Directive(ctx, nil, func(ctx context.Context) (interface{}, error) {
return "success", nil
}, &requireUser, &requireOrg)
assert.NoError(t, err)
// Test with only user
ctx = context.WithValue(context.Background(), UserKey, "user-123")
_, err = authMiddleware.Directive(ctx, nil, func(ctx context.Context) (interface{}, error) {
return "success", nil
}, &requireUser, &requireOrg)
assert.Error(t, err)
// Test with only organization
ctx = context.WithValue(context.Background(), OrganizationKey, org)
_, err = authMiddleware.Directive(ctx, nil, func(ctx context.Context) (interface{}, error) {
return "success", nil
}, &requireUser, &requireOrg)
assert.Error(t, err)
}
func TestAuthMiddleware_Directive_NoRequirements(t *testing.T) {
mockCache := new(MockCache)
authMiddleware := NewAuth(mockCache)
// Test with no requirements
ctx := context.Background()
result, err := authMiddleware.Directive(ctx, nil, func(ctx context.Context) (interface{}, error) {
return "success", nil
}, nil, nil)
assert.NoError(t, err)
assert.Equal(t, "success", result)
}