package middleware

import (
	"context"
	"net/http"
	"net/http/httptest"
	"testing"

	mw "github.com/auth0/go-jwt-middleware/v2"
	"github.com/golang-jwt/jwt/v5"
	"github.com/google/uuid"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/mock"
	"github.com/stretchr/testify/require"
	"gitlab.com/unboundsoftware/eventsourced/eventsourced"

	"gitea.unbound.se/unboundsoftware/schemas/domain"
)

// testRolesClaim is the JWT claim UserHasRole reads roles from.
const testRolesClaim = "https://unbound.se/roles"

// MockCache is a mock implementation of the Cache interface.
type MockCache struct {
	mock.Mock
}

// OrganizationByAPIKey records the call and returns the *domain.Organization
// configured via On(...).Return(...), or nil when the expectation returns nil.
func (m *MockCache) OrganizationByAPIKey(apiKey string) *domain.Organization {
	args := m.Called(apiKey)
	if args.Get(0) == nil {
		return nil
	}
	return args.Get(0).(*domain.Organization)
}

// newTestOrganization builds an organization with the given ID and name.
func newTestOrganization(id uuid.UUID, name string) domain.Organization {
	return domain.Organization{
		BaseAggregate: eventsourced.BaseAggregate{
			ID: eventsourced.IdFromString(id.String()),
		},
		Name: name,
	}
}

// contextCapture records what the auth middleware placed in the request
// context, so tests can assert on it after ServeHTTP returns.
type contextCapture struct {
	user string
	org  *domain.Organization
}

// handler returns a terminal handler that copies the string under UserKey and
// the domain.Organization under OrganizationKey (when present with those
// types) into c, then replies 200 OK.
func (c *contextCapture) handler() http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if u, ok := r.Context().Value(UserKey).(string); ok {
			c.user = u
		}
		if o, ok := r.Context().Value(OrganizationKey).(domain.Organization); ok {
			c.org = &o
		}
		w.WriteHeader(http.StatusOK)
	})
}

// okHandler is a terminal handler that replies 200 OK without inspecting the request.
func okHandler() http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
		w.WriteHeader(http.StatusOK)
	})
}

// okResolver is a GraphQL resolver stub that always succeeds with "success".
func okResolver(context.Context) (any, error) {
	return "success", nil
}

// contextWithToken returns a background context holding an unsigned HS256
// token with the given claims under the go-jwt-middleware context key.
func contextWithToken(claims jwt.MapClaims) context.Context {
	token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
	return context.WithValue(context.Background(), mw.ContextKey{}, token)
}

func TestAuthMiddleware_Handler_WithValidAPIKey(t *testing.T) {
	mockCache := new(MockCache)
	authMiddleware := NewAuth(mockCache)

	expectedOrg := newTestOrganization(uuid.New(), "Test Organization")
	apiKey := "test-api-key-123"
	// Mock expects the plaintext key (the cache handles hashing internally).
	mockCache.On("OrganizationByAPIKey", apiKey).Return(&expectedOrg)

	var captured contextCapture
	req := httptest.NewRequest(http.MethodGet, "/test", nil)
	req = req.WithContext(context.WithValue(req.Context(), ApiKey, apiKey))
	rec := httptest.NewRecorder()

	authMiddleware.Handler(captured.handler()).ServeHTTP(rec, req)

	assert.Equal(t, http.StatusOK, rec.Code)
	require.NotNil(t, captured.org)
	assert.Equal(t, expectedOrg.Name, captured.org.Name)
	assert.Equal(t, expectedOrg.ID.String(), captured.org.ID.String())
	mockCache.AssertExpectations(t)
}

func TestAuthMiddleware_Handler_WithInvalidAPIKey(t *testing.T) {
	mockCache := new(MockCache)
	authMiddleware := NewAuth(mockCache)

	apiKey := "invalid-api-key"
	// Mock expects the plaintext key; an unknown key yields no organization.
	mockCache.On("OrganizationByAPIKey", apiKey).Return(nil)

	var captured contextCapture
	req := httptest.NewRequest(http.MethodGet, "/test", nil)
	req = req.WithContext(context.WithValue(req.Context(), ApiKey, apiKey))
	rec := httptest.NewRecorder()

	authMiddleware.Handler(captured.handler()).ServeHTTP(rec, req)

	// An unknown key is not an error: the request proceeds without an organization.
	assert.Equal(t, http.StatusOK, rec.Code)
	assert.Nil(t, captured.org, "Organization should not be set for invalid API key")
	mockCache.AssertExpectations(t)
}

func TestAuthMiddleware_Handler_WithoutAPIKey(t *testing.T) {
	mockCache := new(MockCache)
	authMiddleware := NewAuth(mockCache)

	// With no API key in the context the middleware still consults the cache,
	// using the empty string as the key.
	mockCache.On("OrganizationByAPIKey", "").Return(nil)

	var captured contextCapture
	req := httptest.NewRequest(http.MethodGet, "/test", nil)
	rec := httptest.NewRecorder()

	authMiddleware.Handler(captured.handler()).ServeHTTP(rec, req)

	assert.Equal(t, http.StatusOK, rec.Code)
	assert.Nil(t, captured.org, "Organization should not be set without API key")
	mockCache.AssertExpectations(t)
}

