feat(tests): add unit tests for WebSocket initialization logic
Adds unit tests for the WebSocket initialization function to validate behavior with valid, invalid, and absent API keys. Introduces a mock cache implementation to simulate organization retrieval based on hashed API keys. Ensures proper context value setting upon initialization, enhancing test coverage and reliability for API key handling in WebSocket connections.
This commit is contained in:
@@ -0,0 +1,467 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
mw "github.com/auth0/go-jwt-middleware/v2"
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/mock"
|
||||
"github.com/stretchr/testify/require"
|
||||
"gitlab.com/unboundsoftware/eventsourced/eventsourced"
|
||||
|
||||
"gitlab.com/unboundsoftware/schemas/domain"
|
||||
"gitlab.com/unboundsoftware/schemas/hash"
|
||||
)
|
||||
|
||||
// MockCache is a mock implementation of the Cache interface
|
||||
type MockCache struct {
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
func (m *MockCache) OrganizationByAPIKey(apiKey string) *domain.Organization {
|
||||
args := m.Called(apiKey)
|
||||
if args.Get(0) == nil {
|
||||
return nil
|
||||
}
|
||||
return args.Get(0).(*domain.Organization)
|
||||
}
|
||||
|
||||
func TestAuthMiddleware_Handler_WithValidAPIKey(t *testing.T) {
|
||||
// Setup
|
||||
mockCache := new(MockCache)
|
||||
authMiddleware := NewAuth(mockCache)
|
||||
|
||||
orgID := uuid.New()
|
||||
expectedOrg := &domain.Organization{
|
||||
BaseAggregate: eventsourced.BaseAggregate{
|
||||
ID: eventsourced.IdFromString(orgID.String()),
|
||||
},
|
||||
Name: "Test Organization",
|
||||
}
|
||||
|
||||
apiKey := "test-api-key-123"
|
||||
hashedKey := hash.String(apiKey)
|
||||
|
||||
mockCache.On("OrganizationByAPIKey", hashedKey).Return(expectedOrg)
|
||||
|
||||
// Create a test handler that checks the context
|
||||
var capturedOrg *domain.Organization
|
||||
testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if org := r.Context().Value(OrganizationKey); org != nil {
|
||||
if o, ok := org.(domain.Organization); ok {
|
||||
capturedOrg = &o
|
||||
}
|
||||
}
|
||||
w.WriteHeader(http.StatusOK)
|
||||
})
|
||||
|
||||
// Create request with API key in context
|
||||
req := httptest.NewRequest(http.MethodGet, "/test", nil)
|
||||
ctx := context.WithValue(req.Context(), ApiKey, apiKey)
|
||||
req = req.WithContext(ctx)
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
// Execute
|
||||
authMiddleware.Handler(testHandler).ServeHTTP(rec, req)
|
||||
|
||||
// Assert
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
require.NotNil(t, capturedOrg)
|
||||
assert.Equal(t, expectedOrg.Name, capturedOrg.Name)
|
||||
assert.Equal(t, expectedOrg.ID.String(), capturedOrg.ID.String())
|
||||
mockCache.AssertExpectations(t)
|
||||
}
|
||||
|
||||
func TestAuthMiddleware_Handler_WithInvalidAPIKey(t *testing.T) {
|
||||
// Setup
|
||||
mockCache := new(MockCache)
|
||||
authMiddleware := NewAuth(mockCache)
|
||||
|
||||
apiKey := "invalid-api-key"
|
||||
hashedKey := hash.String(apiKey)
|
||||
|
||||
mockCache.On("OrganizationByAPIKey", hashedKey).Return(nil)
|
||||
|
||||
// Create a test handler that checks the context
|
||||
var capturedOrg *domain.Organization
|
||||
testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if org := r.Context().Value(OrganizationKey); org != nil {
|
||||
if o, ok := org.(domain.Organization); ok {
|
||||
capturedOrg = &o
|
||||
}
|
||||
}
|
||||
w.WriteHeader(http.StatusOK)
|
||||
})
|
||||
|
||||
// Create request with API key in context
|
||||
req := httptest.NewRequest(http.MethodGet, "/test", nil)
|
||||
ctx := context.WithValue(req.Context(), ApiKey, apiKey)
|
||||
req = req.WithContext(ctx)
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
// Execute
|
||||
authMiddleware.Handler(testHandler).ServeHTTP(rec, req)
|
||||
|
||||
// Assert
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
assert.Nil(t, capturedOrg, "Organization should not be set for invalid API key")
|
||||
mockCache.AssertExpectations(t)
|
||||
}
|
||||
|
||||
func TestAuthMiddleware_Handler_WithoutAPIKey(t *testing.T) {
|
||||
// Setup
|
||||
mockCache := new(MockCache)
|
||||
authMiddleware := NewAuth(mockCache)
|
||||
|
||||
// The middleware always hashes the API key (even if empty) and calls the cache
|
||||
emptyKeyHash := hash.String("")
|
||||
mockCache.On("OrganizationByAPIKey", emptyKeyHash).Return(nil)
|
||||
|
||||
// Create a test handler that checks the context
|
||||
var capturedOrg *domain.Organization
|
||||
testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if org := r.Context().Value(OrganizationKey); org != nil {
|
||||
if o, ok := org.(domain.Organization); ok {
|
||||
capturedOrg = &o
|
||||
}
|
||||
}
|
||||
w.WriteHeader(http.StatusOK)
|
||||
})
|
||||
|
||||
// Create request without API key
|
||||
req := httptest.NewRequest(http.MethodGet, "/test", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
// Execute
|
||||
authMiddleware.Handler(testHandler).ServeHTTP(rec, req)
|
||||
|
||||
// Assert
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
assert.Nil(t, capturedOrg, "Organization should not be set without API key")
|
||||
mockCache.AssertExpectations(t)
|
||||
}
|
||||
|
||||
func TestAuthMiddleware_Handler_WithValidJWT(t *testing.T) {
|
||||
// Setup
|
||||
mockCache := new(MockCache)
|
||||
authMiddleware := NewAuth(mockCache)
|
||||
|
||||
// The middleware always hashes the API key (even if empty) and calls the cache
|
||||
emptyKeyHash := hash.String("")
|
||||
mockCache.On("OrganizationByAPIKey", emptyKeyHash).Return(nil)
|
||||
|
||||
userID := "user-123"
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{
|
||||
"sub": userID,
|
||||
})
|
||||
|
||||
// Create a test handler that checks the context
|
||||
var capturedUser string
|
||||
testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if user := r.Context().Value(UserKey); user != nil {
|
||||
if u, ok := user.(string); ok {
|
||||
capturedUser = u
|
||||
}
|
||||
}
|
||||
w.WriteHeader(http.StatusOK)
|
||||
})
|
||||
|
||||
// Create request with JWT token in context
|
||||
req := httptest.NewRequest(http.MethodGet, "/test", nil)
|
||||
ctx := context.WithValue(req.Context(), mw.ContextKey{}, token)
|
||||
req = req.WithContext(ctx)
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
// Execute
|
||||
authMiddleware.Handler(testHandler).ServeHTTP(rec, req)
|
||||
|
||||
// Assert
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
assert.Equal(t, userID, capturedUser)
|
||||
}
|
||||
|
||||
func TestAuthMiddleware_Handler_APIKeyErrorHandling(t *testing.T) {
|
||||
// Setup
|
||||
mockCache := new(MockCache)
|
||||
authMiddleware := NewAuth(mockCache)
|
||||
|
||||
testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
})
|
||||
|
||||
// Create request with invalid API key type in context
|
||||
req := httptest.NewRequest(http.MethodGet, "/test", nil)
|
||||
ctx := context.WithValue(req.Context(), ApiKey, 12345) // Invalid type
|
||||
req = req.WithContext(ctx)
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
// Execute
|
||||
authMiddleware.Handler(testHandler).ServeHTTP(rec, req)
|
||||
|
||||
// Assert
|
||||
assert.Equal(t, http.StatusInternalServerError, rec.Code)
|
||||
assert.Contains(t, rec.Body.String(), "Invalid API Key format")
|
||||
}
|
||||
|
||||
func TestAuthMiddleware_Handler_JWTErrorHandling(t *testing.T) {
|
||||
// Setup
|
||||
mockCache := new(MockCache)
|
||||
authMiddleware := NewAuth(mockCache)
|
||||
|
||||
testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
})
|
||||
|
||||
// Create request with invalid JWT token type in context
|
||||
req := httptest.NewRequest(http.MethodGet, "/test", nil)
|
||||
ctx := context.WithValue(req.Context(), mw.ContextKey{}, "not-a-token") // Invalid type
|
||||
req = req.WithContext(ctx)
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
// Execute
|
||||
authMiddleware.Handler(testHandler).ServeHTTP(rec, req)
|
||||
|
||||
// Assert
|
||||
assert.Equal(t, http.StatusInternalServerError, rec.Code)
|
||||
assert.Contains(t, rec.Body.String(), "Invalid JWT token format")
|
||||
}
|
||||
|
||||
func TestAuthMiddleware_Handler_BothJWTAndAPIKey(t *testing.T) {
|
||||
// Setup
|
||||
mockCache := new(MockCache)
|
||||
authMiddleware := NewAuth(mockCache)
|
||||
|
||||
orgID := uuid.New()
|
||||
expectedOrg := &domain.Organization{
|
||||
BaseAggregate: eventsourced.BaseAggregate{
|
||||
ID: eventsourced.IdFromString(orgID.String()),
|
||||
},
|
||||
Name: "Test Organization",
|
||||
}
|
||||
|
||||
userID := "user-123"
|
||||
apiKey := "test-api-key-123"
|
||||
hashedKey := hash.String(apiKey)
|
||||
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{
|
||||
"sub": userID,
|
||||
})
|
||||
|
||||
mockCache.On("OrganizationByAPIKey", hashedKey).Return(expectedOrg)
|
||||
|
||||
// Create a test handler that checks both user and organization in context
|
||||
var capturedUser string
|
||||
var capturedOrg *domain.Organization
|
||||
testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if user := r.Context().Value(UserKey); user != nil {
|
||||
if u, ok := user.(string); ok {
|
||||
capturedUser = u
|
||||
}
|
||||
}
|
||||
if org := r.Context().Value(OrganizationKey); org != nil {
|
||||
if o, ok := org.(domain.Organization); ok {
|
||||
capturedOrg = &o
|
||||
}
|
||||
}
|
||||
w.WriteHeader(http.StatusOK)
|
||||
})
|
||||
|
||||
// Create request with both JWT and API key in context
|
||||
req := httptest.NewRequest(http.MethodGet, "/test", nil)
|
||||
ctx := context.WithValue(req.Context(), mw.ContextKey{}, token)
|
||||
ctx = context.WithValue(ctx, ApiKey, apiKey)
|
||||
req = req.WithContext(ctx)
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
// Execute
|
||||
authMiddleware.Handler(testHandler).ServeHTTP(rec, req)
|
||||
|
||||
// Assert
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
assert.Equal(t, userID, capturedUser)
|
||||
require.NotNil(t, capturedOrg)
|
||||
assert.Equal(t, expectedOrg.Name, capturedOrg.Name)
|
||||
mockCache.AssertExpectations(t)
|
||||
}
|
||||
|
||||
func TestUserFromContext(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
ctx context.Context
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "with valid user",
|
||||
ctx: context.WithValue(context.Background(), UserKey, "user-123"),
|
||||
expected: "user-123",
|
||||
},
|
||||
{
|
||||
name: "without user",
|
||||
ctx: context.Background(),
|
||||
expected: "",
|
||||
},
|
||||
{
|
||||
name: "with invalid type",
|
||||
ctx: context.WithValue(context.Background(), UserKey, 123),
|
||||
expected: "",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := UserFromContext(tt.ctx)
|
||||
assert.Equal(t, tt.expected, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestOrganizationFromContext(t *testing.T) {
|
||||
orgID := uuid.New()
|
||||
org := domain.Organization{
|
||||
BaseAggregate: eventsourced.BaseAggregate{
|
||||
ID: eventsourced.IdFromString(orgID.String()),
|
||||
},
|
||||
Name: "Test Org",
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
ctx context.Context
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "with valid organization",
|
||||
ctx: context.WithValue(context.Background(), OrganizationKey, org),
|
||||
expected: orgID.String(),
|
||||
},
|
||||
{
|
||||
name: "without organization",
|
||||
ctx: context.Background(),
|
||||
expected: "",
|
||||
},
|
||||
{
|
||||
name: "with invalid type",
|
||||
ctx: context.WithValue(context.Background(), OrganizationKey, "not-an-org"),
|
||||
expected: "",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := OrganizationFromContext(tt.ctx)
|
||||
assert.Equal(t, tt.expected, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthMiddleware_Directive_RequiresUser(t *testing.T) {
|
||||
mockCache := new(MockCache)
|
||||
authMiddleware := NewAuth(mockCache)
|
||||
|
||||
requireUser := true
|
||||
|
||||
// Test with user present
|
||||
ctx := context.WithValue(context.Background(), UserKey, "user-123")
|
||||
_, err := authMiddleware.Directive(ctx, nil, func(ctx context.Context) (interface{}, error) {
|
||||
return "success", nil
|
||||
}, &requireUser, nil)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Test without user
|
||||
ctx = context.Background()
|
||||
_, err = authMiddleware.Directive(ctx, nil, func(ctx context.Context) (interface{}, error) {
|
||||
return "success", nil
|
||||
}, &requireUser, nil)
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "no user available in request")
|
||||
}
|
||||
|
||||
func TestAuthMiddleware_Directive_RequiresOrganization(t *testing.T) {
|
||||
mockCache := new(MockCache)
|
||||
authMiddleware := NewAuth(mockCache)
|
||||
|
||||
requireOrg := true
|
||||
orgID := uuid.New()
|
||||
org := domain.Organization{
|
||||
BaseAggregate: eventsourced.BaseAggregate{
|
||||
ID: eventsourced.IdFromString(orgID.String()),
|
||||
},
|
||||
Name: "Test Org",
|
||||
}
|
||||
|
||||
// Test with organization present
|
||||
ctx := context.WithValue(context.Background(), OrganizationKey, org)
|
||||
_, err := authMiddleware.Directive(ctx, nil, func(ctx context.Context) (interface{}, error) {
|
||||
return "success", nil
|
||||
}, nil, &requireOrg)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Test without organization
|
||||
ctx = context.Background()
|
||||
_, err = authMiddleware.Directive(ctx, nil, func(ctx context.Context) (interface{}, error) {
|
||||
return "success", nil
|
||||
}, nil, &requireOrg)
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "no organization available in request")
|
||||
}
|
||||
|
||||
func TestAuthMiddleware_Directive_RequiresBoth(t *testing.T) {
|
||||
mockCache := new(MockCache)
|
||||
authMiddleware := NewAuth(mockCache)
|
||||
|
||||
requireUser := true
|
||||
requireOrg := true
|
||||
orgID := uuid.New()
|
||||
org := domain.Organization{
|
||||
BaseAggregate: eventsourced.BaseAggregate{
|
||||
ID: eventsourced.IdFromString(orgID.String()),
|
||||
},
|
||||
Name: "Test Org",
|
||||
}
|
||||
|
||||
// Test with both present
|
||||
ctx := context.WithValue(context.Background(), UserKey, "user-123")
|
||||
ctx = context.WithValue(ctx, OrganizationKey, org)
|
||||
_, err := authMiddleware.Directive(ctx, nil, func(ctx context.Context) (interface{}, error) {
|
||||
return "success", nil
|
||||
}, &requireUser, &requireOrg)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Test with only user
|
||||
ctx = context.WithValue(context.Background(), UserKey, "user-123")
|
||||
_, err = authMiddleware.Directive(ctx, nil, func(ctx context.Context) (interface{}, error) {
|
||||
return "success", nil
|
||||
}, &requireUser, &requireOrg)
|
||||
assert.Error(t, err)
|
||||
|
||||
// Test with only organization
|
||||
ctx = context.WithValue(context.Background(), OrganizationKey, org)
|
||||
_, err = authMiddleware.Directive(ctx, nil, func(ctx context.Context) (interface{}, error) {
|
||||
return "success", nil
|
||||
}, &requireUser, &requireOrg)
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
func TestAuthMiddleware_Directive_NoRequirements(t *testing.T) {
|
||||
mockCache := new(MockCache)
|
||||
authMiddleware := NewAuth(mockCache)
|
||||
|
||||
// Test with no requirements
|
||||
ctx := context.Background()
|
||||
result, err := authMiddleware.Directive(ctx, nil, func(ctx context.Context) (interface{}, error) {
|
||||
return "success", nil
|
||||
}, nil, nil)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "success", result)
|
||||
}
|
||||
Reference in New Issue
Block a user