feat(tests): add unit tests for WebSocket initialization logic
Adds unit tests for the WebSocket initialization function to validate behavior with valid, invalid, and absent API keys. Introduces a mock cache implementation to simulate organization retrieval based on hashed API keys. Ensures proper context value setting upon initialization, enhancing test coverage and reliability for API key handling in WebSocket connections.
This commit is contained in:
@@ -0,0 +1,336 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/99designs/gqlgen/graphql/handler/transport"
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"gitlab.com/unboundsoftware/eventsourced/eventsourced"
|
||||
|
||||
"gitlab.com/unboundsoftware/schemas/domain"
|
||||
"gitlab.com/unboundsoftware/schemas/hash"
|
||||
"gitlab.com/unboundsoftware/schemas/middleware"
|
||||
)
|
||||
|
||||
// MockCache is a mock implementation for testing
|
||||
type MockCache struct {
|
||||
organizations map[string]*domain.Organization
|
||||
}
|
||||
|
||||
func (m *MockCache) OrganizationByAPIKey(apiKey string) *domain.Organization {
|
||||
return m.organizations[apiKey]
|
||||
}
|
||||
|
||||
func TestWebSocketInitFunc_WithValidAPIKey(t *testing.T) {
|
||||
// Setup
|
||||
orgID := uuid.New()
|
||||
org := &domain.Organization{
|
||||
BaseAggregate: eventsourced.BaseAggregate{
|
||||
ID: eventsourced.IdFromString(orgID.String()),
|
||||
},
|
||||
Name: "Test Organization",
|
||||
}
|
||||
|
||||
apiKey := "test-api-key-123"
|
||||
hashedKey := hash.String(apiKey)
|
||||
|
||||
mockCache := &MockCache{
|
||||
organizations: map[string]*domain.Organization{
|
||||
hashedKey: org,
|
||||
},
|
||||
}
|
||||
|
||||
// Create InitFunc (simulating the WebSocket InitFunc logic)
|
||||
initFunc := func(ctx context.Context, initPayload transport.InitPayload) (context.Context, *transport.InitPayload, error) {
|
||||
// Extract API key from WebSocket connection_init payload
|
||||
if apiKey, ok := initPayload["X-Api-Key"].(string); ok && apiKey != "" {
|
||||
ctx = context.WithValue(ctx, middleware.ApiKey, apiKey)
|
||||
|
||||
// Look up organization by API key
|
||||
if organization := mockCache.OrganizationByAPIKey(hash.String(apiKey)); organization != nil {
|
||||
ctx = context.WithValue(ctx, middleware.OrganizationKey, *organization)
|
||||
}
|
||||
}
|
||||
return ctx, &initPayload, nil
|
||||
}
|
||||
|
||||
// Test
|
||||
ctx := context.Background()
|
||||
initPayload := transport.InitPayload{
|
||||
"X-Api-Key": apiKey,
|
||||
}
|
||||
|
||||
resultCtx, resultPayload, err := initFunc(ctx, initPayload)
|
||||
|
||||
// Assert
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, resultPayload)
|
||||
|
||||
// Check API key is in context
|
||||
if value := resultCtx.Value(middleware.ApiKey); value != nil {
|
||||
assert.Equal(t, apiKey, value.(string))
|
||||
} else {
|
||||
t.Fatal("API key not found in context")
|
||||
}
|
||||
|
||||
// Check organization is in context
|
||||
if value := resultCtx.Value(middleware.OrganizationKey); value != nil {
|
||||
capturedOrg, ok := value.(domain.Organization)
|
||||
require.True(t, ok, "Organization should be of correct type")
|
||||
assert.Equal(t, org.Name, capturedOrg.Name)
|
||||
assert.Equal(t, org.ID.String(), capturedOrg.ID.String())
|
||||
} else {
|
||||
t.Fatal("Organization not found in context")
|
||||
}
|
||||
}
|
||||
|
||||
func TestWebSocketInitFunc_WithInvalidAPIKey(t *testing.T) {
|
||||
// Setup
|
||||
mockCache := &MockCache{
|
||||
organizations: map[string]*domain.Organization{},
|
||||
}
|
||||
|
||||
apiKey := "invalid-api-key"
|
||||
|
||||
// Create InitFunc
|
||||
initFunc := func(ctx context.Context, initPayload transport.InitPayload) (context.Context, *transport.InitPayload, error) {
|
||||
// Extract API key from WebSocket connection_init payload
|
||||
if apiKey, ok := initPayload["X-Api-Key"].(string); ok && apiKey != "" {
|
||||
ctx = context.WithValue(ctx, middleware.ApiKey, apiKey)
|
||||
|
||||
// Look up organization by API key
|
||||
if organization := mockCache.OrganizationByAPIKey(hash.String(apiKey)); organization != nil {
|
||||
ctx = context.WithValue(ctx, middleware.OrganizationKey, *organization)
|
||||
}
|
||||
}
|
||||
return ctx, &initPayload, nil
|
||||
}
|
||||
|
||||
// Test
|
||||
ctx := context.Background()
|
||||
initPayload := transport.InitPayload{
|
||||
"X-Api-Key": apiKey,
|
||||
}
|
||||
|
||||
resultCtx, resultPayload, err := initFunc(ctx, initPayload)
|
||||
|
||||
// Assert
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, resultPayload)
|
||||
|
||||
// Check API key is in context
|
||||
if value := resultCtx.Value(middleware.ApiKey); value != nil {
|
||||
assert.Equal(t, apiKey, value.(string))
|
||||
} else {
|
||||
t.Fatal("API key not found in context")
|
||||
}
|
||||
|
||||
// Check organization is NOT in context (since API key is invalid)
|
||||
value := resultCtx.Value(middleware.OrganizationKey)
|
||||
assert.Nil(t, value, "Organization should not be set for invalid API key")
|
||||
}
|
||||
|
||||
func TestWebSocketInitFunc_WithoutAPIKey(t *testing.T) {
|
||||
// Setup
|
||||
mockCache := &MockCache{
|
||||
organizations: map[string]*domain.Organization{},
|
||||
}
|
||||
|
||||
// Create InitFunc
|
||||
initFunc := func(ctx context.Context, initPayload transport.InitPayload) (context.Context, *transport.InitPayload, error) {
|
||||
// Extract API key from WebSocket connection_init payload
|
||||
if apiKey, ok := initPayload["X-Api-Key"].(string); ok && apiKey != "" {
|
||||
ctx = context.WithValue(ctx, middleware.ApiKey, apiKey)
|
||||
|
||||
// Look up organization by API key
|
||||
if organization := mockCache.OrganizationByAPIKey(hash.String(apiKey)); organization != nil {
|
||||
ctx = context.WithValue(ctx, middleware.OrganizationKey, *organization)
|
||||
}
|
||||
}
|
||||
return ctx, &initPayload, nil
|
||||
}
|
||||
|
||||
// Test
|
||||
ctx := context.Background()
|
||||
initPayload := transport.InitPayload{}
|
||||
|
||||
resultCtx, resultPayload, err := initFunc(ctx, initPayload)
|
||||
|
||||
// Assert
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, resultPayload)
|
||||
|
||||
// Check API key is NOT in context
|
||||
value := resultCtx.Value(middleware.ApiKey)
|
||||
assert.Nil(t, value, "API key should not be set when not provided")
|
||||
|
||||
// Check organization is NOT in context
|
||||
value = resultCtx.Value(middleware.OrganizationKey)
|
||||
assert.Nil(t, value, "Organization should not be set when API key is not provided")
|
||||
}
|
||||
|
||||
func TestWebSocketInitFunc_WithEmptyAPIKey(t *testing.T) {
|
||||
// Setup
|
||||
mockCache := &MockCache{
|
||||
organizations: map[string]*domain.Organization{},
|
||||
}
|
||||
|
||||
// Create InitFunc
|
||||
initFunc := func(ctx context.Context, initPayload transport.InitPayload) (context.Context, *transport.InitPayload, error) {
|
||||
// Extract API key from WebSocket connection_init payload
|
||||
if apiKey, ok := initPayload["X-Api-Key"].(string); ok && apiKey != "" {
|
||||
ctx = context.WithValue(ctx, middleware.ApiKey, apiKey)
|
||||
|
||||
// Look up organization by API key
|
||||
if organization := mockCache.OrganizationByAPIKey(hash.String(apiKey)); organization != nil {
|
||||
ctx = context.WithValue(ctx, middleware.OrganizationKey, *organization)
|
||||
}
|
||||
}
|
||||
return ctx, &initPayload, nil
|
||||
}
|
||||
|
||||
// Test
|
||||
ctx := context.Background()
|
||||
initPayload := transport.InitPayload{
|
||||
"X-Api-Key": "", // Empty string
|
||||
}
|
||||
|
||||
resultCtx, resultPayload, err := initFunc(ctx, initPayload)
|
||||
|
||||
// Assert
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, resultPayload)
|
||||
|
||||
// Check API key is NOT in context (because empty string fails the condition)
|
||||
value := resultCtx.Value(middleware.ApiKey)
|
||||
assert.Nil(t, value, "API key should not be set when empty")
|
||||
|
||||
// Check organization is NOT in context
|
||||
value = resultCtx.Value(middleware.OrganizationKey)
|
||||
assert.Nil(t, value, "Organization should not be set when API key is empty")
|
||||
}
|
||||
|
||||
func TestWebSocketInitFunc_WithWrongTypeAPIKey(t *testing.T) {
|
||||
// Setup
|
||||
mockCache := &MockCache{
|
||||
organizations: map[string]*domain.Organization{},
|
||||
}
|
||||
|
||||
// Create InitFunc
|
||||
initFunc := func(ctx context.Context, initPayload transport.InitPayload) (context.Context, *transport.InitPayload, error) {
|
||||
// Extract API key from WebSocket connection_init payload
|
||||
if apiKey, ok := initPayload["X-Api-Key"].(string); ok && apiKey != "" {
|
||||
ctx = context.WithValue(ctx, middleware.ApiKey, apiKey)
|
||||
|
||||
// Look up organization by API key
|
||||
if organization := mockCache.OrganizationByAPIKey(hash.String(apiKey)); organization != nil {
|
||||
ctx = context.WithValue(ctx, middleware.OrganizationKey, *organization)
|
||||
}
|
||||
}
|
||||
return ctx, &initPayload, nil
|
||||
}
|
||||
|
||||
// Test
|
||||
ctx := context.Background()
|
||||
initPayload := transport.InitPayload{
|
||||
"X-Api-Key": 12345, // Wrong type (int instead of string)
|
||||
}
|
||||
|
||||
resultCtx, resultPayload, err := initFunc(ctx, initPayload)
|
||||
|
||||
// Assert
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, resultPayload)
|
||||
|
||||
// Check API key is NOT in context (type assertion fails)
|
||||
value := resultCtx.Value(middleware.ApiKey)
|
||||
assert.Nil(t, value, "API key should not be set when wrong type")
|
||||
|
||||
// Check organization is NOT in context
|
||||
value = resultCtx.Value(middleware.OrganizationKey)
|
||||
assert.Nil(t, value, "Organization should not be set when API key has wrong type")
|
||||
}
|
||||
|
||||
func TestWebSocketInitFunc_WithMultipleOrganizations(t *testing.T) {
|
||||
// Setup - create multiple organizations
|
||||
org1ID := uuid.New()
|
||||
org1 := &domain.Organization{
|
||||
BaseAggregate: eventsourced.BaseAggregate{
|
||||
ID: eventsourced.IdFromString(org1ID.String()),
|
||||
},
|
||||
Name: "Organization 1",
|
||||
}
|
||||
|
||||
org2ID := uuid.New()
|
||||
org2 := &domain.Organization{
|
||||
BaseAggregate: eventsourced.BaseAggregate{
|
||||
ID: eventsourced.IdFromString(org2ID.String()),
|
||||
},
|
||||
Name: "Organization 2",
|
||||
}
|
||||
|
||||
apiKey1 := "api-key-org-1"
|
||||
apiKey2 := "api-key-org-2"
|
||||
hashedKey1 := hash.String(apiKey1)
|
||||
hashedKey2 := hash.String(apiKey2)
|
||||
|
||||
mockCache := &MockCache{
|
||||
organizations: map[string]*domain.Organization{
|
||||
hashedKey1: org1,
|
||||
hashedKey2: org2,
|
||||
},
|
||||
}
|
||||
|
||||
// Create InitFunc
|
||||
initFunc := func(ctx context.Context, initPayload transport.InitPayload) (context.Context, *transport.InitPayload, error) {
|
||||
// Extract API key from WebSocket connection_init payload
|
||||
if apiKey, ok := initPayload["X-Api-Key"].(string); ok && apiKey != "" {
|
||||
ctx = context.WithValue(ctx, middleware.ApiKey, apiKey)
|
||||
|
||||
// Look up organization by API key
|
||||
if organization := mockCache.OrganizationByAPIKey(hash.String(apiKey)); organization != nil {
|
||||
ctx = context.WithValue(ctx, middleware.OrganizationKey, *organization)
|
||||
}
|
||||
}
|
||||
return ctx, &initPayload, nil
|
||||
}
|
||||
|
||||
// Test with first API key
|
||||
ctx1 := context.Background()
|
||||
initPayload1 := transport.InitPayload{
|
||||
"X-Api-Key": apiKey1,
|
||||
}
|
||||
|
||||
resultCtx1, _, err := initFunc(ctx1, initPayload1)
|
||||
require.NoError(t, err)
|
||||
|
||||
if value := resultCtx1.Value(middleware.OrganizationKey); value != nil {
|
||||
capturedOrg, ok := value.(domain.Organization)
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, org1.Name, capturedOrg.Name)
|
||||
assert.Equal(t, org1.ID.String(), capturedOrg.ID.String())
|
||||
} else {
|
||||
t.Fatal("Organization 1 not found in context")
|
||||
}
|
||||
|
||||
// Test with second API key
|
||||
ctx2 := context.Background()
|
||||
initPayload2 := transport.InitPayload{
|
||||
"X-Api-Key": apiKey2,
|
||||
}
|
||||
|
||||
resultCtx2, _, err := initFunc(ctx2, initPayload2)
|
||||
require.NoError(t, err)
|
||||
|
||||
if value := resultCtx2.Value(middleware.OrganizationKey); value != nil {
|
||||
capturedOrg, ok := value.(domain.Organization)
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, org2.Name, capturedOrg.Name)
|
||||
assert.Equal(t, org2.ID.String(), capturedOrg.ID.String())
|
||||
} else {
|
||||
t.Fatal("Organization 2 not found in context")
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,467 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
mw "github.com/auth0/go-jwt-middleware/v2"
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/mock"
|
||||
"github.com/stretchr/testify/require"
|
||||
"gitlab.com/unboundsoftware/eventsourced/eventsourced"
|
||||
|
||||
"gitlab.com/unboundsoftware/schemas/domain"
|
||||
"gitlab.com/unboundsoftware/schemas/hash"
|
||||
)
|
||||
|
||||
// MockCache is a mock implementation of the Cache interface
|
||||
type MockCache struct {
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
func (m *MockCache) OrganizationByAPIKey(apiKey string) *domain.Organization {
|
||||
args := m.Called(apiKey)
|
||||
if args.Get(0) == nil {
|
||||
return nil
|
||||
}
|
||||
return args.Get(0).(*domain.Organization)
|
||||
}
|
||||
|
||||
func TestAuthMiddleware_Handler_WithValidAPIKey(t *testing.T) {
|
||||
// Setup
|
||||
mockCache := new(MockCache)
|
||||
authMiddleware := NewAuth(mockCache)
|
||||
|
||||
orgID := uuid.New()
|
||||
expectedOrg := &domain.Organization{
|
||||
BaseAggregate: eventsourced.BaseAggregate{
|
||||
ID: eventsourced.IdFromString(orgID.String()),
|
||||
},
|
||||
Name: "Test Organization",
|
||||
}
|
||||
|
||||
apiKey := "test-api-key-123"
|
||||
hashedKey := hash.String(apiKey)
|
||||
|
||||
mockCache.On("OrganizationByAPIKey", hashedKey).Return(expectedOrg)
|
||||
|
||||
// Create a test handler that checks the context
|
||||
var capturedOrg *domain.Organization
|
||||
testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if org := r.Context().Value(OrganizationKey); org != nil {
|
||||
if o, ok := org.(domain.Organization); ok {
|
||||
capturedOrg = &o
|
||||
}
|
||||
}
|
||||
w.WriteHeader(http.StatusOK)
|
||||
})
|
||||
|
||||
// Create request with API key in context
|
||||
req := httptest.NewRequest(http.MethodGet, "/test", nil)
|
||||
ctx := context.WithValue(req.Context(), ApiKey, apiKey)
|
||||
req = req.WithContext(ctx)
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
// Execute
|
||||
authMiddleware.Handler(testHandler).ServeHTTP(rec, req)
|
||||
|
||||
// Assert
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
require.NotNil(t, capturedOrg)
|
||||
assert.Equal(t, expectedOrg.Name, capturedOrg.Name)
|
||||
assert.Equal(t, expectedOrg.ID.String(), capturedOrg.ID.String())
|
||||
mockCache.AssertExpectations(t)
|
||||
}
|
||||
|
||||
func TestAuthMiddleware_Handler_WithInvalidAPIKey(t *testing.T) {
|
||||
// Setup
|
||||
mockCache := new(MockCache)
|
||||
authMiddleware := NewAuth(mockCache)
|
||||
|
||||
apiKey := "invalid-api-key"
|
||||
hashedKey := hash.String(apiKey)
|
||||
|
||||
mockCache.On("OrganizationByAPIKey", hashedKey).Return(nil)
|
||||
|
||||
// Create a test handler that checks the context
|
||||
var capturedOrg *domain.Organization
|
||||
testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if org := r.Context().Value(OrganizationKey); org != nil {
|
||||
if o, ok := org.(domain.Organization); ok {
|
||||
capturedOrg = &o
|
||||
}
|
||||
}
|
||||
w.WriteHeader(http.StatusOK)
|
||||
})
|
||||
|
||||
// Create request with API key in context
|
||||
req := httptest.NewRequest(http.MethodGet, "/test", nil)
|
||||
ctx := context.WithValue(req.Context(), ApiKey, apiKey)
|
||||
req = req.WithContext(ctx)
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
// Execute
|
||||
authMiddleware.Handler(testHandler).ServeHTTP(rec, req)
|
||||
|
||||
// Assert
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
assert.Nil(t, capturedOrg, "Organization should not be set for invalid API key")
|
||||
mockCache.AssertExpectations(t)
|
||||
}
|
||||
|
||||
func TestAuthMiddleware_Handler_WithoutAPIKey(t *testing.T) {
|
||||
// Setup
|
||||
mockCache := new(MockCache)
|
||||
authMiddleware := NewAuth(mockCache)
|
||||
|
||||
// The middleware always hashes the API key (even if empty) and calls the cache
|
||||
emptyKeyHash := hash.String("")
|
||||
mockCache.On("OrganizationByAPIKey", emptyKeyHash).Return(nil)
|
||||
|
||||
// Create a test handler that checks the context
|
||||
var capturedOrg *domain.Organization
|
||||
testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if org := r.Context().Value(OrganizationKey); org != nil {
|
||||
if o, ok := org.(domain.Organization); ok {
|
||||
capturedOrg = &o
|
||||
}
|
||||
}
|
||||
w.WriteHeader(http.StatusOK)
|
||||
})
|
||||
|
||||
// Create request without API key
|
||||
req := httptest.NewRequest(http.MethodGet, "/test", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
// Execute
|
||||
authMiddleware.Handler(testHandler).ServeHTTP(rec, req)
|
||||
|
||||
// Assert
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
assert.Nil(t, capturedOrg, "Organization should not be set without API key")
|
||||
mockCache.AssertExpectations(t)
|
||||
}
|
||||
|
||||
func TestAuthMiddleware_Handler_WithValidJWT(t *testing.T) {
|
||||
// Setup
|
||||
mockCache := new(MockCache)
|
||||
authMiddleware := NewAuth(mockCache)
|
||||
|
||||
// The middleware always hashes the API key (even if empty) and calls the cache
|
||||
emptyKeyHash := hash.String("")
|
||||
mockCache.On("OrganizationByAPIKey", emptyKeyHash).Return(nil)
|
||||
|
||||
userID := "user-123"
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{
|
||||
"sub": userID,
|
||||
})
|
||||
|
||||
// Create a test handler that checks the context
|
||||
var capturedUser string
|
||||
testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if user := r.Context().Value(UserKey); user != nil {
|
||||
if u, ok := user.(string); ok {
|
||||
capturedUser = u
|
||||
}
|
||||
}
|
||||
w.WriteHeader(http.StatusOK)
|
||||
})
|
||||
|
||||
// Create request with JWT token in context
|
||||
req := httptest.NewRequest(http.MethodGet, "/test", nil)
|
||||
ctx := context.WithValue(req.Context(), mw.ContextKey{}, token)
|
||||
req = req.WithContext(ctx)
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
// Execute
|
||||
authMiddleware.Handler(testHandler).ServeHTTP(rec, req)
|
||||
|
||||
// Assert
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
assert.Equal(t, userID, capturedUser)
|
||||
}
|
||||
|
||||
func TestAuthMiddleware_Handler_APIKeyErrorHandling(t *testing.T) {
|
||||
// Setup
|
||||
mockCache := new(MockCache)
|
||||
authMiddleware := NewAuth(mockCache)
|
||||
|
||||
testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
})
|
||||
|
||||
// Create request with invalid API key type in context
|
||||
req := httptest.NewRequest(http.MethodGet, "/test", nil)
|
||||
ctx := context.WithValue(req.Context(), ApiKey, 12345) // Invalid type
|
||||
req = req.WithContext(ctx)
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
// Execute
|
||||
authMiddleware.Handler(testHandler).ServeHTTP(rec, req)
|
||||
|
||||
// Assert
|
||||
assert.Equal(t, http.StatusInternalServerError, rec.Code)
|
||||
assert.Contains(t, rec.Body.String(), "Invalid API Key format")
|
||||
}
|
||||
|
||||
func TestAuthMiddleware_Handler_JWTErrorHandling(t *testing.T) {
|
||||
// Setup
|
||||
mockCache := new(MockCache)
|
||||
authMiddleware := NewAuth(mockCache)
|
||||
|
||||
testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
})
|
||||
|
||||
// Create request with invalid JWT token type in context
|
||||
req := httptest.NewRequest(http.MethodGet, "/test", nil)
|
||||
ctx := context.WithValue(req.Context(), mw.ContextKey{}, "not-a-token") // Invalid type
|
||||
req = req.WithContext(ctx)
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
// Execute
|
||||
authMiddleware.Handler(testHandler).ServeHTTP(rec, req)
|
||||
|
||||
// Assert
|
||||
assert.Equal(t, http.StatusInternalServerError, rec.Code)
|
||||
assert.Contains(t, rec.Body.String(), "Invalid JWT token format")
|
||||
}
|
||||
|
||||
func TestAuthMiddleware_Handler_BothJWTAndAPIKey(t *testing.T) {
|
||||
// Setup
|
||||
mockCache := new(MockCache)
|
||||
authMiddleware := NewAuth(mockCache)
|
||||
|
||||
orgID := uuid.New()
|
||||
expectedOrg := &domain.Organization{
|
||||
BaseAggregate: eventsourced.BaseAggregate{
|
||||
ID: eventsourced.IdFromString(orgID.String()),
|
||||
},
|
||||
Name: "Test Organization",
|
||||
}
|
||||
|
||||
userID := "user-123"
|
||||
apiKey := "test-api-key-123"
|
||||
hashedKey := hash.String(apiKey)
|
||||
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{
|
||||
"sub": userID,
|
||||
})
|
||||
|
||||
mockCache.On("OrganizationByAPIKey", hashedKey).Return(expectedOrg)
|
||||
|
||||
// Create a test handler that checks both user and organization in context
|
||||
var capturedUser string
|
||||
var capturedOrg *domain.Organization
|
||||
testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if user := r.Context().Value(UserKey); user != nil {
|
||||
if u, ok := user.(string); ok {
|
||||
capturedUser = u
|
||||
}
|
||||
}
|
||||
if org := r.Context().Value(OrganizationKey); org != nil {
|
||||
if o, ok := org.(domain.Organization); ok {
|
||||
capturedOrg = &o
|
||||
}
|
||||
}
|
||||
w.WriteHeader(http.StatusOK)
|
||||
})
|
||||
|
||||
// Create request with both JWT and API key in context
|
||||
req := httptest.NewRequest(http.MethodGet, "/test", nil)
|
||||
ctx := context.WithValue(req.Context(), mw.ContextKey{}, token)
|
||||
ctx = context.WithValue(ctx, ApiKey, apiKey)
|
||||
req = req.WithContext(ctx)
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
// Execute
|
||||
authMiddleware.Handler(testHandler).ServeHTTP(rec, req)
|
||||
|
||||
// Assert
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
assert.Equal(t, userID, capturedUser)
|
||||
require.NotNil(t, capturedOrg)
|
||||
assert.Equal(t, expectedOrg.Name, capturedOrg.Name)
|
||||
mockCache.AssertExpectations(t)
|
||||
}
|
||||
|
||||
func TestUserFromContext(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
ctx context.Context
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "with valid user",
|
||||
ctx: context.WithValue(context.Background(), UserKey, "user-123"),
|
||||
expected: "user-123",
|
||||
},
|
||||
{
|
||||
name: "without user",
|
||||
ctx: context.Background(),
|
||||
expected: "",
|
||||
},
|
||||
{
|
||||
name: "with invalid type",
|
||||
ctx: context.WithValue(context.Background(), UserKey, 123),
|
||||
expected: "",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := UserFromContext(tt.ctx)
|
||||
assert.Equal(t, tt.expected, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestOrganizationFromContext(t *testing.T) {
|
||||
orgID := uuid.New()
|
||||
org := domain.Organization{
|
||||
BaseAggregate: eventsourced.BaseAggregate{
|
||||
ID: eventsourced.IdFromString(orgID.String()),
|
||||
},
|
||||
Name: "Test Org",
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
ctx context.Context
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "with valid organization",
|
||||
ctx: context.WithValue(context.Background(), OrganizationKey, org),
|
||||
expected: orgID.String(),
|
||||
},
|
||||
{
|
||||
name: "without organization",
|
||||
ctx: context.Background(),
|
||||
expected: "",
|
||||
},
|
||||
{
|
||||
name: "with invalid type",
|
||||
ctx: context.WithValue(context.Background(), OrganizationKey, "not-an-org"),
|
||||
expected: "",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := OrganizationFromContext(tt.ctx)
|
||||
assert.Equal(t, tt.expected, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthMiddleware_Directive_RequiresUser(t *testing.T) {
|
||||
mockCache := new(MockCache)
|
||||
authMiddleware := NewAuth(mockCache)
|
||||
|
||||
requireUser := true
|
||||
|
||||
// Test with user present
|
||||
ctx := context.WithValue(context.Background(), UserKey, "user-123")
|
||||
_, err := authMiddleware.Directive(ctx, nil, func(ctx context.Context) (interface{}, error) {
|
||||
return "success", nil
|
||||
}, &requireUser, nil)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Test without user
|
||||
ctx = context.Background()
|
||||
_, err = authMiddleware.Directive(ctx, nil, func(ctx context.Context) (interface{}, error) {
|
||||
return "success", nil
|
||||
}, &requireUser, nil)
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "no user available in request")
|
||||
}
|
||||
|
||||
func TestAuthMiddleware_Directive_RequiresOrganization(t *testing.T) {
|
||||
mockCache := new(MockCache)
|
||||
authMiddleware := NewAuth(mockCache)
|
||||
|
||||
requireOrg := true
|
||||
orgID := uuid.New()
|
||||
org := domain.Organization{
|
||||
BaseAggregate: eventsourced.BaseAggregate{
|
||||
ID: eventsourced.IdFromString(orgID.String()),
|
||||
},
|
||||
Name: "Test Org",
|
||||
}
|
||||
|
||||
// Test with organization present
|
||||
ctx := context.WithValue(context.Background(), OrganizationKey, org)
|
||||
_, err := authMiddleware.Directive(ctx, nil, func(ctx context.Context) (interface{}, error) {
|
||||
return "success", nil
|
||||
}, nil, &requireOrg)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Test without organization
|
||||
ctx = context.Background()
|
||||
_, err = authMiddleware.Directive(ctx, nil, func(ctx context.Context) (interface{}, error) {
|
||||
return "success", nil
|
||||
}, nil, &requireOrg)
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "no organization available in request")
|
||||
}
|
||||
|
||||
func TestAuthMiddleware_Directive_RequiresBoth(t *testing.T) {
|
||||
mockCache := new(MockCache)
|
||||
authMiddleware := NewAuth(mockCache)
|
||||
|
||||
requireUser := true
|
||||
requireOrg := true
|
||||
orgID := uuid.New()
|
||||
org := domain.Organization{
|
||||
BaseAggregate: eventsourced.BaseAggregate{
|
||||
ID: eventsourced.IdFromString(orgID.String()),
|
||||
},
|
||||
Name: "Test Org",
|
||||
}
|
||||
|
||||
// Test with both present
|
||||
ctx := context.WithValue(context.Background(), UserKey, "user-123")
|
||||
ctx = context.WithValue(ctx, OrganizationKey, org)
|
||||
_, err := authMiddleware.Directive(ctx, nil, func(ctx context.Context) (interface{}, error) {
|
||||
return "success", nil
|
||||
}, &requireUser, &requireOrg)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Test with only user
|
||||
ctx = context.WithValue(context.Background(), UserKey, "user-123")
|
||||
_, err = authMiddleware.Directive(ctx, nil, func(ctx context.Context) (interface{}, error) {
|
||||
return "success", nil
|
||||
}, &requireUser, &requireOrg)
|
||||
assert.Error(t, err)
|
||||
|
||||
// Test with only organization
|
||||
ctx = context.WithValue(context.Background(), OrganizationKey, org)
|
||||
_, err = authMiddleware.Directive(ctx, nil, func(ctx context.Context) (interface{}, error) {
|
||||
return "success", nil
|
||||
}, &requireUser, &requireOrg)
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
func TestAuthMiddleware_Directive_NoRequirements(t *testing.T) {
|
||||
mockCache := new(MockCache)
|
||||
authMiddleware := NewAuth(mockCache)
|
||||
|
||||
// Test with no requirements
|
||||
ctx := context.Background()
|
||||
result, err := authMiddleware.Directive(ctx, nil, func(ctx context.Context) (interface{}, error) {
|
||||
return "success", nil
|
||||
}, nil, nil)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "success", result)
|
||||
}
|
||||
Reference in New Issue
Block a user