From 130e92dc5f84a0a398a8c3a4ef40b14fb44f2248 Mon Sep 17 00:00:00 2001 From: Joakim Olsson Date: Fri, 21 Nov 2025 10:21:08 +0100 Subject: [PATCH] feat(cache): add concurrency safety and logging improvements Implement read-write mutex locks for cache functions to ensure concurrency safety. Add debug logging for cache updates to enhance traceability of operations. Optimize user addition logic to prevent duplicates. Introduce a new test file for comprehensive cache functionality testing, ensuring reliable behavior. --- cache/cache.go | 45 ++++- cache/cache_test.go | 446 ++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 487 insertions(+), 4 deletions(-) create mode 100644 cache/cache_test.go diff --git a/cache/cache.go b/cache/cache.go index ed058a7..8e93f7f 100644 --- a/cache/cache.go +++ b/cache/cache.go @@ -3,15 +3,18 @@ package cache import ( "fmt" "log/slog" + "sync" "time" "github.com/sparetimecoders/goamqp" + "gitlab.com/unboundsoftware/eventsourced/eventsourced" "gitlab.com/unboundsoftware/schemas/domain" "gitlab.com/unboundsoftware/schemas/hash" ) type Cache struct { + mu sync.RWMutex organizations map[string]domain.Organization users map[string][]string apiKeys map[string]domain.APIKey // keyed by organizationId-name @@ -22,6 +25,9 @@ type Cache struct { } func (c *Cache) OrganizationByAPIKey(apiKey string) *domain.Organization { + c.mu.RLock() + defer c.mu.RUnlock() + // Find the API key by comparing hashes for _, key := range c.apiKeys { if hash.CompareAPIKey(key.Key, apiKey) { @@ -36,6 +42,9 @@ func (c *Cache) OrganizationByAPIKey(apiKey string) *domain.Organization { } func (c *Cache) OrganizationsByUser(sub string) []domain.Organization { + c.mu.RLock() + defer c.mu.RUnlock() + orgIds := c.users[sub] orgs := make([]domain.Organization, len(orgIds)) for i, id := range orgIds { @@ -45,6 +54,9 @@ func (c *Cache) OrganizationsByUser(sub string) []domain.Organization { } func (c *Cache) ApiKeyByKey(key string) *domain.APIKey { + c.mu.RLock() + defer c.mu.RUnlock() + // Find the API key by comparing hashes for _, apiKey := range c.apiKeys { if hash.CompareAPIKey(apiKey.Key, key) { @@ -55,6 +67,9 @@ func (c *Cache) ApiKeyByKey(key string) *domain.APIKey { } func (c *Cache) Services(orgId, ref, lastUpdate string) ([]string, string) { + c.mu.RLock() + defer c.mu.RUnlock() + key := refKey(orgId, ref) var services []string if lastUpdate == "" || c.lastUpdate[key] > lastUpdate { @@ -66,16 +81,25 @@ func (c *Cache) Services(orgId, ref, lastUpdate string) ([]string, string) { } func (c *Cache) SubGraphId(orgId, ref, service string) string { + c.mu.RLock() + defer c.mu.RUnlock() + return c.subGraphs[subGraphKey(orgId, ref, service)] } func (c *Cache) Update(msg any, _ goamqp.Headers) (any, error) { + c.mu.Lock() + defer c.mu.Unlock() + switch m := msg.(type) { case *domain.OrganizationAdded: - o := domain.Organization{} + o := domain.Organization{ + BaseAggregate: eventsourced.BaseAggregateFromString(m.ID.String()), + } m.UpdateOrganization(&o) c.organizations[m.ID.String()] = o c.addUser(m.Initiator, o) + c.logger.With("org_id", m.ID.String(), "event", "OrganizationAdded").Debug("cache updated") case *domain.APIKeyAdded: key := domain.APIKey{ Name: m.Name, @@ -92,8 +116,10 @@ func (c *Cache) Update(msg any, _ goamqp.Headers) (any, error) { org := c.organizations[m.OrganizationId] org.APIKeys = append(org.APIKeys, key) c.organizations[m.OrganizationId] = org + c.logger.With("org_id", m.OrganizationId, "key_name", m.Name, "event", "APIKeyAdded").Debug("cache updated") case *domain.SubGraphUpdated: c.updateSubGraph(m.OrganizationId, m.Ref, m.ID.String(), m.Service, m.Time) + c.logger.With("org_id", m.OrganizationId, "ref", m.Ref, "service", m.Service, "event", "SubGraphUpdated").Debug("cache updated") case *domain.Organization: c.organizations[m.ID.String()] = *m c.addUser(m.CreatedBy, *m) @@ -101,8 +127,10 @@ func (c *Cache) Update(msg any, _ goamqp.Headers) (any, error) { // Use composite key: organizationId-name c.apiKeys[apiKeyId(k.OrganizationId, k.Name)] = k } + c.logger.With("org_id", m.ID.String(), "event", "Organization aggregate loaded").Debug("cache updated") case *domain.SubGraph: c.updateSubGraph(m.OrganizationId, m.Ref, m.ID.String(), m.Service, m.ChangedAt) + c.logger.With("org_id", m.OrganizationId, "ref", m.Ref, "service", m.Service, "event", "SubGraph aggregate loaded").Debug("cache updated") default: c.logger.With("msg", msg).Warn("unexpected message received") } @@ -123,11 +151,20 @@ func (c *Cache) updateSubGraph(orgId string, ref string, subGraphId string, serv func (c *Cache) addUser(sub string, organization domain.Organization) { user, exists := c.users[sub] + orgId := organization.ID.String() if !exists { - c.users[sub] = []string{organization.ID.String()} - } else { - c.users[sub] = append(user, organization.ID.String()) + c.users[sub] = []string{orgId} + return } + + // Check if organization already exists for this user + for _, id := range user { + if id == orgId { + return // Already exists, no need to add + } + } + + c.users[sub] = append(user, orgId) } func New(logger *slog.Logger) *Cache { diff --git a/cache/cache_test.go b/cache/cache_test.go new file mode 100644 index 0000000..f743461 --- /dev/null +++ b/cache/cache_test.go @@ -0,0 +1,446 @@ +package cache + +import ( + "log/slog" + "os" + "sync" + "testing" + "time" + + "github.com/google/uuid" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "gitlab.com/unboundsoftware/eventsourced/eventsourced" + + "gitlab.com/unboundsoftware/schemas/domain" + "gitlab.com/unboundsoftware/schemas/hash" +) + +func TestCache_OrganizationByAPIKey(t *testing.T) { + logger := slog.New(slog.NewTextHandler(os.Stdout, nil)) + c := New(logger) + + orgID := uuid.New().String() + apiKey := "test-api-key-123" // gitleaks:allow + hashedKey, err := hash.APIKey(apiKey) + require.NoError(t, err) + + // Add organization to cache + org := domain.Organization{ + BaseAggregate: eventsourced.BaseAggregateFromString(orgID), + Name: "Test Org", + } + c.organizations[orgID] = org + + // Add API key to cache + c.apiKeys[apiKeyId(orgID, "test-key")] = domain.APIKey{ + Name: "test-key", + OrganizationId: orgID, + Key: hashedKey, + Refs: []string{"main"}, + Read: true, + Publish: true, + } + + // Test finding organization by plaintext API key + foundOrg := c.OrganizationByAPIKey(apiKey) + require.NotNil(t, foundOrg) + assert.Equal(t, org.Name, foundOrg.Name) + assert.Equal(t, orgID, foundOrg.ID.String()) + + // Test with wrong API key + notFoundOrg := c.OrganizationByAPIKey("wrong-key") + assert.Nil(t, notFoundOrg) +} + +func TestCache_OrganizationByAPIKey_Legacy(t *testing.T) { + logger := slog.New(slog.NewTextHandler(os.Stdout, nil)) + c := New(logger) + + orgID := uuid.New().String() + apiKey := "legacy-api-key-456" // gitleaks:allow + legacyHash := hash.String(apiKey) + + // Add organization to cache + org := domain.Organization{ + BaseAggregate: eventsourced.BaseAggregateFromString(orgID), + Name: "Legacy Org", + } + c.organizations[orgID] = org + + // Add API key with legacy SHA256 hash + c.apiKeys[apiKeyId(orgID, "legacy-key")] = domain.APIKey{ + Name: "legacy-key", + OrganizationId: orgID, + Key: legacyHash, + Refs: []string{"main"}, + Read: true, + Publish: false, + } + + // Test finding organization with legacy hash + foundOrg := c.OrganizationByAPIKey(apiKey) + require.NotNil(t, foundOrg) + assert.Equal(t, org.Name, foundOrg.Name) +} + +func TestCache_OrganizationsByUser(t *testing.T) { + logger := slog.New(slog.NewTextHandler(os.Stdout, nil)) + c := New(logger) + + userSub := "user-123" + org1ID := uuid.New().String() + org2ID := uuid.New().String() + + org1 := domain.Organization{ + BaseAggregate: eventsourced.BaseAggregateFromString(org1ID), + Name: "Org 1", + } + org2 := domain.Organization{ + BaseAggregate: eventsourced.BaseAggregateFromString(org2ID), + Name: "Org 2", + } + + c.organizations[org1ID] = org1 + c.organizations[org2ID] = org2 + c.users[userSub] = []string{org1ID, org2ID} + + orgs := c.OrganizationsByUser(userSub) + assert.Len(t, orgs, 2) + assert.Contains(t, []string{orgs[0].Name, orgs[1].Name}, "Org 1") + assert.Contains(t, []string{orgs[0].Name, orgs[1].Name}, "Org 2") +} + +func TestCache_ApiKeyByKey(t *testing.T) { + logger := slog.New(slog.NewTextHandler(os.Stdout, nil)) + c := New(logger) + + orgID := uuid.New().String() + apiKey := "test-api-key-789" // gitleaks:allow + hashedKey, err := hash.APIKey(apiKey) + require.NoError(t, err) + + expectedKey := domain.APIKey{ + Name: "test-key", + OrganizationId: orgID, + Key: hashedKey, + Refs: []string{"main", "dev"}, + Read: true, + Publish: true, + } + + c.apiKeys[apiKeyId(orgID, "test-key")] = expectedKey + + foundKey := c.ApiKeyByKey(apiKey) + require.NotNil(t, foundKey) + assert.Equal(t, expectedKey.Name, foundKey.Name) + assert.Equal(t, expectedKey.OrganizationId, foundKey.OrganizationId) + assert.Equal(t, expectedKey.Refs, foundKey.Refs) + + // Test with wrong key + notFoundKey := c.ApiKeyByKey("wrong-key") + assert.Nil(t, notFoundKey) +} + +func TestCache_Services(t *testing.T) { + logger := slog.New(slog.NewTextHandler(os.Stdout, nil)) + c := New(logger) + + orgID := uuid.New().String() + ref := "main" + service1 := "service-1" + service2 := "service-2" + lastUpdate := "2024-01-01T12:00:00Z" + + c.services[orgID] = map[string]map[string]struct{}{ + ref: { + service1: {}, + service2: {}, + }, + } + c.lastUpdate[refKey(orgID, ref)] = lastUpdate + + // Test getting services with empty lastUpdate + services, returnedLastUpdate := c.Services(orgID, ref, "") + assert.Len(t, services, 2) + assert.Contains(t, services, service1) + assert.Contains(t, services, service2) + assert.Equal(t, lastUpdate, returnedLastUpdate) + + // Test with older lastUpdate (should return services) + services, returnedLastUpdate = c.Services(orgID, ref, "2023-12-31T12:00:00Z") + assert.Len(t, services, 2) + assert.Equal(t, lastUpdate, returnedLastUpdate) + + // Test with newer lastUpdate (should return empty) + services, returnedLastUpdate = c.Services(orgID, ref, "2024-01-02T12:00:00Z") + assert.Len(t, services, 0) + assert.Equal(t, lastUpdate, returnedLastUpdate) +} + +func TestCache_SubGraphId(t *testing.T) { + logger := slog.New(slog.NewTextHandler(os.Stdout, nil)) + c := New(logger) + + orgID := uuid.New().String() + ref := "main" + service := "test-service" + subGraphID := uuid.New().String() + + c.subGraphs[subGraphKey(orgID, ref, service)] = subGraphID + + foundID := c.SubGraphId(orgID, ref, service) + assert.Equal(t, subGraphID, foundID) + + // Test with non-existent key + notFoundID := c.SubGraphId("wrong-org", ref, service) + assert.Empty(t, notFoundID) +} + +func TestCache_Update_OrganizationAdded(t *testing.T) { + logger := slog.New(slog.NewTextHandler(os.Stdout, nil)) + c := New(logger) + + orgID := uuid.New().String() + event := &domain.OrganizationAdded{ + Name: "New Org", + Initiator: "user-123", + } + event.ID = *eventsourced.IdFromString(orgID) + + _, err := c.Update(event, nil) + require.NoError(t, err) + + // Verify organization was added + org, exists := c.organizations[orgID] + assert.True(t, exists) + assert.Equal(t, "New Org", org.Name) + + // Verify user was added + assert.Contains(t, c.users["user-123"], orgID) +} + +func TestCache_Update_APIKeyAdded(t *testing.T) { + logger := slog.New(slog.NewTextHandler(os.Stdout, nil)) + c := New(logger) + + orgID := uuid.New().String() + keyName := "test-key" + hashedKey := "hashed-key-value" + + // Add organization first + org := domain.Organization{ + BaseAggregate: eventsourced.BaseAggregateFromString(orgID), + Name: "Test Org", + APIKeys: []domain.APIKey{}, + } + c.organizations[orgID] = org + + event := &domain.APIKeyAdded{ + OrganizationId: orgID, + Name: keyName, + Key: hashedKey, + Refs: []string{"main"}, + Read: true, + Publish: false, + Initiator: "user-123", + } + event.ID = *eventsourced.IdFromString(uuid.New().String()) + + _, err := c.Update(event, nil) + require.NoError(t, err) + + // Verify API key was added to cache + key, exists := c.apiKeys[apiKeyId(orgID, keyName)] + assert.True(t, exists) + assert.Equal(t, keyName, key.Name) + assert.Equal(t, hashedKey, key.Key) + assert.Equal(t, []string{"main"}, key.Refs) + + // Verify API key was added to organization + updatedOrg := c.organizations[orgID] + assert.Len(t, updatedOrg.APIKeys, 1) + assert.Equal(t, keyName, updatedOrg.APIKeys[0].Name) +} + +func TestCache_Update_SubGraphUpdated(t *testing.T) { + logger := slog.New(slog.NewTextHandler(os.Stdout, nil)) + c := New(logger) + + orgID := uuid.New().String() + ref := "main" + service := "test-service" + subGraphID := uuid.New().String() + + event := &domain.SubGraphUpdated{ + OrganizationId: orgID, + Ref: ref, + Service: service, + Initiator: "user-123", + } + event.ID = *eventsourced.IdFromString(subGraphID) + event.SetWhen(time.Now()) + + _, err := c.Update(event, nil) + require.NoError(t, err) + + // Verify subgraph was added to services + assert.Contains(t, c.services[orgID][ref], subGraphID) + + // Verify subgraph ID was stored + assert.Equal(t, subGraphID, c.subGraphs[subGraphKey(orgID, ref, service)]) + + // Verify lastUpdate was set + assert.NotEmpty(t, c.lastUpdate[refKey(orgID, ref)]) +} + +func TestCache_AddUser_NoDuplicates(t *testing.T) { + logger := slog.New(slog.NewTextHandler(os.Stdout, nil)) + c := New(logger) + + userSub := "user-123" + orgID := uuid.New().String() + org := domain.Organization{ + BaseAggregate: eventsourced.BaseAggregateFromString(orgID), + Name: "Test Org", + } + + // Add user first time + c.addUser(userSub, org) + assert.Len(t, c.users[userSub], 1) + assert.Equal(t, orgID, c.users[userSub][0]) + + // Add same user/org again - should not create duplicate + c.addUser(userSub, org) + assert.Len(t, c.users[userSub], 1, "Should not add duplicate organization") + assert.Equal(t, orgID, c.users[userSub][0]) +} + +func TestCache_ConcurrentReads(t *testing.T) { + logger := slog.New(slog.NewTextHandler(os.Stdout, nil)) + c := New(logger) + + // Setup test data + orgID := uuid.New().String() + apiKey := "test-concurrent-key" // gitleaks:allow + hashedKey, err := hash.APIKey(apiKey) + require.NoError(t, err) + + org := domain.Organization{ + BaseAggregate: eventsourced.BaseAggregateFromString(orgID), + Name: "Concurrent Test Org", + } + c.organizations[orgID] = org + c.apiKeys[apiKeyId(orgID, "test-key")] = domain.APIKey{ + Name: "test-key", + OrganizationId: orgID, + Key: hashedKey, + } + + // Run concurrent reads + var wg sync.WaitGroup + numGoroutines := 100 + + for i := 0; i < numGoroutines; i++ { + wg.Add(1) + go func() { + defer wg.Done() + org := c.OrganizationByAPIKey(apiKey) + assert.NotNil(t, org) + assert.Equal(t, "Concurrent Test Org", org.Name) + }() + } + + wg.Wait() +} + +func TestCache_ConcurrentWrites(t *testing.T) { + logger := slog.New(slog.NewTextHandler(os.Stdout, nil)) + c := New(logger) + + var wg sync.WaitGroup + numGoroutines := 50 + + // Concurrent organization additions + for i := 0; i < numGoroutines; i++ { + wg.Add(1) + go func(index int) { + defer wg.Done() + orgID := uuid.New().String() + event := &domain.OrganizationAdded{ + Name: "Org " + string(rune(index)), + Initiator: "user-" + string(rune(index)), + } + event.ID = *eventsourced.IdFromString(orgID) + _, err := c.Update(event, nil) + assert.NoError(t, err) + }(i) + } + + wg.Wait() + + // Verify all organizations were added + assert.Equal(t, numGoroutines, len(c.organizations)) +} + +func TestCache_ConcurrentReadsAndWrites(t *testing.T) { + logger := slog.New(slog.NewTextHandler(os.Stdout, nil)) + c := New(logger) + + // Setup initial data + orgID := uuid.New().String() + apiKey := "test-rw-key" // gitleaks:allow + hashedKey, err := hash.APIKey(apiKey) + require.NoError(t, err) + + org := domain.Organization{ + BaseAggregate: eventsourced.BaseAggregateFromString(orgID), + Name: "RW Test Org", + } + c.organizations[orgID] = org + c.apiKeys[apiKeyId(orgID, "test-key")] = domain.APIKey{ + Name: "test-key", + OrganizationId: orgID, + Key: hashedKey, + } + c.users["user-initial"] = []string{orgID} + + var wg sync.WaitGroup + numReaders := 50 + numWriters := 20 + + // Concurrent readers + for i := 0; i < numReaders; i++ { + wg.Add(1) + go func() { + defer wg.Done() + for j := 0; j < 10; j++ { + org := c.OrganizationByAPIKey(apiKey) + assert.NotNil(t, org) + orgs := c.OrganizationsByUser("user-initial") + assert.NotEmpty(t, orgs) + } + }() + } + + // Concurrent writers + for i := 0; i < numWriters; i++ { + wg.Add(1) + go func(index int) { + defer wg.Done() + newOrgID := uuid.New().String() + event := &domain.OrganizationAdded{ + Name: "New Org " + string(rune(index)), + Initiator: "user-new-" + string(rune(index)), + } + event.ID = *eventsourced.IdFromString(newOrgID) + _, err := c.Update(event, nil) + assert.NoError(t, err) + }(i) + } + + wg.Wait() + + // Verify cache is in consistent state + assert.GreaterOrEqual(t, len(c.organizations), numWriters) +}