4468903535
Adds a hashed key storage mechanism for API keys in the cache. Replaces the direct mapping of API keys with composite keys based on organizationId and name. Looks up API keys by hash comparison for improved security. Updates the related tests to verify correct functionality and validate the hashing. Also adds a new dependency on `golang.org/x/crypto`.
363 lines
11 KiB
Go
363 lines
11 KiB
Go
package main
|
|
|
|
import (
|
|
"context"
|
|
"testing"
|
|
|
|
"github.com/99designs/gqlgen/graphql/handler/transport"
|
|
"github.com/google/uuid"
|
|
"github.com/stretchr/testify/assert"
|
|
"github.com/stretchr/testify/require"
|
|
"gitlab.com/unboundsoftware/eventsourced/eventsourced"
|
|
|
|
"gitlab.com/unboundsoftware/schemas/domain"
|
|
"gitlab.com/unboundsoftware/schemas/hash"
|
|
"gitlab.com/unboundsoftware/schemas/middleware"
|
|
)
|
|
|
|
// MockCache is a mock implementation for testing
|
|
type MockCache struct {
|
|
organizations map[string]*domain.Organization // keyed by orgId-name composite
|
|
apiKeys map[string]string // maps orgId-name to hashed key
|
|
}
|
|
|
|
func (m *MockCache) OrganizationByAPIKey(plainKey string) *domain.Organization {
|
|
// Find organization by comparing plaintext key with stored hash
|
|
for compositeKey, hashedKey := range m.apiKeys {
|
|
if hash.CompareAPIKey(hashedKey, plainKey) {
|
|
return m.organizations[compositeKey]
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func TestWebSocketInitFunc_WithValidAPIKey(t *testing.T) {
|
|
// Setup
|
|
orgID := uuid.New()
|
|
org := &domain.Organization{
|
|
BaseAggregate: eventsourced.BaseAggregate{
|
|
ID: eventsourced.IdFromString(orgID.String()),
|
|
},
|
|
Name: "Test Organization",
|
|
}
|
|
|
|
apiKey := "test-api-key-123"
|
|
hashedKey, err := hash.APIKey(apiKey)
|
|
require.NoError(t, err)
|
|
|
|
compositeKey := orgID.String() + "-test-key"
|
|
|
|
mockCache := &MockCache{
|
|
organizations: map[string]*domain.Organization{
|
|
compositeKey: org,
|
|
},
|
|
apiKeys: map[string]string{
|
|
compositeKey: hashedKey,
|
|
},
|
|
}
|
|
|
|
// Create InitFunc (simulating the WebSocket InitFunc logic)
|
|
initFunc := func(ctx context.Context, initPayload transport.InitPayload) (context.Context, *transport.InitPayload, error) {
|
|
// Extract API key from WebSocket connection_init payload
|
|
if apiKey, ok := initPayload["X-Api-Key"].(string); ok && apiKey != "" {
|
|
ctx = context.WithValue(ctx, middleware.ApiKey, apiKey)
|
|
|
|
// Look up organization by API key (cache handles hash comparison)
|
|
if organization := mockCache.OrganizationByAPIKey(apiKey); organization != nil {
|
|
ctx = context.WithValue(ctx, middleware.OrganizationKey, *organization)
|
|
}
|
|
}
|
|
return ctx, &initPayload, nil
|
|
}
|
|
|
|
// Test
|
|
ctx := context.Background()
|
|
initPayload := transport.InitPayload{
|
|
"X-Api-Key": apiKey,
|
|
}
|
|
|
|
resultCtx, resultPayload, err := initFunc(ctx, initPayload)
|
|
|
|
// Assert
|
|
require.NoError(t, err)
|
|
require.NotNil(t, resultPayload)
|
|
|
|
// Check API key is in context
|
|
if value := resultCtx.Value(middleware.ApiKey); value != nil {
|
|
assert.Equal(t, apiKey, value.(string))
|
|
} else {
|
|
t.Fatal("API key not found in context")
|
|
}
|
|
|
|
// Check organization is in context
|
|
if value := resultCtx.Value(middleware.OrganizationKey); value != nil {
|
|
capturedOrg, ok := value.(domain.Organization)
|
|
require.True(t, ok, "Organization should be of correct type")
|
|
assert.Equal(t, org.Name, capturedOrg.Name)
|
|
assert.Equal(t, org.ID.String(), capturedOrg.ID.String())
|
|
} else {
|
|
t.Fatal("Organization not found in context")
|
|
}
|
|
}
|
|
|
|
func TestWebSocketInitFunc_WithInvalidAPIKey(t *testing.T) {
|
|
// Setup
|
|
mockCache := &MockCache{
|
|
organizations: map[string]*domain.Organization{},
|
|
apiKeys: map[string]string{},
|
|
}
|
|
|
|
apiKey := "invalid-api-key"
|
|
|
|
// Create InitFunc
|
|
initFunc := func(ctx context.Context, initPayload transport.InitPayload) (context.Context, *transport.InitPayload, error) {
|
|
// Extract API key from WebSocket connection_init payload
|
|
if apiKey, ok := initPayload["X-Api-Key"].(string); ok && apiKey != "" {
|
|
ctx = context.WithValue(ctx, middleware.ApiKey, apiKey)
|
|
|
|
// Look up organization by API key (cache handles hash comparison)
|
|
if organization := mockCache.OrganizationByAPIKey(apiKey); organization != nil {
|
|
ctx = context.WithValue(ctx, middleware.OrganizationKey, *organization)
|
|
}
|
|
}
|
|
return ctx, &initPayload, nil
|
|
}
|
|
|
|
// Test
|
|
ctx := context.Background()
|
|
initPayload := transport.InitPayload{
|
|
"X-Api-Key": apiKey,
|
|
}
|
|
|
|
resultCtx, resultPayload, err := initFunc(ctx, initPayload)
|
|
|
|
// Assert
|
|
require.NoError(t, err)
|
|
require.NotNil(t, resultPayload)
|
|
|
|
// Check API key is in context
|
|
if value := resultCtx.Value(middleware.ApiKey); value != nil {
|
|
assert.Equal(t, apiKey, value.(string))
|
|
} else {
|
|
t.Fatal("API key not found in context")
|
|
}
|
|
|
|
// Check organization is NOT in context (since API key is invalid)
|
|
value := resultCtx.Value(middleware.OrganizationKey)
|
|
assert.Nil(t, value, "Organization should not be set for invalid API key")
|
|
}
|
|
|
|
func TestWebSocketInitFunc_WithoutAPIKey(t *testing.T) {
|
|
// Setup
|
|
mockCache := &MockCache{
|
|
organizations: map[string]*domain.Organization{},
|
|
apiKeys: map[string]string{},
|
|
}
|
|
|
|
// Create InitFunc
|
|
initFunc := func(ctx context.Context, initPayload transport.InitPayload) (context.Context, *transport.InitPayload, error) {
|
|
// Extract API key from WebSocket connection_init payload
|
|
if apiKey, ok := initPayload["X-Api-Key"].(string); ok && apiKey != "" {
|
|
ctx = context.WithValue(ctx, middleware.ApiKey, apiKey)
|
|
|
|
// Look up organization by API key (cache handles hash comparison)
|
|
if organization := mockCache.OrganizationByAPIKey(apiKey); organization != nil {
|
|
ctx = context.WithValue(ctx, middleware.OrganizationKey, *organization)
|
|
}
|
|
}
|
|
return ctx, &initPayload, nil
|
|
}
|
|
|
|
// Test
|
|
ctx := context.Background()
|
|
initPayload := transport.InitPayload{}
|
|
|
|
resultCtx, resultPayload, err := initFunc(ctx, initPayload)
|
|
|
|
// Assert
|
|
require.NoError(t, err)
|
|
require.NotNil(t, resultPayload)
|
|
|
|
// Check API key is NOT in context
|
|
value := resultCtx.Value(middleware.ApiKey)
|
|
assert.Nil(t, value, "API key should not be set when not provided")
|
|
|
|
// Check organization is NOT in context
|
|
value = resultCtx.Value(middleware.OrganizationKey)
|
|
assert.Nil(t, value, "Organization should not be set when API key is not provided")
|
|
}
|
|
|
|
func TestWebSocketInitFunc_WithEmptyAPIKey(t *testing.T) {
|
|
// Setup
|
|
mockCache := &MockCache{
|
|
organizations: map[string]*domain.Organization{},
|
|
apiKeys: map[string]string{},
|
|
}
|
|
|
|
// Create InitFunc
|
|
initFunc := func(ctx context.Context, initPayload transport.InitPayload) (context.Context, *transport.InitPayload, error) {
|
|
// Extract API key from WebSocket connection_init payload
|
|
if apiKey, ok := initPayload["X-Api-Key"].(string); ok && apiKey != "" {
|
|
ctx = context.WithValue(ctx, middleware.ApiKey, apiKey)
|
|
|
|
// Look up organization by API key (cache handles hash comparison)
|
|
if organization := mockCache.OrganizationByAPIKey(apiKey); organization != nil {
|
|
ctx = context.WithValue(ctx, middleware.OrganizationKey, *organization)
|
|
}
|
|
}
|
|
return ctx, &initPayload, nil
|
|
}
|
|
|
|
// Test
|
|
ctx := context.Background()
|
|
initPayload := transport.InitPayload{
|
|
"X-Api-Key": "", // Empty string
|
|
}
|
|
|
|
resultCtx, resultPayload, err := initFunc(ctx, initPayload)
|
|
|
|
// Assert
|
|
require.NoError(t, err)
|
|
require.NotNil(t, resultPayload)
|
|
|
|
// Check API key is NOT in context (because empty string fails the condition)
|
|
value := resultCtx.Value(middleware.ApiKey)
|
|
assert.Nil(t, value, "API key should not be set when empty")
|
|
|
|
// Check organization is NOT in context
|
|
value = resultCtx.Value(middleware.OrganizationKey)
|
|
assert.Nil(t, value, "Organization should not be set when API key is empty")
|
|
}
|
|
|
|
func TestWebSocketInitFunc_WithWrongTypeAPIKey(t *testing.T) {
|
|
// Setup
|
|
mockCache := &MockCache{
|
|
organizations: map[string]*domain.Organization{},
|
|
apiKeys: map[string]string{},
|
|
}
|
|
|
|
// Create InitFunc
|
|
initFunc := func(ctx context.Context, initPayload transport.InitPayload) (context.Context, *transport.InitPayload, error) {
|
|
// Extract API key from WebSocket connection_init payload
|
|
if apiKey, ok := initPayload["X-Api-Key"].(string); ok && apiKey != "" {
|
|
ctx = context.WithValue(ctx, middleware.ApiKey, apiKey)
|
|
|
|
// Look up organization by API key (cache handles hash comparison)
|
|
if organization := mockCache.OrganizationByAPIKey(apiKey); organization != nil {
|
|
ctx = context.WithValue(ctx, middleware.OrganizationKey, *organization)
|
|
}
|
|
}
|
|
return ctx, &initPayload, nil
|
|
}
|
|
|
|
// Test
|
|
ctx := context.Background()
|
|
initPayload := transport.InitPayload{
|
|
"X-Api-Key": 12345, // Wrong type (int instead of string)
|
|
}
|
|
|
|
resultCtx, resultPayload, err := initFunc(ctx, initPayload)
|
|
|
|
// Assert
|
|
require.NoError(t, err)
|
|
require.NotNil(t, resultPayload)
|
|
|
|
// Check API key is NOT in context (type assertion fails)
|
|
value := resultCtx.Value(middleware.ApiKey)
|
|
assert.Nil(t, value, "API key should not be set when wrong type")
|
|
|
|
// Check organization is NOT in context
|
|
value = resultCtx.Value(middleware.OrganizationKey)
|
|
assert.Nil(t, value, "Organization should not be set when API key has wrong type")
|
|
}
|
|
|
|
func TestWebSocketInitFunc_WithMultipleOrganizations(t *testing.T) {
|
|
// Setup - create multiple organizations
|
|
org1ID := uuid.New()
|
|
org1 := &domain.Organization{
|
|
BaseAggregate: eventsourced.BaseAggregate{
|
|
ID: eventsourced.IdFromString(org1ID.String()),
|
|
},
|
|
Name: "Organization 1",
|
|
}
|
|
|
|
org2ID := uuid.New()
|
|
org2 := &domain.Organization{
|
|
BaseAggregate: eventsourced.BaseAggregate{
|
|
ID: eventsourced.IdFromString(org2ID.String()),
|
|
},
|
|
Name: "Organization 2",
|
|
}
|
|
|
|
apiKey1 := "api-key-org-1"
|
|
apiKey2 := "api-key-org-2"
|
|
hashedKey1, err := hash.APIKey(apiKey1)
|
|
require.NoError(t, err)
|
|
hashedKey2, err := hash.APIKey(apiKey2)
|
|
require.NoError(t, err)
|
|
|
|
compositeKey1 := org1ID.String() + "-key1"
|
|
compositeKey2 := org2ID.String() + "-key2"
|
|
|
|
mockCache := &MockCache{
|
|
organizations: map[string]*domain.Organization{
|
|
compositeKey1: org1,
|
|
compositeKey2: org2,
|
|
},
|
|
apiKeys: map[string]string{
|
|
compositeKey1: hashedKey1,
|
|
compositeKey2: hashedKey2,
|
|
},
|
|
}
|
|
|
|
// Create InitFunc
|
|
initFunc := func(ctx context.Context, initPayload transport.InitPayload) (context.Context, *transport.InitPayload, error) {
|
|
// Extract API key from WebSocket connection_init payload
|
|
if apiKey, ok := initPayload["X-Api-Key"].(string); ok && apiKey != "" {
|
|
ctx = context.WithValue(ctx, middleware.ApiKey, apiKey)
|
|
|
|
// Look up organization by API key (cache handles hash comparison)
|
|
if organization := mockCache.OrganizationByAPIKey(apiKey); organization != nil {
|
|
ctx = context.WithValue(ctx, middleware.OrganizationKey, *organization)
|
|
}
|
|
}
|
|
return ctx, &initPayload, nil
|
|
}
|
|
|
|
// Test with first API key
|
|
ctx1 := context.Background()
|
|
initPayload1 := transport.InitPayload{
|
|
"X-Api-Key": apiKey1,
|
|
}
|
|
|
|
resultCtx1, _, err := initFunc(ctx1, initPayload1)
|
|
require.NoError(t, err)
|
|
|
|
if value := resultCtx1.Value(middleware.OrganizationKey); value != nil {
|
|
capturedOrg, ok := value.(domain.Organization)
|
|
require.True(t, ok)
|
|
assert.Equal(t, org1.Name, capturedOrg.Name)
|
|
assert.Equal(t, org1.ID.String(), capturedOrg.ID.String())
|
|
} else {
|
|
t.Fatal("Organization 1 not found in context")
|
|
}
|
|
|
|
// Test with second API key
|
|
ctx2 := context.Background()
|
|
initPayload2 := transport.InitPayload{
|
|
"X-Api-Key": apiKey2,
|
|
}
|
|
|
|
resultCtx2, _, err := initFunc(ctx2, initPayload2)
|
|
require.NoError(t, err)
|
|
|
|
if value := resultCtx2.Value(middleware.OrganizationKey); value != nil {
|
|
capturedOrg, ok := value.(domain.Organization)
|
|
require.True(t, ok)
|
|
assert.Equal(t, org2.Name, capturedOrg.Name)
|
|
assert.Equal(t, org2.ID.String(), capturedOrg.ID.String())
|
|
} else {
|
|
t.Fatal("Organization 2 not found in context")
|
|
}
|
|
}
|