4d18cf4175
Adds unit tests for the WebSocket initialization function to validate behavior with valid, invalid, and absent API keys. Introduces a mock cache implementation to simulate organization retrieval based on hashed API keys. Ensures proper context value setting upon initialization, enhancing test coverage and reliability for API key handling in WebSocket connections.
337 lines
10 KiB
Go
337 lines
10 KiB
Go
package main
|
|
|
|
import (
|
|
"context"
|
|
"testing"
|
|
|
|
"github.com/99designs/gqlgen/graphql/handler/transport"
|
|
"github.com/google/uuid"
|
|
"github.com/stretchr/testify/assert"
|
|
"github.com/stretchr/testify/require"
|
|
"gitlab.com/unboundsoftware/eventsourced/eventsourced"
|
|
|
|
"gitlab.com/unboundsoftware/schemas/domain"
|
|
"gitlab.com/unboundsoftware/schemas/hash"
|
|
"gitlab.com/unboundsoftware/schemas/middleware"
|
|
)
|
|
|
|
// MockCache is a mock implementation for testing
|
|
type MockCache struct {
|
|
organizations map[string]*domain.Organization
|
|
}
|
|
|
|
func (m *MockCache) OrganizationByAPIKey(apiKey string) *domain.Organization {
|
|
return m.organizations[apiKey]
|
|
}
|
|
|
|
func TestWebSocketInitFunc_WithValidAPIKey(t *testing.T) {
|
|
// Setup
|
|
orgID := uuid.New()
|
|
org := &domain.Organization{
|
|
BaseAggregate: eventsourced.BaseAggregate{
|
|
ID: eventsourced.IdFromString(orgID.String()),
|
|
},
|
|
Name: "Test Organization",
|
|
}
|
|
|
|
apiKey := "test-api-key-123"
|
|
hashedKey := hash.String(apiKey)
|
|
|
|
mockCache := &MockCache{
|
|
organizations: map[string]*domain.Organization{
|
|
hashedKey: org,
|
|
},
|
|
}
|
|
|
|
// Create InitFunc (simulating the WebSocket InitFunc logic)
|
|
initFunc := func(ctx context.Context, initPayload transport.InitPayload) (context.Context, *transport.InitPayload, error) {
|
|
// Extract API key from WebSocket connection_init payload
|
|
if apiKey, ok := initPayload["X-Api-Key"].(string); ok && apiKey != "" {
|
|
ctx = context.WithValue(ctx, middleware.ApiKey, apiKey)
|
|
|
|
// Look up organization by API key
|
|
if organization := mockCache.OrganizationByAPIKey(hash.String(apiKey)); organization != nil {
|
|
ctx = context.WithValue(ctx, middleware.OrganizationKey, *organization)
|
|
}
|
|
}
|
|
return ctx, &initPayload, nil
|
|
}
|
|
|
|
// Test
|
|
ctx := context.Background()
|
|
initPayload := transport.InitPayload{
|
|
"X-Api-Key": apiKey,
|
|
}
|
|
|
|
resultCtx, resultPayload, err := initFunc(ctx, initPayload)
|
|
|
|
// Assert
|
|
require.NoError(t, err)
|
|
require.NotNil(t, resultPayload)
|
|
|
|
// Check API key is in context
|
|
if value := resultCtx.Value(middleware.ApiKey); value != nil {
|
|
assert.Equal(t, apiKey, value.(string))
|
|
} else {
|
|
t.Fatal("API key not found in context")
|
|
}
|
|
|
|
// Check organization is in context
|
|
if value := resultCtx.Value(middleware.OrganizationKey); value != nil {
|
|
capturedOrg, ok := value.(domain.Organization)
|
|
require.True(t, ok, "Organization should be of correct type")
|
|
assert.Equal(t, org.Name, capturedOrg.Name)
|
|
assert.Equal(t, org.ID.String(), capturedOrg.ID.String())
|
|
} else {
|
|
t.Fatal("Organization not found in context")
|
|
}
|
|
}
|
|
|
|
func TestWebSocketInitFunc_WithInvalidAPIKey(t *testing.T) {
|
|
// Setup
|
|
mockCache := &MockCache{
|
|
organizations: map[string]*domain.Organization{},
|
|
}
|
|
|
|
apiKey := "invalid-api-key"
|
|
|
|
// Create InitFunc
|
|
initFunc := func(ctx context.Context, initPayload transport.InitPayload) (context.Context, *transport.InitPayload, error) {
|
|
// Extract API key from WebSocket connection_init payload
|
|
if apiKey, ok := initPayload["X-Api-Key"].(string); ok && apiKey != "" {
|
|
ctx = context.WithValue(ctx, middleware.ApiKey, apiKey)
|
|
|
|
// Look up organization by API key
|
|
if organization := mockCache.OrganizationByAPIKey(hash.String(apiKey)); organization != nil {
|
|
ctx = context.WithValue(ctx, middleware.OrganizationKey, *organization)
|
|
}
|
|
}
|
|
return ctx, &initPayload, nil
|
|
}
|
|
|
|
// Test
|
|
ctx := context.Background()
|
|
initPayload := transport.InitPayload{
|
|
"X-Api-Key": apiKey,
|
|
}
|
|
|
|
resultCtx, resultPayload, err := initFunc(ctx, initPayload)
|
|
|
|
// Assert
|
|
require.NoError(t, err)
|
|
require.NotNil(t, resultPayload)
|
|
|
|
// Check API key is in context
|
|
if value := resultCtx.Value(middleware.ApiKey); value != nil {
|
|
assert.Equal(t, apiKey, value.(string))
|
|
} else {
|
|
t.Fatal("API key not found in context")
|
|
}
|
|
|
|
// Check organization is NOT in context (since API key is invalid)
|
|
value := resultCtx.Value(middleware.OrganizationKey)
|
|
assert.Nil(t, value, "Organization should not be set for invalid API key")
|
|
}
|
|
|
|
func TestWebSocketInitFunc_WithoutAPIKey(t *testing.T) {
|
|
// Setup
|
|
mockCache := &MockCache{
|
|
organizations: map[string]*domain.Organization{},
|
|
}
|
|
|
|
// Create InitFunc
|
|
initFunc := func(ctx context.Context, initPayload transport.InitPayload) (context.Context, *transport.InitPayload, error) {
|
|
// Extract API key from WebSocket connection_init payload
|
|
if apiKey, ok := initPayload["X-Api-Key"].(string); ok && apiKey != "" {
|
|
ctx = context.WithValue(ctx, middleware.ApiKey, apiKey)
|
|
|
|
// Look up organization by API key
|
|
if organization := mockCache.OrganizationByAPIKey(hash.String(apiKey)); organization != nil {
|
|
ctx = context.WithValue(ctx, middleware.OrganizationKey, *organization)
|
|
}
|
|
}
|
|
return ctx, &initPayload, nil
|
|
}
|
|
|
|
// Test
|
|
ctx := context.Background()
|
|
initPayload := transport.InitPayload{}
|
|
|
|
resultCtx, resultPayload, err := initFunc(ctx, initPayload)
|
|
|
|
// Assert
|
|
require.NoError(t, err)
|
|
require.NotNil(t, resultPayload)
|
|
|
|
// Check API key is NOT in context
|
|
value := resultCtx.Value(middleware.ApiKey)
|
|
assert.Nil(t, value, "API key should not be set when not provided")
|
|
|
|
// Check organization is NOT in context
|
|
value = resultCtx.Value(middleware.OrganizationKey)
|
|
assert.Nil(t, value, "Organization should not be set when API key is not provided")
|
|
}
|
|
|
|
func TestWebSocketInitFunc_WithEmptyAPIKey(t *testing.T) {
|
|
// Setup
|
|
mockCache := &MockCache{
|
|
organizations: map[string]*domain.Organization{},
|
|
}
|
|
|
|
// Create InitFunc
|
|
initFunc := func(ctx context.Context, initPayload transport.InitPayload) (context.Context, *transport.InitPayload, error) {
|
|
// Extract API key from WebSocket connection_init payload
|
|
if apiKey, ok := initPayload["X-Api-Key"].(string); ok && apiKey != "" {
|
|
ctx = context.WithValue(ctx, middleware.ApiKey, apiKey)
|
|
|
|
// Look up organization by API key
|
|
if organization := mockCache.OrganizationByAPIKey(hash.String(apiKey)); organization != nil {
|
|
ctx = context.WithValue(ctx, middleware.OrganizationKey, *organization)
|
|
}
|
|
}
|
|
return ctx, &initPayload, nil
|
|
}
|
|
|
|
// Test
|
|
ctx := context.Background()
|
|
initPayload := transport.InitPayload{
|
|
"X-Api-Key": "", // Empty string
|
|
}
|
|
|
|
resultCtx, resultPayload, err := initFunc(ctx, initPayload)
|
|
|
|
// Assert
|
|
require.NoError(t, err)
|
|
require.NotNil(t, resultPayload)
|
|
|
|
// Check API key is NOT in context (because empty string fails the condition)
|
|
value := resultCtx.Value(middleware.ApiKey)
|
|
assert.Nil(t, value, "API key should not be set when empty")
|
|
|
|
// Check organization is NOT in context
|
|
value = resultCtx.Value(middleware.OrganizationKey)
|
|
assert.Nil(t, value, "Organization should not be set when API key is empty")
|
|
}
|
|
|
|
func TestWebSocketInitFunc_WithWrongTypeAPIKey(t *testing.T) {
|
|
// Setup
|
|
mockCache := &MockCache{
|
|
organizations: map[string]*domain.Organization{},
|
|
}
|
|
|
|
// Create InitFunc
|
|
initFunc := func(ctx context.Context, initPayload transport.InitPayload) (context.Context, *transport.InitPayload, error) {
|
|
// Extract API key from WebSocket connection_init payload
|
|
if apiKey, ok := initPayload["X-Api-Key"].(string); ok && apiKey != "" {
|
|
ctx = context.WithValue(ctx, middleware.ApiKey, apiKey)
|
|
|
|
// Look up organization by API key
|
|
if organization := mockCache.OrganizationByAPIKey(hash.String(apiKey)); organization != nil {
|
|
ctx = context.WithValue(ctx, middleware.OrganizationKey, *organization)
|
|
}
|
|
}
|
|
return ctx, &initPayload, nil
|
|
}
|
|
|
|
// Test
|
|
ctx := context.Background()
|
|
initPayload := transport.InitPayload{
|
|
"X-Api-Key": 12345, // Wrong type (int instead of string)
|
|
}
|
|
|
|
resultCtx, resultPayload, err := initFunc(ctx, initPayload)
|
|
|
|
// Assert
|
|
require.NoError(t, err)
|
|
require.NotNil(t, resultPayload)
|
|
|
|
// Check API key is NOT in context (type assertion fails)
|
|
value := resultCtx.Value(middleware.ApiKey)
|
|
assert.Nil(t, value, "API key should not be set when wrong type")
|
|
|
|
// Check organization is NOT in context
|
|
value = resultCtx.Value(middleware.OrganizationKey)
|
|
assert.Nil(t, value, "Organization should not be set when API key has wrong type")
|
|
}
|
|
|
|
func TestWebSocketInitFunc_WithMultipleOrganizations(t *testing.T) {
|
|
// Setup - create multiple organizations
|
|
org1ID := uuid.New()
|
|
org1 := &domain.Organization{
|
|
BaseAggregate: eventsourced.BaseAggregate{
|
|
ID: eventsourced.IdFromString(org1ID.String()),
|
|
},
|
|
Name: "Organization 1",
|
|
}
|
|
|
|
org2ID := uuid.New()
|
|
org2 := &domain.Organization{
|
|
BaseAggregate: eventsourced.BaseAggregate{
|
|
ID: eventsourced.IdFromString(org2ID.String()),
|
|
},
|
|
Name: "Organization 2",
|
|
}
|
|
|
|
apiKey1 := "api-key-org-1"
|
|
apiKey2 := "api-key-org-2"
|
|
hashedKey1 := hash.String(apiKey1)
|
|
hashedKey2 := hash.String(apiKey2)
|
|
|
|
mockCache := &MockCache{
|
|
organizations: map[string]*domain.Organization{
|
|
hashedKey1: org1,
|
|
hashedKey2: org2,
|
|
},
|
|
}
|
|
|
|
// Create InitFunc
|
|
initFunc := func(ctx context.Context, initPayload transport.InitPayload) (context.Context, *transport.InitPayload, error) {
|
|
// Extract API key from WebSocket connection_init payload
|
|
if apiKey, ok := initPayload["X-Api-Key"].(string); ok && apiKey != "" {
|
|
ctx = context.WithValue(ctx, middleware.ApiKey, apiKey)
|
|
|
|
// Look up organization by API key
|
|
if organization := mockCache.OrganizationByAPIKey(hash.String(apiKey)); organization != nil {
|
|
ctx = context.WithValue(ctx, middleware.OrganizationKey, *organization)
|
|
}
|
|
}
|
|
return ctx, &initPayload, nil
|
|
}
|
|
|
|
// Test with first API key
|
|
ctx1 := context.Background()
|
|
initPayload1 := transport.InitPayload{
|
|
"X-Api-Key": apiKey1,
|
|
}
|
|
|
|
resultCtx1, _, err := initFunc(ctx1, initPayload1)
|
|
require.NoError(t, err)
|
|
|
|
if value := resultCtx1.Value(middleware.OrganizationKey); value != nil {
|
|
capturedOrg, ok := value.(domain.Organization)
|
|
require.True(t, ok)
|
|
assert.Equal(t, org1.Name, capturedOrg.Name)
|
|
assert.Equal(t, org1.ID.String(), capturedOrg.ID.String())
|
|
} else {
|
|
t.Fatal("Organization 1 not found in context")
|
|
}
|
|
|
|
// Test with second API key
|
|
ctx2 := context.Background()
|
|
initPayload2 := transport.InitPayload{
|
|
"X-Api-Key": apiKey2,
|
|
}
|
|
|
|
resultCtx2, _, err := initFunc(ctx2, initPayload2)
|
|
require.NoError(t, err)
|
|
|
|
if value := resultCtx2.Value(middleware.OrganizationKey); value != nil {
|
|
capturedOrg, ok := value.(domain.Organization)
|
|
require.True(t, ok)
|
|
assert.Equal(t, org2.Name, capturedOrg.Name)
|
|
assert.Equal(t, org2.ID.String(), capturedOrg.ID.String())
|
|
} else {
|
|
t.Fatal("Organization 2 not found in context")
|
|
}
|
|
}
|