76 lines
2.3 KiB
Go
76 lines
2.3 KiB
Go
|
|
package auth
|
||
|
|
|
||
|
|
import (
|
||
|
|
"crypto/hmac"
|
||
|
|
"crypto/sha256"
|
||
|
|
"encoding/hex"
|
||
|
|
"net/http"
|
||
|
|
"net/http/httptest"
|
||
|
|
"testing"
|
||
|
|
|
||
|
|
"github.com/stretchr/testify/assert"
|
||
|
|
)
|
||
|
|
|
||
|
|
// sign computes the HMAC-SHA256 of header under key and returns it as a
// lowercase hex string. The tests use it to build a correct "user-signature"
// header value for a given "user" header.
func sign(key, header string) string {
	h := hmac.New(sha256.New, []byte(key))
	// hash.Hash.Write is documented never to return an error.
	h.Write([]byte(header))
	digest := h.Sum(nil)
	return hex.EncodeToString(digest)
}
|
||
|
|
|
||
|
|
func TestUserMiddleware(t *testing.T) {
|
||
|
|
key := "secret"
|
||
|
|
header := `{"email":"jim@example.org","roles":["admin"]}`
|
||
|
|
capture := func(next *bool) http.Handler {
|
||
|
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||
|
|
*next = true
|
||
|
|
if u := FromContext(r.Context()); u != nil {
|
||
|
|
assert.Equal(t, "jim@example.org", u.Email)
|
||
|
|
assert.True(t, u.HasRole("admin"))
|
||
|
|
}
|
||
|
|
})
|
||
|
|
}
|
||
|
|
|
||
|
|
t.Run("valid signature passes and injects user", func(t *testing.T) {
|
||
|
|
called := false
|
||
|
|
req := httptest.NewRequest(http.MethodPost, "/query", nil)
|
||
|
|
req.Header.Set("user", header)
|
||
|
|
req.Header.Set("user-signature", sign(key, header))
|
||
|
|
rw := httptest.NewRecorder()
|
||
|
|
UserMiddleware([]byte(key))(capture(&called)).ServeHTTP(rw, req)
|
||
|
|
assert.True(t, called)
|
||
|
|
assert.Equal(t, http.StatusOK, rw.Code)
|
||
|
|
})
|
||
|
|
|
||
|
|
t.Run("invalid signature is rejected", func(t *testing.T) {
|
||
|
|
called := false
|
||
|
|
req := httptest.NewRequest(http.MethodPost, "/query", nil)
|
||
|
|
req.Header.Set("user", header)
|
||
|
|
req.Header.Set("user-signature", "deadbeef")
|
||
|
|
rw := httptest.NewRecorder()
|
||
|
|
UserMiddleware([]byte(key))(capture(&called)).ServeHTTP(rw, req)
|
||
|
|
assert.False(t, called)
|
||
|
|
assert.Equal(t, http.StatusUnauthorized, rw.Code)
|
||
|
|
})
|
||
|
|
|
||
|
|
t.Run("missing signature when key set is rejected", func(t *testing.T) {
|
||
|
|
req := httptest.NewRequest(http.MethodPost, "/query", nil)
|
||
|
|
req.Header.Set("user", header)
|
||
|
|
rw := httptest.NewRecorder()
|
||
|
|
UserMiddleware([]byte(key))(http.HandlerFunc(func(http.ResponseWriter, *http.Request) {})).ServeHTTP(rw, req)
|
||
|
|
assert.Equal(t, http.StatusUnauthorized, rw.Code)
|
||
|
|
})
|
||
|
|
|
||
|
|
t.Run("empty key skips verification (dev only)", func(t *testing.T) {
|
||
|
|
called := false
|
||
|
|
req := httptest.NewRequest(http.MethodPost, "/query", nil)
|
||
|
|
req.Header.Set("user", header)
|
||
|
|
rw := httptest.NewRecorder()
|
||
|
|
UserMiddleware(nil)(capture(&called)).ServeHTTP(rw, req)
|
||
|
|
assert.True(t, called)
|
||
|
|
})
|
||
|
|
}
|
||
|
|
|
||
|
|
func TestFromContextNil(t *testing.T) {
|
||
|
|
assert.Nil(t, FromContext(httptest.NewRequest(http.MethodGet, "/", nil).Context()))
|
||
|
|
}
|