diff --git a/cache/cache.go b/cache/cache.go index 57d4304..ed058a7 100644 --- a/cache/cache.go +++ b/cache/cache.go @@ -14,7 +14,7 @@ import ( type Cache struct { organizations map[string]domain.Organization users map[string][]string - apiKeys map[string]domain.APIKey + apiKeys map[string]domain.APIKey // keyed by organizationId-name services map[string]map[string]map[string]struct{} subGraphs map[string]string lastUpdate map[string]string @@ -22,15 +22,17 @@ type Cache struct { } func (c *Cache) OrganizationByAPIKey(apiKey string) *domain.Organization { - key, exists := c.apiKeys[apiKey] - if !exists { - return nil + // Find the API key by comparing hashes + for _, key := range c.apiKeys { + if hash.CompareAPIKey(key.Key, apiKey) { + org, exists := c.organizations[key.OrganizationId] + if !exists { + return nil + } + return &org + } } - org, exists := c.organizations[key.OrganizationId] - if !exists { - return nil - } - return &org + return nil } func (c *Cache) OrganizationsByUser(sub string) []domain.Organization { @@ -43,11 +45,13 @@ func (c *Cache) OrganizationsByUser(sub string) []domain.Organization { } func (c *Cache) ApiKeyByKey(key string) *domain.APIKey { - k, exists := c.apiKeys[hash.String(key)] - if !exists { - return nil + // Find the API key by comparing hashes + for _, apiKey := range c.apiKeys { + if hash.CompareAPIKey(apiKey.Key, key) { + return &apiKey + } } - return &k + return nil } func (c *Cache) Services(orgId, ref, lastUpdate string) ([]string, string) { @@ -76,14 +80,15 @@ func (c *Cache) Update(msg any, _ goamqp.Headers) (any, error) { key := domain.APIKey{ Name: m.Name, OrganizationId: m.OrganizationId, - Key: m.Key, + Key: m.Key, // This is now the hashed key Refs: m.Refs, Read: m.Read, Publish: m.Publish, CreatedBy: m.Initiator, CreatedAt: m.When(), } - c.apiKeys[m.Key] = key + // Use composite key: organizationId-name + c.apiKeys[apiKeyId(m.OrganizationId, m.Name)] = key org := c.organizations[m.OrganizationId] org.APIKeys = append(org.APIKeys, key) c.organizations[m.OrganizationId] = org @@ -93,7 +98,8 @@ func (c *Cache) Update(msg any, _ goamqp.Headers) (any, error) { c.organizations[m.ID.String()] = *m c.addUser(m.CreatedBy, *m) for _, k := range m.APIKeys { - c.apiKeys[k.Key] = k + // Use composite key: organizationId-name + c.apiKeys[apiKeyId(k.OrganizationId, k.Name)] = k } case *domain.SubGraph: c.updateSubGraph(m.OrganizationId, m.Ref, m.ID.String(), m.Service, m.ChangedAt) @@ -143,3 +149,7 @@ func refKey(orgId string, ref string) string { func subGraphKey(orgId string, ref string, service string) string { return fmt.Sprintf("%s<->%s<->%s", orgId, ref, service) } + +func apiKeyId(orgId string, name string) string { + return fmt.Sprintf("%s<->%s", orgId, name) +} diff --git a/cmd/service/service.go b/cmd/service/service.go index 147d612..0c7f39c 100644 --- a/cmd/service/service.go +++ b/cmd/service/service.go @@ -30,7 +30,6 @@ import ( "gitlab.com/unboundsoftware/schemas/domain" "gitlab.com/unboundsoftware/schemas/graph" "gitlab.com/unboundsoftware/schemas/graph/generated" - "gitlab.com/unboundsoftware/schemas/hash" "gitlab.com/unboundsoftware/schemas/logging" "gitlab.com/unboundsoftware/schemas/middleware" "gitlab.com/unboundsoftware/schemas/monitoring" @@ -217,8 +216,8 @@ func start(closeEvents chan error, logger *slog.Logger, connectToAmqpFunc func(u logger.Info("WebSocket connection with API key", "has_key", true) ctx = context.WithValue(ctx, middleware.ApiKey, apiKey) - // Look up organization by API key (same logic as auth middleware) - if organization := serviceCache.OrganizationByAPIKey(hash.String(apiKey)); organization != nil { + // Look up organization by API key (cache handles hash comparison) + if organization := serviceCache.OrganizationByAPIKey(apiKey); organization != nil { logger.Info("WebSocket: Organization found for API key", "org_id", organization.ID.String()) ctx = context.WithValue(ctx, middleware.OrganizationKey, *organization) } else { diff --git a/cmd/service/service_test.go b/cmd/service/service_test.go index c9680d9..0bad862 100644 --- a/cmd/service/service_test.go +++ b/cmd/service/service_test.go @@ -17,11 +17,18 @@ import ( // MockCache is a mock implementation for testing type MockCache struct { - organizations map[string]*domain.Organization + organizations map[string]*domain.Organization // keyed by orgId-name composite + apiKeys map[string]string // maps orgId-name to hashed key } -func (m *MockCache) OrganizationByAPIKey(apiKey string) *domain.Organization { - return m.organizations[apiKey] +func (m *MockCache) OrganizationByAPIKey(plainKey string) *domain.Organization { + // Find organization by comparing plaintext key with stored hash + for compositeKey, hashedKey := range m.apiKeys { + if hash.CompareAPIKey(hashedKey, plainKey) { + return m.organizations[compositeKey] + } + } + return nil } func TestWebSocketInitFunc_WithValidAPIKey(t *testing.T) { @@ -35,11 +42,17 @@ func TestWebSocketInitFunc_WithValidAPIKey(t *testing.T) { } apiKey := "test-api-key-123" - hashedKey := hash.String(apiKey) + hashedKey, err := hash.APIKey(apiKey) + require.NoError(t, err) + + compositeKey := orgID.String() + "-test-key" mockCache := &MockCache{ organizations: map[string]*domain.Organization{ - hashedKey: org, + compositeKey: org, + }, + apiKeys: map[string]string{ + compositeKey: hashedKey, }, } @@ -49,8 +62,8 @@ func TestWebSocketInitFunc_WithValidAPIKey(t *testing.T) { if apiKey, ok := initPayload["X-Api-Key"].(string); ok && apiKey != "" { ctx = context.WithValue(ctx, middleware.ApiKey, apiKey) - // Look up organization by API key - if organization := mockCache.OrganizationByAPIKey(hash.String(apiKey)); organization != nil { + // Look up organization by API key (cache handles hash comparison) + if organization := mockCache.OrganizationByAPIKey(apiKey); organization != nil { ctx = context.WithValue(ctx, middleware.OrganizationKey, *organization) } } @@ -91,6 +104,7 @@ func TestWebSocketInitFunc_WithInvalidAPIKey(t *testing.T) { // Setup mockCache := &MockCache{ organizations: map[string]*domain.Organization{}, + apiKeys: map[string]string{}, } apiKey := "invalid-api-key" @@ -101,8 +115,8 @@ func TestWebSocketInitFunc_WithInvalidAPIKey(t *testing.T) { if apiKey, ok := initPayload["X-Api-Key"].(string); ok && apiKey != "" { ctx = context.WithValue(ctx, middleware.ApiKey, apiKey) - // Look up organization by API key - if organization := mockCache.OrganizationByAPIKey(hash.String(apiKey)); organization != nil { + // Look up organization by API key (cache handles hash comparison) + if organization := mockCache.OrganizationByAPIKey(apiKey); organization != nil { ctx = context.WithValue(ctx, middleware.OrganizationKey, *organization) } } @@ -137,6 +151,7 @@ func TestWebSocketInitFunc_WithoutAPIKey(t *testing.T) { // Setup mockCache := &MockCache{ organizations: map[string]*domain.Organization{}, + apiKeys: map[string]string{}, } // Create InitFunc @@ -145,8 +160,8 @@ func TestWebSocketInitFunc_WithoutAPIKey(t *testing.T) { if apiKey, ok := initPayload["X-Api-Key"].(string); ok && apiKey != "" { ctx = context.WithValue(ctx, middleware.ApiKey, apiKey) - // Look up organization by API key - if organization := mockCache.OrganizationByAPIKey(hash.String(apiKey)); organization != nil { + // Look up organization by API key (cache handles hash comparison) + if organization := mockCache.OrganizationByAPIKey(apiKey); organization != nil { ctx = context.WithValue(ctx, middleware.OrganizationKey, *organization) } } @@ -176,6 +191,7 @@ func TestWebSocketInitFunc_WithEmptyAPIKey(t *testing.T) { // Setup mockCache := &MockCache{ organizations: map[string]*domain.Organization{}, + apiKeys: map[string]string{}, } // Create InitFunc @@ -184,8 +200,8 @@ func TestWebSocketInitFunc_WithEmptyAPIKey(t *testing.T) { if apiKey, ok := initPayload["X-Api-Key"].(string); ok && apiKey != "" { ctx = context.WithValue(ctx, middleware.ApiKey, apiKey) - // Look up organization by API key - if organization := mockCache.OrganizationByAPIKey(hash.String(apiKey)); organization != nil { + // Look up organization by API key (cache handles hash comparison) + if organization := mockCache.OrganizationByAPIKey(apiKey); organization != nil { ctx = context.WithValue(ctx, middleware.OrganizationKey, *organization) } } @@ -217,6 +233,7 @@ func TestWebSocketInitFunc_WithWrongTypeAPIKey(t *testing.T) { // Setup mockCache := &MockCache{ organizations: map[string]*domain.Organization{}, + apiKeys: map[string]string{}, } // Create InitFunc @@ -225,8 +242,8 @@ func TestWebSocketInitFunc_WithWrongTypeAPIKey(t *testing.T) { if apiKey, ok := initPayload["X-Api-Key"].(string); ok && apiKey != "" { ctx = context.WithValue(ctx, middleware.ApiKey, apiKey) - // Look up organization by API key - if organization := mockCache.OrganizationByAPIKey(hash.String(apiKey)); organization != nil { + // Look up organization by API key (cache handles hash comparison) + if organization := mockCache.OrganizationByAPIKey(apiKey); organization != nil { ctx = context.WithValue(ctx, middleware.OrganizationKey, *organization) } } @@ -274,13 +291,22 @@ func TestWebSocketInitFunc_WithMultipleOrganizations(t *testing.T) { apiKey1 := "api-key-org-1" apiKey2 := "api-key-org-2" - hashedKey1 := hash.String(apiKey1) - hashedKey2 := hash.String(apiKey2) + hashedKey1, err := hash.APIKey(apiKey1) + require.NoError(t, err) + hashedKey2, err := hash.APIKey(apiKey2) + require.NoError(t, err) + + compositeKey1 := org1ID.String() + "-key1" + compositeKey2 := org2ID.String() + "-key2" mockCache := &MockCache{ organizations: map[string]*domain.Organization{ - hashedKey1: org1, - hashedKey2: org2, + compositeKey1: org1, + compositeKey2: org2, + }, + apiKeys: map[string]string{ + compositeKey1: hashedKey1, + compositeKey2: hashedKey2, }, } @@ -290,8 +316,8 @@ func TestWebSocketInitFunc_WithMultipleOrganizations(t *testing.T) { if apiKey, ok := initPayload["X-Api-Key"].(string); ok && apiKey != "" { ctx = context.WithValue(ctx, middleware.ApiKey, apiKey) - // Look up organization by API key - if organization := mockCache.OrganizationByAPIKey(hash.String(apiKey)); organization != nil { + // Look up organization by API key (cache handles hash comparison) + if organization := mockCache.OrganizationByAPIKey(apiKey); organization != nil { ctx = context.WithValue(ctx, middleware.OrganizationKey, *organization) } } diff --git a/domain/commands.go b/domain/commands.go index 55fede4..1037ab0 100644 --- a/domain/commands.go +++ b/domain/commands.go @@ -56,9 +56,20 @@ func (a AddAPIKey) Validate(_ context.Context, aggregate eventsourced.Aggregate) } func (a AddAPIKey) Event(context.Context) eventsourced.Event { + // Hash the API key using bcrypt for secure storage + // Note: We can't return an error here, but bcrypt errors are extremely rare + // (only if system runs out of memory or bcrypt cost is invalid) + // We use a fixed cost of 12 which is always valid + hashedKey, err := hash.APIKey(a.Key) + if err != nil { + // This should never happen with bcrypt cost 12, but if it does, + // we'll store an empty hash which will fail validation later + hashedKey = "" + } + return &APIKeyAdded{ Name: a.Name, - Key: hash.String(a.Key), + Key: hashedKey, Refs: a.Refs, Read: a.Read, Publish: a.Publish, diff --git a/domain/commands_test.go b/domain/commands_test.go index c62e4e4..aec95ac 100644 --- a/domain/commands_test.go +++ b/domain/commands_test.go @@ -2,10 +2,13 @@ package domain import ( "context" + "strings" "testing" "github.com/stretchr/testify/assert" - "gitlab.com/unboundsoftware/eventsourced/eventsourced" + "github.com/stretchr/testify/require" + + "gitlab.com/unboundsoftware/schemas/hash" ) func TestAddAPIKey_Event(t *testing.T) { @@ -24,7 +27,6 @@ func TestAddAPIKey_Event(t *testing.T) { name string fields fields args args - want eventsourced.Event }{ { name: "event", @@ -37,14 +39,6 @@ func TestAddAPIKey_Event(t *testing.T) { Initiator: "jim@example.org", }, args: args{}, - want: &APIKeyAdded{ - Name: "test", - Key: "dXNfYWtfMTIzNDU2Nzg5MDEyMzQ1NuOwxEKY/BwUmvv0yJlvuSQnrkHkZJuTTKSVmRt4UrhV", - Refs: []string{"Example@dev"}, - Read: true, - Publish: true, - Initiator: "jim@example.org", - }, }, } for _, tt := range tests { @@ -57,7 +51,26 @@ func TestAddAPIKey_Event(t *testing.T) { Publish: tt.fields.Publish, Initiator: tt.fields.Initiator, } - assert.Equalf(t, tt.want, a.Event(tt.args.in0), "Event(%v)", tt.args.in0) + event := a.Event(tt.args.in0) + require.NotNil(t, event) + + // Cast to APIKeyAdded to verify fields + apiKeyEvent, ok := event.(*APIKeyAdded) + require.True(t, ok, "Event should be *APIKeyAdded") + + // Verify non-key fields match exactly + assert.Equal(t, tt.fields.Name, apiKeyEvent.Name) + assert.Equal(t, tt.fields.Refs, apiKeyEvent.Refs) + assert.Equal(t, tt.fields.Read, apiKeyEvent.Read) + assert.Equal(t, tt.fields.Publish, apiKeyEvent.Publish) + assert.Equal(t, tt.fields.Initiator, apiKeyEvent.Initiator) + + // Verify the key is hashed correctly (bcrypt format) + assert.True(t, strings.HasPrefix(apiKeyEvent.Key, "$2"), "Key should be bcrypt hashed") + assert.NotEqual(t, tt.fields.Key, apiKeyEvent.Key, "Key should be hashed, not plaintext") + + // Verify the hash matches the original key + assert.True(t, hash.CompareAPIKey(apiKeyEvent.Key, tt.fields.Key), "Hashed key should match original") }) } } diff --git a/go.mod b/go.mod index 6a4d2f0..7599dfd 100644 --- a/go.mod +++ b/go.mod @@ -31,6 +31,7 @@ require ( go.opentelemetry.io/otel/sdk/log v0.14.0 go.opentelemetry.io/otel/sdk/metric v1.38.0 go.opentelemetry.io/otel/trace v1.38.0 + golang.org/x/crypto v0.43.0 gopkg.in/yaml.v3 v3.0.1 ) diff --git a/graph/converter.go b/graph/converter.go index 4af8538..57ae7d2 100644 --- a/graph/converter.go +++ b/graph/converter.go @@ -38,7 +38,7 @@ func ToGqlAPIKeys(keys []domain.APIKey) []*model.APIKey { result[i] = &model.APIKey{ ID: apiKeyId(k.OrganizationId, k.Name), Name: k.Name, - Key: &k.Key, + Key: nil, // Never return the hashed key - only return plaintext on creation Organization: nil, Refs: k.Refs, Read: k.Read, diff --git a/hash/hash.go b/hash/hash.go index ad5ae73..3a9583c 100644 --- a/hash/hash.go +++ b/hash/hash.go @@ -3,9 +3,72 @@ package hash import ( "crypto/sha256" "encoding/base64" + + "golang.org/x/crypto/bcrypt" ) +// String creates a SHA256 hash of a string (legacy, for non-sensitive data) func String(s string) string { encoded := sha256.New().Sum([]byte(s)) return base64.StdEncoding.EncodeToString(encoded) } + +// APIKey hashes an API key using bcrypt for secure storage +// Cost of 12 provides a good balance between security and performance +func APIKey(key string) (string, error) { + hash, err := bcrypt.GenerateFromPassword([]byte(key), 12) + if err != nil { + return "", err + } + return string(hash), nil +} + +// CompareAPIKey compares a plaintext API key with a hash +// Supports both bcrypt (new) and SHA256 (legacy) hashes for backwards compatibility +// Returns true if they match, false otherwise +// +// Migration Strategy: +// Old API keys stored with SHA256 will continue to work. To upgrade them to bcrypt: +// 1. Keys are automatically upgraded when users re-authenticate (if implemented) +// 2. Or, run a one-time migration using MigrateAPIKeyHash when convenient +func CompareAPIKey(hashedKey, plainKey string) bool { + // Bcrypt hashes start with $2a$, $2b$, or $2y$ + // If the hash starts with $2, it's a bcrypt hash + if len(hashedKey) > 2 && hashedKey[0] == '$' && hashedKey[1] == '2' { + // New bcrypt hash + err := bcrypt.CompareHashAndPassword([]byte(hashedKey), []byte(plainKey)) + return err == nil + } + + // Legacy SHA256 hash - compare using the old method + legacyHash := String(plainKey) + return hashedKey == legacyHash +} + +// IsLegacyHash returns true if the hash is a legacy SHA256 hash (not bcrypt) +func IsLegacyHash(hashedKey string) bool { + return len(hashedKey) <= 2 || hashedKey[0] != '$' || hashedKey[1] != '2' +} + +// MigrateAPIKeyHash can be used to upgrade a legacy SHA256 hash to bcrypt +// This is useful for one-time migrations of existing keys +// Returns the new bcrypt hash if the key is legacy, otherwise returns the original +func MigrateAPIKeyHash(currentHash, plainKey string) (string, bool, error) { + // If already bcrypt, no migration needed + if !IsLegacyHash(currentHash) { + return currentHash, false, nil + } + + // Verify the legacy hash is correct before migrating + if !CompareAPIKey(currentHash, plainKey) { + return "", false, nil // Invalid key, don't migrate + } + + // Generate new bcrypt hash + newHash, err := APIKey(plainKey) + if err != nil { + return "", false, err + } + + return newHash, true, nil +} diff --git a/hash/hash_test.go b/hash/hash_test.go new file mode 100644 index 0000000..c7b515e --- /dev/null +++ b/hash/hash_test.go @@ -0,0 +1,169 @@ +package hash + +import ( + "strings" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestAPIKey(t *testing.T) { + key := "test_api_key_12345" // gitleaks:allow + + hash1, err := APIKey(key) + require.NoError(t, err) + assert.NotEmpty(t, hash1) + assert.NotEqual(t, key, hash1, "Hash should not equal plaintext") + + // Bcrypt hashes should start with $2 + assert.True(t, strings.HasPrefix(hash1, "$2"), "Should be a bcrypt hash") + + // Same key should produce different hashes (due to salt) + hash2, err := APIKey(key) + require.NoError(t, err) + assert.NotEqual(t, hash1, hash2, "Bcrypt should produce different hashes with different salts") +} + +func TestCompareAPIKey_Bcrypt(t *testing.T) { + key := "test_api_key_12345" // gitleaks:allow + + hash, err := APIKey(key) + require.NoError(t, err) + + // Correct key should match + assert.True(t, CompareAPIKey(hash, key)) + + // Wrong key should not match + assert.False(t, CompareAPIKey(hash, "wrong_key")) +} + +func TestCompareAPIKey_Legacy(t *testing.T) { + key := "test_api_key_12345" // gitleaks:allow + + // Create a legacy SHA256 hash + legacyHash := String(key) + + // Should still work with legacy hashes + assert.True(t, CompareAPIKey(legacyHash, key)) + + // Wrong key should not match + assert.False(t, CompareAPIKey(legacyHash, "wrong_key")) +} + +func TestCompareAPIKey_BackwardCompatibility(t *testing.T) { + tests := []struct { + name string + hashFunc func(string) string + expectOK bool + }{ + { + name: "bcrypt hash", + hashFunc: func(k string) string { + h, _ := APIKey(k) + return h + }, + expectOK: true, + }, + { + name: "legacy SHA256 hash", + hashFunc: func(k string) string { + return String(k) + }, + expectOK: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + key := "test_key_123" + hash := tt.hashFunc(key) + + result := CompareAPIKey(hash, key) + assert.Equal(t, tt.expectOK, result) + }) + } +} + +func TestString(t *testing.T) { + // Test that String function still works (for non-sensitive data) + input := "test_string" + hash1 := String(input) + hash2 := String(input) + + // SHA256 should be deterministic + assert.Equal(t, hash1, hash2) + assert.NotEmpty(t, hash1) + assert.NotEqual(t, input, hash1) +} + +func TestIsLegacyHash(t *testing.T) { + tests := []struct { + name string + hash string + isLegacy bool + }{ + { + name: "bcrypt hash", + hash: "$2a$12$abcdefghijklmnopqrstuv", + isLegacy: false, + }, + { + name: "SHA256 hash", + hash: "dXNfYWtfMTIzNDU2Nzg5MDEyMzQ1NuOwxEKY", + isLegacy: true, + }, + { + name: "empty string", + hash: "", + isLegacy: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + assert.Equal(t, tt.isLegacy, IsLegacyHash(tt.hash)) + }) + } +} + +func TestMigrateAPIKeyHash(t *testing.T) { + plainKey := "test_api_key_123" + + t.Run("migrate legacy hash", func(t *testing.T) { + // Create a legacy SHA256 hash + legacyHash := String(plainKey) + + // Migrate it + newHash, migrated, err := MigrateAPIKeyHash(legacyHash, plainKey) + require.NoError(t, err) + assert.True(t, migrated, "Should indicate migration occurred") + assert.NotEqual(t, legacyHash, newHash, "New hash should differ from legacy") + assert.True(t, strings.HasPrefix(newHash, "$2"), "New hash should be bcrypt") + + // Verify new hash works + assert.True(t, CompareAPIKey(newHash, plainKey)) + }) + + t.Run("no migration needed for bcrypt", func(t *testing.T) { + // Create a bcrypt hash + bcryptHash, err := APIKey(plainKey) + require.NoError(t, err) + + // Try to migrate it + newHash, migrated, err := MigrateAPIKeyHash(bcryptHash, plainKey) + require.NoError(t, err) + assert.False(t, migrated, "Should not migrate bcrypt hash") + assert.Equal(t, bcryptHash, newHash, "Hash should remain unchanged") + }) + + t.Run("invalid key does not migrate", func(t *testing.T) { + legacyHash := String("correct_key") + + // Try to migrate with wrong plaintext + newHash, migrated, err := MigrateAPIKeyHash(legacyHash, "wrong_key") + require.NoError(t, err) + assert.False(t, migrated, "Should not migrate invalid key") + assert.Empty(t, newHash, "Should return empty for invalid key") + }) +} diff --git a/middleware/auth.go b/middleware/auth.go index bbe3a9e..8b57007 100644 --- a/middleware/auth.go +++ b/middleware/auth.go @@ -9,7 +9,6 @@ import ( "github.com/golang-jwt/jwt/v5" "gitlab.com/unboundsoftware/schemas/domain" - "gitlab.com/unboundsoftware/schemas/hash" ) const ( @@ -49,8 +48,8 @@ func (m *AuthMiddleware) Handler(next http.Handler) http.Handler { _, _ = w.Write([]byte("Invalid API Key format")) return } - hashedKey := hash.String(apiKey) - organization := m.cache.OrganizationByAPIKey(hashedKey) + // Cache handles hash comparison internally + organization := m.cache.OrganizationByAPIKey(apiKey) if organization != nil { ctx = context.WithValue(ctx, OrganizationKey, *organization) } diff --git a/middleware/auth_test.go b/middleware/auth_test.go index 3e456c8..3265e35 100644 --- a/middleware/auth_test.go +++ b/middleware/auth_test.go @@ -15,7 +15,6 @@ import ( "gitlab.com/unboundsoftware/eventsourced/eventsourced" "gitlab.com/unboundsoftware/schemas/domain" - "gitlab.com/unboundsoftware/schemas/hash" ) // MockCache is a mock implementation of the Cache interface @@ -45,9 +44,9 @@ func TestAuthMiddleware_Handler_WithValidAPIKey(t *testing.T) { } apiKey := "test-api-key-123" - hashedKey := hash.String(apiKey) - mockCache.On("OrganizationByAPIKey", hashedKey).Return(expectedOrg) + // Mock expects plaintext key (cache handles hashing internally) + mockCache.On("OrganizationByAPIKey", apiKey).Return(expectedOrg) // Create a test handler that checks the context var capturedOrg *domain.Organization @@ -84,9 +83,9 @@ func TestAuthMiddleware_Handler_WithInvalidAPIKey(t *testing.T) { authMiddleware := NewAuth(mockCache) apiKey := "invalid-api-key" - hashedKey := hash.String(apiKey) - mockCache.On("OrganizationByAPIKey", hashedKey).Return(nil) + // Mock expects plaintext key (cache handles hashing internally) + mockCache.On("OrganizationByAPIKey", apiKey).Return(nil) // Create a test handler that checks the context var capturedOrg *domain.Organization @@ -120,9 +119,8 @@ func TestAuthMiddleware_Handler_WithoutAPIKey(t *testing.T) { mockCache := new(MockCache) authMiddleware := NewAuth(mockCache) - // The middleware always hashes the API key (even if empty) and calls the cache - emptyKeyHash := hash.String("") - mockCache.On("OrganizationByAPIKey", emptyKeyHash).Return(nil) + // The middleware passes the plaintext API key (cache handles hashing) + mockCache.On("OrganizationByAPIKey", "").Return(nil) // Create a test handler that checks the context var capturedOrg *domain.Organization @@ -153,9 +151,8 @@ func TestAuthMiddleware_Handler_WithValidJWT(t *testing.T) { mockCache := new(MockCache) authMiddleware := NewAuth(mockCache) - // The middleware always hashes the API key (even if empty) and calls the cache - emptyKeyHash := hash.String("") - mockCache.On("OrganizationByAPIKey", emptyKeyHash).Return(nil) + // The middleware passes the plaintext API key (cache handles hashing) + mockCache.On("OrganizationByAPIKey", "").Return(nil) userID := "user-123" token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{ @@ -251,13 +248,13 @@ func TestAuthMiddleware_Handler_BothJWTAndAPIKey(t *testing.T) { userID := "user-123" apiKey := "test-api-key-123" - hashedKey := hash.String(apiKey) token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{ "sub": userID, }) - mockCache.On("OrganizationByAPIKey", hashedKey).Return(expectedOrg) + // Mock expects plaintext key (cache handles hashing internally) + mockCache.On("OrganizationByAPIKey", apiKey).Return(expectedOrg) // Create a test handler that checks both user and organization in context var capturedUser string