Files
schemas/cmd/service/service_test.go
T
argoyle 4468903535 feat(cache): implement hashed API key storage and retrieval
Adds a new hashed key storage mechanism for API keys in the cache. 
Replaces direct mapping to API keys with composite keys based on 
organizationId and name. Implements searching of API keys using 
hash comparisons for improved security. Updates related tests to 
ensure correct functionality and validate the hashing. Also
adds the new dependency `golang.org/x/crypto`.
2025-11-20 22:11:24 +01:00

363 lines
11 KiB
Go

package main
import (
"context"
"testing"
"github.com/99designs/gqlgen/graphql/handler/transport"
"github.com/google/uuid"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"gitlab.com/unboundsoftware/eventsourced/eventsourced"
"gitlab.com/unboundsoftware/schemas/domain"
"gitlab.com/unboundsoftware/schemas/hash"
"gitlab.com/unboundsoftware/schemas/middleware"
)
// MockCache is a mock cache implementation for these tests. Both maps are
// indexed by the same composite key, built from the organization ID and the
// API key name.
type MockCache struct {
	// organizations holds the organization stored under each composite key.
	organizations map[string]*domain.Organization
	// apiKeys holds the hashed API key stored under each composite key.
	apiKeys map[string]string
}
// OrganizationByAPIKey compares plainKey against every stored hash and returns
// the organization stored under the first composite key whose hash matches.
// It returns nil when no hash matches, or when the matching composite key has
// no organization entry.
func (m *MockCache) OrganizationByAPIKey(plainKey string) *domain.Organization {
	for id, stored := range m.apiKeys {
		if !hash.CompareAPIKey(stored, plainKey) {
			continue
		}
		// The matching hash's composite key indexes the organizations map.
		return m.organizations[id]
	}
	return nil
}
func TestWebSocketInitFunc_WithValidAPIKey(t *testing.T) {
// Setup
orgID := uuid.New()
org := &domain.Organization{
BaseAggregate: eventsourced.BaseAggregate{
ID: eventsourced.IdFromString(orgID.String()),
},
Name: "Test Organization",
}
apiKey := "test-api-key-123"
hashedKey, err := hash.APIKey(apiKey)
require.NoError(t, err)
compositeKey := orgID.String() + "-test-key"
mockCache := &MockCache{
organizations: map[string]*domain.Organization{
compositeKey: org,
},
apiKeys: map[string]string{
compositeKey: hashedKey,
},
}
// Create InitFunc (simulating the WebSocket InitFunc logic)
initFunc := func(ctx context.Context, initPayload transport.InitPayload) (context.Context, *transport.InitPayload, error) {
// Extract API key from WebSocket connection_init payload
if apiKey, ok := initPayload["X-Api-Key"].(string); ok && apiKey != "" {
ctx = context.WithValue(ctx, middleware.ApiKey, apiKey)
// Look up organization by API key (cache handles hash comparison)
if organization := mockCache.OrganizationByAPIKey(apiKey); organization != nil {
ctx = context.WithValue(ctx, middleware.OrganizationKey, *organization)
}
}
return ctx, &initPayload, nil
}
// Test
ctx := context.Background()
initPayload := transport.InitPayload{
"X-Api-Key": apiKey,
}
resultCtx, resultPayload, err := initFunc(ctx, initPayload)
// Assert
require.NoError(t, err)
require.NotNil(t, resultPayload)
// Check API key is in context
if value := resultCtx.Value(middleware.ApiKey); value != nil {
assert.Equal(t, apiKey, value.(string))
} else {
t.Fatal("API key not found in context")
}
// Check organization is in context
if value := resultCtx.Value(middleware.OrganizationKey); value != nil {
capturedOrg, ok := value.(domain.Organization)
require.True(t, ok, "Organization should be of correct type")
assert.Equal(t, org.Name, capturedOrg.Name)
assert.Equal(t, org.ID.String(), capturedOrg.ID.String())
} else {
t.Fatal("Organization not found in context")
}
}
func TestWebSocketInitFunc_WithInvalidAPIKey(t *testing.T) {
// Setup
mockCache := &MockCache{
organizations: map[string]*domain.Organization{},
apiKeys: map[string]string{},
}
apiKey := "invalid-api-key"
// Create InitFunc
initFunc := func(ctx context.Context, initPayload transport.InitPayload) (context.Context, *transport.InitPayload, error) {
// Extract API key from WebSocket connection_init payload
if apiKey, ok := initPayload["X-Api-Key"].(string); ok && apiKey != "" {
ctx = context.WithValue(ctx, middleware.ApiKey, apiKey)
// Look up organization by API key (cache handles hash comparison)
if organization := mockCache.OrganizationByAPIKey(apiKey); organization != nil {
ctx = context.WithValue(ctx, middleware.OrganizationKey, *organization)
}
}
return ctx, &initPayload, nil
}
// Test
ctx := context.Background()
initPayload := transport.InitPayload{
"X-Api-Key": apiKey,
}
resultCtx, resultPayload, err := initFunc(ctx, initPayload)
// Assert
require.NoError(t, err)
require.NotNil(t, resultPayload)
// Check API key is in context
if value := resultCtx.Value(middleware.ApiKey); value != nil {
assert.Equal(t, apiKey, value.(string))
} else {
t.Fatal("API key not found in context")
}
// Check organization is NOT in context (since API key is invalid)
value := resultCtx.Value(middleware.OrganizationKey)
assert.Nil(t, value, "Organization should not be set for invalid API key")
}
// TestWebSocketInitFunc_WithoutAPIKey checks that a connection_init payload
// with no "X-Api-Key" entry leaves both the API key and the organization out
// of the returned context.
func TestWebSocketInitFunc_WithoutAPIKey(t *testing.T) {
	// Setup: an empty cache.
	mockCache := &MockCache{
		organizations: make(map[string]*domain.Organization),
		apiKeys:       make(map[string]string),
	}

	// InitFunc under test (simulates the WebSocket InitFunc logic).
	initFunc := func(ctx context.Context, initPayload transport.InitPayload) (context.Context, *transport.InitPayload, error) {
		// Nothing to do unless the payload carries a non-empty string key.
		key, isString := initPayload["X-Api-Key"].(string)
		if !isString || key == "" {
			return ctx, &initPayload, nil
		}
		ctx = context.WithValue(ctx, middleware.ApiKey, key)
		// The cache performs the hash comparison for the lookup.
		if found := mockCache.OrganizationByAPIKey(key); found != nil {
			ctx = context.WithValue(ctx, middleware.OrganizationKey, *found)
		}
		return ctx, &initPayload, nil
	}

	// Test: the payload is empty.
	resultCtx, resultPayload, err := initFunc(context.Background(), transport.InitPayload{})

	// Assert
	require.NoError(t, err)
	require.NotNil(t, resultPayload)

	// Neither value may be present in the context.
	keyValue := resultCtx.Value(middleware.ApiKey)
	assert.Nil(t, keyValue, "API key should not be set when not provided")
	orgValue := resultCtx.Value(middleware.OrganizationKey)
	assert.Nil(t, orgValue, "Organization should not be set when API key is not provided")
}
// TestWebSocketInitFunc_WithEmptyAPIKey checks that an empty-string
// "X-Api-Key" entry is treated as absent: neither the API key nor an
// organization ends up in the returned context.
func TestWebSocketInitFunc_WithEmptyAPIKey(t *testing.T) {
	// Setup: an empty cache.
	mockCache := &MockCache{
		organizations: make(map[string]*domain.Organization),
		apiKeys:       make(map[string]string),
	}

	// InitFunc under test (simulates the WebSocket InitFunc logic).
	initFunc := func(ctx context.Context, initPayload transport.InitPayload) (context.Context, *transport.InitPayload, error) {
		// Nothing to do unless the payload carries a non-empty string key.
		key, isString := initPayload["X-Api-Key"].(string)
		if !isString || key == "" {
			return ctx, &initPayload, nil
		}
		ctx = context.WithValue(ctx, middleware.ApiKey, key)
		// The cache performs the hash comparison for the lookup.
		if found := mockCache.OrganizationByAPIKey(key); found != nil {
			ctx = context.WithValue(ctx, middleware.OrganizationKey, *found)
		}
		return ctx, &initPayload, nil
	}

	// Test: the key is present but is the empty string.
	payload := transport.InitPayload{"X-Api-Key": ""}
	resultCtx, resultPayload, err := initFunc(context.Background(), payload)

	// Assert
	require.NoError(t, err)
	require.NotNil(t, resultPayload)

	// The empty string fails the non-empty check, so nothing is stored.
	keyValue := resultCtx.Value(middleware.ApiKey)
	assert.Nil(t, keyValue, "API key should not be set when empty")
	orgValue := resultCtx.Value(middleware.OrganizationKey)
	assert.Nil(t, orgValue, "Organization should not be set when API key is empty")
}
// TestWebSocketInitFunc_WithWrongTypeAPIKey checks that a non-string
// "X-Api-Key" entry is ignored: neither the API key nor an organization ends
// up in the returned context.
func TestWebSocketInitFunc_WithWrongTypeAPIKey(t *testing.T) {
	// Setup: an empty cache.
	mockCache := &MockCache{
		organizations: make(map[string]*domain.Organization),
		apiKeys:       make(map[string]string),
	}

	// InitFunc under test (simulates the WebSocket InitFunc logic).
	initFunc := func(ctx context.Context, initPayload transport.InitPayload) (context.Context, *transport.InitPayload, error) {
		// Nothing to do unless the payload carries a non-empty string key.
		key, isString := initPayload["X-Api-Key"].(string)
		if !isString || key == "" {
			return ctx, &initPayload, nil
		}
		ctx = context.WithValue(ctx, middleware.ApiKey, key)
		// The cache performs the hash comparison for the lookup.
		if found := mockCache.OrganizationByAPIKey(key); found != nil {
			ctx = context.WithValue(ctx, middleware.OrganizationKey, *found)
		}
		return ctx, &initPayload, nil
	}

	// Test: the key is an int instead of a string.
	payload := transport.InitPayload{"X-Api-Key": 12345}
	resultCtx, resultPayload, err := initFunc(context.Background(), payload)

	// Assert
	require.NoError(t, err)
	require.NotNil(t, resultPayload)

	// The string type assertion fails, so nothing is stored.
	keyValue := resultCtx.Value(middleware.ApiKey)
	assert.Nil(t, keyValue, "API key should not be set when wrong type")
	orgValue := resultCtx.Value(middleware.OrganizationKey)
	assert.Nil(t, orgValue, "Organization should not be set when API key has wrong type")
}
func TestWebSocketInitFunc_WithMultipleOrganizations(t *testing.T) {
// Setup - create multiple organizations
org1ID := uuid.New()
org1 := &domain.Organization{
BaseAggregate: eventsourced.BaseAggregate{
ID: eventsourced.IdFromString(org1ID.String()),
},
Name: "Organization 1",
}
org2ID := uuid.New()
org2 := &domain.Organization{
BaseAggregate: eventsourced.BaseAggregate{
ID: eventsourced.IdFromString(org2ID.String()),
},
Name: "Organization 2",
}
apiKey1 := "api-key-org-1"
apiKey2 := "api-key-org-2"
hashedKey1, err := hash.APIKey(apiKey1)
require.NoError(t, err)
hashedKey2, err := hash.APIKey(apiKey2)
require.NoError(t, err)
compositeKey1 := org1ID.String() + "-key1"
compositeKey2 := org2ID.String() + "-key2"
mockCache := &MockCache{
organizations: map[string]*domain.Organization{
compositeKey1: org1,
compositeKey2: org2,
},
apiKeys: map[string]string{
compositeKey1: hashedKey1,
compositeKey2: hashedKey2,
},
}
// Create InitFunc
initFunc := func(ctx context.Context, initPayload transport.InitPayload) (context.Context, *transport.InitPayload, error) {
// Extract API key from WebSocket connection_init payload
if apiKey, ok := initPayload["X-Api-Key"].(string); ok && apiKey != "" {
ctx = context.WithValue(ctx, middleware.ApiKey, apiKey)
// Look up organization by API key (cache handles hash comparison)
if organization := mockCache.OrganizationByAPIKey(apiKey); organization != nil {
ctx = context.WithValue(ctx, middleware.OrganizationKey, *organization)
}
}
return ctx, &initPayload, nil
}
// Test with first API key
ctx1 := context.Background()
initPayload1 := transport.InitPayload{
"X-Api-Key": apiKey1,
}
resultCtx1, _, err := initFunc(ctx1, initPayload1)
require.NoError(t, err)
if value := resultCtx1.Value(middleware.OrganizationKey); value != nil {
capturedOrg, ok := value.(domain.Organization)
require.True(t, ok)
assert.Equal(t, org1.Name, capturedOrg.Name)
assert.Equal(t, org1.ID.String(), capturedOrg.ID.String())
} else {
t.Fatal("Organization 1 not found in context")
}
// Test with second API key
ctx2 := context.Background()
initPayload2 := transport.InitPayload{
"X-Api-Key": apiKey2,
}
resultCtx2, _, err := initFunc(ctx2, initPayload2)
require.NoError(t, err)
if value := resultCtx2.Value(middleware.OrganizationKey); value != nil {
capturedOrg, ok := value.(domain.Organization)
require.True(t, ok)
assert.Equal(t, org2.Name, capturedOrg.Name)
assert.Equal(t, org2.ID.String(), capturedOrg.ID.String())
} else {
t.Fatal("Organization 2 not found in context")
}
}