fix: enhance API key handling and logging in middleware #627
@@ -0,0 +1,336 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/99designs/gqlgen/graphql/handler/transport"
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"gitlab.com/unboundsoftware/eventsourced/eventsourced"
|
||||
|
||||
"gitlab.com/unboundsoftware/schemas/domain"
|
||||
"gitlab.com/unboundsoftware/schemas/hash"
|
||||
"gitlab.com/unboundsoftware/schemas/middleware"
|
||||
)
|
||||
|
||||
// MockCache is a mock implementation for testing
|
||||
type MockCache struct {
|
||||
organizations map[string]*domain.Organization
|
||||
}
|
||||
|
||||
func (m *MockCache) OrganizationByAPIKey(apiKey string) *domain.Organization {
|
||||
return m.organizations[apiKey]
|
||||
}
|
||||
|
||||
func TestWebSocketInitFunc_WithValidAPIKey(t *testing.T) {
|
||||
// Setup
|
||||
orgID := uuid.New()
|
||||
org := &domain.Organization{
|
||||
BaseAggregate: eventsourced.BaseAggregate{
|
||||
ID: eventsourced.IdFromString(orgID.String()),
|
||||
},
|
||||
Name: "Test Organization",
|
||||
}
|
||||
|
||||
apiKey := "test-api-key-123"
|
||||
hashedKey := hash.String(apiKey)
|
||||
|
||||
mockCache := &MockCache{
|
||||
organizations: map[string]*domain.Organization{
|
||||
hashedKey: org,
|
||||
},
|
||||
}
|
||||
|
||||
// Create InitFunc (simulating the WebSocket InitFunc logic)
|
||||
initFunc := func(ctx context.Context, initPayload transport.InitPayload) (context.Context, *transport.InitPayload, error) {
|
||||
// Extract API key from WebSocket connection_init payload
|
||||
if apiKey, ok := initPayload["X-Api-Key"].(string); ok && apiKey != "" {
|
||||
ctx = context.WithValue(ctx, middleware.ApiKey, apiKey)
|
||||
|
||||
// Look up organization by API key
|
||||
if organization := mockCache.OrganizationByAPIKey(hash.String(apiKey)); organization != nil {
|
||||
ctx = context.WithValue(ctx, middleware.OrganizationKey, *organization)
|
||||
}
|
||||
}
|
||||
return ctx, &initPayload, nil
|
||||
}
|
||||
|
||||
// Test
|
||||
ctx := context.Background()
|
||||
initPayload := transport.InitPayload{
|
||||
"X-Api-Key": apiKey,
|
||||
}
|
||||
|
||||
resultCtx, resultPayload, err := initFunc(ctx, initPayload)
|
||||
|
||||
// Assert
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, resultPayload)
|
||||
|
||||
// Check API key is in context
|
||||
if value := resultCtx.Value(middleware.ApiKey); value != nil {
|
||||
assert.Equal(t, apiKey, value.(string))
|
||||
} else {
|
||||
t.Fatal("API key not found in context")
|
||||
}
|
||||
|
||||
// Check organization is in context
|
||||
if value := resultCtx.Value(middleware.OrganizationKey); value != nil {
|
||||
capturedOrg, ok := value.(domain.Organization)
|
||||
require.True(t, ok, "Organization should be of correct type")
|
||||
assert.Equal(t, org.Name, capturedOrg.Name)
|
||||
assert.Equal(t, org.ID.String(), capturedOrg.ID.String())
|
||||
} else {
|
||||
t.Fatal("Organization not found in context")
|
||||
}
|
||||
}
|
||||
|
||||
func TestWebSocketInitFunc_WithInvalidAPIKey(t *testing.T) {
|
||||
// Setup
|
||||
mockCache := &MockCache{
|
||||
organizations: map[string]*domain.Organization{},
|
||||
}
|
||||
|
||||
apiKey := "invalid-api-key"
|
||||
|
||||
// Create InitFunc
|
||||
initFunc := func(ctx context.Context, initPayload transport.InitPayload) (context.Context, *transport.InitPayload, error) {
|
||||
// Extract API key from WebSocket connection_init payload
|
||||
if apiKey, ok := initPayload["X-Api-Key"].(string); ok && apiKey != "" {
|
||||
ctx = context.WithValue(ctx, middleware.ApiKey, apiKey)
|
||||
|
||||
// Look up organization by API key
|
||||
if organization := mockCache.OrganizationByAPIKey(hash.String(apiKey)); organization != nil {
|
||||
ctx = context.WithValue(ctx, middleware.OrganizationKey, *organization)
|
||||
}
|
||||
}
|
||||
return ctx, &initPayload, nil
|
||||
}
|
||||
|
||||
// Test
|
||||
ctx := context.Background()
|
||||
initPayload := transport.InitPayload{
|
||||
"X-Api-Key": apiKey,
|
||||
}
|
||||
|
||||
resultCtx, resultPayload, err := initFunc(ctx, initPayload)
|
||||
|
||||
// Assert
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, resultPayload)
|
||||
|
||||
// Check API key is in context
|
||||
if value := resultCtx.Value(middleware.ApiKey); value != nil {
|
||||
assert.Equal(t, apiKey, value.(string))
|
||||
} else {
|
||||
t.Fatal("API key not found in context")
|
||||
}
|
||||
|
||||
// Check organization is NOT in context (since API key is invalid)
|
||||
value := resultCtx.Value(middleware.OrganizationKey)
|
||||
assert.Nil(t, value, "Organization should not be set for invalid API key")
|
||||
}
|
||||
|
||||
func TestWebSocketInitFunc_WithoutAPIKey(t *testing.T) {
|
||||
// Setup
|
||||
mockCache := &MockCache{
|
||||
organizations: map[string]*domain.Organization{},
|
||||
}
|
||||
|
||||
// Create InitFunc
|
||||
initFunc := func(ctx context.Context, initPayload transport.InitPayload) (context.Context, *transport.InitPayload, error) {
|
||||
// Extract API key from WebSocket connection_init payload
|
||||
if apiKey, ok := initPayload["X-Api-Key"].(string); ok && apiKey != "" {
|
||||
ctx = context.WithValue(ctx, middleware.ApiKey, apiKey)
|
||||
|
||||
// Look up organization by API key
|
||||
if organization := mockCache.OrganizationByAPIKey(hash.String(apiKey)); organization != nil {
|
||||
ctx = context.WithValue(ctx, middleware.OrganizationKey, *organization)
|
||||
}
|
||||
}
|
||||
return ctx, &initPayload, nil
|
||||
}
|
||||
|
||||
// Test
|
||||
ctx := context.Background()
|
||||
initPayload := transport.InitPayload{}
|
||||
|
||||
resultCtx, resultPayload, err := initFunc(ctx, initPayload)
|
||||
|
||||
// Assert
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, resultPayload)
|
||||
|
||||
// Check API key is NOT in context
|
||||
value := resultCtx.Value(middleware.ApiKey)
|
||||
assert.Nil(t, value, "API key should not be set when not provided")
|
||||
|
||||
// Check organization is NOT in context
|
||||
value = resultCtx.Value(middleware.OrganizationKey)
|
||||
assert.Nil(t, value, "Organization should not be set when API key is not provided")
|
||||
}
|
||||
|
||||
func TestWebSocketInitFunc_WithEmptyAPIKey(t *testing.T) {
|
||||
// Setup
|
||||
mockCache := &MockCache{
|
||||
organizations: map[string]*domain.Organization{},
|
||||
}
|
||||
|
||||
// Create InitFunc
|
||||
initFunc := func(ctx context.Context, initPayload transport.InitPayload) (context.Context, *transport.InitPayload, error) {
|
||||
// Extract API key from WebSocket connection_init payload
|
||||
if apiKey, ok := initPayload["X-Api-Key"].(string); ok && apiKey != "" {
|
||||
ctx = context.WithValue(ctx, middleware.ApiKey, apiKey)
|
||||
|
||||
// Look up organization by API key
|
||||
if organization := mockCache.OrganizationByAPIKey(hash.String(apiKey)); organization != nil {
|
||||
ctx = context.WithValue(ctx, middleware.OrganizationKey, *organization)
|
||||
}
|
||||
}
|
||||
return ctx, &initPayload, nil
|
||||
}
|
||||
|
||||
// Test
|
||||
ctx := context.Background()
|
||||
initPayload := transport.InitPayload{
|
||||
"X-Api-Key": "", // Empty string
|
||||
}
|
||||
|
||||
resultCtx, resultPayload, err := initFunc(ctx, initPayload)
|
||||
|
||||
// Assert
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, resultPayload)
|
||||
|
||||
// Check API key is NOT in context (because empty string fails the condition)
|
||||
value := resultCtx.Value(middleware.ApiKey)
|
||||
assert.Nil(t, value, "API key should not be set when empty")
|
||||
|
||||
// Check organization is NOT in context
|
||||
value = resultCtx.Value(middleware.OrganizationKey)
|
||||
assert.Nil(t, value, "Organization should not be set when API key is empty")
|
||||
}
|
||||
|
||||
func TestWebSocketInitFunc_WithWrongTypeAPIKey(t *testing.T) {
|
||||
// Setup
|
||||
mockCache := &MockCache{
|
||||
organizations: map[string]*domain.Organization{},
|
||||
}
|
||||
|
||||
// Create InitFunc
|
||||
initFunc := func(ctx context.Context, initPayload transport.InitPayload) (context.Context, *transport.InitPayload, error) {
|
||||
// Extract API key from WebSocket connection_init payload
|
||||
if apiKey, ok := initPayload["X-Api-Key"].(string); ok && apiKey != "" {
|
||||
ctx = context.WithValue(ctx, middleware.ApiKey, apiKey)
|
||||
|
||||
// Look up organization by API key
|
||||
if organization := mockCache.OrganizationByAPIKey(hash.String(apiKey)); organization != nil {
|
||||
ctx = context.WithValue(ctx, middleware.OrganizationKey, *organization)
|
||||
}
|
||||
}
|
||||
return ctx, &initPayload, nil
|
||||
}
|
||||
|
||||
// Test
|
||||
ctx := context.Background()
|
||||
initPayload := transport.InitPayload{
|
||||
"X-Api-Key": 12345, // Wrong type (int instead of string)
|
||||
}
|
||||
|
||||
resultCtx, resultPayload, err := initFunc(ctx, initPayload)
|
||||
|
||||
// Assert
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, resultPayload)
|
||||
|
||||
// Check API key is NOT in context (type assertion fails)
|
||||
value := resultCtx.Value(middleware.ApiKey)
|
||||
assert.Nil(t, value, "API key should not be set when wrong type")
|
||||
|
||||
// Check organization is NOT in context
|
||||
value = resultCtx.Value(middleware.OrganizationKey)
|
||||
assert.Nil(t, value, "Organization should not be set when API key has wrong type")
|
||||
}
|
||||
|
||||
func TestWebSocketInitFunc_WithMultipleOrganizations(t *testing.T) {
|
||||
// Setup - create multiple organizations
|
||||
org1ID := uuid.New()
|
||||
org1 := &domain.Organization{
|
||||
BaseAggregate: eventsourced.BaseAggregate{
|
||||
ID: eventsourced.IdFromString(org1ID.String()),
|
||||
},
|
||||
Name: "Organization 1",
|
||||
}
|
||||
|
||||
org2ID := uuid.New()
|
||||
org2 := &domain.Organization{
|
||||
BaseAggregate: eventsourced.BaseAggregate{
|
||||
ID: eventsourced.IdFromString(org2ID.String()),
|
||||
},
|
||||
Name: "Organization 2",
|
||||
}
|
||||
|
||||
apiKey1 := "api-key-org-1"
|
||||
apiKey2 := "api-key-org-2"
|
||||
hashedKey1 := hash.String(apiKey1)
|
||||
hashedKey2 := hash.String(apiKey2)
|
||||
|
||||
mockCache := &MockCache{
|
||||
organizations: map[string]*domain.Organization{
|
||||
hashedKey1: org1,
|
||||
hashedKey2: org2,
|
||||
},
|
||||
}
|
||||
|
||||
// Create InitFunc
|
||||
initFunc := func(ctx context.Context, initPayload transport.InitPayload) (context.Context, *transport.InitPayload, error) {
|
||||
// Extract API key from WebSocket connection_init payload
|
||||
if apiKey, ok := initPayload["X-Api-Key"].(string); ok && apiKey != "" {
|
||||
ctx = context.WithValue(ctx, middleware.ApiKey, apiKey)
|
||||
|
||||
// Look up organization by API key
|
||||
if organization := mockCache.OrganizationByAPIKey(hash.String(apiKey)); organization != nil {
|
||||
ctx = context.WithValue(ctx, middleware.OrganizationKey, *organization)
|
||||
}
|
||||
}
|
||||
return ctx, &initPayload, nil
|
||||
}
|
||||
|
||||
// Test with first API key
|
||||
ctx1 := context.Background()
|
||||
initPayload1 := transport.InitPayload{
|
||||
"X-Api-Key": apiKey1,
|
||||
}
|
||||
|
||||
resultCtx1, _, err := initFunc(ctx1, initPayload1)
|
||||
require.NoError(t, err)
|
||||
|
||||
if value := resultCtx1.Value(middleware.OrganizationKey); value != nil {
|
||||
capturedOrg, ok := value.(domain.Organization)
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, org1.Name, capturedOrg.Name)
|
||||
assert.Equal(t, org1.ID.String(), capturedOrg.ID.String())
|
||||
} else {
|
||||
t.Fatal("Organization 1 not found in context")
|
||||
}
|
||||
|
||||
// Test with second API key
|
||||
ctx2 := context.Background()
|
||||
initPayload2 := transport.InitPayload{
|
||||
"X-Api-Key": apiKey2,
|
||||
}
|
||||
|
||||
resultCtx2, _, err := initFunc(ctx2, initPayload2)
|
||||
require.NoError(t, err)
|
||||
|
||||
if value := resultCtx2.Value(middleware.OrganizationKey); value != nil {
|
||||
capturedOrg, ok := value.(domain.Organization)
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, org2.Name, capturedOrg.Name)
|
||||
assert.Equal(t, org2.ID.String(), capturedOrg.ID.String())
|
||||
} else {
|
||||
t.Fatal("Organization 2 not found in context")
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,467 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
mw "github.com/auth0/go-jwt-middleware/v2"
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/mock"
|
||||
"github.com/stretchr/testify/require"
|
||||
"gitlab.com/unboundsoftware/eventsourced/eventsourced"
|
||||
|
||||
"gitlab.com/unboundsoftware/schemas/domain"
|
||||
"gitlab.com/unboundsoftware/schemas/hash"
|
||||
)
|
||||
|
||||
// MockCache is a mock implementation of the Cache interface
|
||||
type MockCache struct {
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
func (m *MockCache) OrganizationByAPIKey(apiKey string) *domain.Organization {
|
||||
args := m.Called(apiKey)
|
||||
if args.Get(0) == nil {
|
||||
return nil
|
||||
}
|
||||
return args.Get(0).(*domain.Organization)
|
||||
}
|
||||
|
||||
func TestAuthMiddleware_Handler_WithValidAPIKey(t *testing.T) {
|
||||
// Setup
|
||||
mockCache := new(MockCache)
|
||||
authMiddleware := NewAuth(mockCache)
|
||||
|
||||
orgID := uuid.New()
|
||||
expectedOrg := &domain.Organization{
|
||||
BaseAggregate: eventsourced.BaseAggregate{
|
||||
ID: eventsourced.IdFromString(orgID.String()),
|
||||
},
|
||||
Name: "Test Organization",
|
||||
}
|
||||
|
||||
apiKey := "test-api-key-123"
|
||||
hashedKey := hash.String(apiKey)
|
||||
|
||||
mockCache.On("OrganizationByAPIKey", hashedKey).Return(expectedOrg)
|
||||
|
||||
// Create a test handler that checks the context
|
||||
var capturedOrg *domain.Organization
|
||||
testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if org := r.Context().Value(OrganizationKey); org != nil {
|
||||
if o, ok := org.(domain.Organization); ok {
|
||||
capturedOrg = &o
|
||||
}
|
||||
}
|
||||
w.WriteHeader(http.StatusOK)
|
||||
})
|
||||
|
||||
// Create request with API key in context
|
||||
req := httptest.NewRequest(http.MethodGet, "/test", nil)
|
||||
ctx := context.WithValue(req.Context(), ApiKey, apiKey)
|
||||
req = req.WithContext(ctx)
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
// Execute
|
||||
authMiddleware.Handler(testHandler).ServeHTTP(rec, req)
|
||||
|
||||
// Assert
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
require.NotNil(t, capturedOrg)
|
||||
assert.Equal(t, expectedOrg.Name, capturedOrg.Name)
|
||||
assert.Equal(t, expectedOrg.ID.String(), capturedOrg.ID.String())
|
||||
mockCache.AssertExpectations(t)
|
||||
}
|
||||
|
||||
func TestAuthMiddleware_Handler_WithInvalidAPIKey(t *testing.T) {
|
||||
// Setup
|
||||
mockCache := new(MockCache)
|
||||
authMiddleware := NewAuth(mockCache)
|
||||
|
||||
apiKey := "invalid-api-key"
|
||||
hashedKey := hash.String(apiKey)
|
||||
|
||||
mockCache.On("OrganizationByAPIKey", hashedKey).Return(nil)
|
||||
|
||||
// Create a test handler that checks the context
|
||||
var capturedOrg *domain.Organization
|
||||
testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if org := r.Context().Value(OrganizationKey); org != nil {
|
||||
if o, ok := org.(domain.Organization); ok {
|
||||
capturedOrg = &o
|
||||
}
|
||||
}
|
||||
w.WriteHeader(http.StatusOK)
|
||||
})
|
||||
|
||||
// Create request with API key in context
|
||||
req := httptest.NewRequest(http.MethodGet, "/test", nil)
|
||||
ctx := context.WithValue(req.Context(), ApiKey, apiKey)
|
||||
req = req.WithContext(ctx)
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
// Execute
|
||||
authMiddleware.Handler(testHandler).ServeHTTP(rec, req)
|
||||
|
||||
// Assert
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
assert.Nil(t, capturedOrg, "Organization should not be set for invalid API key")
|
||||
mockCache.AssertExpectations(t)
|
||||
}
|
||||
|
||||
func TestAuthMiddleware_Handler_WithoutAPIKey(t *testing.T) {
|
||||
// Setup
|
||||
mockCache := new(MockCache)
|
||||
authMiddleware := NewAuth(mockCache)
|
||||
|
||||
// The middleware always hashes the API key (even if empty) and calls the cache
|
||||
emptyKeyHash := hash.String("")
|
||||
mockCache.On("OrganizationByAPIKey", emptyKeyHash).Return(nil)
|
||||
|
||||
// Create a test handler that checks the context
|
||||
var capturedOrg *domain.Organization
|
||||
testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if org := r.Context().Value(OrganizationKey); org != nil {
|
||||
if o, ok := org.(domain.Organization); ok {
|
||||
capturedOrg = &o
|
||||
}
|
||||
}
|
||||
w.WriteHeader(http.StatusOK)
|
||||
})
|
||||
|
||||
// Create request without API key
|
||||
req := httptest.NewRequest(http.MethodGet, "/test", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
// Execute
|
||||
authMiddleware.Handler(testHandler).ServeHTTP(rec, req)
|
||||
|
||||
// Assert
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
assert.Nil(t, capturedOrg, "Organization should not be set without API key")
|
||||
mockCache.AssertExpectations(t)
|
||||
}
|
||||
|
||||
func TestAuthMiddleware_Handler_WithValidJWT(t *testing.T) {
|
||||
// Setup
|
||||
mockCache := new(MockCache)
|
||||
authMiddleware := NewAuth(mockCache)
|
||||
|
||||
// The middleware always hashes the API key (even if empty) and calls the cache
|
||||
emptyKeyHash := hash.String("")
|
||||
mockCache.On("OrganizationByAPIKey", emptyKeyHash).Return(nil)
|
||||
|
||||
userID := "user-123"
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{
|
||||
"sub": userID,
|
||||
})
|
||||
|
||||
// Create a test handler that checks the context
|
||||
var capturedUser string
|
||||
testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if user := r.Context().Value(UserKey); user != nil {
|
||||
if u, ok := user.(string); ok {
|
||||
capturedUser = u
|
||||
}
|
||||
}
|
||||
w.WriteHeader(http.StatusOK)
|
||||
})
|
||||
|
||||
// Create request with JWT token in context
|
||||
req := httptest.NewRequest(http.MethodGet, "/test", nil)
|
||||
ctx := context.WithValue(req.Context(), mw.ContextKey{}, token)
|
||||
req = req.WithContext(ctx)
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
// Execute
|
||||
authMiddleware.Handler(testHandler).ServeHTTP(rec, req)
|
||||
|
||||
// Assert
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
assert.Equal(t, userID, capturedUser)
|
||||
}
|
||||
|
||||
func TestAuthMiddleware_Handler_APIKeyErrorHandling(t *testing.T) {
|
||||
// Setup
|
||||
mockCache := new(MockCache)
|
||||
authMiddleware := NewAuth(mockCache)
|
||||
|
||||
testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
})
|
||||
|
||||
// Create request with invalid API key type in context
|
||||
req := httptest.NewRequest(http.MethodGet, "/test", nil)
|
||||
ctx := context.WithValue(req.Context(), ApiKey, 12345) // Invalid type
|
||||
req = req.WithContext(ctx)
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
// Execute
|
||||
authMiddleware.Handler(testHandler).ServeHTTP(rec, req)
|
||||
|
||||
// Assert
|
||||
assert.Equal(t, http.StatusInternalServerError, rec.Code)
|
||||
assert.Contains(t, rec.Body.String(), "Invalid API Key format")
|
||||
}
|
||||
|
||||
func TestAuthMiddleware_Handler_JWTErrorHandling(t *testing.T) {
|
||||
// Setup
|
||||
mockCache := new(MockCache)
|
||||
authMiddleware := NewAuth(mockCache)
|
||||
|
||||
testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
})
|
||||
|
||||
// Create request with invalid JWT token type in context
|
||||
req := httptest.NewRequest(http.MethodGet, "/test", nil)
|
||||
ctx := context.WithValue(req.Context(), mw.ContextKey{}, "not-a-token") // Invalid type
|
||||
req = req.WithContext(ctx)
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
// Execute
|
||||
authMiddleware.Handler(testHandler).ServeHTTP(rec, req)
|
||||
|
||||
// Assert
|
||||
assert.Equal(t, http.StatusInternalServerError, rec.Code)
|
||||
assert.Contains(t, rec.Body.String(), "Invalid JWT token format")
|
||||
}
|
||||
|
||||
func TestAuthMiddleware_Handler_BothJWTAndAPIKey(t *testing.T) {
|
||||
// Setup
|
||||
mockCache := new(MockCache)
|
||||
authMiddleware := NewAuth(mockCache)
|
||||
|
||||
orgID := uuid.New()
|
||||
expectedOrg := &domain.Organization{
|
||||
BaseAggregate: eventsourced.BaseAggregate{
|
||||
ID: eventsourced.IdFromString(orgID.String()),
|
||||
},
|
||||
Name: "Test Organization",
|
||||
}
|
||||
|
||||
userID := "user-123"
|
||||
apiKey := "test-api-key-123"
|
||||
hashedKey := hash.String(apiKey)
|
||||
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{
|
||||
"sub": userID,
|
||||
})
|
||||
|
||||
mockCache.On("OrganizationByAPIKey", hashedKey).Return(expectedOrg)
|
||||
|
||||
// Create a test handler that checks both user and organization in context
|
||||
var capturedUser string
|
||||
var capturedOrg *domain.Organization
|
||||
testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if user := r.Context().Value(UserKey); user != nil {
|
||||
if u, ok := user.(string); ok {
|
||||
capturedUser = u
|
||||
}
|
||||
}
|
||||
if org := r.Context().Value(OrganizationKey); org != nil {
|
||||
if o, ok := org.(domain.Organization); ok {
|
||||
capturedOrg = &o
|
||||
}
|
||||
}
|
||||
w.WriteHeader(http.StatusOK)
|
||||
})
|
||||
|
||||
// Create request with both JWT and API key in context
|
||||
req := httptest.NewRequest(http.MethodGet, "/test", nil)
|
||||
ctx := context.WithValue(req.Context(), mw.ContextKey{}, token)
|
||||
ctx = context.WithValue(ctx, ApiKey, apiKey)
|
||||
req = req.WithContext(ctx)
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
// Execute
|
||||
authMiddleware.Handler(testHandler).ServeHTTP(rec, req)
|
||||
|
||||
// Assert
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
assert.Equal(t, userID, capturedUser)
|
||||
require.NotNil(t, capturedOrg)
|
||||
assert.Equal(t, expectedOrg.Name, capturedOrg.Name)
|
||||
mockCache.AssertExpectations(t)
|
||||
}
|
||||
|
||||
func TestUserFromContext(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
ctx context.Context
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "with valid user",
|
||||
ctx: context.WithValue(context.Background(), UserKey, "user-123"),
|
||||
expected: "user-123",
|
||||
},
|
||||
{
|
||||
name: "without user",
|
||||
ctx: context.Background(),
|
||||
expected: "",
|
||||
},
|
||||
{
|
||||
name: "with invalid type",
|
||||
ctx: context.WithValue(context.Background(), UserKey, 123),
|
||||
expected: "",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := UserFromContext(tt.ctx)
|
||||
assert.Equal(t, tt.expected, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestOrganizationFromContext(t *testing.T) {
|
||||
orgID := uuid.New()
|
||||
org := domain.Organization{
|
||||
BaseAggregate: eventsourced.BaseAggregate{
|
||||
ID: eventsourced.IdFromString(orgID.String()),
|
||||
},
|
||||
Name: "Test Org",
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
ctx context.Context
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "with valid organization",
|
||||
ctx: context.WithValue(context.Background(), OrganizationKey, org),
|
||||
expected: orgID.String(),
|
||||
},
|
||||
{
|
||||
name: "without organization",
|
||||
ctx: context.Background(),
|
||||
expected: "",
|
||||
},
|
||||
{
|
||||
name: "with invalid type",
|
||||
ctx: context.WithValue(context.Background(), OrganizationKey, "not-an-org"),
|
||||
expected: "",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := OrganizationFromContext(tt.ctx)
|
||||
assert.Equal(t, tt.expected, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthMiddleware_Directive_RequiresUser(t *testing.T) {
|
||||
mockCache := new(MockCache)
|
||||
authMiddleware := NewAuth(mockCache)
|
||||
|
||||
requireUser := true
|
||||
|
||||
// Test with user present
|
||||
ctx := context.WithValue(context.Background(), UserKey, "user-123")
|
||||
_, err := authMiddleware.Directive(ctx, nil, func(ctx context.Context) (interface{}, error) {
|
||||
return "success", nil
|
||||
}, &requireUser, nil)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Test without user
|
||||
ctx = context.Background()
|
||||
_, err = authMiddleware.Directive(ctx, nil, func(ctx context.Context) (interface{}, error) {
|
||||
return "success", nil
|
||||
}, &requireUser, nil)
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "no user available in request")
|
||||
}
|
||||
|
||||
func TestAuthMiddleware_Directive_RequiresOrganization(t *testing.T) {
|
||||
mockCache := new(MockCache)
|
||||
authMiddleware := NewAuth(mockCache)
|
||||
|
||||
requireOrg := true
|
||||
orgID := uuid.New()
|
||||
org := domain.Organization{
|
||||
BaseAggregate: eventsourced.BaseAggregate{
|
||||
ID: eventsourced.IdFromString(orgID.String()),
|
||||
},
|
||||
Name: "Test Org",
|
||||
}
|
||||
|
||||
// Test with organization present
|
||||
ctx := context.WithValue(context.Background(), OrganizationKey, org)
|
||||
_, err := authMiddleware.Directive(ctx, nil, func(ctx context.Context) (interface{}, error) {
|
||||
return "success", nil
|
||||
}, nil, &requireOrg)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Test without organization
|
||||
ctx = context.Background()
|
||||
_, err = authMiddleware.Directive(ctx, nil, func(ctx context.Context) (interface{}, error) {
|
||||
return "success", nil
|
||||
}, nil, &requireOrg)
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "no organization available in request")
|
||||
}
|
||||
|
||||
func TestAuthMiddleware_Directive_RequiresBoth(t *testing.T) {
|
||||
mockCache := new(MockCache)
|
||||
authMiddleware := NewAuth(mockCache)
|
||||
|
||||
requireUser := true
|
||||
requireOrg := true
|
||||
orgID := uuid.New()
|
||||
org := domain.Organization{
|
||||
BaseAggregate: eventsourced.BaseAggregate{
|
||||
ID: eventsourced.IdFromString(orgID.String()),
|
||||
},
|
||||
Name: "Test Org",
|
||||
}
|
||||
|
||||
// Test with both present
|
||||
ctx := context.WithValue(context.Background(), UserKey, "user-123")
|
||||
ctx = context.WithValue(ctx, OrganizationKey, org)
|
||||
_, err := authMiddleware.Directive(ctx, nil, func(ctx context.Context) (interface{}, error) {
|
||||
return "success", nil
|
||||
}, &requireUser, &requireOrg)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Test with only user
|
||||
ctx = context.WithValue(context.Background(), UserKey, "user-123")
|
||||
_, err = authMiddleware.Directive(ctx, nil, func(ctx context.Context) (interface{}, error) {
|
||||
return "success", nil
|
||||
}, &requireUser, &requireOrg)
|
||||
assert.Error(t, err)
|
||||
|
||||
// Test with only organization
|
||||
ctx = context.WithValue(context.Background(), OrganizationKey, org)
|
||||
_, err = authMiddleware.Directive(ctx, nil, func(ctx context.Context) (interface{}, error) {
|
||||
return "success", nil
|
||||
}, &requireUser, &requireOrg)
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
func TestAuthMiddleware_Directive_NoRequirements(t *testing.T) {
|
||||
mockCache := new(MockCache)
|
||||
authMiddleware := NewAuth(mockCache)
|
||||
|
||||
// Test with no requirements
|
||||
ctx := context.Background()
|
||||
result, err := authMiddleware.Directive(ctx, nil, func(ctx context.Context) (interface{}, error) {
|
||||
return "success", nil
|
||||
}, nil, nil)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "success", result)
|
||||
}
|
||||
Reference in New Issue
Block a user