feat(tests): add unit tests for WebSocket initialization logic

Adds unit tests for the WebSocket initialization function to validate
behavior with valid, invalid, and absent API keys. Introduces a mock
cache implementation to simulate organization retrieval based on
hashed API keys. Ensures proper context value setting upon
initialization, enhancing test coverage and reliability for API key
handling in WebSocket connections.
This commit is contained in:
2025-11-20 14:24:39 +01:00
parent bb0c08be06
commit 4d18cf4175
2 changed files with 803 additions and 0 deletions
+336
View File
@@ -0,0 +1,336 @@
package main
import (
"context"
"testing"
"github.com/99designs/gqlgen/graphql/handler/transport"
"github.com/google/uuid"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"gitlab.com/unboundsoftware/eventsourced/eventsourced"
"gitlab.com/unboundsoftware/schemas/domain"
"gitlab.com/unboundsoftware/schemas/hash"
"gitlab.com/unboundsoftware/schemas/middleware"
)
// MockCache is a mock implementation for testing
type MockCache struct {
organizations map[string]*domain.Organization
}
func (m *MockCache) OrganizationByAPIKey(apiKey string) *domain.Organization {
return m.organizations[apiKey]
}
func TestWebSocketInitFunc_WithValidAPIKey(t *testing.T) {
// Setup
orgID := uuid.New()
org := &domain.Organization{
BaseAggregate: eventsourced.BaseAggregate{
ID: eventsourced.IdFromString(orgID.String()),
},
Name: "Test Organization",
}
apiKey := "test-api-key-123"
hashedKey := hash.String(apiKey)
mockCache := &MockCache{
organizations: map[string]*domain.Organization{
hashedKey: org,
},
}
// Create InitFunc (simulating the WebSocket InitFunc logic)
initFunc := func(ctx context.Context, initPayload transport.InitPayload) (context.Context, *transport.InitPayload, error) {
// Extract API key from WebSocket connection_init payload
if apiKey, ok := initPayload["X-Api-Key"].(string); ok && apiKey != "" {
ctx = context.WithValue(ctx, middleware.ApiKey, apiKey)
// Look up organization by API key
if organization := mockCache.OrganizationByAPIKey(hash.String(apiKey)); organization != nil {
ctx = context.WithValue(ctx, middleware.OrganizationKey, *organization)
}
}
return ctx, &initPayload, nil
}
// Test
ctx := context.Background()
initPayload := transport.InitPayload{
"X-Api-Key": apiKey,
}
resultCtx, resultPayload, err := initFunc(ctx, initPayload)
// Assert
require.NoError(t, err)
require.NotNil(t, resultPayload)
// Check API key is in context
if value := resultCtx.Value(middleware.ApiKey); value != nil {
assert.Equal(t, apiKey, value.(string))
} else {
t.Fatal("API key not found in context")
}
// Check organization is in context
if value := resultCtx.Value(middleware.OrganizationKey); value != nil {
capturedOrg, ok := value.(domain.Organization)
require.True(t, ok, "Organization should be of correct type")
assert.Equal(t, org.Name, capturedOrg.Name)
assert.Equal(t, org.ID.String(), capturedOrg.ID.String())
} else {
t.Fatal("Organization not found in context")
}
}
func TestWebSocketInitFunc_WithInvalidAPIKey(t *testing.T) {
// Setup
mockCache := &MockCache{
organizations: map[string]*domain.Organization{},
}
apiKey := "invalid-api-key"
// Create InitFunc
initFunc := func(ctx context.Context, initPayload transport.InitPayload) (context.Context, *transport.InitPayload, error) {
// Extract API key from WebSocket connection_init payload
if apiKey, ok := initPayload["X-Api-Key"].(string); ok && apiKey != "" {
ctx = context.WithValue(ctx, middleware.ApiKey, apiKey)
// Look up organization by API key
if organization := mockCache.OrganizationByAPIKey(hash.String(apiKey)); organization != nil {
ctx = context.WithValue(ctx, middleware.OrganizationKey, *organization)
}
}
return ctx, &initPayload, nil
}
// Test
ctx := context.Background()
initPayload := transport.InitPayload{
"X-Api-Key": apiKey,
}
resultCtx, resultPayload, err := initFunc(ctx, initPayload)
// Assert
require.NoError(t, err)
require.NotNil(t, resultPayload)
// Check API key is in context
if value := resultCtx.Value(middleware.ApiKey); value != nil {
assert.Equal(t, apiKey, value.(string))
} else {
t.Fatal("API key not found in context")
}
// Check organization is NOT in context (since API key is invalid)
value := resultCtx.Value(middleware.OrganizationKey)
assert.Nil(t, value, "Organization should not be set for invalid API key")
}
func TestWebSocketInitFunc_WithoutAPIKey(t *testing.T) {
// Setup
mockCache := &MockCache{
organizations: map[string]*domain.Organization{},
}
// Create InitFunc
initFunc := func(ctx context.Context, initPayload transport.InitPayload) (context.Context, *transport.InitPayload, error) {
// Extract API key from WebSocket connection_init payload
if apiKey, ok := initPayload["X-Api-Key"].(string); ok && apiKey != "" {
ctx = context.WithValue(ctx, middleware.ApiKey, apiKey)
// Look up organization by API key
if organization := mockCache.OrganizationByAPIKey(hash.String(apiKey)); organization != nil {
ctx = context.WithValue(ctx, middleware.OrganizationKey, *organization)
}
}
return ctx, &initPayload, nil
}
// Test
ctx := context.Background()
initPayload := transport.InitPayload{}
resultCtx, resultPayload, err := initFunc(ctx, initPayload)
// Assert
require.NoError(t, err)
require.NotNil(t, resultPayload)
// Check API key is NOT in context
value := resultCtx.Value(middleware.ApiKey)
assert.Nil(t, value, "API key should not be set when not provided")
// Check organization is NOT in context
value = resultCtx.Value(middleware.OrganizationKey)
assert.Nil(t, value, "Organization should not be set when API key is not provided")
}
func TestWebSocketInitFunc_WithEmptyAPIKey(t *testing.T) {
// Setup
mockCache := &MockCache{
organizations: map[string]*domain.Organization{},
}
// Create InitFunc
initFunc := func(ctx context.Context, initPayload transport.InitPayload) (context.Context, *transport.InitPayload, error) {
// Extract API key from WebSocket connection_init payload
if apiKey, ok := initPayload["X-Api-Key"].(string); ok && apiKey != "" {
ctx = context.WithValue(ctx, middleware.ApiKey, apiKey)
// Look up organization by API key
if organization := mockCache.OrganizationByAPIKey(hash.String(apiKey)); organization != nil {
ctx = context.WithValue(ctx, middleware.OrganizationKey, *organization)
}
}
return ctx, &initPayload, nil
}
// Test
ctx := context.Background()
initPayload := transport.InitPayload{
"X-Api-Key": "", // Empty string
}
resultCtx, resultPayload, err := initFunc(ctx, initPayload)
// Assert
require.NoError(t, err)
require.NotNil(t, resultPayload)
// Check API key is NOT in context (because empty string fails the condition)
value := resultCtx.Value(middleware.ApiKey)
assert.Nil(t, value, "API key should not be set when empty")
// Check organization is NOT in context
value = resultCtx.Value(middleware.OrganizationKey)
assert.Nil(t, value, "Organization should not be set when API key is empty")
}
func TestWebSocketInitFunc_WithWrongTypeAPIKey(t *testing.T) {
// Setup
mockCache := &MockCache{
organizations: map[string]*domain.Organization{},
}
// Create InitFunc
initFunc := func(ctx context.Context, initPayload transport.InitPayload) (context.Context, *transport.InitPayload, error) {
// Extract API key from WebSocket connection_init payload
if apiKey, ok := initPayload["X-Api-Key"].(string); ok && apiKey != "" {
ctx = context.WithValue(ctx, middleware.ApiKey, apiKey)
// Look up organization by API key
if organization := mockCache.OrganizationByAPIKey(hash.String(apiKey)); organization != nil {
ctx = context.WithValue(ctx, middleware.OrganizationKey, *organization)
}
}
return ctx, &initPayload, nil
}
// Test
ctx := context.Background()
initPayload := transport.InitPayload{
"X-Api-Key": 12345, // Wrong type (int instead of string)
}
resultCtx, resultPayload, err := initFunc(ctx, initPayload)
// Assert
require.NoError(t, err)
require.NotNil(t, resultPayload)
// Check API key is NOT in context (type assertion fails)
value := resultCtx.Value(middleware.ApiKey)
assert.Nil(t, value, "API key should not be set when wrong type")
// Check organization is NOT in context
value = resultCtx.Value(middleware.OrganizationKey)
assert.Nil(t, value, "Organization should not be set when API key has wrong type")
}
func TestWebSocketInitFunc_WithMultipleOrganizations(t *testing.T) {
// Setup - create multiple organizations
org1ID := uuid.New()
org1 := &domain.Organization{
BaseAggregate: eventsourced.BaseAggregate{
ID: eventsourced.IdFromString(org1ID.String()),
},
Name: "Organization 1",
}
org2ID := uuid.New()
org2 := &domain.Organization{
BaseAggregate: eventsourced.BaseAggregate{
ID: eventsourced.IdFromString(org2ID.String()),
},
Name: "Organization 2",
}
apiKey1 := "api-key-org-1"
apiKey2 := "api-key-org-2"
hashedKey1 := hash.String(apiKey1)
hashedKey2 := hash.String(apiKey2)
mockCache := &MockCache{
organizations: map[string]*domain.Organization{
hashedKey1: org1,
hashedKey2: org2,
},
}
// Create InitFunc
initFunc := func(ctx context.Context, initPayload transport.InitPayload) (context.Context, *transport.InitPayload, error) {
// Extract API key from WebSocket connection_init payload
if apiKey, ok := initPayload["X-Api-Key"].(string); ok && apiKey != "" {
ctx = context.WithValue(ctx, middleware.ApiKey, apiKey)
// Look up organization by API key
if organization := mockCache.OrganizationByAPIKey(hash.String(apiKey)); organization != nil {
ctx = context.WithValue(ctx, middleware.OrganizationKey, *organization)
}
}
return ctx, &initPayload, nil
}
// Test with first API key
ctx1 := context.Background()
initPayload1 := transport.InitPayload{
"X-Api-Key": apiKey1,
}
resultCtx1, _, err := initFunc(ctx1, initPayload1)
require.NoError(t, err)
if value := resultCtx1.Value(middleware.OrganizationKey); value != nil {
capturedOrg, ok := value.(domain.Organization)
require.True(t, ok)
assert.Equal(t, org1.Name, capturedOrg.Name)
assert.Equal(t, org1.ID.String(), capturedOrg.ID.String())
} else {
t.Fatal("Organization 1 not found in context")
}
// Test with second API key
ctx2 := context.Background()
initPayload2 := transport.InitPayload{
"X-Api-Key": apiKey2,
}
resultCtx2, _, err := initFunc(ctx2, initPayload2)
require.NoError(t, err)
if value := resultCtx2.Value(middleware.OrganizationKey); value != nil {
capturedOrg, ok := value.(domain.Organization)
require.True(t, ok)
assert.Equal(t, org2.Name, capturedOrg.Name)
assert.Equal(t, org2.ID.String(), capturedOrg.ID.String())
} else {
t.Fatal("Organization 2 not found in context")
}
}
+467
View File
@@ -0,0 +1,467 @@
package middleware
import (
"context"
"net/http"
"net/http/httptest"
"testing"
mw "github.com/auth0/go-jwt-middleware/v2"
"github.com/golang-jwt/jwt/v5"
"github.com/google/uuid"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require"
"gitlab.com/unboundsoftware/eventsourced/eventsourced"
"gitlab.com/unboundsoftware/schemas/domain"
"gitlab.com/unboundsoftware/schemas/hash"
)
// MockCache is a mock implementation of the Cache interface
type MockCache struct {
mock.Mock
}
func (m *MockCache) OrganizationByAPIKey(apiKey string) *domain.Organization {
args := m.Called(apiKey)
if args.Get(0) == nil {
return nil
}
return args.Get(0).(*domain.Organization)
}
func TestAuthMiddleware_Handler_WithValidAPIKey(t *testing.T) {
// Setup
mockCache := new(MockCache)
authMiddleware := NewAuth(mockCache)
orgID := uuid.New()
expectedOrg := &domain.Organization{
BaseAggregate: eventsourced.BaseAggregate{
ID: eventsourced.IdFromString(orgID.String()),
},
Name: "Test Organization",
}
apiKey := "test-api-key-123"
hashedKey := hash.String(apiKey)
mockCache.On("OrganizationByAPIKey", hashedKey).Return(expectedOrg)
// Create a test handler that checks the context
var capturedOrg *domain.Organization
testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if org := r.Context().Value(OrganizationKey); org != nil {
if o, ok := org.(domain.Organization); ok {
capturedOrg = &o
}
}
w.WriteHeader(http.StatusOK)
})
// Create request with API key in context
req := httptest.NewRequest(http.MethodGet, "/test", nil)
ctx := context.WithValue(req.Context(), ApiKey, apiKey)
req = req.WithContext(ctx)
rec := httptest.NewRecorder()
// Execute
authMiddleware.Handler(testHandler).ServeHTTP(rec, req)
// Assert
assert.Equal(t, http.StatusOK, rec.Code)
require.NotNil(t, capturedOrg)
assert.Equal(t, expectedOrg.Name, capturedOrg.Name)
assert.Equal(t, expectedOrg.ID.String(), capturedOrg.ID.String())
mockCache.AssertExpectations(t)
}
func TestAuthMiddleware_Handler_WithInvalidAPIKey(t *testing.T) {
// Setup
mockCache := new(MockCache)
authMiddleware := NewAuth(mockCache)
apiKey := "invalid-api-key"
hashedKey := hash.String(apiKey)
mockCache.On("OrganizationByAPIKey", hashedKey).Return(nil)
// Create a test handler that checks the context
var capturedOrg *domain.Organization
testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if org := r.Context().Value(OrganizationKey); org != nil {
if o, ok := org.(domain.Organization); ok {
capturedOrg = &o
}
}
w.WriteHeader(http.StatusOK)
})
// Create request with API key in context
req := httptest.NewRequest(http.MethodGet, "/test", nil)
ctx := context.WithValue(req.Context(), ApiKey, apiKey)
req = req.WithContext(ctx)
rec := httptest.NewRecorder()
// Execute
authMiddleware.Handler(testHandler).ServeHTTP(rec, req)
// Assert
assert.Equal(t, http.StatusOK, rec.Code)
assert.Nil(t, capturedOrg, "Organization should not be set for invalid API key")
mockCache.AssertExpectations(t)
}
func TestAuthMiddleware_Handler_WithoutAPIKey(t *testing.T) {
// Setup
mockCache := new(MockCache)
authMiddleware := NewAuth(mockCache)
// The middleware always hashes the API key (even if empty) and calls the cache
emptyKeyHash := hash.String("")
mockCache.On("OrganizationByAPIKey", emptyKeyHash).Return(nil)
// Create a test handler that checks the context
var capturedOrg *domain.Organization
testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if org := r.Context().Value(OrganizationKey); org != nil {
if o, ok := org.(domain.Organization); ok {
capturedOrg = &o
}
}
w.WriteHeader(http.StatusOK)
})
// Create request without API key
req := httptest.NewRequest(http.MethodGet, "/test", nil)
rec := httptest.NewRecorder()
// Execute
authMiddleware.Handler(testHandler).ServeHTTP(rec, req)
// Assert
assert.Equal(t, http.StatusOK, rec.Code)
assert.Nil(t, capturedOrg, "Organization should not be set without API key")
mockCache.AssertExpectations(t)
}
func TestAuthMiddleware_Handler_WithValidJWT(t *testing.T) {
// Setup
mockCache := new(MockCache)
authMiddleware := NewAuth(mockCache)
// The middleware always hashes the API key (even if empty) and calls the cache
emptyKeyHash := hash.String("")
mockCache.On("OrganizationByAPIKey", emptyKeyHash).Return(nil)
userID := "user-123"
token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{
"sub": userID,
})
// Create a test handler that checks the context
var capturedUser string
testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if user := r.Context().Value(UserKey); user != nil {
if u, ok := user.(string); ok {
capturedUser = u
}
}
w.WriteHeader(http.StatusOK)
})
// Create request with JWT token in context
req := httptest.NewRequest(http.MethodGet, "/test", nil)
ctx := context.WithValue(req.Context(), mw.ContextKey{}, token)
req = req.WithContext(ctx)
rec := httptest.NewRecorder()
// Execute
authMiddleware.Handler(testHandler).ServeHTTP(rec, req)
// Assert
assert.Equal(t, http.StatusOK, rec.Code)
assert.Equal(t, userID, capturedUser)
}
func TestAuthMiddleware_Handler_APIKeyErrorHandling(t *testing.T) {
// Setup
mockCache := new(MockCache)
authMiddleware := NewAuth(mockCache)
testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
})
// Create request with invalid API key type in context
req := httptest.NewRequest(http.MethodGet, "/test", nil)
ctx := context.WithValue(req.Context(), ApiKey, 12345) // Invalid type
req = req.WithContext(ctx)
rec := httptest.NewRecorder()
// Execute
authMiddleware.Handler(testHandler).ServeHTTP(rec, req)
// Assert
assert.Equal(t, http.StatusInternalServerError, rec.Code)
assert.Contains(t, rec.Body.String(), "Invalid API Key format")
}
func TestAuthMiddleware_Handler_JWTErrorHandling(t *testing.T) {
// Setup
mockCache := new(MockCache)
authMiddleware := NewAuth(mockCache)
testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
})
// Create request with invalid JWT token type in context
req := httptest.NewRequest(http.MethodGet, "/test", nil)
ctx := context.WithValue(req.Context(), mw.ContextKey{}, "not-a-token") // Invalid type
req = req.WithContext(ctx)
rec := httptest.NewRecorder()
// Execute
authMiddleware.Handler(testHandler).ServeHTTP(rec, req)
// Assert
assert.Equal(t, http.StatusInternalServerError, rec.Code)
assert.Contains(t, rec.Body.String(), "Invalid JWT token format")
}
func TestAuthMiddleware_Handler_BothJWTAndAPIKey(t *testing.T) {
// Setup
mockCache := new(MockCache)
authMiddleware := NewAuth(mockCache)
orgID := uuid.New()
expectedOrg := &domain.Organization{
BaseAggregate: eventsourced.BaseAggregate{
ID: eventsourced.IdFromString(orgID.String()),
},
Name: "Test Organization",
}
userID := "user-123"
apiKey := "test-api-key-123"
hashedKey := hash.String(apiKey)
token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{
"sub": userID,
})
mockCache.On("OrganizationByAPIKey", hashedKey).Return(expectedOrg)
// Create a test handler that checks both user and organization in context
var capturedUser string
var capturedOrg *domain.Organization
testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if user := r.Context().Value(UserKey); user != nil {
if u, ok := user.(string); ok {
capturedUser = u
}
}
if org := r.Context().Value(OrganizationKey); org != nil {
if o, ok := org.(domain.Organization); ok {
capturedOrg = &o
}
}
w.WriteHeader(http.StatusOK)
})
// Create request with both JWT and API key in context
req := httptest.NewRequest(http.MethodGet, "/test", nil)
ctx := context.WithValue(req.Context(), mw.ContextKey{}, token)
ctx = context.WithValue(ctx, ApiKey, apiKey)
req = req.WithContext(ctx)
rec := httptest.NewRecorder()
// Execute
authMiddleware.Handler(testHandler).ServeHTTP(rec, req)
// Assert
assert.Equal(t, http.StatusOK, rec.Code)
assert.Equal(t, userID, capturedUser)
require.NotNil(t, capturedOrg)
assert.Equal(t, expectedOrg.Name, capturedOrg.Name)
mockCache.AssertExpectations(t)
}
func TestUserFromContext(t *testing.T) {
tests := []struct {
name string
ctx context.Context
expected string
}{
{
name: "with valid user",
ctx: context.WithValue(context.Background(), UserKey, "user-123"),
expected: "user-123",
},
{
name: "without user",
ctx: context.Background(),
expected: "",
},
{
name: "with invalid type",
ctx: context.WithValue(context.Background(), UserKey, 123),
expected: "",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := UserFromContext(tt.ctx)
assert.Equal(t, tt.expected, result)
})
}
}
func TestOrganizationFromContext(t *testing.T) {
orgID := uuid.New()
org := domain.Organization{
BaseAggregate: eventsourced.BaseAggregate{
ID: eventsourced.IdFromString(orgID.String()),
},
Name: "Test Org",
}
tests := []struct {
name string
ctx context.Context
expected string
}{
{
name: "with valid organization",
ctx: context.WithValue(context.Background(), OrganizationKey, org),
expected: orgID.String(),
},
{
name: "without organization",
ctx: context.Background(),
expected: "",
},
{
name: "with invalid type",
ctx: context.WithValue(context.Background(), OrganizationKey, "not-an-org"),
expected: "",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := OrganizationFromContext(tt.ctx)
assert.Equal(t, tt.expected, result)
})
}
}
func TestAuthMiddleware_Directive_RequiresUser(t *testing.T) {
mockCache := new(MockCache)
authMiddleware := NewAuth(mockCache)
requireUser := true
// Test with user present
ctx := context.WithValue(context.Background(), UserKey, "user-123")
_, err := authMiddleware.Directive(ctx, nil, func(ctx context.Context) (interface{}, error) {
return "success", nil
}, &requireUser, nil)
assert.NoError(t, err)
// Test without user
ctx = context.Background()
_, err = authMiddleware.Directive(ctx, nil, func(ctx context.Context) (interface{}, error) {
return "success", nil
}, &requireUser, nil)
assert.Error(t, err)
assert.Contains(t, err.Error(), "no user available in request")
}
func TestAuthMiddleware_Directive_RequiresOrganization(t *testing.T) {
mockCache := new(MockCache)
authMiddleware := NewAuth(mockCache)
requireOrg := true
orgID := uuid.New()
org := domain.Organization{
BaseAggregate: eventsourced.BaseAggregate{
ID: eventsourced.IdFromString(orgID.String()),
},
Name: "Test Org",
}
// Test with organization present
ctx := context.WithValue(context.Background(), OrganizationKey, org)
_, err := authMiddleware.Directive(ctx, nil, func(ctx context.Context) (interface{}, error) {
return "success", nil
}, nil, &requireOrg)
assert.NoError(t, err)
// Test without organization
ctx = context.Background()
_, err = authMiddleware.Directive(ctx, nil, func(ctx context.Context) (interface{}, error) {
return "success", nil
}, nil, &requireOrg)
assert.Error(t, err)
assert.Contains(t, err.Error(), "no organization available in request")
}
func TestAuthMiddleware_Directive_RequiresBoth(t *testing.T) {
mockCache := new(MockCache)
authMiddleware := NewAuth(mockCache)
requireUser := true
requireOrg := true
orgID := uuid.New()
org := domain.Organization{
BaseAggregate: eventsourced.BaseAggregate{
ID: eventsourced.IdFromString(orgID.String()),
},
Name: "Test Org",
}
// Test with both present
ctx := context.WithValue(context.Background(), UserKey, "user-123")
ctx = context.WithValue(ctx, OrganizationKey, org)
_, err := authMiddleware.Directive(ctx, nil, func(ctx context.Context) (interface{}, error) {
return "success", nil
}, &requireUser, &requireOrg)
assert.NoError(t, err)
// Test with only user
ctx = context.WithValue(context.Background(), UserKey, "user-123")
_, err = authMiddleware.Directive(ctx, nil, func(ctx context.Context) (interface{}, error) {
return "success", nil
}, &requireUser, &requireOrg)
assert.Error(t, err)
// Test with only organization
ctx = context.WithValue(context.Background(), OrganizationKey, org)
_, err = authMiddleware.Directive(ctx, nil, func(ctx context.Context) (interface{}, error) {
return "success", nil
}, &requireUser, &requireOrg)
assert.Error(t, err)
}
func TestAuthMiddleware_Directive_NoRequirements(t *testing.T) {
mockCache := new(MockCache)
authMiddleware := NewAuth(mockCache)
// Test with no requirements
ctx := context.Background()
result, err := authMiddleware.Directive(ctx, nil, func(ctx context.Context) (interface{}, error) {
return "success", nil
}, nil, nil)
assert.NoError(t, err)
assert.Equal(t, "success", result)
}