feat(tests): add unit tests for WebSocket initialization logic

Adds unit tests for the WebSocket initialization function to validate
behavior with valid, invalid, and absent API keys. Introduces a mock
cache implementation to simulate organization retrieval based on
hashed API keys. Ensures proper context value setting upon
initialization, enhancing test coverage and reliability for API key
handling in WebSocket connections.
This commit is contained in:
2025-11-20 14:24:39 +01:00
parent bb0c08be06
commit 4d18cf4175
2 changed files with 803 additions and 0 deletions
+467
View File
@@ -0,0 +1,467 @@
package middleware
import (
"context"
"net/http"
"net/http/httptest"
"testing"
mw "github.com/auth0/go-jwt-middleware/v2"
"github.com/golang-jwt/jwt/v5"
"github.com/google/uuid"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require"
"gitlab.com/unboundsoftware/eventsourced/eventsourced"
"gitlab.com/unboundsoftware/schemas/domain"
"gitlab.com/unboundsoftware/schemas/hash"
)
// MockCache is a mock implementation of the Cache interface
type MockCache struct {
mock.Mock
}
func (m *MockCache) OrganizationByAPIKey(apiKey string) *domain.Organization {
args := m.Called(apiKey)
if args.Get(0) == nil {
return nil
}
return args.Get(0).(*domain.Organization)
}
func TestAuthMiddleware_Handler_WithValidAPIKey(t *testing.T) {
// Setup
mockCache := new(MockCache)
authMiddleware := NewAuth(mockCache)
orgID := uuid.New()
expectedOrg := &domain.Organization{
BaseAggregate: eventsourced.BaseAggregate{
ID: eventsourced.IdFromString(orgID.String()),
},
Name: "Test Organization",
}
apiKey := "test-api-key-123"
hashedKey := hash.String(apiKey)
mockCache.On("OrganizationByAPIKey", hashedKey).Return(expectedOrg)
// Create a test handler that checks the context
var capturedOrg *domain.Organization
testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if org := r.Context().Value(OrganizationKey); org != nil {
if o, ok := org.(domain.Organization); ok {
capturedOrg = &o
}
}
w.WriteHeader(http.StatusOK)
})
// Create request with API key in context
req := httptest.NewRequest(http.MethodGet, "/test", nil)
ctx := context.WithValue(req.Context(), ApiKey, apiKey)
req = req.WithContext(ctx)
rec := httptest.NewRecorder()
// Execute
authMiddleware.Handler(testHandler).ServeHTTP(rec, req)
// Assert
assert.Equal(t, http.StatusOK, rec.Code)
require.NotNil(t, capturedOrg)
assert.Equal(t, expectedOrg.Name, capturedOrg.Name)
assert.Equal(t, expectedOrg.ID.String(), capturedOrg.ID.String())
mockCache.AssertExpectations(t)
}
func TestAuthMiddleware_Handler_WithInvalidAPIKey(t *testing.T) {
// Setup
mockCache := new(MockCache)
authMiddleware := NewAuth(mockCache)
apiKey := "invalid-api-key"
hashedKey := hash.String(apiKey)
mockCache.On("OrganizationByAPIKey", hashedKey).Return(nil)
// Create a test handler that checks the context
var capturedOrg *domain.Organization
testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if org := r.Context().Value(OrganizationKey); org != nil {
if o, ok := org.(domain.Organization); ok {
capturedOrg = &o
}
}
w.WriteHeader(http.StatusOK)
})
// Create request with API key in context
req := httptest.NewRequest(http.MethodGet, "/test", nil)
ctx := context.WithValue(req.Context(), ApiKey, apiKey)
req = req.WithContext(ctx)
rec := httptest.NewRecorder()
// Execute
authMiddleware.Handler(testHandler).ServeHTTP(rec, req)
// Assert
assert.Equal(t, http.StatusOK, rec.Code)
assert.Nil(t, capturedOrg, "Organization should not be set for invalid API key")
mockCache.AssertExpectations(t)
}
func TestAuthMiddleware_Handler_WithoutAPIKey(t *testing.T) {
// Setup
mockCache := new(MockCache)
authMiddleware := NewAuth(mockCache)
// The middleware always hashes the API key (even if empty) and calls the cache
emptyKeyHash := hash.String("")
mockCache.On("OrganizationByAPIKey", emptyKeyHash).Return(nil)
// Create a test handler that checks the context
var capturedOrg *domain.Organization
testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if org := r.Context().Value(OrganizationKey); org != nil {
if o, ok := org.(domain.Organization); ok {
capturedOrg = &o
}
}
w.WriteHeader(http.StatusOK)
})
// Create request without API key
req := httptest.NewRequest(http.MethodGet, "/test", nil)
rec := httptest.NewRecorder()
// Execute
authMiddleware.Handler(testHandler).ServeHTTP(rec, req)
// Assert
assert.Equal(t, http.StatusOK, rec.Code)
assert.Nil(t, capturedOrg, "Organization should not be set without API key")
mockCache.AssertExpectations(t)
}
func TestAuthMiddleware_Handler_WithValidJWT(t *testing.T) {
// Setup
mockCache := new(MockCache)
authMiddleware := NewAuth(mockCache)
// The middleware always hashes the API key (even if empty) and calls the cache
emptyKeyHash := hash.String("")
mockCache.On("OrganizationByAPIKey", emptyKeyHash).Return(nil)
userID := "user-123"
token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{
"sub": userID,
})
// Create a test handler that checks the context
var capturedUser string
testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if user := r.Context().Value(UserKey); user != nil {
if u, ok := user.(string); ok {
capturedUser = u
}
}
w.WriteHeader(http.StatusOK)
})
// Create request with JWT token in context
req := httptest.NewRequest(http.MethodGet, "/test", nil)
ctx := context.WithValue(req.Context(), mw.ContextKey{}, token)
req = req.WithContext(ctx)
rec := httptest.NewRecorder()
// Execute
authMiddleware.Handler(testHandler).ServeHTTP(rec, req)
// Assert
assert.Equal(t, http.StatusOK, rec.Code)
assert.Equal(t, userID, capturedUser)
}
func TestAuthMiddleware_Handler_APIKeyErrorHandling(t *testing.T) {
// Setup
mockCache := new(MockCache)
authMiddleware := NewAuth(mockCache)
testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
})
// Create request with invalid API key type in context
req := httptest.NewRequest(http.MethodGet, "/test", nil)
ctx := context.WithValue(req.Context(), ApiKey, 12345) // Invalid type
req = req.WithContext(ctx)
rec := httptest.NewRecorder()
// Execute
authMiddleware.Handler(testHandler).ServeHTTP(rec, req)
// Assert
assert.Equal(t, http.StatusInternalServerError, rec.Code)
assert.Contains(t, rec.Body.String(), "Invalid API Key format")
}
func TestAuthMiddleware_Handler_JWTErrorHandling(t *testing.T) {
// Setup
mockCache := new(MockCache)
authMiddleware := NewAuth(mockCache)
testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
})
// Create request with invalid JWT token type in context
req := httptest.NewRequest(http.MethodGet, "/test", nil)
ctx := context.WithValue(req.Context(), mw.ContextKey{}, "not-a-token") // Invalid type
req = req.WithContext(ctx)
rec := httptest.NewRecorder()
// Execute
authMiddleware.Handler(testHandler).ServeHTTP(rec, req)
// Assert
assert.Equal(t, http.StatusInternalServerError, rec.Code)
assert.Contains(t, rec.Body.String(), "Invalid JWT token format")
}
func TestAuthMiddleware_Handler_BothJWTAndAPIKey(t *testing.T) {
// Setup
mockCache := new(MockCache)
authMiddleware := NewAuth(mockCache)
orgID := uuid.New()
expectedOrg := &domain.Organization{
BaseAggregate: eventsourced.BaseAggregate{
ID: eventsourced.IdFromString(orgID.String()),
},
Name: "Test Organization",
}
userID := "user-123"
apiKey := "test-api-key-123"
hashedKey := hash.String(apiKey)
token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{
"sub": userID,
})
mockCache.On("OrganizationByAPIKey", hashedKey).Return(expectedOrg)
// Create a test handler that checks both user and organization in context
var capturedUser string
var capturedOrg *domain.Organization
testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if user := r.Context().Value(UserKey); user != nil {
if u, ok := user.(string); ok {
capturedUser = u
}
}
if org := r.Context().Value(OrganizationKey); org != nil {
if o, ok := org.(domain.Organization); ok {
capturedOrg = &o
}
}
w.WriteHeader(http.StatusOK)
})
// Create request with both JWT and API key in context
req := httptest.NewRequest(http.MethodGet, "/test", nil)
ctx := context.WithValue(req.Context(), mw.ContextKey{}, token)
ctx = context.WithValue(ctx, ApiKey, apiKey)
req = req.WithContext(ctx)
rec := httptest.NewRecorder()
// Execute
authMiddleware.Handler(testHandler).ServeHTTP(rec, req)
// Assert
assert.Equal(t, http.StatusOK, rec.Code)
assert.Equal(t, userID, capturedUser)
require.NotNil(t, capturedOrg)
assert.Equal(t, expectedOrg.Name, capturedOrg.Name)
mockCache.AssertExpectations(t)
}
func TestUserFromContext(t *testing.T) {
tests := []struct {
name string
ctx context.Context
expected string
}{
{
name: "with valid user",
ctx: context.WithValue(context.Background(), UserKey, "user-123"),
expected: "user-123",
},
{
name: "without user",
ctx: context.Background(),
expected: "",
},
{
name: "with invalid type",
ctx: context.WithValue(context.Background(), UserKey, 123),
expected: "",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := UserFromContext(tt.ctx)
assert.Equal(t, tt.expected, result)
})
}
}
func TestOrganizationFromContext(t *testing.T) {
orgID := uuid.New()
org := domain.Organization{
BaseAggregate: eventsourced.BaseAggregate{
ID: eventsourced.IdFromString(orgID.String()),
},
Name: "Test Org",
}
tests := []struct {
name string
ctx context.Context
expected string
}{
{
name: "with valid organization",
ctx: context.WithValue(context.Background(), OrganizationKey, org),
expected: orgID.String(),
},
{
name: "without organization",
ctx: context.Background(),
expected: "",
},
{
name: "with invalid type",
ctx: context.WithValue(context.Background(), OrganizationKey, "not-an-org"),
expected: "",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := OrganizationFromContext(tt.ctx)
assert.Equal(t, tt.expected, result)
})
}
}
func TestAuthMiddleware_Directive_RequiresUser(t *testing.T) {
mockCache := new(MockCache)
authMiddleware := NewAuth(mockCache)
requireUser := true
// Test with user present
ctx := context.WithValue(context.Background(), UserKey, "user-123")
_, err := authMiddleware.Directive(ctx, nil, func(ctx context.Context) (interface{}, error) {
return "success", nil
}, &requireUser, nil)
assert.NoError(t, err)
// Test without user
ctx = context.Background()
_, err = authMiddleware.Directive(ctx, nil, func(ctx context.Context) (interface{}, error) {
return "success", nil
}, &requireUser, nil)
assert.Error(t, err)
assert.Contains(t, err.Error(), "no user available in request")
}
func TestAuthMiddleware_Directive_RequiresOrganization(t *testing.T) {
mockCache := new(MockCache)
authMiddleware := NewAuth(mockCache)
requireOrg := true
orgID := uuid.New()
org := domain.Organization{
BaseAggregate: eventsourced.BaseAggregate{
ID: eventsourced.IdFromString(orgID.String()),
},
Name: "Test Org",
}
// Test with organization present
ctx := context.WithValue(context.Background(), OrganizationKey, org)
_, err := authMiddleware.Directive(ctx, nil, func(ctx context.Context) (interface{}, error) {
return "success", nil
}, nil, &requireOrg)
assert.NoError(t, err)
// Test without organization
ctx = context.Background()
_, err = authMiddleware.Directive(ctx, nil, func(ctx context.Context) (interface{}, error) {
return "success", nil
}, nil, &requireOrg)
assert.Error(t, err)
assert.Contains(t, err.Error(), "no organization available in request")
}
func TestAuthMiddleware_Directive_RequiresBoth(t *testing.T) {
mockCache := new(MockCache)
authMiddleware := NewAuth(mockCache)
requireUser := true
requireOrg := true
orgID := uuid.New()
org := domain.Organization{
BaseAggregate: eventsourced.BaseAggregate{
ID: eventsourced.IdFromString(orgID.String()),
},
Name: "Test Org",
}
// Test with both present
ctx := context.WithValue(context.Background(), UserKey, "user-123")
ctx = context.WithValue(ctx, OrganizationKey, org)
_, err := authMiddleware.Directive(ctx, nil, func(ctx context.Context) (interface{}, error) {
return "success", nil
}, &requireUser, &requireOrg)
assert.NoError(t, err)
// Test with only user
ctx = context.WithValue(context.Background(), UserKey, "user-123")
_, err = authMiddleware.Directive(ctx, nil, func(ctx context.Context) (interface{}, error) {
return "success", nil
}, &requireUser, &requireOrg)
assert.Error(t, err)
// Test with only organization
ctx = context.WithValue(context.Background(), OrganizationKey, org)
_, err = authMiddleware.Directive(ctx, nil, func(ctx context.Context) (interface{}, error) {
return "success", nil
}, &requireUser, &requireOrg)
assert.Error(t, err)
}
func TestAuthMiddleware_Directive_NoRequirements(t *testing.T) {
mockCache := new(MockCache)
authMiddleware := NewAuth(mockCache)
// Test with no requirements
ctx := context.Background()
result, err := authMiddleware.Directive(ctx, nil, func(ctx context.Context) (interface{}, error) {
return "success", nil
}, nil, nil)
assert.NoError(t, err)
assert.Equal(t, "success", result)
}