func TestAuthMiddleware_Handler_WithValidJWT(t *testing.T) {
	mockCache := new(MockCache)
	authMiddleware := NewAuth(mockCache)

	// No API key is supplied, so the middleware looks up the empty key
	// (same as TestAuthMiddleware_Handler_WithoutAPIKey).
	mockCache.On("OrganizationByAPIKey", "").Return(nil)

	userID := "user-123"
	token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{"sub": userID})

	var captured contextCapture
	req := httptest.NewRequest(http.MethodGet, "/test", nil)
	// Place the token where go-jwt-middleware stores it.
	req = req.WithContext(context.WithValue(req.Context(), mw.ContextKey{}, token))
	rec := httptest.NewRecorder()

	authMiddleware.Handler(captured.handler()).ServeHTTP(rec, req)

	assert.Equal(t, http.StatusOK, rec.Code)
	assert.Equal(t, userID, captured.user)
	assert.Nil(t, captured.org)
	// Previously the expectation above was registered but never verified.
	mockCache.AssertExpectations(t)
}

func TestAuthMiddleware_Handler_APIKeyErrorHandling(t *testing.T) {
	authMiddleware := NewAuth(new(MockCache))

	req := httptest.NewRequest(http.MethodGet, "/test", nil)
	// A non-string value under ApiKey must be rejected as malformed.
	req = req.WithContext(context.WithValue(req.Context(), ApiKey, 12345))
	rec := httptest.NewRecorder()

	authMiddleware.Handler(okHandler()).ServeHTTP(rec, req)

	assert.Equal(t, http.StatusInternalServerError, rec.Code)
	assert.Contains(t, rec.Body.String(), "Invalid API Key format")
}

func TestAuthMiddleware_Handler_JWTErrorHandling(t *testing.T) {
	authMiddleware := NewAuth(new(MockCache))

	req := httptest.NewRequest(http.MethodGet, "/test", nil)
	// A non-*jwt.Token value under the JWT context key must be rejected.
	req = req.WithContext(context.WithValue(req.Context(), mw.ContextKey{}, "not-a-token"))
	rec := httptest.NewRecorder()

	authMiddleware.Handler(okHandler()).ServeHTTP(rec, req)

	assert.Equal(t, http.StatusInternalServerError, rec.Code)
	assert.Contains(t, rec.Body.String(), "Invalid JWT token format")
}

func TestAuthMiddleware_Handler_BothJWTAndAPIKey(t *testing.T) {
	mockCache := new(MockCache)
	authMiddleware := NewAuth(mockCache)

	expectedOrg := newTestOrganization(uuid.New(), "Test Organization")
	userID := "user-123"
	apiKey := "test-api-key-123"
	token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{"sub": userID})

	// Mock expects the plaintext key (the cache handles hashing internally).
	mockCache.On("OrganizationByAPIKey", apiKey).Return(&expectedOrg)

	var captured contextCapture
	req := httptest.NewRequest(http.MethodGet, "/test", nil)
	ctx := context.WithValue(req.Context(), mw.ContextKey{}, token)
	ctx = context.WithValue(ctx, ApiKey, apiKey)
	req = req.WithContext(ctx)
	rec := httptest.NewRecorder()

	authMiddleware.Handler(captured.handler()).ServeHTTP(rec, req)

	// Both credentials are honoured: user and organization are set together.
	assert.Equal(t, http.StatusOK, rec.Code)
	assert.Equal(t, userID, captured.user)
	require.NotNil(t, captured.org)
	assert.Equal(t, expectedOrg.Name, captured.org.Name)
	assert.Equal(t, expectedOrg.ID.String(), captured.org.ID.String())
	mockCache.AssertExpectations(t)
}

func TestUserFromContext(t *testing.T) {
	tests := []struct {
		name     string
		ctx      context.Context
		expected string
	}{
		{
			name:     "with valid user",
			ctx:      context.WithValue(context.Background(), UserKey, "user-123"),
			expected: "user-123",
		},
		{
			name:     "without user",
			ctx:      context.Background(),
			expected: "",
		},
		{
			name:     "with invalid type",
			ctx:      context.WithValue(context.Background(), UserKey, 123),
			expected: "",
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			assert.Equal(t, tt.expected, UserFromContext(tt.ctx))
		})
	}
}

func TestOrganizationFromContext(t *testing.T) {
	orgID := uuid.New()
	org := newTestOrganization(orgID, "Test Org")

	tests := []struct {
		name     string
		ctx      context.Context
		expected string
	}{
		{
			name:     "with valid organization",
			ctx:      context.WithValue(context.Background(), OrganizationKey, org),
			expected: orgID.String(),
		},
		{
			name:     "without organization",
			ctx:      context.Background(),
			expected: "",
		},
		{
			name:     "with invalid type",
			ctx:      context.WithValue(context.Background(), OrganizationKey, "not-an-org"),
			expected: "",
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			assert.Equal(t, tt.expected, OrganizationFromContext(tt.ctx))
		})
	}
}

