package main

import (
	"context"
	"testing"

	"github.com/99designs/gqlgen/graphql/handler/transport"
	"github.com/google/uuid"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
	"gitlab.com/unboundsoftware/eventsourced/eventsourced"

	"gitea.unbound.se/unboundsoftware/schemas/domain"
	"gitea.unbound.se/unboundsoftware/schemas/hash"
	"gitea.unbound.se/unboundsoftware/schemas/middleware"
)

// MockCache is a mock implementation for testing. Both maps share the same
// composite key so a matching hash can be resolved to its organization.
type MockCache struct {
	organizations map[string]*domain.Organization // keyed by orgId-name composite
	apiKeys       map[string]string               // maps orgId-name to hashed key
}

// OrganizationByAPIKey returns the organization whose stored hash matches
// plainKey according to hash.CompareAPIKey, or nil when no stored hash matches.
func (m *MockCache) OrganizationByAPIKey(plainKey string) *domain.Organization {
	for compositeKey, hashedKey := range m.apiKeys {
		if hash.CompareAPIKey(hashedKey, plainKey) {
			return m.organizations[compositeKey]
		}
	}
	return nil
}

// newTestWebSocketInitFunc builds the init function shared by every test in
// this file (simulating the WebSocket InitFunc logic). The returned function:
//   - reads "X-Api-Key" from the connection_init payload;
//   - when it is a non-empty string, stores it in the context under
//     middleware.ApiKey;
//   - when cache resolves that key to an organization, stores a copy of the
//     organization under middleware.OrganizationKey.
//
// It always returns the (possibly extended) context, a pointer to the payload
// it was given, and a nil error.
//
// NOTE(review): this is a test-local copy of the logic; presumably it has to be
// kept in sync with the production InitFunc — confirm.
func newTestWebSocketInitFunc(cache *MockCache) func(ctx context.Context, initPayload transport.InitPayload) (context.Context, *transport.InitPayload, error) {
	return func(ctx context.Context, initPayload transport.InitPayload) (context.Context, *transport.InitPayload, error) {
		// A missing, empty, or non-string value leaves the context untouched.
		key, ok := initPayload["X-Api-Key"].(string)
		if !ok || key == "" {
			return ctx, &initPayload, nil
		}

		ctx = context.WithValue(ctx, middleware.ApiKey, key)

		// Look up organization by API key (cache handles hash comparison).
		if organization := cache.OrganizationByAPIKey(key); organization != nil {
			ctx = context.WithValue(ctx, middleware.OrganizationKey, *organization)
		}
		return ctx, &initPayload, nil
	}
}

// TestWebSocketInitFunc_WithValidAPIKey verifies that a known API key ends up
// in the context together with the organization it belongs to.
func TestWebSocketInitFunc_WithValidAPIKey(t *testing.T) {
	// Setup
	orgID := uuid.New()
	org := &domain.Organization{
		BaseAggregate: eventsourced.BaseAggregate{
			ID: eventsourced.IdFromString(orgID.String()),
		},
		Name: "Test Organization",
	}

	apiKey := "test-api-key-123"
	hashedKey, err := hash.APIKey(apiKey)
	require.NoError(t, err)

	compositeKey := orgID.String() + "-test-key"
	mockCache := &MockCache{
		organizations: map[string]*domain.Organization{
			compositeKey: org,
		},
		apiKeys: map[string]string{
			compositeKey: hashedKey,
		},
	}
	initFunc := newTestWebSocketInitFunc(mockCache)

	// Test
	initPayload := transport.InitPayload{
		"X-Api-Key": apiKey,
	}
	resultCtx, resultPayload, err := initFunc(context.Background(), initPayload)

	// Assert
	require.NoError(t, err)
	require.NotNil(t, resultPayload)

	// Check API key is in context. Comparing the raw value avoids a panicking
	// type assertion if something other than a string was stored.
	storedKey := resultCtx.Value(middleware.ApiKey)
	require.NotNil(t, storedKey, "API key not found in context")
	assert.Equal(t, apiKey, storedKey)

	// Check organization is in context
	storedOrg := resultCtx.Value(middleware.OrganizationKey)
	require.NotNil(t, storedOrg, "Organization not found in context")
	capturedOrg, ok := storedOrg.(domain.Organization)
	require.True(t, ok, "Organization should be of correct type")
	assert.Equal(t, org.Name, capturedOrg.Name)
	assert.Equal(t, org.ID.String(), capturedOrg.ID.String())
}

// TestWebSocketInitFunc_WithInvalidAPIKey verifies that an unknown API key is
// still recorded in the context but no organization is attached.
func TestWebSocketInitFunc_WithInvalidAPIKey(t *testing.T) {
	// Setup
	mockCache := &MockCache{
		organizations: map[string]*domain.Organization{},
		apiKeys:       map[string]string{},
	}
	apiKey := "invalid-api-key"
	initFunc := newTestWebSocketInitFunc(mockCache)

	// Test
	initPayload := transport.InitPayload{
		"X-Api-Key": apiKey,
	}
	resultCtx, resultPayload, err := initFunc(context.Background(), initPayload)

	// Assert
	require.NoError(t, err)
	require.NotNil(t, resultPayload)

	// Check API key is in context
	storedKey := resultCtx.Value(middleware.ApiKey)
	require.NotNil(t, storedKey, "API key not found in context")
	assert.Equal(t, apiKey, storedKey)

	// Check organization is NOT in context (since API key is invalid)
	assert.Nil(t, resultCtx.Value(middleware.OrganizationKey), "Organization should not be set for invalid API key")
}

