Compare commits
26 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 06aeedc3b0 | |||
| fce85782f0 | |||
| 9cd8218eb4 | |||
| 98ef62b144 | |||
| e0cdd2aa58 | |||
| e22e8b339c | |||
|
6404f7a497
|
|||
| 5dc5043d46 | |||
| bcca005256 | |||
|
a9dea19531
|
|||
|
130e92dc5f
|
|||
| c4112a005f | |||
| 549f6617df | |||
| a1b0f49aab | |||
|
4468903535
|
|||
| df054ca451 | |||
| 1e2236dc9e | |||
|
6ccd7f4f25
|
|||
| b1a46f9d4e | |||
|
47dbf827f2
|
|||
|
df44ddbb8e
|
|||
|
9368d77bc8
|
|||
|
4d18cf4175
|
|||
|
bb0c08be06
|
|||
| a9a47c1690 | |||
| de073ce2da |
@@ -1,9 +1,11 @@
|
||||
.idea
|
||||
.claude
|
||||
.testCoverage.txt
|
||||
.testCoverage.txt.tmp
|
||||
coverage.html
|
||||
/exported
|
||||
/release
|
||||
/schemactl
|
||||
/service
|
||||
CHANGES.md
|
||||
VERSION
|
||||
|
||||
@@ -41,7 +41,7 @@ repos:
|
||||
hooks:
|
||||
- id: golangci-lint-full
|
||||
- repo: https://github.com/gitleaks/gitleaks
|
||||
rev: v8.29.0
|
||||
rev: v8.29.1
|
||||
hooks:
|
||||
- id: gitleaks
|
||||
exclude: '^ctl/generated.go|graph/generated/.*$|^graph/model/models_gen.go|^tools/.*$$'
|
||||
|
||||
@@ -2,6 +2,34 @@
|
||||
|
||||
All notable changes to this project will be documented in this file.
|
||||
|
||||
## [0.8.0] - 2025-11-21
|
||||
|
||||
### 🚀 Features
|
||||
|
||||
- *(tests)* Add unit tests for WebSocket initialization logic
|
||||
- Add latestSchema query for retrieving schema updates
|
||||
- Add CLAUDE.md for project documentation and guidelines
|
||||
- *(cache)* Implement hashed API key storage and retrieval
|
||||
- *(health)* Add health checking endpoints and logic
|
||||
- *(cache)* Add concurrency safety and logging improvements
|
||||
|
||||
### 🐛 Bug Fixes
|
||||
|
||||
- Enhance API key handling and logging in middleware
|
||||
- Add command executor interface for better testing
|
||||
- *(deps)* Update module golang.org/x/crypto to v0.45.0
|
||||
- *(deps)* Update module github.com/auth0/go-jwt-middleware/v2 to v2.3.1
|
||||
|
||||
### 🧪 Testing
|
||||
|
||||
- Enhance assertions for version and subscription config
|
||||
- *(cache)* Reduce goroutines for race detector stability
|
||||
|
||||
### ⚙️ Miscellaneous Tasks
|
||||
|
||||
- *(deps)* Update pre-commit hook gitleaks/gitleaks to v8.29.1
|
||||
- *(deps)* Update node.js to v24
|
||||
|
||||
## [0.7.0] - 2025-11-19
|
||||
|
||||
### 🚀 Features
|
||||
|
||||
@@ -0,0 +1,136 @@
|
||||
# CLAUDE.md
|
||||
|
||||
This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository.
|
||||
|
||||
## Project Overview
|
||||
|
||||
This is a GraphQL schema registry service that manages federated GraphQL schemas for microservices. It allows services to publish their subgraph schemas and provides merged supergraphs with Cosmo Router configuration for federated GraphQL gateways.
|
||||
|
||||
## Architecture
|
||||
|
||||
### Event Sourcing
|
||||
The system uses event sourcing via `gitlab.com/unboundsoftware/eventsourced`. Key domain aggregates are:
|
||||
- **Organization** (domain/aggregates.go): Manages organizations, users, and API keys
|
||||
- **SubGraph** (domain/aggregates.go): Tracks subgraph schemas with versioning
|
||||
|
||||
All state changes flow through events (domain/events.go) and commands (domain/commands.go). The EventStore persists events to PostgreSQL, and events are published to RabbitMQ for downstream consumers.
|
||||
|
||||
### GraphQL Layer
|
||||
- **Schema**: graph/schema.graphqls defines the API
|
||||
- **Resolvers**: graph/schema.resolvers.go implements mutations/queries
|
||||
- **Generated Code**: graph/generated/ and graph/model/ (auto-generated by gqlgen)
|
||||
|
||||
The resolver (graph/resolver.go) coordinates between the EventStore, Publisher (RabbitMQ), Cache, and PubSub for subscriptions.
|
||||
|
||||
### Schema Merging
|
||||
The sdlmerge/ package handles GraphQL schema federation:
|
||||
- Merges multiple subgraph SDL schemas into a unified supergraph
|
||||
- Uses wundergraph/graphql-go-tools for AST manipulation
|
||||
- Removes duplicates, extends types, and applies federation directives
|
||||
|
||||
### Authentication & Authorization
|
||||
- **Auth0 JWT** (middleware/auth0.go): Validates user tokens from Auth0
|
||||
- **API Keys** (middleware/apikey.go): Validates service API keys
|
||||
- **Auth Middleware** (middleware/auth.go): Routes auth based on context
|
||||
|
||||
The @auth directive controls field-level access (user vs organization API key).
|
||||
|
||||
### Cosmo Router Integration
|
||||
The service generates Cosmo Router configuration (graph/cosmo.go) using the wgc CLI tool installed in the Docker container. This config enables federated query execution across subgraphs.
|
||||
|
||||
### PubSub for Real-time Updates
|
||||
graph/pubsub.go implements subscription support for schemaUpdates, allowing clients to receive real-time notifications when schemas change.
|
||||
|
||||
## Commands
|
||||
|
||||
### Code Generation
|
||||
```bash
|
||||
# Generate GraphQL server code (gqlgen), format, and organize imports
|
||||
go generate ./...
|
||||
```
|
||||
|
||||
Always run this after modifying graph/schema.graphqls. The go:generate directives are in:
|
||||
- graph/resolver.go: runs gqlgen, gofumpt, and goimports
|
||||
- ctl/ctl.go: generates genqlient client code
|
||||
|
||||
### Testing
|
||||
```bash
|
||||
# Run all tests
|
||||
go test ./... -v
|
||||
|
||||
# Run tests with race detection and coverage (as used in CI)
|
||||
CGO_ENABLED=1 go test -race -coverprofile=coverage.txt -covermode=atomic ./...
|
||||
|
||||
# Run specific package tests
|
||||
go test ./middleware -v
|
||||
go test ./graph -v -run TestGenerateCosmoRouterConfig
|
||||
|
||||
# Run single test
|
||||
go test ./cmd/service -v -run TestWebSocket
|
||||
```
|
||||
|
||||
### Building
|
||||
```bash
|
||||
# Build the service binary
|
||||
go build -o service ./cmd/service/service.go
|
||||
|
||||
# Build the CLI tool
|
||||
go build -o schemactl ./cmd/schemactl/schemactl.go
|
||||
|
||||
# Docker build (multi-stage)
|
||||
docker build -t schemas .
|
||||
```
|
||||
|
||||
The Dockerfile runs tests with coverage before building the production binary.
|
||||
|
||||
### Running the Service
|
||||
```bash
|
||||
# Start the service (requires PostgreSQL and RabbitMQ)
|
||||
go run ./cmd/service/service.go \
|
||||
--postgres-url="postgres://user:pass@localhost:5432/schemas?sslmode=disable" \
|
||||
--amqp-url="amqp://user:pass@localhost:5672/" \
|
||||
--issuer="your-auth0-domain.auth0.com"
|
||||
|
||||
# The service listens on port 8080 by default
|
||||
# GraphQL Playground available at http://localhost:8080/
|
||||
```
|
||||
|
||||
### Using the schemactl CLI
|
||||
```bash
|
||||
# Publish a subgraph schema
|
||||
schemactl publish \
|
||||
--api-key="your-api-key" \
|
||||
--schema-ref="production" \
|
||||
--service="users" \
|
||||
--url="http://users-service:8080/query" \
|
||||
--sdl=schema.graphql
|
||||
|
||||
# List subgraphs for a ref
|
||||
schemactl list \
|
||||
--api-key="your-api-key" \
|
||||
--schema-ref="production"
|
||||
```
|
||||
|
||||
## Development Workflow
|
||||
|
||||
1. **Schema Changes**: Edit graph/schema.graphqls → run `go generate ./...`
|
||||
2. **Resolver Implementation**: Implement in graph/schema.resolvers.go
|
||||
3. **Testing**: Write tests, run `go test ./...`
|
||||
4. **Pre-commit**: Hooks run go-mod-tidy, goimports, gofumpt, golangci-lint, and tests
|
||||
|
||||
## Key Dependencies
|
||||
|
||||
- **gqlgen**: GraphQL server generation
|
||||
- **genqlient**: GraphQL client generation (for ctl package)
|
||||
- **eventsourced**: Event sourcing framework
|
||||
- **wundergraph/graphql-go-tools**: Schema federation and composition
|
||||
- **wgc CLI**: Cosmo Router config generation (Node.js tool)
|
||||
- **Auth0**: JWT authentication
|
||||
- **OpenTelemetry**: Observability (traces, metrics, logs)
|
||||
|
||||
## Important Files
|
||||
|
||||
- gqlgen.yml: gqlgen configuration
|
||||
- graph/tools.go: Declares build-time tool dependencies
|
||||
- .pre-commit-config.yaml: Pre-commit hooks configuration
|
||||
- cliff.toml: Changelog generation config
|
||||
+9
-1
@@ -24,9 +24,17 @@ RUN GOOS=linux GOARCH=amd64 go build \
|
||||
FROM scratch as export
|
||||
COPY --from=build /build/coverage.txt /
|
||||
|
||||
FROM scratch
|
||||
FROM node:24-alpine@sha256:2867d550cf9d8bb50059a0fff528741f11a84d985c732e60e19e8e75c7239c43
|
||||
ENV TZ Europe/Stockholm
|
||||
|
||||
# Install wgc CLI globally for Cosmo Router composition
|
||||
RUN npm install -g wgc@latest
|
||||
|
||||
# Copy timezone data and certificates
|
||||
COPY --from=build /usr/share/zoneinfo /usr/share/zoneinfo
|
||||
COPY --from=build /etc/ssl/certs/ca-certificates.crt /etc/ssl/certs/
|
||||
|
||||
# Copy the service binary
|
||||
COPY --from=build /release/service /
|
||||
|
||||
CMD ["/service"]
|
||||
|
||||
Vendored
+67
-20
@@ -3,18 +3,21 @@ package cache
|
||||
import (
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/sparetimecoders/goamqp"
|
||||
"gitlab.com/unboundsoftware/eventsourced/eventsourced"
|
||||
|
||||
"gitlab.com/unboundsoftware/schemas/domain"
|
||||
"gitlab.com/unboundsoftware/schemas/hash"
|
||||
)
|
||||
|
||||
type Cache struct {
|
||||
mu sync.RWMutex
|
||||
organizations map[string]domain.Organization
|
||||
users map[string][]string
|
||||
apiKeys map[string]domain.APIKey
|
||||
apiKeys map[string]domain.APIKey // keyed by organizationId-name
|
||||
services map[string]map[string]map[string]struct{}
|
||||
subGraphs map[string]string
|
||||
lastUpdate map[string]string
|
||||
@@ -22,18 +25,26 @@ type Cache struct {
|
||||
}
|
||||
|
||||
func (c *Cache) OrganizationByAPIKey(apiKey string) *domain.Organization {
|
||||
key, exists := c.apiKeys[apiKey]
|
||||
if !exists {
|
||||
return nil
|
||||
c.mu.RLock()
|
||||
defer c.mu.RUnlock()
|
||||
|
||||
// Find the API key by comparing hashes
|
||||
for _, key := range c.apiKeys {
|
||||
if hash.CompareAPIKey(key.Key, apiKey) {
|
||||
org, exists := c.organizations[key.OrganizationId]
|
||||
if !exists {
|
||||
return nil
|
||||
}
|
||||
return &org
|
||||
}
|
||||
}
|
||||
org, exists := c.organizations[key.OrganizationId]
|
||||
if !exists {
|
||||
return nil
|
||||
}
|
||||
return &org
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Cache) OrganizationsByUser(sub string) []domain.Organization {
|
||||
c.mu.RLock()
|
||||
defer c.mu.RUnlock()
|
||||
|
||||
orgIds := c.users[sub]
|
||||
orgs := make([]domain.Organization, len(orgIds))
|
||||
for i, id := range orgIds {
|
||||
@@ -43,14 +54,22 @@ func (c *Cache) OrganizationsByUser(sub string) []domain.Organization {
|
||||
}
|
||||
|
||||
func (c *Cache) ApiKeyByKey(key string) *domain.APIKey {
|
||||
k, exists := c.apiKeys[hash.String(key)]
|
||||
if !exists {
|
||||
return nil
|
||||
c.mu.RLock()
|
||||
defer c.mu.RUnlock()
|
||||
|
||||
// Find the API key by comparing hashes
|
||||
for _, apiKey := range c.apiKeys {
|
||||
if hash.CompareAPIKey(apiKey.Key, key) {
|
||||
return &apiKey
|
||||
}
|
||||
}
|
||||
return &k
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Cache) Services(orgId, ref, lastUpdate string) ([]string, string) {
|
||||
c.mu.RLock()
|
||||
defer c.mu.RUnlock()
|
||||
|
||||
key := refKey(orgId, ref)
|
||||
var services []string
|
||||
if lastUpdate == "" || c.lastUpdate[key] > lastUpdate {
|
||||
@@ -62,41 +81,56 @@ func (c *Cache) Services(orgId, ref, lastUpdate string) ([]string, string) {
|
||||
}
|
||||
|
||||
func (c *Cache) SubGraphId(orgId, ref, service string) string {
|
||||
c.mu.RLock()
|
||||
defer c.mu.RUnlock()
|
||||
|
||||
return c.subGraphs[subGraphKey(orgId, ref, service)]
|
||||
}
|
||||
|
||||
func (c *Cache) Update(msg any, _ goamqp.Headers) (any, error) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
switch m := msg.(type) {
|
||||
case *domain.OrganizationAdded:
|
||||
o := domain.Organization{}
|
||||
o := domain.Organization{
|
||||
BaseAggregate: eventsourced.BaseAggregateFromString(m.ID.String()),
|
||||
}
|
||||
m.UpdateOrganization(&o)
|
||||
c.organizations[m.ID.String()] = o
|
||||
c.addUser(m.Initiator, o)
|
||||
c.logger.With("org_id", m.ID.String(), "event", "OrganizationAdded").Debug("cache updated")
|
||||
case *domain.APIKeyAdded:
|
||||
key := domain.APIKey{
|
||||
Name: m.Name,
|
||||
OrganizationId: m.OrganizationId,
|
||||
Key: m.Key,
|
||||
Key: m.Key, // This is now the hashed key
|
||||
Refs: m.Refs,
|
||||
Read: m.Read,
|
||||
Publish: m.Publish,
|
||||
CreatedBy: m.Initiator,
|
||||
CreatedAt: m.When(),
|
||||
}
|
||||
c.apiKeys[m.Key] = key
|
||||
// Use composite key: organizationId-name
|
||||
c.apiKeys[apiKeyId(m.OrganizationId, m.Name)] = key
|
||||
org := c.organizations[m.OrganizationId]
|
||||
org.APIKeys = append(org.APIKeys, key)
|
||||
c.organizations[m.OrganizationId] = org
|
||||
c.logger.With("org_id", m.OrganizationId, "key_name", m.Name, "event", "APIKeyAdded").Debug("cache updated")
|
||||
case *domain.SubGraphUpdated:
|
||||
c.updateSubGraph(m.OrganizationId, m.Ref, m.ID.String(), m.Service, m.Time)
|
||||
c.logger.With("org_id", m.OrganizationId, "ref", m.Ref, "service", m.Service, "event", "SubGraphUpdated").Debug("cache updated")
|
||||
case *domain.Organization:
|
||||
c.organizations[m.ID.String()] = *m
|
||||
c.addUser(m.CreatedBy, *m)
|
||||
for _, k := range m.APIKeys {
|
||||
c.apiKeys[k.Key] = k
|
||||
// Use composite key: organizationId-name
|
||||
c.apiKeys[apiKeyId(k.OrganizationId, k.Name)] = k
|
||||
}
|
||||
c.logger.With("org_id", m.ID.String(), "event", "Organization aggregate loaded").Debug("cache updated")
|
||||
case *domain.SubGraph:
|
||||
c.updateSubGraph(m.OrganizationId, m.Ref, m.ID.String(), m.Service, m.ChangedAt)
|
||||
c.logger.With("org_id", m.OrganizationId, "ref", m.Ref, "service", m.Service, "event", "SubGraph aggregate loaded").Debug("cache updated")
|
||||
default:
|
||||
c.logger.With("msg", msg).Warn("unexpected message received")
|
||||
}
|
||||
@@ -117,11 +151,20 @@ func (c *Cache) updateSubGraph(orgId string, ref string, subGraphId string, serv
|
||||
|
||||
func (c *Cache) addUser(sub string, organization domain.Organization) {
|
||||
user, exists := c.users[sub]
|
||||
orgId := organization.ID.String()
|
||||
if !exists {
|
||||
c.users[sub] = []string{organization.ID.String()}
|
||||
} else {
|
||||
c.users[sub] = append(user, organization.ID.String())
|
||||
c.users[sub] = []string{orgId}
|
||||
return
|
||||
}
|
||||
|
||||
// Check if organization already exists for this user
|
||||
for _, id := range user {
|
||||
if id == orgId {
|
||||
return // Already exists, no need to add
|
||||
}
|
||||
}
|
||||
|
||||
c.users[sub] = append(user, orgId)
|
||||
}
|
||||
|
||||
func New(logger *slog.Logger) *Cache {
|
||||
@@ -143,3 +186,7 @@ func refKey(orgId string, ref string) string {
|
||||
func subGraphKey(orgId string, ref string, service string) string {
|
||||
return fmt.Sprintf("%s<->%s<->%s", orgId, ref, service)
|
||||
}
|
||||
|
||||
func apiKeyId(orgId string, name string) string {
|
||||
return fmt.Sprintf("%s<->%s", orgId, name)
|
||||
}
|
||||
|
||||
Vendored
+447
@@ -0,0 +1,447 @@
|
||||
package cache
|
||||
|
||||
import (
|
||||
"log/slog"
|
||||
"os"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"gitlab.com/unboundsoftware/eventsourced/eventsourced"
|
||||
|
||||
"gitlab.com/unboundsoftware/schemas/domain"
|
||||
"gitlab.com/unboundsoftware/schemas/hash"
|
||||
)
|
||||
|
||||
func TestCache_OrganizationByAPIKey(t *testing.T) {
|
||||
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
|
||||
c := New(logger)
|
||||
|
||||
orgID := uuid.New().String()
|
||||
apiKey := "test-api-key-123" // gitleaks:allow
|
||||
hashedKey, err := hash.APIKey(apiKey)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Add organization to cache
|
||||
org := domain.Organization{
|
||||
BaseAggregate: eventsourced.BaseAggregateFromString(orgID),
|
||||
Name: "Test Org",
|
||||
}
|
||||
c.organizations[orgID] = org
|
||||
|
||||
// Add API key to cache
|
||||
c.apiKeys[apiKeyId(orgID, "test-key")] = domain.APIKey{
|
||||
Name: "test-key",
|
||||
OrganizationId: orgID,
|
||||
Key: hashedKey,
|
||||
Refs: []string{"main"},
|
||||
Read: true,
|
||||
Publish: true,
|
||||
}
|
||||
|
||||
// Test finding organization by plaintext API key
|
||||
foundOrg := c.OrganizationByAPIKey(apiKey)
|
||||
require.NotNil(t, foundOrg)
|
||||
assert.Equal(t, org.Name, foundOrg.Name)
|
||||
assert.Equal(t, orgID, foundOrg.ID.String())
|
||||
|
||||
// Test with wrong API key
|
||||
notFoundOrg := c.OrganizationByAPIKey("wrong-key")
|
||||
assert.Nil(t, notFoundOrg)
|
||||
}
|
||||
|
||||
func TestCache_OrganizationByAPIKey_Legacy(t *testing.T) {
|
||||
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
|
||||
c := New(logger)
|
||||
|
||||
orgID := uuid.New().String()
|
||||
apiKey := "legacy-api-key-456" // gitleaks:allow
|
||||
legacyHash := hash.String(apiKey)
|
||||
|
||||
// Add organization to cache
|
||||
org := domain.Organization{
|
||||
BaseAggregate: eventsourced.BaseAggregateFromString(orgID),
|
||||
Name: "Legacy Org",
|
||||
}
|
||||
c.organizations[orgID] = org
|
||||
|
||||
// Add API key with legacy SHA256 hash
|
||||
c.apiKeys[apiKeyId(orgID, "legacy-key")] = domain.APIKey{
|
||||
Name: "legacy-key",
|
||||
OrganizationId: orgID,
|
||||
Key: legacyHash,
|
||||
Refs: []string{"main"},
|
||||
Read: true,
|
||||
Publish: false,
|
||||
}
|
||||
|
||||
// Test finding organization with legacy hash
|
||||
foundOrg := c.OrganizationByAPIKey(apiKey)
|
||||
require.NotNil(t, foundOrg)
|
||||
assert.Equal(t, org.Name, foundOrg.Name)
|
||||
}
|
||||
|
||||
func TestCache_OrganizationsByUser(t *testing.T) {
|
||||
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
|
||||
c := New(logger)
|
||||
|
||||
userSub := "user-123"
|
||||
org1ID := uuid.New().String()
|
||||
org2ID := uuid.New().String()
|
||||
|
||||
org1 := domain.Organization{
|
||||
BaseAggregate: eventsourced.BaseAggregateFromString(org1ID),
|
||||
Name: "Org 1",
|
||||
}
|
||||
org2 := domain.Organization{
|
||||
BaseAggregate: eventsourced.BaseAggregateFromString(org2ID),
|
||||
Name: "Org 2",
|
||||
}
|
||||
|
||||
c.organizations[org1ID] = org1
|
||||
c.organizations[org2ID] = org2
|
||||
c.users[userSub] = []string{org1ID, org2ID}
|
||||
|
||||
orgs := c.OrganizationsByUser(userSub)
|
||||
assert.Len(t, orgs, 2)
|
||||
assert.Contains(t, []string{orgs[0].Name, orgs[1].Name}, "Org 1")
|
||||
assert.Contains(t, []string{orgs[0].Name, orgs[1].Name}, "Org 2")
|
||||
}
|
||||
|
||||
func TestCache_ApiKeyByKey(t *testing.T) {
|
||||
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
|
||||
c := New(logger)
|
||||
|
||||
orgID := uuid.New().String()
|
||||
apiKey := "test-api-key-789" // gitleaks:allow
|
||||
hashedKey, err := hash.APIKey(apiKey)
|
||||
require.NoError(t, err)
|
||||
|
||||
expectedKey := domain.APIKey{
|
||||
Name: "test-key",
|
||||
OrganizationId: orgID,
|
||||
Key: hashedKey,
|
||||
Refs: []string{"main", "dev"},
|
||||
Read: true,
|
||||
Publish: true,
|
||||
}
|
||||
|
||||
c.apiKeys[apiKeyId(orgID, "test-key")] = expectedKey
|
||||
|
||||
foundKey := c.ApiKeyByKey(apiKey)
|
||||
require.NotNil(t, foundKey)
|
||||
assert.Equal(t, expectedKey.Name, foundKey.Name)
|
||||
assert.Equal(t, expectedKey.OrganizationId, foundKey.OrganizationId)
|
||||
assert.Equal(t, expectedKey.Refs, foundKey.Refs)
|
||||
|
||||
// Test with wrong key
|
||||
notFoundKey := c.ApiKeyByKey("wrong-key")
|
||||
assert.Nil(t, notFoundKey)
|
||||
}
|
||||
|
||||
func TestCache_Services(t *testing.T) {
|
||||
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
|
||||
c := New(logger)
|
||||
|
||||
orgID := uuid.New().String()
|
||||
ref := "main"
|
||||
service1 := "service-1"
|
||||
service2 := "service-2"
|
||||
lastUpdate := "2024-01-01T12:00:00Z"
|
||||
|
||||
c.services[orgID] = map[string]map[string]struct{}{
|
||||
ref: {
|
||||
service1: {},
|
||||
service2: {},
|
||||
},
|
||||
}
|
||||
c.lastUpdate[refKey(orgID, ref)] = lastUpdate
|
||||
|
||||
// Test getting services with empty lastUpdate
|
||||
services, returnedLastUpdate := c.Services(orgID, ref, "")
|
||||
assert.Len(t, services, 2)
|
||||
assert.Contains(t, services, service1)
|
||||
assert.Contains(t, services, service2)
|
||||
assert.Equal(t, lastUpdate, returnedLastUpdate)
|
||||
|
||||
// Test with older lastUpdate (should return services)
|
||||
services, returnedLastUpdate = c.Services(orgID, ref, "2023-12-31T12:00:00Z")
|
||||
assert.Len(t, services, 2)
|
||||
assert.Equal(t, lastUpdate, returnedLastUpdate)
|
||||
|
||||
// Test with newer lastUpdate (should return empty)
|
||||
services, returnedLastUpdate = c.Services(orgID, ref, "2024-01-02T12:00:00Z")
|
||||
assert.Len(t, services, 0)
|
||||
assert.Equal(t, lastUpdate, returnedLastUpdate)
|
||||
}
|
||||
|
||||
func TestCache_SubGraphId(t *testing.T) {
|
||||
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
|
||||
c := New(logger)
|
||||
|
||||
orgID := uuid.New().String()
|
||||
ref := "main"
|
||||
service := "test-service"
|
||||
subGraphID := uuid.New().String()
|
||||
|
||||
c.subGraphs[subGraphKey(orgID, ref, service)] = subGraphID
|
||||
|
||||
foundID := c.SubGraphId(orgID, ref, service)
|
||||
assert.Equal(t, subGraphID, foundID)
|
||||
|
||||
// Test with non-existent key
|
||||
notFoundID := c.SubGraphId("wrong-org", ref, service)
|
||||
assert.Empty(t, notFoundID)
|
||||
}
|
||||
|
||||
func TestCache_Update_OrganizationAdded(t *testing.T) {
|
||||
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
|
||||
c := New(logger)
|
||||
|
||||
orgID := uuid.New().String()
|
||||
event := &domain.OrganizationAdded{
|
||||
Name: "New Org",
|
||||
Initiator: "user-123",
|
||||
}
|
||||
event.ID = *eventsourced.IdFromString(orgID)
|
||||
|
||||
_, err := c.Update(event, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify organization was added
|
||||
org, exists := c.organizations[orgID]
|
||||
assert.True(t, exists)
|
||||
assert.Equal(t, "New Org", org.Name)
|
||||
|
||||
// Verify user was added
|
||||
assert.Contains(t, c.users["user-123"], orgID)
|
||||
}
|
||||
|
||||
func TestCache_Update_APIKeyAdded(t *testing.T) {
|
||||
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
|
||||
c := New(logger)
|
||||
|
||||
orgID := uuid.New().String()
|
||||
keyName := "test-key"
|
||||
hashedKey := "hashed-key-value"
|
||||
|
||||
// Add organization first
|
||||
org := domain.Organization{
|
||||
BaseAggregate: eventsourced.BaseAggregateFromString(orgID),
|
||||
Name: "Test Org",
|
||||
APIKeys: []domain.APIKey{},
|
||||
}
|
||||
c.organizations[orgID] = org
|
||||
|
||||
event := &domain.APIKeyAdded{
|
||||
OrganizationId: orgID,
|
||||
Name: keyName,
|
||||
Key: hashedKey,
|
||||
Refs: []string{"main"},
|
||||
Read: true,
|
||||
Publish: false,
|
||||
Initiator: "user-123",
|
||||
}
|
||||
event.ID = *eventsourced.IdFromString(uuid.New().String())
|
||||
|
||||
_, err := c.Update(event, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify API key was added to cache
|
||||
key, exists := c.apiKeys[apiKeyId(orgID, keyName)]
|
||||
assert.True(t, exists)
|
||||
assert.Equal(t, keyName, key.Name)
|
||||
assert.Equal(t, hashedKey, key.Key)
|
||||
assert.Equal(t, []string{"main"}, key.Refs)
|
||||
|
||||
// Verify API key was added to organization
|
||||
updatedOrg := c.organizations[orgID]
|
||||
assert.Len(t, updatedOrg.APIKeys, 1)
|
||||
assert.Equal(t, keyName, updatedOrg.APIKeys[0].Name)
|
||||
}
|
||||
|
||||
func TestCache_Update_SubGraphUpdated(t *testing.T) {
|
||||
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
|
||||
c := New(logger)
|
||||
|
||||
orgID := uuid.New().String()
|
||||
ref := "main"
|
||||
service := "test-service"
|
||||
subGraphID := uuid.New().String()
|
||||
|
||||
event := &domain.SubGraphUpdated{
|
||||
OrganizationId: orgID,
|
||||
Ref: ref,
|
||||
Service: service,
|
||||
Initiator: "user-123",
|
||||
}
|
||||
event.ID = *eventsourced.IdFromString(subGraphID)
|
||||
event.SetWhen(time.Now())
|
||||
|
||||
_, err := c.Update(event, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify subgraph was added to services
|
||||
assert.Contains(t, c.services[orgID][ref], subGraphID)
|
||||
|
||||
// Verify subgraph ID was stored
|
||||
assert.Equal(t, subGraphID, c.subGraphs[subGraphKey(orgID, ref, service)])
|
||||
|
||||
// Verify lastUpdate was set
|
||||
assert.NotEmpty(t, c.lastUpdate[refKey(orgID, ref)])
|
||||
}
|
||||
|
||||
func TestCache_AddUser_NoDuplicates(t *testing.T) {
|
||||
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
|
||||
c := New(logger)
|
||||
|
||||
userSub := "user-123"
|
||||
orgID := uuid.New().String()
|
||||
org := domain.Organization{
|
||||
BaseAggregate: eventsourced.BaseAggregateFromString(orgID),
|
||||
Name: "Test Org",
|
||||
}
|
||||
|
||||
// Add user first time
|
||||
c.addUser(userSub, org)
|
||||
assert.Len(t, c.users[userSub], 1)
|
||||
assert.Equal(t, orgID, c.users[userSub][0])
|
||||
|
||||
// Add same user/org again - should not create duplicate
|
||||
c.addUser(userSub, org)
|
||||
assert.Len(t, c.users[userSub], 1, "Should not add duplicate organization")
|
||||
assert.Equal(t, orgID, c.users[userSub][0])
|
||||
}
|
||||
|
||||
func TestCache_ConcurrentReads(t *testing.T) {
|
||||
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
|
||||
c := New(logger)
|
||||
|
||||
// Setup test data
|
||||
orgID := uuid.New().String()
|
||||
apiKey := "test-concurrent-key" // gitleaks:allow
|
||||
hashedKey, err := hash.APIKey(apiKey)
|
||||
require.NoError(t, err)
|
||||
|
||||
org := domain.Organization{
|
||||
BaseAggregate: eventsourced.BaseAggregateFromString(orgID),
|
||||
Name: "Concurrent Test Org",
|
||||
}
|
||||
c.organizations[orgID] = org
|
||||
c.apiKeys[apiKeyId(orgID, "test-key")] = domain.APIKey{
|
||||
Name: "test-key",
|
||||
OrganizationId: orgID,
|
||||
Key: hashedKey,
|
||||
}
|
||||
|
||||
// Run concurrent reads (reduced for race detector)
|
||||
var wg sync.WaitGroup
|
||||
numGoroutines := 20
|
||||
|
||||
for i := 0; i < numGoroutines; i++ {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
org := c.OrganizationByAPIKey(apiKey)
|
||||
assert.NotNil(t, org)
|
||||
assert.Equal(t, "Concurrent Test Org", org.Name)
|
||||
}()
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
func TestCache_ConcurrentWrites(t *testing.T) {
|
||||
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
|
||||
c := New(logger)
|
||||
|
||||
var wg sync.WaitGroup
|
||||
numGoroutines := 10 // Reduced for race detector
|
||||
|
||||
// Concurrent organization additions
|
||||
for i := 0; i < numGoroutines; i++ {
|
||||
wg.Add(1)
|
||||
go func(index int) {
|
||||
defer wg.Done()
|
||||
orgID := uuid.New().String()
|
||||
event := &domain.OrganizationAdded{
|
||||
Name: "Org " + string(rune(index)),
|
||||
Initiator: "user-" + string(rune(index)),
|
||||
}
|
||||
event.ID = *eventsourced.IdFromString(orgID)
|
||||
_, err := c.Update(event, nil)
|
||||
assert.NoError(t, err)
|
||||
}(i)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
|
||||
// Verify all organizations were added
|
||||
assert.Equal(t, numGoroutines, len(c.organizations))
|
||||
}
|
||||
|
||||
func TestCache_ConcurrentReadsAndWrites(t *testing.T) {
|
||||
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
|
||||
c := New(logger)
|
||||
|
||||
// Setup initial data
|
||||
orgID := uuid.New().String()
|
||||
apiKey := "test-rw-key" // gitleaks:allow
|
||||
hashedKey, err := hash.APIKey(apiKey)
|
||||
require.NoError(t, err)
|
||||
|
||||
org := domain.Organization{
|
||||
BaseAggregate: eventsourced.BaseAggregateFromString(orgID),
|
||||
Name: "RW Test Org",
|
||||
}
|
||||
c.organizations[orgID] = org
|
||||
c.apiKeys[apiKeyId(orgID, "test-key")] = domain.APIKey{
|
||||
Name: "test-key",
|
||||
OrganizationId: orgID,
|
||||
Key: hashedKey,
|
||||
}
|
||||
c.users["user-initial"] = []string{orgID}
|
||||
|
||||
var wg sync.WaitGroup
|
||||
numReaders := 10 // Reduced for race detector
|
||||
numWriters := 5 // Reduced for race detector
|
||||
iterations := 3 // Reduced for race detector
|
||||
|
||||
// Concurrent readers
|
||||
for i := 0; i < numReaders; i++ {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
for j := 0; j < iterations; j++ {
|
||||
org := c.OrganizationByAPIKey(apiKey)
|
||||
assert.NotNil(t, org)
|
||||
orgs := c.OrganizationsByUser("user-initial")
|
||||
assert.NotEmpty(t, orgs)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// Concurrent writers
|
||||
for i := 0; i < numWriters; i++ {
|
||||
wg.Add(1)
|
||||
go func(index int) {
|
||||
defer wg.Done()
|
||||
newOrgID := uuid.New().String()
|
||||
event := &domain.OrganizationAdded{
|
||||
Name: "New Org " + string(rune(index)),
|
||||
Initiator: "user-new-" + string(rune(index)),
|
||||
}
|
||||
event.ID = *eventsourced.IdFromString(newOrgID)
|
||||
_, err := c.Update(event, nil)
|
||||
assert.NoError(t, err)
|
||||
}(i)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
|
||||
// Verify cache is in consistent state
|
||||
assert.GreaterOrEqual(t, len(c.organizations), numWriters)
|
||||
}
|
||||
+24
-5
@@ -30,6 +30,7 @@ import (
|
||||
"gitlab.com/unboundsoftware/schemas/domain"
|
||||
"gitlab.com/unboundsoftware/schemas/graph"
|
||||
"gitlab.com/unboundsoftware/schemas/graph/generated"
|
||||
"gitlab.com/unboundsoftware/schemas/health"
|
||||
"gitlab.com/unboundsoftware/schemas/logging"
|
||||
"gitlab.com/unboundsoftware/schemas/middleware"
|
||||
"gitlab.com/unboundsoftware/schemas/monitoring"
|
||||
@@ -210,6 +211,24 @@ func start(closeEvents chan error, logger *slog.Logger, connectToAmqpFunc func(u
|
||||
|
||||
srv.AddTransport(transport.Websocket{
|
||||
KeepAlivePingInterval: 10 * time.Second,
|
||||
InitFunc: func(ctx context.Context, initPayload transport.InitPayload) (context.Context, *transport.InitPayload, error) {
|
||||
// Extract API key from WebSocket connection_init payload
|
||||
if apiKey, ok := initPayload["X-Api-Key"].(string); ok && apiKey != "" {
|
||||
logger.Info("WebSocket connection with API key", "has_key", true)
|
||||
ctx = context.WithValue(ctx, middleware.ApiKey, apiKey)
|
||||
|
||||
// Look up organization by API key (cache handles hash comparison)
|
||||
if organization := serviceCache.OrganizationByAPIKey(apiKey); organization != nil {
|
||||
logger.Info("WebSocket: Organization found for API key", "org_id", organization.ID.String())
|
||||
ctx = context.WithValue(ctx, middleware.OrganizationKey, *organization)
|
||||
} else {
|
||||
logger.Warn("WebSocket: No organization found for API key")
|
||||
}
|
||||
} else {
|
||||
logger.Info("WebSocket connection without API key")
|
||||
}
|
||||
return ctx, &initPayload, nil
|
||||
},
|
||||
})
|
||||
srv.AddTransport(transport.Options{})
|
||||
srv.AddTransport(transport.GET{})
|
||||
@@ -223,8 +242,12 @@ func start(closeEvents chan error, logger *slog.Logger, connectToAmqpFunc func(u
|
||||
Cache: lru.New[string](100),
|
||||
})
|
||||
|
||||
healthChecker := health.New(db.DB, logger)
|
||||
|
||||
mux.Handle("/", monitoring.Handler(playground.Handler("GraphQL playground", "/query")))
|
||||
mux.Handle("/health", http.HandlerFunc(healthFunc))
|
||||
mux.Handle("/health", http.HandlerFunc(healthChecker.LivenessHandler))
|
||||
mux.Handle("/health/live", http.HandlerFunc(healthChecker.LivenessHandler))
|
||||
mux.Handle("/health/ready", http.HandlerFunc(healthChecker.ReadinessHandler))
|
||||
mux.Handle("/query", cors.AllowAll().Handler(
|
||||
monitoring.Handler(
|
||||
mw.Middleware().CheckJWT(
|
||||
@@ -283,10 +306,6 @@ func loadSubGraphs(ctx context.Context, eventStore eventsourced.EventStore, serv
|
||||
return nil
|
||||
}
|
||||
|
||||
func healthFunc(w http.ResponseWriter, _ *http.Request) {
|
||||
_, _ = w.Write([]byte("OK"))
|
||||
}
|
||||
|
||||
func ConnectAMQP(url string) (Connection, error) {
|
||||
return goamqp.NewFromURL(serviceName, url)
|
||||
}
|
||||
|
||||
@@ -0,0 +1,362 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/99designs/gqlgen/graphql/handler/transport"
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"gitlab.com/unboundsoftware/eventsourced/eventsourced"
|
||||
|
||||
"gitlab.com/unboundsoftware/schemas/domain"
|
||||
"gitlab.com/unboundsoftware/schemas/hash"
|
||||
"gitlab.com/unboundsoftware/schemas/middleware"
|
||||
)
|
||||
|
||||
// MockCache is a mock implementation for testing
|
||||
type MockCache struct {
|
||||
organizations map[string]*domain.Organization // keyed by orgId-name composite
|
||||
apiKeys map[string]string // maps orgId-name to hashed key
|
||||
}
|
||||
|
||||
func (m *MockCache) OrganizationByAPIKey(plainKey string) *domain.Organization {
|
||||
// Find organization by comparing plaintext key with stored hash
|
||||
for compositeKey, hashedKey := range m.apiKeys {
|
||||
if hash.CompareAPIKey(hashedKey, plainKey) {
|
||||
return m.organizations[compositeKey]
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestWebSocketInitFunc_WithValidAPIKey(t *testing.T) {
|
||||
// Setup
|
||||
orgID := uuid.New()
|
||||
org := &domain.Organization{
|
||||
BaseAggregate: eventsourced.BaseAggregate{
|
||||
ID: eventsourced.IdFromString(orgID.String()),
|
||||
},
|
||||
Name: "Test Organization",
|
||||
}
|
||||
|
||||
apiKey := "test-api-key-123"
|
||||
hashedKey, err := hash.APIKey(apiKey)
|
||||
require.NoError(t, err)
|
||||
|
||||
compositeKey := orgID.String() + "-test-key"
|
||||
|
||||
mockCache := &MockCache{
|
||||
organizations: map[string]*domain.Organization{
|
||||
compositeKey: org,
|
||||
},
|
||||
apiKeys: map[string]string{
|
||||
compositeKey: hashedKey,
|
||||
},
|
||||
}
|
||||
|
||||
// Create InitFunc (simulating the WebSocket InitFunc logic)
|
||||
initFunc := func(ctx context.Context, initPayload transport.InitPayload) (context.Context, *transport.InitPayload, error) {
|
||||
// Extract API key from WebSocket connection_init payload
|
||||
if apiKey, ok := initPayload["X-Api-Key"].(string); ok && apiKey != "" {
|
||||
ctx = context.WithValue(ctx, middleware.ApiKey, apiKey)
|
||||
|
||||
// Look up organization by API key (cache handles hash comparison)
|
||||
if organization := mockCache.OrganizationByAPIKey(apiKey); organization != nil {
|
||||
ctx = context.WithValue(ctx, middleware.OrganizationKey, *organization)
|
||||
}
|
||||
}
|
||||
return ctx, &initPayload, nil
|
||||
}
|
||||
|
||||
// Test
|
||||
ctx := context.Background()
|
||||
initPayload := transport.InitPayload{
|
||||
"X-Api-Key": apiKey,
|
||||
}
|
||||
|
||||
resultCtx, resultPayload, err := initFunc(ctx, initPayload)
|
||||
|
||||
// Assert
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, resultPayload)
|
||||
|
||||
// Check API key is in context
|
||||
if value := resultCtx.Value(middleware.ApiKey); value != nil {
|
||||
assert.Equal(t, apiKey, value.(string))
|
||||
} else {
|
||||
t.Fatal("API key not found in context")
|
||||
}
|
||||
|
||||
// Check organization is in context
|
||||
if value := resultCtx.Value(middleware.OrganizationKey); value != nil {
|
||||
capturedOrg, ok := value.(domain.Organization)
|
||||
require.True(t, ok, "Organization should be of correct type")
|
||||
assert.Equal(t, org.Name, capturedOrg.Name)
|
||||
assert.Equal(t, org.ID.String(), capturedOrg.ID.String())
|
||||
} else {
|
||||
t.Fatal("Organization not found in context")
|
||||
}
|
||||
}
|
||||
|
||||
func TestWebSocketInitFunc_WithInvalidAPIKey(t *testing.T) {
|
||||
// Setup
|
||||
mockCache := &MockCache{
|
||||
organizations: map[string]*domain.Organization{},
|
||||
apiKeys: map[string]string{},
|
||||
}
|
||||
|
||||
apiKey := "invalid-api-key"
|
||||
|
||||
// Create InitFunc
|
||||
initFunc := func(ctx context.Context, initPayload transport.InitPayload) (context.Context, *transport.InitPayload, error) {
|
||||
// Extract API key from WebSocket connection_init payload
|
||||
if apiKey, ok := initPayload["X-Api-Key"].(string); ok && apiKey != "" {
|
||||
ctx = context.WithValue(ctx, middleware.ApiKey, apiKey)
|
||||
|
||||
// Look up organization by API key (cache handles hash comparison)
|
||||
if organization := mockCache.OrganizationByAPIKey(apiKey); organization != nil {
|
||||
ctx = context.WithValue(ctx, middleware.OrganizationKey, *organization)
|
||||
}
|
||||
}
|
||||
return ctx, &initPayload, nil
|
||||
}
|
||||
|
||||
// Test
|
||||
ctx := context.Background()
|
||||
initPayload := transport.InitPayload{
|
||||
"X-Api-Key": apiKey,
|
||||
}
|
||||
|
||||
resultCtx, resultPayload, err := initFunc(ctx, initPayload)
|
||||
|
||||
// Assert
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, resultPayload)
|
||||
|
||||
// Check API key is in context
|
||||
if value := resultCtx.Value(middleware.ApiKey); value != nil {
|
||||
assert.Equal(t, apiKey, value.(string))
|
||||
} else {
|
||||
t.Fatal("API key not found in context")
|
||||
}
|
||||
|
||||
// Check organization is NOT in context (since API key is invalid)
|
||||
value := resultCtx.Value(middleware.OrganizationKey)
|
||||
assert.Nil(t, value, "Organization should not be set for invalid API key")
|
||||
}
|
||||
|
||||
func TestWebSocketInitFunc_WithoutAPIKey(t *testing.T) {
|
||||
// Setup
|
||||
mockCache := &MockCache{
|
||||
organizations: map[string]*domain.Organization{},
|
||||
apiKeys: map[string]string{},
|
||||
}
|
||||
|
||||
// Create InitFunc
|
||||
initFunc := func(ctx context.Context, initPayload transport.InitPayload) (context.Context, *transport.InitPayload, error) {
|
||||
// Extract API key from WebSocket connection_init payload
|
||||
if apiKey, ok := initPayload["X-Api-Key"].(string); ok && apiKey != "" {
|
||||
ctx = context.WithValue(ctx, middleware.ApiKey, apiKey)
|
||||
|
||||
// Look up organization by API key (cache handles hash comparison)
|
||||
if organization := mockCache.OrganizationByAPIKey(apiKey); organization != nil {
|
||||
ctx = context.WithValue(ctx, middleware.OrganizationKey, *organization)
|
||||
}
|
||||
}
|
||||
return ctx, &initPayload, nil
|
||||
}
|
||||
|
||||
// Test
|
||||
ctx := context.Background()
|
||||
initPayload := transport.InitPayload{}
|
||||
|
||||
resultCtx, resultPayload, err := initFunc(ctx, initPayload)
|
||||
|
||||
// Assert
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, resultPayload)
|
||||
|
||||
// Check API key is NOT in context
|
||||
value := resultCtx.Value(middleware.ApiKey)
|
||||
assert.Nil(t, value, "API key should not be set when not provided")
|
||||
|
||||
// Check organization is NOT in context
|
||||
value = resultCtx.Value(middleware.OrganizationKey)
|
||||
assert.Nil(t, value, "Organization should not be set when API key is not provided")
|
||||
}
|
||||
|
||||
func TestWebSocketInitFunc_WithEmptyAPIKey(t *testing.T) {
|
||||
// Setup
|
||||
mockCache := &MockCache{
|
||||
organizations: map[string]*domain.Organization{},
|
||||
apiKeys: map[string]string{},
|
||||
}
|
||||
|
||||
// Create InitFunc
|
||||
initFunc := func(ctx context.Context, initPayload transport.InitPayload) (context.Context, *transport.InitPayload, error) {
|
||||
// Extract API key from WebSocket connection_init payload
|
||||
if apiKey, ok := initPayload["X-Api-Key"].(string); ok && apiKey != "" {
|
||||
ctx = context.WithValue(ctx, middleware.ApiKey, apiKey)
|
||||
|
||||
// Look up organization by API key (cache handles hash comparison)
|
||||
if organization := mockCache.OrganizationByAPIKey(apiKey); organization != nil {
|
||||
ctx = context.WithValue(ctx, middleware.OrganizationKey, *organization)
|
||||
}
|
||||
}
|
||||
return ctx, &initPayload, nil
|
||||
}
|
||||
|
||||
// Test
|
||||
ctx := context.Background()
|
||||
initPayload := transport.InitPayload{
|
||||
"X-Api-Key": "", // Empty string
|
||||
}
|
||||
|
||||
resultCtx, resultPayload, err := initFunc(ctx, initPayload)
|
||||
|
||||
// Assert
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, resultPayload)
|
||||
|
||||
// Check API key is NOT in context (because empty string fails the condition)
|
||||
value := resultCtx.Value(middleware.ApiKey)
|
||||
assert.Nil(t, value, "API key should not be set when empty")
|
||||
|
||||
// Check organization is NOT in context
|
||||
value = resultCtx.Value(middleware.OrganizationKey)
|
||||
assert.Nil(t, value, "Organization should not be set when API key is empty")
|
||||
}
|
||||
|
||||
func TestWebSocketInitFunc_WithWrongTypeAPIKey(t *testing.T) {
|
||||
// Setup
|
||||
mockCache := &MockCache{
|
||||
organizations: map[string]*domain.Organization{},
|
||||
apiKeys: map[string]string{},
|
||||
}
|
||||
|
||||
// Create InitFunc
|
||||
initFunc := func(ctx context.Context, initPayload transport.InitPayload) (context.Context, *transport.InitPayload, error) {
|
||||
// Extract API key from WebSocket connection_init payload
|
||||
if apiKey, ok := initPayload["X-Api-Key"].(string); ok && apiKey != "" {
|
||||
ctx = context.WithValue(ctx, middleware.ApiKey, apiKey)
|
||||
|
||||
// Look up organization by API key (cache handles hash comparison)
|
||||
if organization := mockCache.OrganizationByAPIKey(apiKey); organization != nil {
|
||||
ctx = context.WithValue(ctx, middleware.OrganizationKey, *organization)
|
||||
}
|
||||
}
|
||||
return ctx, &initPayload, nil
|
||||
}
|
||||
|
||||
// Test
|
||||
ctx := context.Background()
|
||||
initPayload := transport.InitPayload{
|
||||
"X-Api-Key": 12345, // Wrong type (int instead of string)
|
||||
}
|
||||
|
||||
resultCtx, resultPayload, err := initFunc(ctx, initPayload)
|
||||
|
||||
// Assert
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, resultPayload)
|
||||
|
||||
// Check API key is NOT in context (type assertion fails)
|
||||
value := resultCtx.Value(middleware.ApiKey)
|
||||
assert.Nil(t, value, "API key should not be set when wrong type")
|
||||
|
||||
// Check organization is NOT in context
|
||||
value = resultCtx.Value(middleware.OrganizationKey)
|
||||
assert.Nil(t, value, "Organization should not be set when API key has wrong type")
|
||||
}
|
||||
|
||||
func TestWebSocketInitFunc_WithMultipleOrganizations(t *testing.T) {
|
||||
// Setup - create multiple organizations
|
||||
org1ID := uuid.New()
|
||||
org1 := &domain.Organization{
|
||||
BaseAggregate: eventsourced.BaseAggregate{
|
||||
ID: eventsourced.IdFromString(org1ID.String()),
|
||||
},
|
||||
Name: "Organization 1",
|
||||
}
|
||||
|
||||
org2ID := uuid.New()
|
||||
org2 := &domain.Organization{
|
||||
BaseAggregate: eventsourced.BaseAggregate{
|
||||
ID: eventsourced.IdFromString(org2ID.String()),
|
||||
},
|
||||
Name: "Organization 2",
|
||||
}
|
||||
|
||||
apiKey1 := "api-key-org-1"
|
||||
apiKey2 := "api-key-org-2"
|
||||
hashedKey1, err := hash.APIKey(apiKey1)
|
||||
require.NoError(t, err)
|
||||
hashedKey2, err := hash.APIKey(apiKey2)
|
||||
require.NoError(t, err)
|
||||
|
||||
compositeKey1 := org1ID.String() + "-key1"
|
||||
compositeKey2 := org2ID.String() + "-key2"
|
||||
|
||||
mockCache := &MockCache{
|
||||
organizations: map[string]*domain.Organization{
|
||||
compositeKey1: org1,
|
||||
compositeKey2: org2,
|
||||
},
|
||||
apiKeys: map[string]string{
|
||||
compositeKey1: hashedKey1,
|
||||
compositeKey2: hashedKey2,
|
||||
},
|
||||
}
|
||||
|
||||
// Create InitFunc
|
||||
initFunc := func(ctx context.Context, initPayload transport.InitPayload) (context.Context, *transport.InitPayload, error) {
|
||||
// Extract API key from WebSocket connection_init payload
|
||||
if apiKey, ok := initPayload["X-Api-Key"].(string); ok && apiKey != "" {
|
||||
ctx = context.WithValue(ctx, middleware.ApiKey, apiKey)
|
||||
|
||||
// Look up organization by API key (cache handles hash comparison)
|
||||
if organization := mockCache.OrganizationByAPIKey(apiKey); organization != nil {
|
||||
ctx = context.WithValue(ctx, middleware.OrganizationKey, *organization)
|
||||
}
|
||||
}
|
||||
return ctx, &initPayload, nil
|
||||
}
|
||||
|
||||
// Test with first API key
|
||||
ctx1 := context.Background()
|
||||
initPayload1 := transport.InitPayload{
|
||||
"X-Api-Key": apiKey1,
|
||||
}
|
||||
|
||||
resultCtx1, _, err := initFunc(ctx1, initPayload1)
|
||||
require.NoError(t, err)
|
||||
|
||||
if value := resultCtx1.Value(middleware.OrganizationKey); value != nil {
|
||||
capturedOrg, ok := value.(domain.Organization)
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, org1.Name, capturedOrg.Name)
|
||||
assert.Equal(t, org1.ID.String(), capturedOrg.ID.String())
|
||||
} else {
|
||||
t.Fatal("Organization 1 not found in context")
|
||||
}
|
||||
|
||||
// Test with second API key
|
||||
ctx2 := context.Background()
|
||||
initPayload2 := transport.InitPayload{
|
||||
"X-Api-Key": apiKey2,
|
||||
}
|
||||
|
||||
resultCtx2, _, err := initFunc(ctx2, initPayload2)
|
||||
require.NoError(t, err)
|
||||
|
||||
if value := resultCtx2.Value(middleware.OrganizationKey); value != nil {
|
||||
capturedOrg, ok := value.(domain.Organization)
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, org2.Name, capturedOrg.Name)
|
||||
assert.Equal(t, org2.ID.String(), capturedOrg.ID.String())
|
||||
} else {
|
||||
t.Fatal("Organization 2 not found in context")
|
||||
}
|
||||
}
|
||||
+12
-1
@@ -56,9 +56,20 @@ func (a AddAPIKey) Validate(_ context.Context, aggregate eventsourced.Aggregate)
|
||||
}
|
||||
|
||||
func (a AddAPIKey) Event(context.Context) eventsourced.Event {
|
||||
// Hash the API key using bcrypt for secure storage
|
||||
// Note: We can't return an error here, but bcrypt errors are extremely rare
|
||||
// (only if system runs out of memory or bcrypt cost is invalid)
|
||||
// We use a fixed cost of 12 which is always valid
|
||||
hashedKey, err := hash.APIKey(a.Key)
|
||||
if err != nil {
|
||||
// This should never happen with bcrypt cost 12, but if it does,
|
||||
// we'll store an empty hash which will fail validation later
|
||||
hashedKey = ""
|
||||
}
|
||||
|
||||
return &APIKeyAdded{
|
||||
Name: a.Name,
|
||||
Key: hash.String(a.Key),
|
||||
Key: hashedKey,
|
||||
Refs: a.Refs,
|
||||
Read: a.Read,
|
||||
Publish: a.Publish,
|
||||
|
||||
+24
-11
@@ -2,10 +2,13 @@ package domain
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"gitlab.com/unboundsoftware/eventsourced/eventsourced"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"gitlab.com/unboundsoftware/schemas/hash"
|
||||
)
|
||||
|
||||
func TestAddAPIKey_Event(t *testing.T) {
|
||||
@@ -24,7 +27,6 @@ func TestAddAPIKey_Event(t *testing.T) {
|
||||
name string
|
||||
fields fields
|
||||
args args
|
||||
want eventsourced.Event
|
||||
}{
|
||||
{
|
||||
name: "event",
|
||||
@@ -37,14 +39,6 @@ func TestAddAPIKey_Event(t *testing.T) {
|
||||
Initiator: "jim@example.org",
|
||||
},
|
||||
args: args{},
|
||||
want: &APIKeyAdded{
|
||||
Name: "test",
|
||||
Key: "dXNfYWtfMTIzNDU2Nzg5MDEyMzQ1NuOwxEKY/BwUmvv0yJlvuSQnrkHkZJuTTKSVmRt4UrhV",
|
||||
Refs: []string{"Example@dev"},
|
||||
Read: true,
|
||||
Publish: true,
|
||||
Initiator: "jim@example.org",
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
@@ -57,7 +51,26 @@ func TestAddAPIKey_Event(t *testing.T) {
|
||||
Publish: tt.fields.Publish,
|
||||
Initiator: tt.fields.Initiator,
|
||||
}
|
||||
assert.Equalf(t, tt.want, a.Event(tt.args.in0), "Event(%v)", tt.args.in0)
|
||||
event := a.Event(tt.args.in0)
|
||||
require.NotNil(t, event)
|
||||
|
||||
// Cast to APIKeyAdded to verify fields
|
||||
apiKeyEvent, ok := event.(*APIKeyAdded)
|
||||
require.True(t, ok, "Event should be *APIKeyAdded")
|
||||
|
||||
// Verify non-key fields match exactly
|
||||
assert.Equal(t, tt.fields.Name, apiKeyEvent.Name)
|
||||
assert.Equal(t, tt.fields.Refs, apiKeyEvent.Refs)
|
||||
assert.Equal(t, tt.fields.Read, apiKeyEvent.Read)
|
||||
assert.Equal(t, tt.fields.Publish, apiKeyEvent.Publish)
|
||||
assert.Equal(t, tt.fields.Initiator, apiKeyEvent.Initiator)
|
||||
|
||||
// Verify the key is hashed correctly (bcrypt format)
|
||||
assert.True(t, strings.HasPrefix(apiKeyEvent.Key, "$2"), "Key should be bcrypt hashed")
|
||||
assert.NotEqual(t, tt.fields.Key, apiKeyEvent.Key, "Key should be hashed, not plaintext")
|
||||
|
||||
// Verify the hash matches the original key
|
||||
assert.True(t, hash.CompareAPIKey(apiKeyEvent.Key, tt.fields.Key), "Hashed key should match original")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -4,11 +4,13 @@ go 1.25
|
||||
|
||||
require (
|
||||
github.com/99designs/gqlgen v0.17.83
|
||||
github.com/DATA-DOG/go-sqlmock v1.5.2
|
||||
github.com/Khan/genqlient v0.8.1
|
||||
github.com/alecthomas/kong v1.13.0
|
||||
github.com/apex/log v1.9.0
|
||||
github.com/auth0/go-jwt-middleware/v2 v2.3.0
|
||||
github.com/auth0/go-jwt-middleware/v2 v2.3.1
|
||||
github.com/golang-jwt/jwt/v5 v5.3.0
|
||||
github.com/google/uuid v1.6.0
|
||||
github.com/jmoiron/sqlx v1.4.0
|
||||
github.com/pkg/errors v0.9.1
|
||||
github.com/pressly/goose/v3 v3.26.0
|
||||
@@ -30,6 +32,8 @@ require (
|
||||
go.opentelemetry.io/otel/sdk/log v0.14.0
|
||||
go.opentelemetry.io/otel/sdk/metric v1.38.0
|
||||
go.opentelemetry.io/otel/trace v1.38.0
|
||||
golang.org/x/crypto v0.45.0
|
||||
gopkg.in/yaml.v3 v3.0.1
|
||||
)
|
||||
|
||||
require (
|
||||
@@ -41,7 +45,6 @@ require (
|
||||
github.com/go-logr/logr v1.4.3 // indirect
|
||||
github.com/go-logr/stdr v1.2.2 // indirect
|
||||
github.com/go-viper/mapstructure/v2 v2.4.0 // indirect
|
||||
github.com/google/uuid v1.6.0 // indirect
|
||||
github.com/gorilla/websocket v1.5.1 // indirect
|
||||
github.com/grpc-ecosystem/grpc-gateway/v2 v2.27.2 // indirect
|
||||
github.com/hashicorp/golang-lru/v2 v2.0.7 // indirect
|
||||
@@ -51,6 +54,7 @@ require (
|
||||
github.com/rabbitmq/amqp091-go v1.10.0 // indirect
|
||||
github.com/sethvargo/go-retry v0.3.0 // indirect
|
||||
github.com/sosodev/duration v1.3.1 // indirect
|
||||
github.com/stretchr/objx v0.5.2 // indirect
|
||||
github.com/tidwall/gjson v1.17.0 // indirect
|
||||
github.com/tidwall/match v1.1.1 // indirect
|
||||
github.com/tidwall/pretty v1.2.1 // indirect
|
||||
@@ -63,14 +67,13 @@ require (
|
||||
go.opentelemetry.io/proto/otlp v1.7.1 // indirect
|
||||
go.uber.org/multierr v1.11.0 // indirect
|
||||
golang.org/x/mod v0.29.0 // indirect
|
||||
golang.org/x/net v0.46.0 // indirect
|
||||
golang.org/x/sync v0.17.0 // indirect
|
||||
golang.org/x/sys v0.37.0 // indirect
|
||||
golang.org/x/text v0.30.0 // indirect
|
||||
golang.org/x/net v0.47.0 // indirect
|
||||
golang.org/x/sync v0.18.0 // indirect
|
||||
golang.org/x/sys v0.38.0 // indirect
|
||||
golang.org/x/text v0.31.0 // indirect
|
||||
golang.org/x/tools v0.38.0 // indirect
|
||||
google.golang.org/genproto/googleapis/api v0.0.0-20250825161204-c5933d9347a5 // indirect
|
||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20250825161204-c5933d9347a5 // indirect
|
||||
google.golang.org/grpc v1.75.0 // indirect
|
||||
google.golang.org/protobuf v1.36.10 // indirect
|
||||
gopkg.in/yaml.v3 v3.0.1 // indirect
|
||||
)
|
||||
|
||||
@@ -27,8 +27,8 @@ github.com/aphistic/golf v0.0.0-20180712155816-02c07f170c5a/go.mod h1:3NqKYiepwy
|
||||
github.com/aphistic/sweet v0.2.0/go.mod h1:fWDlIh/isSE9n6EPsRmC0det+whmX6dJid3stzu0Xys=
|
||||
github.com/arbovm/levenshtein v0.0.0-20160628152529-48b4e1c0c4d0 h1:jfIu9sQUG6Ig+0+Ap1h4unLjW6YQJpKZVmUzxsD4E/Q=
|
||||
github.com/arbovm/levenshtein v0.0.0-20160628152529-48b4e1c0c4d0/go.mod h1:t2tdKJDJF9BV14lnkjHmOQgcvEKgtqs5a1N3LNdJhGE=
|
||||
github.com/auth0/go-jwt-middleware/v2 v2.3.0 h1:4QREj6cS3d8dS05bEm443jhnqQF97FX9sMBeWqnNRzE=
|
||||
github.com/auth0/go-jwt-middleware/v2 v2.3.0/go.mod h1:dL4ObBs1/dj4/W4cYxd8rqAdDGXYyd5rqbpMIxcbVrU=
|
||||
github.com/auth0/go-jwt-middleware/v2 v2.3.1 h1:lbDyWE9aLydb3zrank+Gufb9qGJN9u//7EbJK07pRrw=
|
||||
github.com/auth0/go-jwt-middleware/v2 v2.3.1/go.mod h1:mqVr0gdB5zuaFyQFWMJH/c/2hehNjbYUD4i8Dpyf+Hc=
|
||||
github.com/aws/aws-sdk-go v1.20.6/go.mod h1:KmX6BPdI08NWTb3/sm4ZGu5ShLoqVDhKgpiN924inxo=
|
||||
github.com/aybabtme/rgbterm v0.0.0-20170906152045-cc83f3b3ce59/go.mod h1:q/89r3U2H7sSsE2t6Kca0lfwTK8JdoNGS/yzM/4iH5I=
|
||||
github.com/buger/jsonparser v1.1.1 h1:2PnMjfWD7wBILjqQbt530v576A/cAbQvEW9gGIpYMUs=
|
||||
@@ -83,6 +83,7 @@ github.com/jmespath/go-jmespath v0.0.0-20180206201540-c2b33e8439af/go.mod h1:Nht
|
||||
github.com/jmoiron/sqlx v1.4.0 h1:1PLqN7S1UYp5t4SrVVnt4nUVNemrDAtxlulVe+Qgm3o=
|
||||
github.com/jmoiron/sqlx v1.4.0/go.mod h1:ZrZ7UsYB/weZdl2Bxg6jCRO9c3YHl8r3ahlKmRT4JLY=
|
||||
github.com/jpillora/backoff v0.0.0-20180909062703-3050d21c67d7/go.mod h1:2iMrUgbbvHEiQClaW2NsSzMyGHqN+rDFqY705q49KG0=
|
||||
github.com/kisielk/sqlstruct v0.0.0-20201105191214-5f3e10d3ab46/go.mod h1:yyMNCyc/Ib3bDTKd379tNMpB/7/H5TjM2Y9QJ5THLbE=
|
||||
github.com/kr/logfmt v0.0.0-20140226030751-b84e30acd515/go.mod h1:+0opPa2QZZtGFBFZlji/RkVcI2GknAs/DXo4wKdlNEc=
|
||||
github.com/kr/pretty v0.2.0/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI=
|
||||
github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE=
|
||||
@@ -141,6 +142,8 @@ github.com/sosodev/duration v1.3.1/go.mod h1:RQIBBX0+fMLc/D9+Jb/fwvVmo0eZvDDEERA
|
||||
github.com/sparetimecoders/goamqp v0.3.3 h1:z/nfTPmrjeU/rIVuNOgsVLCimp3WFoNFvS3ZzXRJ6HE=
|
||||
github.com/sparetimecoders/goamqp v0.3.3/go.mod h1:W9NRCpWLE+Vruv2dcRSbszNil2O826d2Nv6kAkETW5o=
|
||||
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
|
||||
github.com/stretchr/objx v0.5.2 h1:xuMeJ0Sdp5ZMRXx/aWO6RZxdr3beISkG5/G/aIRr3pY=
|
||||
github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA=
|
||||
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
|
||||
github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
|
||||
github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=
|
||||
@@ -212,8 +215,8 @@ go.uber.org/multierr v1.11.0 h1:blXXJkSxSSfBVBlC76pxqeO+LN3aDfLQo+309xJstO0=
|
||||
go.uber.org/multierr v1.11.0/go.mod h1:20+QtiLqy0Nd6FdQB9TLXag12DsQkrbs3htMFfDN80Y=
|
||||
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
|
||||
golang.org/x/crypto v0.0.0-20190426145343-a29dc8fdc734/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
|
||||
golang.org/x/crypto v0.43.0 h1:dduJYIi3A3KOfdGOHX8AVZ/jGiyPa3IbBozJ5kNuE04=
|
||||
golang.org/x/crypto v0.43.0/go.mod h1:BFbav4mRNlXJL4wNeejLpWxB7wMbc79PdRGhWKncxR0=
|
||||
golang.org/x/crypto v0.45.0 h1:jMBrvKuj23MTlT0bQEOBcAE0mjg8mK9RXFhRH6nyF3Q=
|
||||
golang.org/x/crypto v0.45.0/go.mod h1:XTGrrkGJve7CYK7J8PEww4aY7gM3qMCElcJQ8n8JdX4=
|
||||
golang.org/x/exp v0.0.0-20250620022241-b7579e27df2b h1:M2rDM6z3Fhozi9O7NWsxAkg/yqS/lQJ6PmkyIV3YP+o=
|
||||
golang.org/x/exp v0.0.0-20250620022241-b7579e27df2b/go.mod h1:3//PLf8L/X+8b4vuAfHzxeRUl04Adcb341+IGKfnqS8=
|
||||
golang.org/x/mod v0.29.0 h1:HV8lRxZC4l2cr3Zq1LvtOsi/ThTgWnUk/y64QSs8GwA=
|
||||
@@ -221,21 +224,21 @@ golang.org/x/mod v0.29.0/go.mod h1:NyhrlYXJ2H4eJiRy/WDBO6HMqZQ6q9nk4JzS3NuCK+w=
|
||||
golang.org/x/net v0.0.0-20180906233101-161cd47e91fd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
|
||||
golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
|
||||
golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
|
||||
golang.org/x/net v0.46.0 h1:giFlY12I07fugqwPuWJi68oOnpfqFnJIJzaIIm2JVV4=
|
||||
golang.org/x/net v0.46.0/go.mod h1:Q9BGdFy1y4nkUwiLvT5qtyhAnEHgnQ/zd8PfU6nc210=
|
||||
golang.org/x/net v0.47.0 h1:Mx+4dIFzqraBXUugkia1OOvlD6LemFo1ALMHjrXDOhY=
|
||||
golang.org/x/net v0.47.0/go.mod h1:/jNxtkgq5yWUGYkaZGqo27cfGZ1c5Nen03aYrrKpVRU=
|
||||
golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sync v0.17.0 h1:l60nONMj9l5drqw6jlhIELNv9I0A4OFgRsG9k2oT9Ug=
|
||||
golang.org/x/sync v0.17.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI=
|
||||
golang.org/x/sync v0.18.0 h1:kr88TuHDroi+UVf+0hZnirlk8o8T+4MrK6mr60WkH/I=
|
||||
golang.org/x/sync v0.18.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI=
|
||||
golang.org/x/sys v0.0.0-20180909124046-d0be0721c37e/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
||||
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
||||
golang.org/x/sys v0.0.0-20190222072716-a9d3bda3a223/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
||||
golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.37.0 h1:fdNQudmxPjkdUTPnLn5mdQv7Zwvbvpaxqs831goi9kQ=
|
||||
golang.org/x/sys v0.37.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
|
||||
golang.org/x/sys v0.38.0 h1:3yZWxaJjBmCWXqhN1qh02AkOnCQ1poK6oF+a7xWL6Gc=
|
||||
golang.org/x/sys v0.38.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
|
||||
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
||||
golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk=
|
||||
golang.org/x/text v0.30.0 h1:yznKA/E9zq54KzlzBEAWn1NXSQ8DIp/NYMy88xJjl4k=
|
||||
golang.org/x/text v0.30.0/go.mod h1:yDdHFIX9t+tORqspjENWgzaCVXgk0yYnYuSZ8UzzBVM=
|
||||
golang.org/x/text v0.31.0 h1:aC8ghyu4JhP8VojJ2lEHBnochRno1sgL6nEi9WGFGMM=
|
||||
golang.org/x/text v0.31.0/go.mod h1:tKRAlv61yKIjGGHX/4tP1LTbc13YSec1pxVEWXzfoeM=
|
||||
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
|
||||
golang.org/x/tools v0.38.0 h1:Hx2Xv8hISq8Lm16jvBZ2VQf+RLmbd7wVUsALibYI/IQ=
|
||||
golang.org/x/tools v0.38.0/go.mod h1:yEsQ/d/YK8cjh0L6rZlY8tgtlKiBNTL14pGDJPJpYQs=
|
||||
|
||||
+1
-1
@@ -38,7 +38,7 @@ func ToGqlAPIKeys(keys []domain.APIKey) []*model.APIKey {
|
||||
result[i] = &model.APIKey{
|
||||
ID: apiKeyId(k.OrganizationId, k.Name),
|
||||
Name: k.Name,
|
||||
Key: &k.Key,
|
||||
Key: nil, // Never return the hashed key - only return plaintext on creation
|
||||
Organization: nil,
|
||||
Refs: k.Refs,
|
||||
Read: k.Read,
|
||||
|
||||
+98
-27
@@ -1,54 +1,125 @@
|
||||
package graph
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
|
||||
"gopkg.in/yaml.v3"
|
||||
|
||||
"gitlab.com/unboundsoftware/schemas/graph/model"
|
||||
)
|
||||
|
||||
// GenerateCosmoRouterConfig generates a Cosmo Router execution config from subgraphs
|
||||
func GenerateCosmoRouterConfig(subGraphs []*model.SubGraph) (string, error) {
|
||||
// Build the Cosmo router config structure
|
||||
// This is a simplified version - you may need to adjust based on actual Cosmo requirements
|
||||
config := map[string]interface{}{
|
||||
"version": "1",
|
||||
"subgraphs": convertSubGraphsToCosmo(subGraphs),
|
||||
// Add other Cosmo-specific configuration as needed
|
||||
}
|
||||
|
||||
// Marshal to JSON
|
||||
configJSON, err := json.MarshalIndent(config, "", " ")
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("marshal cosmo config: %w", err)
|
||||
}
|
||||
|
||||
return string(configJSON), nil
|
||||
// CommandExecutor is an interface for executing external commands
|
||||
// This allows for mocking in tests
|
||||
type CommandExecutor interface {
|
||||
Execute(name string, args ...string) ([]byte, error)
|
||||
}
|
||||
|
||||
func convertSubGraphsToCosmo(subGraphs []*model.SubGraph) []map[string]interface{} {
|
||||
cosmoSubgraphs := make([]map[string]interface{}, 0, len(subGraphs))
|
||||
// DefaultCommandExecutor implements CommandExecutor using os/exec
|
||||
type DefaultCommandExecutor struct{}
|
||||
|
||||
// Execute runs a command and returns its combined output
|
||||
func (e *DefaultCommandExecutor) Execute(name string, args ...string) ([]byte, error) {
|
||||
cmd := exec.Command(name, args...)
|
||||
return cmd.CombinedOutput()
|
||||
}
|
||||
|
||||
// GenerateCosmoRouterConfig generates a Cosmo Router execution config from subgraphs
|
||||
// using the official wgc CLI tool via npx
|
||||
func GenerateCosmoRouterConfig(subGraphs []*model.SubGraph) (string, error) {
|
||||
return GenerateCosmoRouterConfigWithExecutor(subGraphs, &DefaultCommandExecutor{})
|
||||
}
|
||||
|
||||
// GenerateCosmoRouterConfigWithExecutor generates a Cosmo Router execution config from subgraphs
|
||||
// using the provided command executor (useful for testing)
|
||||
func GenerateCosmoRouterConfigWithExecutor(subGraphs []*model.SubGraph, executor CommandExecutor) (string, error) {
|
||||
if len(subGraphs) == 0 {
|
||||
return "", fmt.Errorf("no subgraphs provided")
|
||||
}
|
||||
|
||||
// Create a temporary directory for composition
|
||||
tmpDir, err := os.MkdirTemp("", "cosmo-compose-*")
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("create temp dir: %w", err)
|
||||
}
|
||||
defer os.RemoveAll(tmpDir)
|
||||
|
||||
// Write each subgraph SDL to a file
|
||||
type SubgraphConfig struct {
|
||||
Name string `yaml:"name"`
|
||||
RoutingURL string `yaml:"routing_url,omitempty"`
|
||||
Schema map[string]string `yaml:"schema"`
|
||||
Subscription map[string]interface{} `yaml:"subscription,omitempty"`
|
||||
}
|
||||
|
||||
type InputConfig struct {
|
||||
Version int `yaml:"version"`
|
||||
Subgraphs []SubgraphConfig `yaml:"subgraphs"`
|
||||
}
|
||||
|
||||
inputConfig := InputConfig{
|
||||
Version: 1,
|
||||
Subgraphs: make([]SubgraphConfig, 0, len(subGraphs)),
|
||||
}
|
||||
|
||||
for _, sg := range subGraphs {
|
||||
cosmoSg := map[string]interface{}{
|
||||
"name": sg.Service,
|
||||
"sdl": sg.Sdl,
|
||||
// Write SDL to a temp file
|
||||
schemaFile := filepath.Join(tmpDir, fmt.Sprintf("%s.graphql", sg.Service))
|
||||
if err := os.WriteFile(schemaFile, []byte(sg.Sdl), 0o644); err != nil {
|
||||
return "", fmt.Errorf("write schema file for %s: %w", sg.Service, err)
|
||||
}
|
||||
|
||||
subgraphCfg := SubgraphConfig{
|
||||
Name: sg.Service,
|
||||
Schema: map[string]string{
|
||||
"file": schemaFile,
|
||||
},
|
||||
}
|
||||
|
||||
if sg.URL != nil {
|
||||
cosmoSg["routing_url"] = *sg.URL
|
||||
subgraphCfg.RoutingURL = *sg.URL
|
||||
}
|
||||
|
||||
if sg.WsURL != nil {
|
||||
cosmoSg["subscription"] = map[string]interface{}{
|
||||
subgraphCfg.Subscription = map[string]interface{}{
|
||||
"url": *sg.WsURL,
|
||||
"protocol": "ws",
|
||||
"websocket_subprotocol": "graphql-ws",
|
||||
}
|
||||
}
|
||||
|
||||
cosmoSubgraphs = append(cosmoSubgraphs, cosmoSg)
|
||||
inputConfig.Subgraphs = append(inputConfig.Subgraphs, subgraphCfg)
|
||||
}
|
||||
|
||||
return cosmoSubgraphs
|
||||
// Write input config YAML
|
||||
inputFile := filepath.Join(tmpDir, "input.yaml")
|
||||
inputYAML, err := yaml.Marshal(inputConfig)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("marshal input config: %w", err)
|
||||
}
|
||||
if err := os.WriteFile(inputFile, inputYAML, 0o644); err != nil {
|
||||
return "", fmt.Errorf("write input config: %w", err)
|
||||
}
|
||||
|
||||
// Execute wgc router compose
|
||||
// wgc is installed globally in the Docker image
|
||||
outputFile := filepath.Join(tmpDir, "config.json")
|
||||
output, err := executor.Execute("wgc", "router", "compose",
|
||||
"--input", inputFile,
|
||||
"--out", outputFile,
|
||||
"--suppress-warnings",
|
||||
)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("wgc router compose failed: %w\nOutput: %s", err, string(output))
|
||||
}
|
||||
|
||||
// Read the generated config
|
||||
configJSON, err := os.ReadFile(outputFile)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("read output config: %w", err)
|
||||
}
|
||||
|
||||
return string(configJSON), nil
|
||||
}
|
||||
|
||||
+293
-86
@@ -2,14 +2,185 @@ package graph
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"gopkg.in/yaml.v3"
|
||||
|
||||
"gitlab.com/unboundsoftware/schemas/graph/model"
|
||||
)
|
||||
|
||||
// MockCommandExecutor implements CommandExecutor for testing
|
||||
type MockCommandExecutor struct {
|
||||
// CallCount tracks how many times Execute was called
|
||||
CallCount int
|
||||
// LastArgs stores the arguments from the last call
|
||||
LastArgs []string
|
||||
// Error can be set to simulate command failures
|
||||
Error error
|
||||
}
|
||||
|
||||
// Execute mocks the wgc command by generating a realistic config.json file
|
||||
func (m *MockCommandExecutor) Execute(name string, args ...string) ([]byte, error) {
|
||||
m.CallCount++
|
||||
m.LastArgs = append([]string{name}, args...)
|
||||
|
||||
if m.Error != nil {
|
||||
return nil, m.Error
|
||||
}
|
||||
|
||||
// Parse the input file to understand what subgraphs we're composing
|
||||
var inputFile, outputFile string
|
||||
for i, arg := range args {
|
||||
if arg == "--input" && i+1 < len(args) {
|
||||
inputFile = args[i+1]
|
||||
}
|
||||
if arg == "--out" && i+1 < len(args) {
|
||||
outputFile = args[i+1]
|
||||
}
|
||||
}
|
||||
|
||||
if inputFile == "" || outputFile == "" {
|
||||
return nil, fmt.Errorf("missing required arguments")
|
||||
}
|
||||
|
||||
// Read the input YAML to get subgraph information
|
||||
inputData, err := os.ReadFile(inputFile)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read input file: %w", err)
|
||||
}
|
||||
|
||||
var input struct {
|
||||
Version int `yaml:"version"`
|
||||
Subgraphs []struct {
|
||||
Name string `yaml:"name"`
|
||||
RoutingURL string `yaml:"routing_url,omitempty"`
|
||||
Schema map[string]string `yaml:"schema"`
|
||||
Subscription map[string]interface{} `yaml:"subscription,omitempty"`
|
||||
} `yaml:"subgraphs"`
|
||||
}
|
||||
|
||||
if err := yaml.Unmarshal(inputData, &input); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse input YAML: %w", err)
|
||||
}
|
||||
|
||||
// Generate a realistic Cosmo Router config based on the input
|
||||
config := map[string]interface{}{
|
||||
"version": "mock-version-uuid",
|
||||
"subgraphs": func() []map[string]interface{} {
|
||||
subgraphs := make([]map[string]interface{}, len(input.Subgraphs))
|
||||
for i, sg := range input.Subgraphs {
|
||||
subgraph := map[string]interface{}{
|
||||
"id": fmt.Sprintf("mock-id-%d", i),
|
||||
"name": sg.Name,
|
||||
}
|
||||
if sg.RoutingURL != "" {
|
||||
subgraph["routingUrl"] = sg.RoutingURL
|
||||
}
|
||||
subgraphs[i] = subgraph
|
||||
}
|
||||
return subgraphs
|
||||
}(),
|
||||
"engineConfig": map[string]interface{}{
|
||||
"graphqlSchema": generateMockSchema(input.Subgraphs),
|
||||
"datasourceConfigurations": func() []map[string]interface{} {
|
||||
dsConfigs := make([]map[string]interface{}, len(input.Subgraphs))
|
||||
for i, sg := range input.Subgraphs {
|
||||
// Read SDL from file
|
||||
sdl := ""
|
||||
if schemaFile, ok := sg.Schema["file"]; ok {
|
||||
if sdlData, err := os.ReadFile(schemaFile); err == nil {
|
||||
sdl = string(sdlData)
|
||||
}
|
||||
}
|
||||
|
||||
dsConfig := map[string]interface{}{
|
||||
"id": fmt.Sprintf("datasource-%d", i),
|
||||
"kind": "GRAPHQL",
|
||||
"customGraphql": map[string]interface{}{
|
||||
"federation": map[string]interface{}{
|
||||
"enabled": true,
|
||||
"serviceSdl": sdl,
|
||||
},
|
||||
"subscription": func() map[string]interface{} {
|
||||
if len(sg.Subscription) > 0 {
|
||||
return map[string]interface{}{
|
||||
"enabled": true,
|
||||
"url": map[string]interface{}{
|
||||
"staticVariableContent": sg.Subscription["url"],
|
||||
},
|
||||
"protocol": sg.Subscription["protocol"],
|
||||
"websocketSubprotocol": sg.Subscription["websocket_subprotocol"],
|
||||
}
|
||||
}
|
||||
return map[string]interface{}{
|
||||
"enabled": false,
|
||||
}
|
||||
}(),
|
||||
},
|
||||
}
|
||||
dsConfigs[i] = dsConfig
|
||||
}
|
||||
return dsConfigs
|
||||
}(),
|
||||
},
|
||||
}
|
||||
|
||||
// Write the config to the output file
|
||||
configJSON, err := json.Marshal(config)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to marshal config: %w", err)
|
||||
}
|
||||
|
||||
if err := os.WriteFile(outputFile, configJSON, 0o644); err != nil {
|
||||
return nil, fmt.Errorf("failed to write output file: %w", err)
|
||||
}
|
||||
|
||||
return []byte("Success"), nil
|
||||
}
|
||||
|
||||
// generateMockSchema creates a simple merged schema from subgraphs
|
||||
func generateMockSchema(subgraphs []struct {
|
||||
Name string `yaml:"name"`
|
||||
RoutingURL string `yaml:"routing_url,omitempty"`
|
||||
Schema map[string]string `yaml:"schema"`
|
||||
Subscription map[string]interface{} `yaml:"subscription,omitempty"`
|
||||
},
|
||||
) string {
|
||||
schema := strings.Builder{}
|
||||
schema.WriteString("schema {\n query: Query\n")
|
||||
|
||||
// Check if any subgraph has subscriptions
|
||||
hasSubscriptions := false
|
||||
for _, sg := range subgraphs {
|
||||
if len(sg.Subscription) > 0 {
|
||||
hasSubscriptions = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if hasSubscriptions {
|
||||
schema.WriteString(" subscription: Subscription\n")
|
||||
}
|
||||
schema.WriteString("}\n\n")
|
||||
|
||||
// Add types by reading SDL files
|
||||
for _, sg := range subgraphs {
|
||||
if schemaFile, ok := sg.Schema["file"]; ok {
|
||||
if sdlData, err := os.ReadFile(schemaFile); err == nil {
|
||||
schema.WriteString(string(sdlData))
|
||||
schema.WriteString("\n")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return schema.String()
|
||||
}
|
||||
|
||||
func TestGenerateCosmoRouterConfig(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
@@ -33,7 +204,10 @@ func TestGenerateCosmoRouterConfig(t *testing.T) {
|
||||
err := json.Unmarshal([]byte(config), &result)
|
||||
require.NoError(t, err, "Config should be valid JSON")
|
||||
|
||||
assert.Equal(t, "1", result["version"], "Version should be 1")
|
||||
// Version is a UUID string from wgc
|
||||
version, ok := result["version"].(string)
|
||||
require.True(t, ok, "Version should be a string")
|
||||
assert.NotEmpty(t, version, "Version should not be empty")
|
||||
|
||||
subgraphs, ok := result["subgraphs"].([]interface{})
|
||||
require.True(t, ok, "subgraphs should be an array")
|
||||
@@ -41,14 +215,26 @@ func TestGenerateCosmoRouterConfig(t *testing.T) {
|
||||
|
||||
sg := subgraphs[0].(map[string]interface{})
|
||||
assert.Equal(t, "test-service", sg["name"])
|
||||
assert.Equal(t, "http://localhost:4001/query", sg["routing_url"])
|
||||
assert.Equal(t, "type Query { test: String }", sg["sdl"])
|
||||
assert.Equal(t, "http://localhost:4001/query", sg["routingUrl"])
|
||||
|
||||
subscription, ok := sg["subscription"].(map[string]interface{})
|
||||
// Check that datasource configurations include subscription settings
|
||||
engineConfig, ok := result["engineConfig"].(map[string]interface{})
|
||||
require.True(t, ok, "Should have engineConfig")
|
||||
|
||||
dsConfigs, ok := engineConfig["datasourceConfigurations"].([]interface{})
|
||||
require.True(t, ok && len(dsConfigs) > 0, "Should have datasource configurations")
|
||||
|
||||
ds := dsConfigs[0].(map[string]interface{})
|
||||
customGraphql, ok := ds["customGraphql"].(map[string]interface{})
|
||||
require.True(t, ok, "Should have customGraphql config")
|
||||
|
||||
subscription, ok := customGraphql["subscription"].(map[string]interface{})
|
||||
require.True(t, ok, "Should have subscription config")
|
||||
assert.Equal(t, "ws://localhost:4001/query", subscription["url"])
|
||||
assert.Equal(t, "ws", subscription["protocol"])
|
||||
assert.Equal(t, "graphql-ws", subscription["websocket_subprotocol"])
|
||||
assert.True(t, subscription["enabled"].(bool), "Subscription should be enabled")
|
||||
|
||||
subUrl, ok := subscription["url"].(map[string]interface{})
|
||||
require.True(t, ok, "Should have subscription URL")
|
||||
assert.Equal(t, "ws://localhost:4001/query", subUrl["staticVariableContent"])
|
||||
},
|
||||
},
|
||||
{
|
||||
@@ -80,18 +266,28 @@ func TestGenerateCosmoRouterConfig(t *testing.T) {
|
||||
subgraphs := result["subgraphs"].([]interface{})
|
||||
assert.Len(t, subgraphs, 3, "Should have 3 subgraphs")
|
||||
|
||||
// Check first service has no subscription
|
||||
// Check service names
|
||||
sg1 := subgraphs[0].(map[string]interface{})
|
||||
assert.Equal(t, "service-1", sg1["name"])
|
||||
_, hasSubscription := sg1["subscription"]
|
||||
assert.False(t, hasSubscription, "service-1 should not have subscription config")
|
||||
|
||||
// Check third service has subscription
|
||||
sg3 := subgraphs[2].(map[string]interface{})
|
||||
assert.Equal(t, "service-3", sg3["name"])
|
||||
subscription, hasSubscription := sg3["subscription"]
|
||||
assert.True(t, hasSubscription, "service-3 should have subscription config")
|
||||
assert.NotNil(t, subscription)
|
||||
|
||||
// Check that datasource configurations include subscription for service-3
|
||||
engineConfig, ok := result["engineConfig"].(map[string]interface{})
|
||||
require.True(t, ok, "Should have engineConfig")
|
||||
|
||||
dsConfigs, ok := engineConfig["datasourceConfigurations"].([]interface{})
|
||||
require.True(t, ok && len(dsConfigs) == 3, "Should have 3 datasource configurations")
|
||||
|
||||
// Find service-3's datasource config (should have subscription enabled)
|
||||
ds3 := dsConfigs[2].(map[string]interface{})
|
||||
customGraphql, ok := ds3["customGraphql"].(map[string]interface{})
|
||||
require.True(t, ok, "Service-3 should have customGraphql config")
|
||||
|
||||
subscription, ok := customGraphql["subscription"].(map[string]interface{})
|
||||
require.True(t, ok, "Service-3 should have subscription config")
|
||||
assert.True(t, subscription["enabled"].(bool), "Service-3 subscription should be enabled")
|
||||
},
|
||||
},
|
||||
{
|
||||
@@ -113,39 +309,43 @@ func TestGenerateCosmoRouterConfig(t *testing.T) {
|
||||
subgraphs := result["subgraphs"].([]interface{})
|
||||
sg := subgraphs[0].(map[string]interface{})
|
||||
|
||||
// Should not have routing_url or subscription fields if URLs are nil
|
||||
_, hasRoutingURL := sg["routing_url"]
|
||||
assert.False(t, hasRoutingURL, "Should not have routing_url when URL is nil")
|
||||
// Should not have routing URL when URL is nil
|
||||
_, hasRoutingURL := sg["routingUrl"]
|
||||
assert.False(t, hasRoutingURL, "Should not have routingUrl when URL is nil")
|
||||
|
||||
_, hasSubscription := sg["subscription"]
|
||||
assert.False(t, hasSubscription, "Should not have subscription when WsURL is nil")
|
||||
// Check datasource configurations don't have subscription enabled
|
||||
engineConfig, ok := result["engineConfig"].(map[string]interface{})
|
||||
require.True(t, ok, "Should have engineConfig")
|
||||
|
||||
dsConfigs, ok := engineConfig["datasourceConfigurations"].([]interface{})
|
||||
require.True(t, ok && len(dsConfigs) > 0, "Should have datasource configurations")
|
||||
|
||||
ds := dsConfigs[0].(map[string]interface{})
|
||||
customGraphql, ok := ds["customGraphql"].(map[string]interface{})
|
||||
require.True(t, ok, "Should have customGraphql config")
|
||||
|
||||
subscription, ok := customGraphql["subscription"].(map[string]interface{})
|
||||
if ok {
|
||||
// wgc always enables subscription but URL should be empty when WsURL is nil
|
||||
subUrl, hasUrl := subscription["url"].(map[string]interface{})
|
||||
if hasUrl {
|
||||
_, hasStaticContent := subUrl["staticVariableContent"]
|
||||
assert.False(t, hasStaticContent, "Subscription URL should be empty when WsURL is nil")
|
||||
}
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "empty subgraphs",
|
||||
subGraphs: []*model.SubGraph{},
|
||||
wantErr: false,
|
||||
validate: func(t *testing.T, config string) {
|
||||
var result map[string]interface{}
|
||||
err := json.Unmarshal([]byte(config), &result)
|
||||
require.NoError(t, err)
|
||||
|
||||
subgraphs := result["subgraphs"].([]interface{})
|
||||
assert.Len(t, subgraphs, 0, "Should have empty subgraphs array")
|
||||
},
|
||||
wantErr: true,
|
||||
validate: nil,
|
||||
},
|
||||
{
|
||||
name: "nil subgraphs",
|
||||
subGraphs: nil,
|
||||
wantErr: false,
|
||||
validate: func(t *testing.T, config string) {
|
||||
var result map[string]interface{}
|
||||
err := json.Unmarshal([]byte(config), &result)
|
||||
require.NoError(t, err)
|
||||
|
||||
subgraphs := result["subgraphs"].([]interface{})
|
||||
assert.Len(t, subgraphs, 0, "Should handle nil subgraphs as empty array")
|
||||
},
|
||||
wantErr: true,
|
||||
validate: nil,
|
||||
},
|
||||
{
|
||||
name: "complex SDL with multiple types",
|
||||
@@ -173,29 +373,58 @@ func TestGenerateCosmoRouterConfig(t *testing.T) {
|
||||
err := json.Unmarshal([]byte(config), &result)
|
||||
require.NoError(t, err)
|
||||
|
||||
subgraphs := result["subgraphs"].([]interface{})
|
||||
sg := subgraphs[0].(map[string]interface{})
|
||||
sdl := sg["sdl"].(string)
|
||||
// Check the composed graphqlSchema contains the types
|
||||
engineConfig, ok := result["engineConfig"].(map[string]interface{})
|
||||
require.True(t, ok, "Should have engineConfig")
|
||||
|
||||
assert.Contains(t, sdl, "type Query")
|
||||
assert.Contains(t, sdl, "type User")
|
||||
assert.Contains(t, sdl, "email: String!")
|
||||
graphqlSchema, ok := engineConfig["graphqlSchema"].(string)
|
||||
require.True(t, ok, "Should have graphqlSchema")
|
||||
|
||||
assert.Contains(t, graphqlSchema, "Query", "Schema should contain Query type")
|
||||
assert.Contains(t, graphqlSchema, "User", "Schema should contain User type")
|
||||
|
||||
// Check datasource has the original SDL
|
||||
dsConfigs, ok := engineConfig["datasourceConfigurations"].([]interface{})
|
||||
require.True(t, ok && len(dsConfigs) > 0, "Should have datasource configurations")
|
||||
|
||||
ds := dsConfigs[0].(map[string]interface{})
|
||||
customGraphql, ok := ds["customGraphql"].(map[string]interface{})
|
||||
require.True(t, ok, "Should have customGraphql config")
|
||||
|
||||
federation, ok := customGraphql["federation"].(map[string]interface{})
|
||||
require.True(t, ok, "Should have federation config")
|
||||
|
||||
serviceSdl, ok := federation["serviceSdl"].(string)
|
||||
require.True(t, ok, "Should have serviceSdl")
|
||||
assert.Contains(t, serviceSdl, "email: String!", "SDL should contain email field")
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
config, err := GenerateCosmoRouterConfig(tt.subGraphs)
|
||||
// Use mock executor for all tests
|
||||
mockExecutor := &MockCommandExecutor{}
|
||||
config, err := GenerateCosmoRouterConfigWithExecutor(tt.subGraphs, mockExecutor)
|
||||
|
||||
if tt.wantErr {
|
||||
assert.Error(t, err)
|
||||
// Verify executor was not called for error cases
|
||||
if len(tt.subGraphs) == 0 {
|
||||
assert.Equal(t, 0, mockExecutor.CallCount, "Should not call executor for empty subgraphs")
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.NotEmpty(t, config, "Config should not be empty")
|
||||
|
||||
// Verify executor was called correctly
|
||||
assert.Equal(t, 1, mockExecutor.CallCount, "Should call executor once")
|
||||
assert.Equal(t, "wgc", mockExecutor.LastArgs[0], "Should call wgc command")
|
||||
assert.Contains(t, mockExecutor.LastArgs, "router", "Should include 'router' arg")
|
||||
assert.Contains(t, mockExecutor.LastArgs, "compose", "Should include 'compose' arg")
|
||||
|
||||
if tt.validate != nil {
|
||||
tt.validate(t, config)
|
||||
}
|
||||
@@ -203,53 +432,31 @@ func TestGenerateCosmoRouterConfig(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestConvertSubGraphsToCosmo(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
subGraphs []*model.SubGraph
|
||||
wantLen int
|
||||
validate func(t *testing.T, result []map[string]interface{})
|
||||
}{
|
||||
// TestGenerateCosmoRouterConfig_MockError tests error handling with mock executor
|
||||
func TestGenerateCosmoRouterConfig_MockError(t *testing.T) {
|
||||
subGraphs := []*model.SubGraph{
|
||||
{
|
||||
name: "preserves subgraph order",
|
||||
subGraphs: []*model.SubGraph{
|
||||
{Service: "alpha", URL: stringPtr("http://a"), Sdl: "a"},
|
||||
{Service: "beta", URL: stringPtr("http://b"), Sdl: "b"},
|
||||
{Service: "gamma", URL: stringPtr("http://c"), Sdl: "c"},
|
||||
},
|
||||
wantLen: 3,
|
||||
validate: func(t *testing.T, result []map[string]interface{}) {
|
||||
assert.Equal(t, "alpha", result[0]["name"])
|
||||
assert.Equal(t, "beta", result[1]["name"])
|
||||
assert.Equal(t, "gamma", result[2]["name"])
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "includes SDL exactly as provided",
|
||||
subGraphs: []*model.SubGraph{
|
||||
{
|
||||
Service: "test",
|
||||
URL: stringPtr("http://test"),
|
||||
Sdl: "type Query { special: String! }",
|
||||
},
|
||||
},
|
||||
wantLen: 1,
|
||||
validate: func(t *testing.T, result []map[string]interface{}) {
|
||||
assert.Equal(t, "type Query { special: String! }", result[0]["sdl"])
|
||||
},
|
||||
Service: "test-service",
|
||||
URL: stringPtr("http://localhost:4001/query"),
|
||||
Sdl: "type Query { test: String }",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := convertSubGraphsToCosmo(tt.subGraphs)
|
||||
assert.Len(t, result, tt.wantLen)
|
||||
|
||||
if tt.validate != nil {
|
||||
tt.validate(t, result)
|
||||
}
|
||||
})
|
||||
// Create a mock executor that returns an error
|
||||
mockExecutor := &MockCommandExecutor{
|
||||
Error: fmt.Errorf("simulated wgc failure"),
|
||||
}
|
||||
|
||||
config, err := GenerateCosmoRouterConfigWithExecutor(subGraphs, mockExecutor)
|
||||
|
||||
// Verify error is propagated
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "wgc router compose failed")
|
||||
assert.Contains(t, err.Error(), "simulated wgc failure")
|
||||
assert.Empty(t, config)
|
||||
|
||||
// Verify executor was called
|
||||
assert.Equal(t, 1, mockExecutor.CallCount, "Should have attempted to call executor")
|
||||
}
|
||||
|
||||
// Helper function for tests
|
||||
|
||||
@@ -74,6 +74,7 @@ type ComplexityRoot struct {
|
||||
}
|
||||
|
||||
Query struct {
|
||||
LatestSchema func(childComplexity int, ref string) int
|
||||
Organizations func(childComplexity int) int
|
||||
Supergraph func(childComplexity int, ref string, isAfter *string) int
|
||||
}
|
||||
@@ -124,6 +125,7 @@ type MutationResolver interface {
|
||||
type QueryResolver interface {
|
||||
Organizations(ctx context.Context) ([]*model.Organization, error)
|
||||
Supergraph(ctx context.Context, ref string, isAfter *string) (model.Supergraph, error)
|
||||
LatestSchema(ctx context.Context, ref string) (*model.SchemaUpdate, error)
|
||||
}
|
||||
type SubscriptionResolver interface {
|
||||
SchemaUpdates(ctx context.Context, ref string) (<-chan *model.SchemaUpdate, error)
|
||||
@@ -250,6 +252,17 @@ func (e *executableSchema) Complexity(ctx context.Context, typeName, field strin
|
||||
|
||||
return e.complexity.Organization.Users(childComplexity), true
|
||||
|
||||
case "Query.latestSchema":
|
||||
if e.complexity.Query.LatestSchema == nil {
|
||||
break
|
||||
}
|
||||
|
||||
args, err := ec.field_Query_latestSchema_args(ctx, rawArgs)
|
||||
if err != nil {
|
||||
return 0, false
|
||||
}
|
||||
|
||||
return e.complexity.Query.LatestSchema(childComplexity, args["ref"].(string)), true
|
||||
case "Query.organizations":
|
||||
if e.complexity.Query.Organizations == nil {
|
||||
break
|
||||
@@ -520,6 +533,7 @@ var sources = []*ast.Source{
|
||||
{Name: "../schema.graphqls", Input: `type Query {
|
||||
organizations: [Organization!]! @auth(user: true)
|
||||
supergraph(ref: String!, isAfter: String): Supergraph! @auth(organization: true)
|
||||
latestSchema(ref: String!): SchemaUpdate! @auth(organization: true)
|
||||
}
|
||||
|
||||
type Mutation {
|
||||
@@ -671,6 +685,17 @@ func (ec *executionContext) field_Query___type_args(ctx context.Context, rawArgs
|
||||
return args, nil
|
||||
}
|
||||
|
||||
func (ec *executionContext) field_Query_latestSchema_args(ctx context.Context, rawArgs map[string]any) (map[string]any, error) {
|
||||
var err error
|
||||
args := map[string]any{}
|
||||
arg0, err := graphql.ProcessArgField(ctx, rawArgs, "ref", ec.unmarshalNString2string)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
args["ref"] = arg0
|
||||
return args, nil
|
||||
}
|
||||
|
||||
func (ec *executionContext) field_Query_supergraph_args(ctx context.Context, rawArgs map[string]any) (map[string]any, error) {
|
||||
var err error
|
||||
args := map[string]any{}
|
||||
@@ -1434,6 +1459,75 @@ func (ec *executionContext) fieldContext_Query_supergraph(ctx context.Context, f
|
||||
return fc, nil
|
||||
}
|
||||
|
||||
func (ec *executionContext) _Query_latestSchema(ctx context.Context, field graphql.CollectedField) (ret graphql.Marshaler) {
|
||||
return graphql.ResolveField(
|
||||
ctx,
|
||||
ec.OperationContext,
|
||||
field,
|
||||
ec.fieldContext_Query_latestSchema,
|
||||
func(ctx context.Context) (any, error) {
|
||||
fc := graphql.GetFieldContext(ctx)
|
||||
return ec.resolvers.Query().LatestSchema(ctx, fc.Args["ref"].(string))
|
||||
},
|
||||
func(ctx context.Context, next graphql.Resolver) graphql.Resolver {
|
||||
directive0 := next
|
||||
|
||||
directive1 := func(ctx context.Context) (any, error) {
|
||||
organization, err := ec.unmarshalOBoolean2ᚖbool(ctx, true)
|
||||
if err != nil {
|
||||
var zeroVal *model.SchemaUpdate
|
||||
return zeroVal, err
|
||||
}
|
||||
if ec.directives.Auth == nil {
|
||||
var zeroVal *model.SchemaUpdate
|
||||
return zeroVal, errors.New("directive auth is not implemented")
|
||||
}
|
||||
return ec.directives.Auth(ctx, nil, directive0, nil, organization)
|
||||
}
|
||||
|
||||
next = directive1
|
||||
return next
|
||||
},
|
||||
ec.marshalNSchemaUpdate2ᚖgitlabᚗcomᚋunboundsoftwareᚋschemasᚋgraphᚋmodelᚐSchemaUpdate,
|
||||
true,
|
||||
true,
|
||||
)
|
||||
}
|
||||
|
||||
func (ec *executionContext) fieldContext_Query_latestSchema(ctx context.Context, field graphql.CollectedField) (fc *graphql.FieldContext, err error) {
|
||||
fc = &graphql.FieldContext{
|
||||
Object: "Query",
|
||||
Field: field,
|
||||
IsMethod: true,
|
||||
IsResolver: true,
|
||||
Child: func(ctx context.Context, field graphql.CollectedField) (*graphql.FieldContext, error) {
|
||||
switch field.Name {
|
||||
case "ref":
|
||||
return ec.fieldContext_SchemaUpdate_ref(ctx, field)
|
||||
case "id":
|
||||
return ec.fieldContext_SchemaUpdate_id(ctx, field)
|
||||
case "subGraphs":
|
||||
return ec.fieldContext_SchemaUpdate_subGraphs(ctx, field)
|
||||
case "cosmoRouterConfig":
|
||||
return ec.fieldContext_SchemaUpdate_cosmoRouterConfig(ctx, field)
|
||||
}
|
||||
return nil, fmt.Errorf("no field named %q was found under type SchemaUpdate", field.Name)
|
||||
},
|
||||
}
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
err = ec.Recover(ctx, r)
|
||||
ec.Error(ctx, err)
|
||||
}
|
||||
}()
|
||||
ctx = graphql.WithFieldContext(ctx, fc)
|
||||
if fc.Args, err = ec.field_Query_latestSchema_args(ctx, field.ArgumentMap(ec.Variables)); err != nil {
|
||||
ec.Error(ctx, err)
|
||||
return fc, err
|
||||
}
|
||||
return fc, nil
|
||||
}
|
||||
|
||||
func (ec *executionContext) _Query___type(ctx context.Context, field graphql.CollectedField) (ret graphql.Marshaler) {
|
||||
return graphql.ResolveField(
|
||||
ctx,
|
||||
@@ -3997,6 +4091,28 @@ func (ec *executionContext) _Query(ctx context.Context, sel ast.SelectionSet) gr
|
||||
func(ctx context.Context) graphql.Marshaler { return innerFunc(ctx, out) })
|
||||
}
|
||||
|
||||
out.Concurrently(i, func(ctx context.Context) graphql.Marshaler { return rrm(innerCtx) })
|
||||
case "latestSchema":
|
||||
field := field
|
||||
|
||||
innerFunc := func(ctx context.Context, fs *graphql.FieldSet) (res graphql.Marshaler) {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
ec.Error(ctx, ec.Recover(ctx, r))
|
||||
}
|
||||
}()
|
||||
res = ec._Query_latestSchema(ctx, field)
|
||||
if res == graphql.Null {
|
||||
atomic.AddUint32(&fs.Invalids, 1)
|
||||
}
|
||||
return res
|
||||
}
|
||||
|
||||
rrm := func(ctx context.Context) graphql.Marshaler {
|
||||
return ec.OperationContext.RootResolverMiddleware(ctx,
|
||||
func(ctx context.Context) graphql.Marshaler { return innerFunc(ctx, out) })
|
||||
}
|
||||
|
||||
out.Concurrently(i, func(ctx context.Context) graphql.Marshaler { return rrm(innerCtx) })
|
||||
case "__type":
|
||||
out.Values[i] = ec.OperationContext.RootResolverMiddleware(innerCtx, func(ctx context.Context) (res graphql.Marshaler) {
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
type Query {
|
||||
organizations: [Organization!]! @auth(user: true)
|
||||
supergraph(ref: String!, isAfter: String): Supergraph! @auth(organization: true)
|
||||
latestSchema(ref: String!): SchemaUpdate! @auth(organization: true)
|
||||
}
|
||||
|
||||
type Mutation {
|
||||
|
||||
+113
-5
@@ -123,6 +123,13 @@ func (r *mutationResolver) UpdateSubGraph(ctx context.Context, input model.Input
|
||||
// Publish schema update to subscribers
|
||||
go func() {
|
||||
services, lastUpdate := r.Cache.Services(orgId, input.Ref, "")
|
||||
r.Logger.Info("Publishing schema update after subgraph change",
|
||||
"ref", input.Ref,
|
||||
"orgId", orgId,
|
||||
"lastUpdate", lastUpdate,
|
||||
"servicesCount", len(services),
|
||||
)
|
||||
|
||||
subGraphs := make([]*model.SubGraph, len(services))
|
||||
for i, id := range services {
|
||||
sg, err := r.fetchSubGraph(context.Background(), id)
|
||||
@@ -149,12 +156,21 @@ func (r *mutationResolver) UpdateSubGraph(ctx context.Context, input model.Input
|
||||
}
|
||||
|
||||
// Publish to all subscribers of this ref
|
||||
r.PubSub.Publish(input.Ref, &model.SchemaUpdate{
|
||||
update := &model.SchemaUpdate{
|
||||
Ref: input.Ref,
|
||||
ID: lastUpdate,
|
||||
SubGraphs: subGraphs,
|
||||
CosmoRouterConfig: &cosmoConfig,
|
||||
})
|
||||
}
|
||||
|
||||
r.Logger.Info("Publishing schema update to subscribers",
|
||||
"ref", update.Ref,
|
||||
"id", update.ID,
|
||||
"subGraphsCount", len(update.SubGraphs),
|
||||
"cosmoConfigLength", len(cosmoConfig),
|
||||
)
|
||||
|
||||
r.PubSub.Publish(input.Ref, update)
|
||||
}()
|
||||
|
||||
return r.toGqlSubGraph(subGraph), nil
|
||||
@@ -222,11 +238,84 @@ func (r *queryResolver) Supergraph(ctx context.Context, ref string, isAfter *str
|
||||
}, nil
|
||||
}
|
||||
|
||||
// LatestSchema is the resolver for the latestSchema field.
|
||||
func (r *queryResolver) LatestSchema(ctx context.Context, ref string) (*model.SchemaUpdate, error) {
|
||||
orgId := middleware.OrganizationFromContext(ctx)
|
||||
|
||||
r.Logger.Info("LatestSchema query",
|
||||
"ref", ref,
|
||||
"orgId", orgId,
|
||||
)
|
||||
|
||||
_, err := r.apiKeyCanAccessRef(ctx, ref, false)
|
||||
if err != nil {
|
||||
r.Logger.Error("API key cannot access ref", "error", err, "ref", ref)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Get current services and schema
|
||||
services, lastUpdate := r.Cache.Services(orgId, ref, "")
|
||||
r.Logger.Info("Fetching latest schema",
|
||||
"ref", ref,
|
||||
"orgId", orgId,
|
||||
"lastUpdate", lastUpdate,
|
||||
"servicesCount", len(services),
|
||||
)
|
||||
|
||||
subGraphs := make([]*model.SubGraph, len(services))
|
||||
for i, id := range services {
|
||||
sg, err := r.fetchSubGraph(ctx, id)
|
||||
if err != nil {
|
||||
r.Logger.Error("fetch subgraph", "error", err, "id", id)
|
||||
return nil, err
|
||||
}
|
||||
subGraphs[i] = &model.SubGraph{
|
||||
ID: sg.ID.String(),
|
||||
Service: sg.Service,
|
||||
URL: sg.Url,
|
||||
WsURL: sg.WSUrl,
|
||||
Sdl: sg.Sdl,
|
||||
ChangedBy: sg.ChangedBy,
|
||||
ChangedAt: sg.ChangedAt,
|
||||
}
|
||||
}
|
||||
|
||||
// Generate Cosmo router config
|
||||
cosmoConfig, err := GenerateCosmoRouterConfig(subGraphs)
|
||||
if err != nil {
|
||||
r.Logger.Error("generate cosmo config", "error", err)
|
||||
cosmoConfig = "" // Return empty if generation fails
|
||||
}
|
||||
|
||||
update := &model.SchemaUpdate{
|
||||
Ref: ref,
|
||||
ID: lastUpdate,
|
||||
SubGraphs: subGraphs,
|
||||
CosmoRouterConfig: &cosmoConfig,
|
||||
}
|
||||
|
||||
r.Logger.Info("Latest schema fetched",
|
||||
"ref", update.Ref,
|
||||
"id", update.ID,
|
||||
"subGraphsCount", len(update.SubGraphs),
|
||||
"cosmoConfigLength", len(cosmoConfig),
|
||||
)
|
||||
|
||||
return update, nil
|
||||
}
|
||||
|
||||
// SchemaUpdates is the resolver for the schemaUpdates field.
|
||||
func (r *subscriptionResolver) SchemaUpdates(ctx context.Context, ref string) (<-chan *model.SchemaUpdate, error) {
|
||||
orgId := middleware.OrganizationFromContext(ctx)
|
||||
|
||||
r.Logger.Info("SchemaUpdates subscription started",
|
||||
"ref", ref,
|
||||
"orgId", orgId,
|
||||
)
|
||||
|
||||
_, err := r.apiKeyCanAccessRef(ctx, ref, false)
|
||||
if err != nil {
|
||||
r.Logger.Error("API key cannot access ref", "error", err, "ref", ref)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -235,12 +324,22 @@ func (r *subscriptionResolver) SchemaUpdates(ctx context.Context, ref string) (<
|
||||
|
||||
// Send initial state immediately
|
||||
go func() {
|
||||
// Use background context for async operation
|
||||
bgCtx := context.Background()
|
||||
|
||||
services, lastUpdate := r.Cache.Services(orgId, ref, "")
|
||||
r.Logger.Info("Preparing initial schema update",
|
||||
"ref", ref,
|
||||
"orgId", orgId,
|
||||
"lastUpdate", lastUpdate,
|
||||
"servicesCount", len(services),
|
||||
)
|
||||
|
||||
subGraphs := make([]*model.SubGraph, len(services))
|
||||
for i, id := range services {
|
||||
sg, err := r.fetchSubGraph(ctx, id)
|
||||
sg, err := r.fetchSubGraph(bgCtx, id)
|
||||
if err != nil {
|
||||
r.Logger.Error("fetch subgraph for initial update", "error", err)
|
||||
r.Logger.Error("fetch subgraph for initial update", "error", err, "id", id)
|
||||
continue
|
||||
}
|
||||
subGraphs[i] = &model.SubGraph{
|
||||
@@ -262,12 +361,21 @@ func (r *subscriptionResolver) SchemaUpdates(ctx context.Context, ref string) (<
|
||||
}
|
||||
|
||||
// Send initial update
|
||||
ch <- &model.SchemaUpdate{
|
||||
update := &model.SchemaUpdate{
|
||||
Ref: ref,
|
||||
ID: lastUpdate,
|
||||
SubGraphs: subGraphs,
|
||||
CosmoRouterConfig: &cosmoConfig,
|
||||
}
|
||||
|
||||
r.Logger.Info("Sending initial schema update",
|
||||
"ref", update.Ref,
|
||||
"id", update.ID,
|
||||
"subGraphsCount", len(update.SubGraphs),
|
||||
"cosmoConfigLength", len(cosmoConfig),
|
||||
)
|
||||
|
||||
ch <- update
|
||||
}()
|
||||
|
||||
// Clean up subscription when context is done
|
||||
|
||||
@@ -3,9 +3,72 @@ package hash
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
)
|
||||
|
||||
// String creates a SHA256 hash of a string (legacy, for non-sensitive data)
|
||||
func String(s string) string {
|
||||
encoded := sha256.New().Sum([]byte(s))
|
||||
return base64.StdEncoding.EncodeToString(encoded)
|
||||
}
|
||||
|
||||
// APIKey hashes an API key using bcrypt for secure storage
|
||||
// Cost of 12 provides a good balance between security and performance
|
||||
func APIKey(key string) (string, error) {
|
||||
hash, err := bcrypt.GenerateFromPassword([]byte(key), 12)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return string(hash), nil
|
||||
}
|
||||
|
||||
// CompareAPIKey compares a plaintext API key with a hash
|
||||
// Supports both bcrypt (new) and SHA256 (legacy) hashes for backwards compatibility
|
||||
// Returns true if they match, false otherwise
|
||||
//
|
||||
// Migration Strategy:
|
||||
// Old API keys stored with SHA256 will continue to work. To upgrade them to bcrypt:
|
||||
// 1. Keys are automatically upgraded when users re-authenticate (if implemented)
|
||||
// 2. Or, run a one-time migration using MigrateAPIKeyHash when convenient
|
||||
func CompareAPIKey(hashedKey, plainKey string) bool {
|
||||
// Bcrypt hashes start with $2a$, $2b$, or $2y$
|
||||
// If the hash starts with $2, it's a bcrypt hash
|
||||
if len(hashedKey) > 2 && hashedKey[0] == '$' && hashedKey[1] == '2' {
|
||||
// New bcrypt hash
|
||||
err := bcrypt.CompareHashAndPassword([]byte(hashedKey), []byte(plainKey))
|
||||
return err == nil
|
||||
}
|
||||
|
||||
// Legacy SHA256 hash - compare using the old method
|
||||
legacyHash := String(plainKey)
|
||||
return hashedKey == legacyHash
|
||||
}
|
||||
|
||||
// IsLegacyHash returns true if the hash is a legacy SHA256 hash (not bcrypt)
|
||||
func IsLegacyHash(hashedKey string) bool {
|
||||
return len(hashedKey) <= 2 || hashedKey[0] != '$' || hashedKey[1] != '2'
|
||||
}
|
||||
|
||||
// MigrateAPIKeyHash can be used to upgrade a legacy SHA256 hash to bcrypt
|
||||
// This is useful for one-time migrations of existing keys
|
||||
// Returns the new bcrypt hash if the key is legacy, otherwise returns the original
|
||||
func MigrateAPIKeyHash(currentHash, plainKey string) (string, bool, error) {
|
||||
// If already bcrypt, no migration needed
|
||||
if !IsLegacyHash(currentHash) {
|
||||
return currentHash, false, nil
|
||||
}
|
||||
|
||||
// Verify the legacy hash is correct before migrating
|
||||
if !CompareAPIKey(currentHash, plainKey) {
|
||||
return "", false, nil // Invalid key, don't migrate
|
||||
}
|
||||
|
||||
// Generate new bcrypt hash
|
||||
newHash, err := APIKey(plainKey)
|
||||
if err != nil {
|
||||
return "", false, err
|
||||
}
|
||||
|
||||
return newHash, true, nil
|
||||
}
|
||||
|
||||
@@ -0,0 +1,169 @@
|
||||
package hash
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestAPIKey(t *testing.T) {
|
||||
key := "test_api_key_12345" // gitleaks:allow
|
||||
|
||||
hash1, err := APIKey(key)
|
||||
require.NoError(t, err)
|
||||
assert.NotEmpty(t, hash1)
|
||||
assert.NotEqual(t, key, hash1, "Hash should not equal plaintext")
|
||||
|
||||
// Bcrypt hashes should start with $2
|
||||
assert.True(t, strings.HasPrefix(hash1, "$2"), "Should be a bcrypt hash")
|
||||
|
||||
// Same key should produce different hashes (due to salt)
|
||||
hash2, err := APIKey(key)
|
||||
require.NoError(t, err)
|
||||
assert.NotEqual(t, hash1, hash2, "Bcrypt should produce different hashes with different salts")
|
||||
}
|
||||
|
||||
func TestCompareAPIKey_Bcrypt(t *testing.T) {
|
||||
key := "test_api_key_12345" // gitleaks:allow
|
||||
|
||||
hash, err := APIKey(key)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Correct key should match
|
||||
assert.True(t, CompareAPIKey(hash, key))
|
||||
|
||||
// Wrong key should not match
|
||||
assert.False(t, CompareAPIKey(hash, "wrong_key"))
|
||||
}
|
||||
|
||||
func TestCompareAPIKey_Legacy(t *testing.T) {
|
||||
key := "test_api_key_12345" // gitleaks:allow
|
||||
|
||||
// Create a legacy SHA256 hash
|
||||
legacyHash := String(key)
|
||||
|
||||
// Should still work with legacy hashes
|
||||
assert.True(t, CompareAPIKey(legacyHash, key))
|
||||
|
||||
// Wrong key should not match
|
||||
assert.False(t, CompareAPIKey(legacyHash, "wrong_key"))
|
||||
}
|
||||
|
||||
func TestCompareAPIKey_BackwardCompatibility(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
hashFunc func(string) string
|
||||
expectOK bool
|
||||
}{
|
||||
{
|
||||
name: "bcrypt hash",
|
||||
hashFunc: func(k string) string {
|
||||
h, _ := APIKey(k)
|
||||
return h
|
||||
},
|
||||
expectOK: true,
|
||||
},
|
||||
{
|
||||
name: "legacy SHA256 hash",
|
||||
hashFunc: func(k string) string {
|
||||
return String(k)
|
||||
},
|
||||
expectOK: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
key := "test_key_123"
|
||||
hash := tt.hashFunc(key)
|
||||
|
||||
result := CompareAPIKey(hash, key)
|
||||
assert.Equal(t, tt.expectOK, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestString(t *testing.T) {
|
||||
// Test that String function still works (for non-sensitive data)
|
||||
input := "test_string"
|
||||
hash1 := String(input)
|
||||
hash2 := String(input)
|
||||
|
||||
// SHA256 should be deterministic
|
||||
assert.Equal(t, hash1, hash2)
|
||||
assert.NotEmpty(t, hash1)
|
||||
assert.NotEqual(t, input, hash1)
|
||||
}
|
||||
|
||||
func TestIsLegacyHash(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
hash string
|
||||
isLegacy bool
|
||||
}{
|
||||
{
|
||||
name: "bcrypt hash",
|
||||
hash: "$2a$12$abcdefghijklmnopqrstuv",
|
||||
isLegacy: false,
|
||||
},
|
||||
{
|
||||
name: "SHA256 hash",
|
||||
hash: "dXNfYWtfMTIzNDU2Nzg5MDEyMzQ1NuOwxEKY",
|
||||
isLegacy: true,
|
||||
},
|
||||
{
|
||||
name: "empty string",
|
||||
hash: "",
|
||||
isLegacy: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
assert.Equal(t, tt.isLegacy, IsLegacyHash(tt.hash))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestMigrateAPIKeyHash(t *testing.T) {
|
||||
plainKey := "test_api_key_123"
|
||||
|
||||
t.Run("migrate legacy hash", func(t *testing.T) {
|
||||
// Create a legacy SHA256 hash
|
||||
legacyHash := String(plainKey)
|
||||
|
||||
// Migrate it
|
||||
newHash, migrated, err := MigrateAPIKeyHash(legacyHash, plainKey)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, migrated, "Should indicate migration occurred")
|
||||
assert.NotEqual(t, legacyHash, newHash, "New hash should differ from legacy")
|
||||
assert.True(t, strings.HasPrefix(newHash, "$2"), "New hash should be bcrypt")
|
||||
|
||||
// Verify new hash works
|
||||
assert.True(t, CompareAPIKey(newHash, plainKey))
|
||||
})
|
||||
|
||||
t.Run("no migration needed for bcrypt", func(t *testing.T) {
|
||||
// Create a bcrypt hash
|
||||
bcryptHash, err := APIKey(plainKey)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Try to migrate it
|
||||
newHash, migrated, err := MigrateAPIKeyHash(bcryptHash, plainKey)
|
||||
require.NoError(t, err)
|
||||
assert.False(t, migrated, "Should not migrate bcrypt hash")
|
||||
assert.Equal(t, bcryptHash, newHash, "Hash should remain unchanged")
|
||||
})
|
||||
|
||||
t.Run("invalid key does not migrate", func(t *testing.T) {
|
||||
legacyHash := String("correct_key")
|
||||
|
||||
// Try to migrate with wrong plaintext
|
||||
newHash, migrated, err := MigrateAPIKeyHash(legacyHash, "wrong_key")
|
||||
require.NoError(t, err)
|
||||
assert.False(t, migrated, "Should not migrate invalid key")
|
||||
assert.Empty(t, newHash, "Should return empty for invalid key")
|
||||
})
|
||||
}
|
||||
@@ -0,0 +1,73 @@
|
||||
package health
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"time"
|
||||
)
|
||||
|
||||
type Checker struct {
|
||||
db *sql.DB
|
||||
logger *slog.Logger
|
||||
}
|
||||
|
||||
func New(db *sql.DB, logger *slog.Logger) *Checker {
|
||||
return &Checker{
|
||||
db: db,
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
type HealthStatus struct {
|
||||
Status string `json:"status"`
|
||||
Checks map[string]string `json:"checks,omitempty"`
|
||||
}
|
||||
|
||||
// LivenessHandler checks if the application is running
|
||||
// This is a simple check that always returns OK if the handler is reached
|
||||
func (h *Checker) LivenessHandler(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_ = json.NewEncoder(w).Encode(HealthStatus{
|
||||
Status: "UP",
|
||||
})
|
||||
}
|
||||
|
||||
// ReadinessHandler checks if the application is ready to accept traffic
|
||||
// This checks database connectivity and other critical dependencies
|
||||
func (h *Checker) ReadinessHandler(w http.ResponseWriter, r *http.Request) {
|
||||
ctx, cancel := context.WithTimeout(r.Context(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
checks := make(map[string]string)
|
||||
allHealthy := true
|
||||
|
||||
// Check database connectivity
|
||||
if err := h.db.PingContext(ctx); err != nil {
|
||||
h.logger.With("error", err).Warn("database health check failed")
|
||||
checks["database"] = "DOWN"
|
||||
allHealthy = false
|
||||
} else {
|
||||
checks["database"] = "UP"
|
||||
}
|
||||
|
||||
status := HealthStatus{
|
||||
Status: "UP",
|
||||
Checks: checks,
|
||||
}
|
||||
|
||||
if !allHealthy {
|
||||
status.Status = "DOWN"
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusServiceUnavailable)
|
||||
_ = json.NewEncoder(w).Encode(status)
|
||||
return
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_ = json.NewEncoder(w).Encode(status)
|
||||
}
|
||||
@@ -0,0 +1,75 @@
|
||||
package health
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"github.com/DATA-DOG/go-sqlmock"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestLivenessHandler(t *testing.T) {
|
||||
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
|
||||
db, _, err := sqlmock.New()
|
||||
require.NoError(t, err)
|
||||
defer db.Close()
|
||||
|
||||
checker := New(db, logger)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/health/live", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
checker.LivenessHandler(rec, req)
|
||||
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
assert.Contains(t, rec.Body.String(), `"status":"UP"`)
|
||||
}
|
||||
|
||||
func TestReadinessHandler_Healthy(t *testing.T) {
|
||||
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
|
||||
db, mock, err := sqlmock.New(sqlmock.MonitorPingsOption(true))
|
||||
require.NoError(t, err)
|
||||
defer db.Close()
|
||||
|
||||
// Expect a ping and return success
|
||||
mock.ExpectPing().WillReturnError(nil)
|
||||
|
||||
checker := New(db, logger)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/health/ready", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
checker.ReadinessHandler(rec, req)
|
||||
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
assert.Contains(t, rec.Body.String(), `"status":"UP"`)
|
||||
assert.Contains(t, rec.Body.String(), `"database":"UP"`)
|
||||
assert.NoError(t, mock.ExpectationsWereMet())
|
||||
}
|
||||
|
||||
func TestReadinessHandler_DatabaseDown(t *testing.T) {
|
||||
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
|
||||
db, mock, err := sqlmock.New(sqlmock.MonitorPingsOption(true))
|
||||
require.NoError(t, err)
|
||||
defer db.Close()
|
||||
|
||||
// Expect a ping and return error
|
||||
mock.ExpectPing().WillReturnError(sql.ErrConnDone)
|
||||
|
||||
checker := New(db, logger)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/health/ready", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
checker.ReadinessHandler(rec, req)
|
||||
|
||||
assert.Equal(t, http.StatusServiceUnavailable, rec.Code)
|
||||
assert.Contains(t, rec.Body.String(), `"status":"DOWN"`)
|
||||
assert.Contains(t, rec.Body.String(), `"database":"DOWN"`)
|
||||
assert.NoError(t, mock.ExpectationsWereMet())
|
||||
}
|
||||
+10
-1
@@ -44,13 +44,22 @@ spec:
|
||||
requests:
|
||||
cpu: "20m"
|
||||
memory: "20Mi"
|
||||
livenessProbe:
|
||||
httpGet:
|
||||
path: /health/live
|
||||
port: 8080
|
||||
initialDelaySeconds: 10
|
||||
periodSeconds: 10
|
||||
timeoutSeconds: 5
|
||||
failureThreshold: 3
|
||||
readinessProbe:
|
||||
httpGet:
|
||||
path: /health
|
||||
path: /health/ready
|
||||
port: 8080
|
||||
initialDelaySeconds: 5
|
||||
periodSeconds: 5
|
||||
timeoutSeconds: 5
|
||||
failureThreshold: 3
|
||||
imagePullPolicy: IfNotPresent
|
||||
image: registry.gitlab.com/unboundsoftware/schemas:${COMMIT}
|
||||
ports:
|
||||
|
||||
+3
-2
@@ -9,7 +9,6 @@ import (
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
|
||||
"gitlab.com/unboundsoftware/schemas/domain"
|
||||
"gitlab.com/unboundsoftware/schemas/hash"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -49,7 +48,9 @@ func (m *AuthMiddleware) Handler(next http.Handler) http.Handler {
|
||||
_, _ = w.Write([]byte("Invalid API Key format"))
|
||||
return
|
||||
}
|
||||
if organization := m.cache.OrganizationByAPIKey(hash.String(apiKey)); organization != nil {
|
||||
// Cache handles hash comparison internally
|
||||
organization := m.cache.OrganizationByAPIKey(apiKey)
|
||||
if organization != nil {
|
||||
ctx = context.WithValue(ctx, OrganizationKey, *organization)
|
||||
}
|
||||
|
||||
|
||||
@@ -0,0 +1,464 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
mw "github.com/auth0/go-jwt-middleware/v2"
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/mock"
|
||||
"github.com/stretchr/testify/require"
|
||||
"gitlab.com/unboundsoftware/eventsourced/eventsourced"
|
||||
|
||||
"gitlab.com/unboundsoftware/schemas/domain"
|
||||
)
|
||||
|
||||
// MockCache is a mock implementation of the Cache interface
|
||||
type MockCache struct {
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
func (m *MockCache) OrganizationByAPIKey(apiKey string) *domain.Organization {
|
||||
args := m.Called(apiKey)
|
||||
if args.Get(0) == nil {
|
||||
return nil
|
||||
}
|
||||
return args.Get(0).(*domain.Organization)
|
||||
}
|
||||
|
||||
func TestAuthMiddleware_Handler_WithValidAPIKey(t *testing.T) {
|
||||
// Setup
|
||||
mockCache := new(MockCache)
|
||||
authMiddleware := NewAuth(mockCache)
|
||||
|
||||
orgID := uuid.New()
|
||||
expectedOrg := &domain.Organization{
|
||||
BaseAggregate: eventsourced.BaseAggregate{
|
||||
ID: eventsourced.IdFromString(orgID.String()),
|
||||
},
|
||||
Name: "Test Organization",
|
||||
}
|
||||
|
||||
apiKey := "test-api-key-123"
|
||||
|
||||
// Mock expects plaintext key (cache handles hashing internally)
|
||||
mockCache.On("OrganizationByAPIKey", apiKey).Return(expectedOrg)
|
||||
|
||||
// Create a test handler that checks the context
|
||||
var capturedOrg *domain.Organization
|
||||
testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if org := r.Context().Value(OrganizationKey); org != nil {
|
||||
if o, ok := org.(domain.Organization); ok {
|
||||
capturedOrg = &o
|
||||
}
|
||||
}
|
||||
w.WriteHeader(http.StatusOK)
|
||||
})
|
||||
|
||||
// Create request with API key in context
|
||||
req := httptest.NewRequest(http.MethodGet, "/test", nil)
|
||||
ctx := context.WithValue(req.Context(), ApiKey, apiKey)
|
||||
req = req.WithContext(ctx)
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
// Execute
|
||||
authMiddleware.Handler(testHandler).ServeHTTP(rec, req)
|
||||
|
||||
// Assert
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
require.NotNil(t, capturedOrg)
|
||||
assert.Equal(t, expectedOrg.Name, capturedOrg.Name)
|
||||
assert.Equal(t, expectedOrg.ID.String(), capturedOrg.ID.String())
|
||||
mockCache.AssertExpectations(t)
|
||||
}
|
||||
|
||||
func TestAuthMiddleware_Handler_WithInvalidAPIKey(t *testing.T) {
|
||||
// Setup
|
||||
mockCache := new(MockCache)
|
||||
authMiddleware := NewAuth(mockCache)
|
||||
|
||||
apiKey := "invalid-api-key"
|
||||
|
||||
// Mock expects plaintext key (cache handles hashing internally)
|
||||
mockCache.On("OrganizationByAPIKey", apiKey).Return(nil)
|
||||
|
||||
// Create a test handler that checks the context
|
||||
var capturedOrg *domain.Organization
|
||||
testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if org := r.Context().Value(OrganizationKey); org != nil {
|
||||
if o, ok := org.(domain.Organization); ok {
|
||||
capturedOrg = &o
|
||||
}
|
||||
}
|
||||
w.WriteHeader(http.StatusOK)
|
||||
})
|
||||
|
||||
// Create request with API key in context
|
||||
req := httptest.NewRequest(http.MethodGet, "/test", nil)
|
||||
ctx := context.WithValue(req.Context(), ApiKey, apiKey)
|
||||
req = req.WithContext(ctx)
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
// Execute
|
||||
authMiddleware.Handler(testHandler).ServeHTTP(rec, req)
|
||||
|
||||
// Assert
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
assert.Nil(t, capturedOrg, "Organization should not be set for invalid API key")
|
||||
mockCache.AssertExpectations(t)
|
||||
}
|
||||
|
||||
func TestAuthMiddleware_Handler_WithoutAPIKey(t *testing.T) {
|
||||
// Setup
|
||||
mockCache := new(MockCache)
|
||||
authMiddleware := NewAuth(mockCache)
|
||||
|
||||
// The middleware passes the plaintext API key (cache handles hashing)
|
||||
mockCache.On("OrganizationByAPIKey", "").Return(nil)
|
||||
|
||||
// Create a test handler that checks the context
|
||||
var capturedOrg *domain.Organization
|
||||
testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if org := r.Context().Value(OrganizationKey); org != nil {
|
||||
if o, ok := org.(domain.Organization); ok {
|
||||
capturedOrg = &o
|
||||
}
|
||||
}
|
||||
w.WriteHeader(http.StatusOK)
|
||||
})
|
||||
|
||||
// Create request without API key
|
||||
req := httptest.NewRequest(http.MethodGet, "/test", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
// Execute
|
||||
authMiddleware.Handler(testHandler).ServeHTTP(rec, req)
|
||||
|
||||
// Assert
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
assert.Nil(t, capturedOrg, "Organization should not be set without API key")
|
||||
mockCache.AssertExpectations(t)
|
||||
}
|
||||
|
||||
func TestAuthMiddleware_Handler_WithValidJWT(t *testing.T) {
|
||||
// Setup
|
||||
mockCache := new(MockCache)
|
||||
authMiddleware := NewAuth(mockCache)
|
||||
|
||||
// The middleware passes the plaintext API key (cache handles hashing)
|
||||
mockCache.On("OrganizationByAPIKey", "").Return(nil)
|
||||
|
||||
userID := "user-123"
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{
|
||||
"sub": userID,
|
||||
})
|
||||
|
||||
// Create a test handler that checks the context
|
||||
var capturedUser string
|
||||
testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if user := r.Context().Value(UserKey); user != nil {
|
||||
if u, ok := user.(string); ok {
|
||||
capturedUser = u
|
||||
}
|
||||
}
|
||||
w.WriteHeader(http.StatusOK)
|
||||
})
|
||||
|
||||
// Create request with JWT token in context
|
||||
req := httptest.NewRequest(http.MethodGet, "/test", nil)
|
||||
ctx := context.WithValue(req.Context(), mw.ContextKey{}, token)
|
||||
req = req.WithContext(ctx)
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
// Execute
|
||||
authMiddleware.Handler(testHandler).ServeHTTP(rec, req)
|
||||
|
||||
// Assert
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
assert.Equal(t, userID, capturedUser)
|
||||
}
|
||||
|
||||
func TestAuthMiddleware_Handler_APIKeyErrorHandling(t *testing.T) {
|
||||
// Setup
|
||||
mockCache := new(MockCache)
|
||||
authMiddleware := NewAuth(mockCache)
|
||||
|
||||
testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
})
|
||||
|
||||
// Create request with invalid API key type in context
|
||||
req := httptest.NewRequest(http.MethodGet, "/test", nil)
|
||||
ctx := context.WithValue(req.Context(), ApiKey, 12345) // Invalid type
|
||||
req = req.WithContext(ctx)
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
// Execute
|
||||
authMiddleware.Handler(testHandler).ServeHTTP(rec, req)
|
||||
|
||||
// Assert
|
||||
assert.Equal(t, http.StatusInternalServerError, rec.Code)
|
||||
assert.Contains(t, rec.Body.String(), "Invalid API Key format")
|
||||
}
|
||||
|
||||
func TestAuthMiddleware_Handler_JWTErrorHandling(t *testing.T) {
|
||||
// Setup
|
||||
mockCache := new(MockCache)
|
||||
authMiddleware := NewAuth(mockCache)
|
||||
|
||||
testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
})
|
||||
|
||||
// Create request with invalid JWT token type in context
|
||||
req := httptest.NewRequest(http.MethodGet, "/test", nil)
|
||||
ctx := context.WithValue(req.Context(), mw.ContextKey{}, "not-a-token") // Invalid type
|
||||
req = req.WithContext(ctx)
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
// Execute
|
||||
authMiddleware.Handler(testHandler).ServeHTTP(rec, req)
|
||||
|
||||
// Assert
|
||||
assert.Equal(t, http.StatusInternalServerError, rec.Code)
|
||||
assert.Contains(t, rec.Body.String(), "Invalid JWT token format")
|
||||
}
|
||||
|
||||
func TestAuthMiddleware_Handler_BothJWTAndAPIKey(t *testing.T) {
|
||||
// Setup
|
||||
mockCache := new(MockCache)
|
||||
authMiddleware := NewAuth(mockCache)
|
||||
|
||||
orgID := uuid.New()
|
||||
expectedOrg := &domain.Organization{
|
||||
BaseAggregate: eventsourced.BaseAggregate{
|
||||
ID: eventsourced.IdFromString(orgID.String()),
|
||||
},
|
||||
Name: "Test Organization",
|
||||
}
|
||||
|
||||
userID := "user-123"
|
||||
apiKey := "test-api-key-123"
|
||||
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{
|
||||
"sub": userID,
|
||||
})
|
||||
|
||||
// Mock expects plaintext key (cache handles hashing internally)
|
||||
mockCache.On("OrganizationByAPIKey", apiKey).Return(expectedOrg)
|
||||
|
||||
// Create a test handler that checks both user and organization in context
|
||||
var capturedUser string
|
||||
var capturedOrg *domain.Organization
|
||||
testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if user := r.Context().Value(UserKey); user != nil {
|
||||
if u, ok := user.(string); ok {
|
||||
capturedUser = u
|
||||
}
|
||||
}
|
||||
if org := r.Context().Value(OrganizationKey); org != nil {
|
||||
if o, ok := org.(domain.Organization); ok {
|
||||
capturedOrg = &o
|
||||
}
|
||||
}
|
||||
w.WriteHeader(http.StatusOK)
|
||||
})
|
||||
|
||||
// Create request with both JWT and API key in context
|
||||
req := httptest.NewRequest(http.MethodGet, "/test", nil)
|
||||
ctx := context.WithValue(req.Context(), mw.ContextKey{}, token)
|
||||
ctx = context.WithValue(ctx, ApiKey, apiKey)
|
||||
req = req.WithContext(ctx)
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
// Execute
|
||||
authMiddleware.Handler(testHandler).ServeHTTP(rec, req)
|
||||
|
||||
// Assert
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
assert.Equal(t, userID, capturedUser)
|
||||
require.NotNil(t, capturedOrg)
|
||||
assert.Equal(t, expectedOrg.Name, capturedOrg.Name)
|
||||
mockCache.AssertExpectations(t)
|
||||
}
|
||||
|
||||
func TestUserFromContext(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
ctx context.Context
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "with valid user",
|
||||
ctx: context.WithValue(context.Background(), UserKey, "user-123"),
|
||||
expected: "user-123",
|
||||
},
|
||||
{
|
||||
name: "without user",
|
||||
ctx: context.Background(),
|
||||
expected: "",
|
||||
},
|
||||
{
|
||||
name: "with invalid type",
|
||||
ctx: context.WithValue(context.Background(), UserKey, 123),
|
||||
expected: "",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := UserFromContext(tt.ctx)
|
||||
assert.Equal(t, tt.expected, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestOrganizationFromContext(t *testing.T) {
|
||||
orgID := uuid.New()
|
||||
org := domain.Organization{
|
||||
BaseAggregate: eventsourced.BaseAggregate{
|
||||
ID: eventsourced.IdFromString(orgID.String()),
|
||||
},
|
||||
Name: "Test Org",
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
ctx context.Context
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "with valid organization",
|
||||
ctx: context.WithValue(context.Background(), OrganizationKey, org),
|
||||
expected: orgID.String(),
|
||||
},
|
||||
{
|
||||
name: "without organization",
|
||||
ctx: context.Background(),
|
||||
expected: "",
|
||||
},
|
||||
{
|
||||
name: "with invalid type",
|
||||
ctx: context.WithValue(context.Background(), OrganizationKey, "not-an-org"),
|
||||
expected: "",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := OrganizationFromContext(tt.ctx)
|
||||
assert.Equal(t, tt.expected, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthMiddleware_Directive_RequiresUser(t *testing.T) {
|
||||
mockCache := new(MockCache)
|
||||
authMiddleware := NewAuth(mockCache)
|
||||
|
||||
requireUser := true
|
||||
|
||||
// Test with user present
|
||||
ctx := context.WithValue(context.Background(), UserKey, "user-123")
|
||||
_, err := authMiddleware.Directive(ctx, nil, func(ctx context.Context) (interface{}, error) {
|
||||
return "success", nil
|
||||
}, &requireUser, nil)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Test without user
|
||||
ctx = context.Background()
|
||||
_, err = authMiddleware.Directive(ctx, nil, func(ctx context.Context) (interface{}, error) {
|
||||
return "success", nil
|
||||
}, &requireUser, nil)
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "no user available in request")
|
||||
}
|
||||
|
||||
func TestAuthMiddleware_Directive_RequiresOrganization(t *testing.T) {
|
||||
mockCache := new(MockCache)
|
||||
authMiddleware := NewAuth(mockCache)
|
||||
|
||||
requireOrg := true
|
||||
orgID := uuid.New()
|
||||
org := domain.Organization{
|
||||
BaseAggregate: eventsourced.BaseAggregate{
|
||||
ID: eventsourced.IdFromString(orgID.String()),
|
||||
},
|
||||
Name: "Test Org",
|
||||
}
|
||||
|
||||
// Test with organization present
|
||||
ctx := context.WithValue(context.Background(), OrganizationKey, org)
|
||||
_, err := authMiddleware.Directive(ctx, nil, func(ctx context.Context) (interface{}, error) {
|
||||
return "success", nil
|
||||
}, nil, &requireOrg)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Test without organization
|
||||
ctx = context.Background()
|
||||
_, err = authMiddleware.Directive(ctx, nil, func(ctx context.Context) (interface{}, error) {
|
||||
return "success", nil
|
||||
}, nil, &requireOrg)
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "no organization available in request")
|
||||
}
|
||||
|
||||
func TestAuthMiddleware_Directive_RequiresBoth(t *testing.T) {
|
||||
mockCache := new(MockCache)
|
||||
authMiddleware := NewAuth(mockCache)
|
||||
|
||||
requireUser := true
|
||||
requireOrg := true
|
||||
orgID := uuid.New()
|
||||
org := domain.Organization{
|
||||
BaseAggregate: eventsourced.BaseAggregate{
|
||||
ID: eventsourced.IdFromString(orgID.String()),
|
||||
},
|
||||
Name: "Test Org",
|
||||
}
|
||||
|
||||
// Test with both present
|
||||
ctx := context.WithValue(context.Background(), UserKey, "user-123")
|
||||
ctx = context.WithValue(ctx, OrganizationKey, org)
|
||||
_, err := authMiddleware.Directive(ctx, nil, func(ctx context.Context) (interface{}, error) {
|
||||
return "success", nil
|
||||
}, &requireUser, &requireOrg)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Test with only user
|
||||
ctx = context.WithValue(context.Background(), UserKey, "user-123")
|
||||
_, err = authMiddleware.Directive(ctx, nil, func(ctx context.Context) (interface{}, error) {
|
||||
return "success", nil
|
||||
}, &requireUser, &requireOrg)
|
||||
assert.Error(t, err)
|
||||
|
||||
// Test with only organization
|
||||
ctx = context.WithValue(context.Background(), OrganizationKey, org)
|
||||
_, err = authMiddleware.Directive(ctx, nil, func(ctx context.Context) (interface{}, error) {
|
||||
return "success", nil
|
||||
}, &requireUser, &requireOrg)
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
func TestAuthMiddleware_Directive_NoRequirements(t *testing.T) {
|
||||
mockCache := new(MockCache)
|
||||
authMiddleware := NewAuth(mockCache)
|
||||
|
||||
// Test with no requirements
|
||||
ctx := context.Background()
|
||||
result, err := authMiddleware.Directive(ctx, nil, func(ctx context.Context) (interface{}, error) {
|
||||
return "success", nil
|
||||
}, nil, nil)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "success", result)
|
||||
}
|
||||
Reference in New Issue
Block a user