func TestAuthMiddleware_Directive_RequiresUser(t *testing.T) {
	authMiddleware := NewAuth(new(MockCache))
	requireUser := true

	// User present: the resolver runs.
	ctx := context.WithValue(context.Background(), UserKey, "user-123")
	_, err := authMiddleware.Directive(ctx, nil, okResolver, &requireUser, nil)
	assert.NoError(t, err)

	// No user: the directive rejects the call. require stops the test before
	// err.Error() could dereference a nil error.
	_, err = authMiddleware.Directive(context.Background(), nil, okResolver, &requireUser, nil)
	require.Error(t, err)
	assert.Contains(t, err.Error(), "no user available in request")
}

func TestAuthMiddleware_Directive_RequiresOrganization(t *testing.T) {
	authMiddleware := NewAuth(new(MockCache))
	requireOrg := true
	org := newTestOrganization(uuid.New(), "Test Org")

	// Organization present: the resolver runs.
	ctx := context.WithValue(context.Background(), OrganizationKey, org)
	_, err := authMiddleware.Directive(ctx, nil, okResolver, nil, &requireOrg)
	assert.NoError(t, err)

	// No organization: the directive rejects the call.
	_, err = authMiddleware.Directive(context.Background(), nil, okResolver, nil, &requireOrg)
	require.Error(t, err)
	assert.Contains(t, err.Error(), "no organization available in request")
}

func TestAuthMiddleware_Directive_RequiresBoth(t *testing.T) {
	authMiddleware := NewAuth(new(MockCache))
	requireUser := true
	requireOrg := true
	org := newTestOrganization(uuid.New(), "Test Org")

	// When both user and organization are marked as acceptable, the directive
	// uses OR logic: either one is sufficient.

	// Both present: succeeds.
	ctx := context.WithValue(context.Background(), UserKey, "user-123")
	ctx = context.WithValue(ctx, OrganizationKey, org)
	_, err := authMiddleware.Directive(ctx, nil, okResolver, &requireUser, &requireOrg)
	assert.NoError(t, err)

	// Only user: succeeds (OR logic).
	ctx = context.WithValue(context.Background(), UserKey, "user-123")
	_, err = authMiddleware.Directive(ctx, nil, okResolver, &requireUser, &requireOrg)
	assert.NoError(t, err)

	// Only organization: succeeds (OR logic).
	ctx = context.WithValue(context.Background(), OrganizationKey, org)
	_, err = authMiddleware.Directive(ctx, nil, okResolver, &requireUser, &requireOrg)
	assert.NoError(t, err)

	// Neither: fails.
	_, err = authMiddleware.Directive(context.Background(), nil, okResolver, &requireUser, &requireOrg)
	require.Error(t, err)
	assert.Contains(t, err.Error(), "authentication required")
}

func TestAuthMiddleware_Directive_NoRequirements(t *testing.T) {
	authMiddleware := NewAuth(new(MockCache))

	// With no requirements the resolver runs and its result is passed through.
	result, err := authMiddleware.Directive(context.Background(), nil, okResolver, nil, nil)
	assert.NoError(t, err)
	assert.Equal(t, "success", result)
}

func TestUserHasRole_WithValidRole(t *testing.T) {
	ctx := contextWithToken(jwt.MapClaims{
		"sub":          "user-123",
		testRolesClaim: []any{"admin", "user"},
	})

	assert.True(t, UserHasRole(ctx, "admin"))
	assert.True(t, UserHasRole(ctx, "user"))
}

func TestUserHasRole_WithoutRole(t *testing.T) {
	ctx := contextWithToken(jwt.MapClaims{
		"sub":          "user-123",
		testRolesClaim: []any{"user"},
	})

	assert.False(t, UserHasRole(ctx, "admin"))
}

func TestUserHasRole_WithoutRolesClaim(t *testing.T) {
	ctx := contextWithToken(jwt.MapClaims{"sub": "user-123"})

	// A missing roles claim means no roles.
	assert.False(t, UserHasRole(ctx, "admin"))
}

func TestUserHasRole_WithoutToken(t *testing.T) {
	// No token in the context means no roles.
	assert.False(t, UserHasRole(context.Background(), "admin"))
}

func TestUserHasRole_WithInvalidTokenType(t *testing.T) {
	// A non-*jwt.Token value under the JWT key means no roles.
	ctx := context.WithValue(context.Background(), mw.ContextKey{}, "not-a-token")

	assert.False(t, UserHasRole(ctx, "admin"))
}

func TestUserHasRole_WithInvalidRolesType(t *testing.T) {
	ctx := contextWithToken(jwt.MapClaims{
		"sub":          "user-123",
		testRolesClaim: "not-an-array",
	})

	// A roles claim that is not an array means no roles.
	assert.False(t, UserHasRole(ctx, "admin"))
}

func TestUserHasRole_WithInvalidRoleElementType(t *testing.T) {
	ctx := contextWithToken(jwt.MapClaims{
		"sub":          "user-123",
		testRolesClaim: []any{123, 456}, // numbers instead of strings
	})

	// Non-string role elements never match.
	assert.False(t, UserHasRole(ctx, "admin"))
}