// TestWebSocketInitFunc_WithoutAPIKey verifies that an empty payload leaves
// both context values unset.
func TestWebSocketInitFunc_WithoutAPIKey(t *testing.T) {
	// Setup
	mockCache := &MockCache{
		organizations: map[string]*domain.Organization{},
		apiKeys:       map[string]string{},
	}
	initFunc := newTestWebSocketInitFunc(mockCache)

	// Test
	resultCtx, resultPayload, err := initFunc(context.Background(), transport.InitPayload{})

	// Assert
	require.NoError(t, err)
	require.NotNil(t, resultPayload)

	// Check API key is NOT in context
	assert.Nil(t, resultCtx.Value(middleware.ApiKey), "API key should not be set when not provided")

	// Check organization is NOT in context
	assert.Nil(t, resultCtx.Value(middleware.OrganizationKey), "Organization should not be set when API key is not provided")
}

// TestWebSocketInitFunc_WithEmptyAPIKey verifies that an empty-string API key
// is treated the same as a missing one.
func TestWebSocketInitFunc_WithEmptyAPIKey(t *testing.T) {
	// Setup
	mockCache := &MockCache{
		organizations: map[string]*domain.Organization{},
		apiKeys:       map[string]string{},
	}
	initFunc := newTestWebSocketInitFunc(mockCache)

	// Test
	initPayload := transport.InitPayload{
		"X-Api-Key": "", // Empty string
	}
	resultCtx, resultPayload, err := initFunc(context.Background(), initPayload)

	// Assert
	require.NoError(t, err)
	require.NotNil(t, resultPayload)

	// Check API key is NOT in context (because empty string fails the condition)
	assert.Nil(t, resultCtx.Value(middleware.ApiKey), "API key should not be set when empty")

	// Check organization is NOT in context
	assert.Nil(t, resultCtx.Value(middleware.OrganizationKey), "Organization should not be set when API key is empty")
}

// TestWebSocketInitFunc_WithWrongTypeAPIKey verifies that a non-string API key
// value in the payload is ignored.
func TestWebSocketInitFunc_WithWrongTypeAPIKey(t *testing.T) {
	// Setup
	mockCache := &MockCache{
		organizations: map[string]*domain.Organization{},
		apiKeys:       map[string]string{},
	}
	initFunc := newTestWebSocketInitFunc(mockCache)

	// Test
	initPayload := transport.InitPayload{
		"X-Api-Key": 12345, // Wrong type (int instead of string)
	}
	resultCtx, resultPayload, err := initFunc(context.Background(), initPayload)

	// Assert
	require.NoError(t, err)
	require.NotNil(t, resultPayload)

	// Check API key is NOT in context (type assertion fails)
	assert.Nil(t, resultCtx.Value(middleware.ApiKey), "API key should not be set when wrong type")

	// Check organization is NOT in context
	assert.Nil(t, resultCtx.Value(middleware.OrganizationKey), "Organization should not be set when API key has wrong type")
}

// TestWebSocketInitFunc_WithMultipleOrganizations verifies that each API key
// resolves to its own organization when several are registered in the cache.
func TestWebSocketInitFunc_WithMultipleOrganizations(t *testing.T) {
	// Setup - create multiple organizations
	org1ID := uuid.New()
	org1 := &domain.Organization{
		BaseAggregate: eventsourced.BaseAggregate{
			ID: eventsourced.IdFromString(org1ID.String()),
		},
		Name: "Organization 1",
	}

	org2ID := uuid.New()
	org2 := &domain.Organization{
		BaseAggregate: eventsourced.BaseAggregate{
			ID: eventsourced.IdFromString(org2ID.String()),
		},
		Name: "Organization 2",
	}

	apiKey1 := "api-key-org-1"
	apiKey2 := "api-key-org-2"

	hashedKey1, err := hash.APIKey(apiKey1)
	require.NoError(t, err)
	hashedKey2, err := hash.APIKey(apiKey2)
	require.NoError(t, err)

	compositeKey1 := org1ID.String() + "-key1"
	compositeKey2 := org2ID.String() + "-key2"

	mockCache := &MockCache{
		organizations: map[string]*domain.Organization{
			compositeKey1: org1,
			compositeKey2: org2,
		},
		apiKeys: map[string]string{
			compositeKey1: hashedKey1,
			compositeKey2: hashedKey2,
		},
	}
	initFunc := newTestWebSocketInitFunc(mockCache)

	// Each case starts from a fresh background context, first key then second.
	cases := []struct {
		apiKey     string
		want       *domain.Organization
		missingMsg string
	}{
		{apiKey: apiKey1, want: org1, missingMsg: "Organization 1 not found in context"},
		{apiKey: apiKey2, want: org2, missingMsg: "Organization 2 not found in context"},
	}

	for _, tc := range cases {
		initPayload := transport.InitPayload{
			"X-Api-Key": tc.apiKey,
		}
		resultCtx, _, err := initFunc(context.Background(), initPayload)
		require.NoError(t, err)

		storedOrg := resultCtx.Value(middleware.OrganizationKey)
		require.NotNil(t, storedOrg, tc.missingMsg)
		capturedOrg, ok := storedOrg.(domain.Organization)
		require.True(t, ok)
		assert.Equal(t, tc.want.Name, capturedOrg.Name)
		assert.Equal(t, tc.want.ID.String(), capturedOrg.ID.String())
	}
}