Files
schemas/middleware/auth_test.go
T
argoyle ffcf41b85a feat: add commands for managing organizations and users
Introduce `AddUserToOrganization`, `RemoveAPIKey`, and 
`RemoveOrganization` commands to enhance organization 
management. Implement validation for user addition and 
API key removal. Update GraphQL schema to support new 
mutations and add caching for the new events, ensuring 
that organizations and their relationships are accurately 
represented in the cache.
2025-11-22 18:37:07 +01:00

565 lines
16 KiB
Go

package middleware
import (
"context"
"net/http"
"net/http/httptest"
"testing"
mw "github.com/auth0/go-jwt-middleware/v2"
"github.com/golang-jwt/jwt/v5"
"github.com/google/uuid"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require"
"gitlab.com/unboundsoftware/eventsourced/eventsourced"
"gitlab.com/unboundsoftware/schemas/domain"
)
// MockCache is a mock implementation of the Cache interface, used by the
// tests in this file as the dependency passed to NewAuth. It embeds
// mock.Mock so tests can register expected calls with On and verify them
// with AssertExpectations.
type MockCache struct {
mock.Mock
}
// OrganizationByAPIKey records the call on the embedded mock and hands back
// whatever organization the test configured for apiKey. An untyped nil first
// return value is passed through as a nil *domain.Organization.
func (m *MockCache) OrganizationByAPIKey(apiKey string) *domain.Organization {
	configured := m.Called(apiKey).Get(0)
	if configured != nil {
		return configured.(*domain.Organization)
	}
	return nil
}
// TestAuthMiddleware_Handler_WithValidAPIKey verifies that when the request
// context carries an API key the cache resolves to an organization, the
// wrapped handler finds that organization under OrganizationKey.
func TestAuthMiddleware_Handler_WithValidAPIKey(t *testing.T) {
	const apiKey = "test-api-key-123"

	cacheMock := &MockCache{}
	auth := NewAuth(cacheMock)

	want := &domain.Organization{
		BaseAggregate: eventsourced.BaseAggregate{
			ID: eventsourced.IdFromString(uuid.New().String()),
		},
		Name: "Test Organization",
	}

	// The mock is keyed on the plaintext key (cache handles hashing internally).
	cacheMock.On("OrganizationByAPIKey", apiKey).Return(want)

	// Downstream handler: remember the organization found in the context, if any.
	var got *domain.Organization
	next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if org, ok := r.Context().Value(OrganizationKey).(domain.Organization); ok {
			got = &org
		}
		w.WriteHeader(http.StatusOK)
	})

	// Request whose context already holds the API key.
	req := httptest.NewRequest(http.MethodGet, "/test", nil)
	req = req.WithContext(context.WithValue(req.Context(), ApiKey, apiKey))
	rec := httptest.NewRecorder()

	auth.Handler(next).ServeHTTP(rec, req)

	assert.Equal(t, http.StatusOK, rec.Code)
	require.NotNil(t, got)
	assert.Equal(t, want.Name, got.Name)
	assert.Equal(t, want.ID.String(), got.ID.String())
	cacheMock.AssertExpectations(t)
}
// TestAuthMiddleware_Handler_WithInvalidAPIKey verifies that an API key the
// cache cannot resolve still lets the request through with 200, but leaves
// no organization in the context seen by the wrapped handler.
func TestAuthMiddleware_Handler_WithInvalidAPIKey(t *testing.T) {
	const apiKey = "invalid-api-key"

	cacheMock := &MockCache{}
	auth := NewAuth(cacheMock)

	// The mock is keyed on the plaintext key (cache handles hashing internally).
	cacheMock.On("OrganizationByAPIKey", apiKey).Return(nil)

	// Downstream handler: remember the organization found in the context, if any.
	var got *domain.Organization
	next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if org, ok := r.Context().Value(OrganizationKey).(domain.Organization); ok {
			got = &org
		}
		w.WriteHeader(http.StatusOK)
	})

	// Request whose context already holds the API key.
	req := httptest.NewRequest(http.MethodGet, "/test", nil)
	req = req.WithContext(context.WithValue(req.Context(), ApiKey, apiKey))
	rec := httptest.NewRecorder()

	auth.Handler(next).ServeHTTP(rec, req)

	assert.Equal(t, http.StatusOK, rec.Code)
	assert.Nil(t, got, "Organization should not be set for invalid API key")
	cacheMock.AssertExpectations(t)
}
// TestAuthMiddleware_Handler_WithoutAPIKey verifies that a request with no
// API key in its context results in a cache lookup for the empty key, a 200
// response, and no organization visible to the wrapped handler.
func TestAuthMiddleware_Handler_WithoutAPIKey(t *testing.T) {
	cacheMock := &MockCache{}
	auth := NewAuth(cacheMock)

	// The middleware passes the plaintext API key (cache handles hashing);
	// with nothing in the context that key is the empty string.
	cacheMock.On("OrganizationByAPIKey", "").Return(nil)

	// Downstream handler: remember the organization found in the context, if any.
	var got *domain.Organization
	next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if org, ok := r.Context().Value(OrganizationKey).(domain.Organization); ok {
			got = &org
		}
		w.WriteHeader(http.StatusOK)
	})

	// Plain request: nothing is added to its context.
	rec := httptest.NewRecorder()
	auth.Handler(next).ServeHTTP(rec, httptest.NewRequest(http.MethodGet, "/test", nil))

	assert.Equal(t, http.StatusOK, rec.Code)
	assert.Nil(t, got, "Organization should not be set without API key")
	cacheMock.AssertExpectations(t)
}
// TestAuthMiddleware_Handler_WithValidJWT verifies that when a JWT is present
// in the request context under mw.ContextKey{}, the wrapped handler finds the
// token's "sub" claim under UserKey.
func TestAuthMiddleware_Handler_WithValidJWT(t *testing.T) {
	const userID = "user-123"

	cacheMock := &MockCache{}
	auth := NewAuth(cacheMock)

	// The middleware passes the plaintext API key (cache handles hashing);
	// no key is supplied here, so the lookup is registered for "".
	cacheMock.On("OrganizationByAPIKey", "").Return(nil)

	token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{"sub": userID})

	// Downstream handler: remember the user found in the context, if any.
	var got string
	next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if u, ok := r.Context().Value(UserKey).(string); ok {
			got = u
		}
		w.WriteHeader(http.StatusOK)
	})

	// Request whose context already holds the token.
	req := httptest.NewRequest(http.MethodGet, "/test", nil)
	req = req.WithContext(context.WithValue(req.Context(), mw.ContextKey{}, token))
	rec := httptest.NewRecorder()

	auth.Handler(next).ServeHTTP(rec, req)

	assert.Equal(t, http.StatusOK, rec.Code)
	assert.Equal(t, userID, got)
}
func TestAuthMiddleware_Handler_APIKeyErrorHandling(t *testing.T) {
// Setup
mockCache := new(MockCache)
authMiddleware := NewAuth(mockCache)
testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
})
// Create request with invalid API key type in context
req := httptest.NewRequest(http.MethodGet, "/test", nil)
ctx := context.WithValue(req.Context(), ApiKey, 12345) // Invalid type
req = req.WithContext(ctx)
rec := httptest.NewRecorder()
// Execute
authMiddleware.Handler(testHandler).ServeHTTP(rec, req)
// Assert
assert.Equal(t, http.StatusInternalServerError, rec.Code)
assert.Contains(t, rec.Body.String(), "Invalid API Key format")
}
func TestAuthMiddleware_Handler_JWTErrorHandling(t *testing.T) {
// Setup
mockCache := new(MockCache)
authMiddleware := NewAuth(mockCache)
testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
})
// Create request with invalid JWT token type in context
req := httptest.NewRequest(http.MethodGet, "/test", nil)
ctx := context.WithValue(req.Context(), mw.ContextKey{}, "not-a-token") // Invalid type
req = req.WithContext(ctx)
rec := httptest.NewRecorder()
// Execute
authMiddleware.Handler(testHandler).ServeHTTP(rec, req)
// Assert
assert.Equal(t, http.StatusInternalServerError, rec.Code)
assert.Contains(t, rec.Body.String(), "Invalid JWT token format")
}
// TestAuthMiddleware_Handler_BothJWTAndAPIKey verifies that when the request
// context carries both a JWT and a resolvable API key, the wrapped handler
// sees the user under UserKey and the organization under OrganizationKey.
func TestAuthMiddleware_Handler_BothJWTAndAPIKey(t *testing.T) {
	const (
		userID = "user-123"
		apiKey = "test-api-key-123"
	)

	cacheMock := &MockCache{}
	auth := NewAuth(cacheMock)

	wantOrg := &domain.Organization{
		BaseAggregate: eventsourced.BaseAggregate{
			ID: eventsourced.IdFromString(uuid.New().String()),
		},
		Name: "Test Organization",
	}
	token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{"sub": userID})

	// The mock is keyed on the plaintext key (cache handles hashing internally).
	cacheMock.On("OrganizationByAPIKey", apiKey).Return(wantOrg)

	// Downstream handler: remember both the user and the organization, if present.
	var (
		gotUser string
		gotOrg  *domain.Organization
	)
	next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		reqCtx := r.Context()
		if u, ok := reqCtx.Value(UserKey).(string); ok {
			gotUser = u
		}
		if org, ok := reqCtx.Value(OrganizationKey).(domain.Organization); ok {
			gotOrg = &org
		}
		w.WriteHeader(http.StatusOK)
	})

	// Request whose context holds the token and, layered on top, the API key.
	req := httptest.NewRequest(http.MethodGet, "/test", nil)
	withToken := context.WithValue(req.Context(), mw.ContextKey{}, token)
	req = req.WithContext(context.WithValue(withToken, ApiKey, apiKey))
	rec := httptest.NewRecorder()

	auth.Handler(next).ServeHTTP(rec, req)

	assert.Equal(t, http.StatusOK, rec.Code)
	assert.Equal(t, userID, gotUser)
	require.NotNil(t, gotOrg)
	assert.Equal(t, wantOrg.Name, gotOrg.Name)
	cacheMock.AssertExpectations(t)
}
// TestUserFromContext is a table-driven check of UserFromContext: it expects
// the stored string when UserKey holds one, and "" when the key is absent or
// holds a non-string value.
func TestUserFromContext(t *testing.T) {
	base := context.Background()

	cases := []struct {
		name string
		ctx  context.Context
		want string
	}{
		{name: "with valid user", ctx: context.WithValue(base, UserKey, "user-123"), want: "user-123"},
		{name: "without user", ctx: base, want: ""},
		{name: "with invalid type", ctx: context.WithValue(base, UserKey, 123), want: ""},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			assert.Equal(t, tc.want, UserFromContext(tc.ctx))
		})
	}
}
// TestOrganizationFromContext is a table-driven check of
// OrganizationFromContext: it expects the organization's ID string when
// OrganizationKey holds a domain.Organization, and "" when the key is absent
// or holds a value of another type.
func TestOrganizationFromContext(t *testing.T) {
	id := uuid.New().String()
	stored := domain.Organization{
		BaseAggregate: eventsourced.BaseAggregate{
			ID: eventsourced.IdFromString(id),
		},
		Name: "Test Org",
	}
	base := context.Background()

	cases := []struct {
		name string
		ctx  context.Context
		want string
	}{
		{name: "with valid organization", ctx: context.WithValue(base, OrganizationKey, stored), want: id},
		{name: "without organization", ctx: base, want: ""},
		{name: "with invalid type", ctx: context.WithValue(base, OrganizationKey, "not-an-org"), want: ""},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			assert.Equal(t, tc.want, OrganizationFromContext(tc.ctx))
		})
	}
}
// TestAuthMiddleware_Directive_RequiresUser exercises Directive with only the
// user requirement set: it must succeed when UserKey is present in the
// context and fail with "no user available in request" when it is not.
func TestAuthMiddleware_Directive_RequiresUser(t *testing.T) {
	mockCache := new(MockCache)
	authMiddleware := NewAuth(mockCache)
	requireUser := true

	// Resolver the directive is expected to invoke once the check passes.
	next := func(ctx context.Context) (interface{}, error) {
		return "success", nil
	}

	// Test with user present
	ctx := context.WithValue(context.Background(), UserKey, "user-123")
	_, err := authMiddleware.Directive(ctx, nil, next, &requireUser, nil)
	assert.NoError(t, err)

	// Test without user
	_, err = authMiddleware.Directive(context.Background(), nil, next, &requireUser, nil)
	// require (not assert) so that an unexpected nil error stops the test
	// here with a clear failure instead of panicking on err.Error() below.
	require.Error(t, err)
	assert.Contains(t, err.Error(), "no user available in request")
}
// TestAuthMiddleware_Directive_RequiresOrganization exercises Directive with
// only the organization requirement set: it must succeed when
// OrganizationKey is present in the context and fail with
// "no organization available in request" when it is not.
func TestAuthMiddleware_Directive_RequiresOrganization(t *testing.T) {
	mockCache := new(MockCache)
	authMiddleware := NewAuth(mockCache)
	requireOrg := true
	orgID := uuid.New()
	org := domain.Organization{
		BaseAggregate: eventsourced.BaseAggregate{
			ID: eventsourced.IdFromString(orgID.String()),
		},
		Name: "Test Org",
	}

	// Resolver the directive is expected to invoke once the check passes.
	next := func(ctx context.Context) (interface{}, error) {
		return "success", nil
	}

	// Test with organization present
	ctx := context.WithValue(context.Background(), OrganizationKey, org)
	_, err := authMiddleware.Directive(ctx, nil, next, nil, &requireOrg)
	assert.NoError(t, err)

	// Test without organization
	_, err = authMiddleware.Directive(context.Background(), nil, next, nil, &requireOrg)
	// require (not assert) so that an unexpected nil error stops the test
	// here with a clear failure instead of panicking on err.Error() below.
	require.Error(t, err)
	assert.Contains(t, err.Error(), "no organization available in request")
}
// TestAuthMiddleware_Directive_RequiresBoth exercises Directive with both the
// user and the organization flags set. The test expects OR semantics: either
// credential alone is enough, and only a context with neither is rejected
// with "authentication required".
func TestAuthMiddleware_Directive_RequiresBoth(t *testing.T) {
	mockCache := new(MockCache)
	authMiddleware := NewAuth(mockCache)
	requireUser := true
	requireOrg := true
	orgID := uuid.New()
	org := domain.Organization{
		BaseAggregate: eventsourced.BaseAggregate{
			ID: eventsourced.IdFromString(orgID.String()),
		},
		Name: "Test Org",
	}

	// Resolver the directive is expected to invoke once the check passes.
	next := func(ctx context.Context) (interface{}, error) {
		return "success", nil
	}

	// When both user and organization are marked as acceptable,
	// the directive uses OR logic - either one is sufficient

	// Test with both present - should succeed
	ctx := context.WithValue(context.Background(), UserKey, "user-123")
	ctx = context.WithValue(ctx, OrganizationKey, org)
	_, err := authMiddleware.Directive(ctx, nil, next, &requireUser, &requireOrg)
	assert.NoError(t, err)

	// Test with only user - should succeed (OR logic)
	ctx = context.WithValue(context.Background(), UserKey, "user-123")
	_, err = authMiddleware.Directive(ctx, nil, next, &requireUser, &requireOrg)
	assert.NoError(t, err)

	// Test with only organization - should succeed (OR logic)
	ctx = context.WithValue(context.Background(), OrganizationKey, org)
	_, err = authMiddleware.Directive(ctx, nil, next, &requireUser, &requireOrg)
	assert.NoError(t, err)

	// Test with neither - should fail
	_, err = authMiddleware.Directive(context.Background(), nil, next, &requireUser, &requireOrg)
	// require (not assert) so that an unexpected nil error stops the test
	// here with a clear failure instead of panicking on err.Error() below.
	require.Error(t, err)
	assert.Contains(t, err.Error(), "authentication required")
}
func TestAuthMiddleware_Directive_NoRequirements(t *testing.T) {
mockCache := new(MockCache)
authMiddleware := NewAuth(mockCache)
// Test with no requirements
ctx := context.Background()
result, err := authMiddleware.Directive(ctx, nil, func(ctx context.Context) (interface{}, error) {
return "success", nil
}, nil, nil)
assert.NoError(t, err)
assert.Equal(t, "success", result)
}
// TestUserHasRole_WithValidRole verifies that UserHasRole reports true for
// every role listed in the token's "https://unbound.se/roles" claim.
func TestUserHasRole_WithValidRole(t *testing.T) {
	claims := jwt.MapClaims{
		"sub":                      "user-123",
		"https://unbound.se/roles": []interface{}{"admin", "user"},
	}
	ctx := context.WithValue(
		context.Background(),
		mw.ContextKey{},
		jwt.NewWithClaims(jwt.SigningMethodHS256, claims),
	)

	// Each role present in the claim must be found, in the listed order.
	for _, role := range []string{"admin", "user"} {
		assert.True(t, UserHasRole(ctx, role))
	}
}
// TestUserHasRole_WithoutRole verifies that UserHasRole reports false for a
// role that is not among those listed in the token's roles claim.
func TestUserHasRole_WithoutRole(t *testing.T) {
	claims := jwt.MapClaims{
		"sub":                      "user-123",
		"https://unbound.se/roles": []interface{}{"user"},
	}
	ctx := context.WithValue(
		context.Background(),
		mw.ContextKey{},
		jwt.NewWithClaims(jwt.SigningMethodHS256, claims),
	)

	// "admin" is not in the claim, so the lookup must miss.
	assert.False(t, UserHasRole(ctx, "admin"))
}
// TestUserHasRole_WithoutRolesClaim verifies that UserHasRole reports false
// when the token carries no roles claim at all.
func TestUserHasRole_WithoutRolesClaim(t *testing.T) {
	// Only "sub" is set; the roles claim is deliberately omitted.
	claims := jwt.MapClaims{"sub": "user-123"}
	ctx := context.WithValue(
		context.Background(),
		mw.ContextKey{},
		jwt.NewWithClaims(jwt.SigningMethodHS256, claims),
	)

	assert.False(t, UserHasRole(ctx, "admin"))
}
func TestUserHasRole_WithoutToken(t *testing.T) {
ctx := context.Background()
// Test should return false when no token in context
hasRole := UserHasRole(ctx, "admin")
assert.False(t, hasRole)
}
// TestUserHasRole_WithInvalidTokenType verifies that UserHasRole reports
// false when the value stored under mw.ContextKey{} is not a token.
func TestUserHasRole_WithInvalidTokenType(t *testing.T) {
	// A plain string takes the place of the token.
	badCtx := context.WithValue(context.Background(), mw.ContextKey{}, "not-a-token")

	assert.False(t, UserHasRole(badCtx, "admin"))
}
// TestUserHasRole_WithInvalidRolesType verifies that UserHasRole reports
// false when the roles claim is a single string rather than a list.
func TestUserHasRole_WithInvalidRolesType(t *testing.T) {
	claims := jwt.MapClaims{
		"sub":                      "user-123",
		"https://unbound.se/roles": "not-an-array",
	}
	ctx := context.WithValue(
		context.Background(),
		mw.ContextKey{},
		jwt.NewWithClaims(jwt.SigningMethodHS256, claims),
	)

	assert.False(t, UserHasRole(ctx, "admin"))
}
// TestUserHasRole_WithInvalidRoleElementType verifies that UserHasRole
// reports false when the roles claim is a list whose elements are not
// strings.
func TestUserHasRole_WithInvalidRoleElementType(t *testing.T) {
	claims := jwt.MapClaims{
		"sub": "user-123",
		// Numbers instead of strings.
		"https://unbound.se/roles": []interface{}{123, 456},
	}
	ctx := context.WithValue(
		context.Background(),
		mw.ContextKey{},
		jwt.NewWithClaims(jwt.SigningMethodHS256, claims),
	)

	assert.False(t, UserHasRole(ctx, "admin"))
}