195 lines
5.6 KiB
Go
195 lines
5.6 KiB
Go
|
|
package main
|
||
|
|
|
||
|
|
import (
|
||
|
|
"context"
|
||
|
|
"errors"
|
||
|
|
"fmt"
|
||
|
|
"log/slog"
|
||
|
|
"net/http"
|
||
|
|
"os"
|
||
|
|
"os/signal"
|
||
|
|
"sync"
|
||
|
|
"syscall"
|
||
|
|
"time"
|
||
|
|
|
||
|
|
"github.com/alecthomas/kong"
|
||
|
|
"github.com/rs/cors"
|
||
|
|
|
||
|
|
"gitlab.com/unboundsoftware/auth0mock/auth"
|
||
|
|
"gitlab.com/unboundsoftware/auth0mock/handlers"
|
||
|
|
"gitlab.com/unboundsoftware/auth0mock/store"
|
||
|
|
)
|
||
|
|
|
||
|
|
// buildVersion identifies the running build; it is logged at startup and
// attached to every log record by setupLogger. It is a var (not a const) so
// it can be replaced at link time.
// NOTE(review): presumably release builds set it via
// -ldflags "-X main.buildVersion=..." — confirm against the build config.
var buildVersion = "dev"

// serviceName is attached to every log record under the "service" key.
var serviceName = "auth0mock"
|
||
|
|
|
||
|
|
// CLI defines the command-line interface. kong fills each field from its
// command-line flag (name tag), the environment variable named in its env
// tag, or its default tag. Field order and tags are significant to kong, so
// only comments should be edited here.
type CLI struct {
	// Port is the TCP port the HTTP server listens on.
	Port       int    `name:"port" env:"PORT" help:"Listen port" default:"3333"`
	// Issuer is a bare host[:port]; main builds the issuer URL from it as
	// "https://<Issuer>/" before passing it to auth.NewJWTService.
	Issuer     string `name:"issuer" env:"ISSUER" help:"JWT issuer (without https://)" default:"localhost:3333"`
	// Audience is passed to auth.NewJWTService as the JWT audience.
	Audience   string `name:"audience" env:"AUDIENCE" help:"JWT audience" default:"https://generic-audience"`
	// UsersFile seeds the user store at startup; if loading fails, main logs
	// a warning and keeps running.
	UsersFile  string `name:"users-file" env:"USERS_FILE" help:"Path to initial users JSON file" default:"./users.json"`
	// AdminClaim is the custom-claim key passed to auth.NewJWTService for the
	// admin flag.
	AdminClaim string `name:"admin-claim" env:"ADMIN_CUSTOM_CLAIM" help:"Admin custom claim key" default:"https://unbound.se/admin"`
	// EmailClaim is the custom-claim key passed to auth.NewJWTService for the
	// user's email.
	EmailClaim string `name:"email-claim" env:"EMAIL_CUSTOM_CLAIM" help:"Email custom claim key" default:"https://unbound.se/email"`
	// LogLevel is the minimum level passed to setupLogger; the enum tag limits
	// it to debug, info, warn or error.
	LogLevel   string `name:"log-level" env:"LOG_LEVEL" help:"Log level" default:"info" enum:"debug,info,warn,error"`
	// LogFormat selects setupLogger's output format; the enum tag limits it to
	// json or text.
	LogFormat  string `name:"log-format" env:"LOG_FORMAT" help:"Log format" default:"text" enum:"json,text"`
}
|
||
|
|
|
||
|
|
func main() {
|
||
|
|
var cli CLI
|
||
|
|
_ = kong.Parse(&cli)
|
||
|
|
|
||
|
|
// Setup logger
|
||
|
|
logger := setupLogger(cli.LogLevel, cli.LogFormat)
|
||
|
|
logger.Info("starting auth0mock",
|
||
|
|
"version", buildVersion,
|
||
|
|
"port", cli.Port,
|
||
|
|
"issuer", cli.Issuer,
|
||
|
|
)
|
||
|
|
|
||
|
|
// Initialize stores
|
||
|
|
userStore := store.NewUserStore()
|
||
|
|
if err := userStore.LoadFromFile(cli.UsersFile); err != nil {
|
||
|
|
logger.Warn("failed to load users file", "path", cli.UsersFile, "error", err)
|
||
|
|
}
|
||
|
|
|
||
|
|
sessionStore := store.NewSessionStore(logger)
|
||
|
|
|
||
|
|
// Initialize JWT service
|
||
|
|
issuerURL := fmt.Sprintf("https://%s/", cli.Issuer)
|
||
|
|
jwtService, err := auth.NewJWTService(issuerURL, cli.Audience, cli.AdminClaim, cli.EmailClaim)
|
||
|
|
if err != nil {
|
||
|
|
logger.Error("failed to create JWT service", "error", err)
|
||
|
|
os.Exit(1)
|
||
|
|
}
|
||
|
|
|
||
|
|
// Initialize handlers
|
||
|
|
discoveryHandler := handlers.NewDiscoveryHandler(jwtService)
|
||
|
|
oauthHandler, err := handlers.NewOAuthHandler(jwtService, sessionStore, logger)
|
||
|
|
if err != nil {
|
||
|
|
logger.Error("failed to create OAuth handler", "error", err)
|
||
|
|
os.Exit(1)
|
||
|
|
}
|
||
|
|
managementHandler := handlers.NewManagementHandler(userStore, logger)
|
||
|
|
sessionHandler := handlers.NewSessionHandler(jwtService, sessionStore, logger)
|
||
|
|
|
||
|
|
// Setup routes
|
||
|
|
mux := http.NewServeMux()
|
||
|
|
|
||
|
|
// CORS middleware
|
||
|
|
corsHandler := cors.New(cors.Options{
|
||
|
|
AllowedOrigins: []string{"*"},
|
||
|
|
AllowedMethods: []string{"GET", "POST", "PATCH", "OPTIONS"},
|
||
|
|
AllowedHeaders: []string{"*"},
|
||
|
|
AllowCredentials: true,
|
||
|
|
})
|
||
|
|
|
||
|
|
// Discovery endpoints
|
||
|
|
mux.HandleFunc("GET /.well-known/openid-configuration", discoveryHandler.OpenIDConfiguration)
|
||
|
|
mux.HandleFunc("GET /.well-known/jwks.json", discoveryHandler.JWKS)
|
||
|
|
|
||
|
|
// OAuth endpoints
|
||
|
|
mux.HandleFunc("POST /oauth/token", oauthHandler.Token)
|
||
|
|
mux.HandleFunc("GET /authorize", oauthHandler.Authorize)
|
||
|
|
mux.HandleFunc("POST /code", oauthHandler.Code)
|
||
|
|
|
||
|
|
// Session endpoints
|
||
|
|
mux.HandleFunc("GET /userinfo", sessionHandler.UserInfo)
|
||
|
|
mux.HandleFunc("POST /tokeninfo", sessionHandler.TokenInfo)
|
||
|
|
mux.HandleFunc("GET /v2/logout", sessionHandler.Logout)
|
||
|
|
|
||
|
|
// Management API endpoints
|
||
|
|
mux.HandleFunc("GET /api/v2/users-by-email", managementHandler.GetUsersByEmail)
|
||
|
|
mux.HandleFunc("POST /api/v2/users", managementHandler.CreateUser)
|
||
|
|
mux.HandleFunc("PATCH /api/v2/users/", managementHandler.UpdateUser)
|
||
|
|
mux.HandleFunc("POST /api/v2/tickets/password-change", managementHandler.PasswordChangeTicket)
|
||
|
|
|
||
|
|
// Health check
|
||
|
|
mux.HandleFunc("GET /health", func(w http.ResponseWriter, r *http.Request) {
|
||
|
|
w.Write([]byte("OK"))
|
||
|
|
})
|
||
|
|
|
||
|
|
// Static files
|
||
|
|
mux.Handle("GET /favicon.ico", http.FileServer(http.Dir("public")))
|
||
|
|
|
||
|
|
// Create HTTP server
|
||
|
|
httpSrv := &http.Server{
|
||
|
|
Addr: fmt.Sprintf(":%d", cli.Port),
|
||
|
|
Handler: corsHandler.Handler(mux),
|
||
|
|
}
|
||
|
|
|
||
|
|
// Start session cleanup
|
||
|
|
rootCtx, rootCancel := context.WithCancel(context.Background())
|
||
|
|
sessionStore.StartCleanup(rootCtx)
|
||
|
|
|
||
|
|
// Graceful shutdown
|
||
|
|
wg := sync.WaitGroup{}
|
||
|
|
sigint := make(chan os.Signal, 1)
|
||
|
|
signal.Notify(sigint, os.Interrupt, syscall.SIGTERM)
|
||
|
|
|
||
|
|
// Signal handler goroutine
|
||
|
|
wg.Add(1)
|
||
|
|
go func() {
|
||
|
|
defer wg.Done()
|
||
|
|
sig := <-sigint
|
||
|
|
if sig != nil {
|
||
|
|
signal.Reset(os.Interrupt, syscall.SIGTERM)
|
||
|
|
logger.Info("received shutdown signal")
|
||
|
|
rootCancel()
|
||
|
|
}
|
||
|
|
}()
|
||
|
|
|
||
|
|
// Shutdown handler goroutine
|
||
|
|
wg.Add(1)
|
||
|
|
go func() {
|
||
|
|
defer wg.Done()
|
||
|
|
<-rootCtx.Done()
|
||
|
|
|
||
|
|
shutdownCtx, shutdownRelease := context.WithTimeout(context.Background(), 10*time.Second)
|
||
|
|
defer shutdownRelease()
|
||
|
|
|
||
|
|
if err := httpSrv.Shutdown(shutdownCtx); err != nil {
|
||
|
|
logger.Error("failed to shutdown HTTP server", "error", err)
|
||
|
|
}
|
||
|
|
close(sigint)
|
||
|
|
}()
|
||
|
|
|
||
|
|
// HTTP server goroutine
|
||
|
|
wg.Add(1)
|
||
|
|
go func() {
|
||
|
|
defer wg.Done()
|
||
|
|
defer rootCancel()
|
||
|
|
|
||
|
|
logger.Info("listening", "port", cli.Port)
|
||
|
|
if err := httpSrv.ListenAndServe(); !errors.Is(err, http.ErrServerClosed) {
|
||
|
|
logger.Error("HTTP server error", "error", err)
|
||
|
|
}
|
||
|
|
}()
|
||
|
|
|
||
|
|
wg.Wait()
|
||
|
|
logger.Info("shutdown complete")
|
||
|
|
}
|
||
|
|
|
||
|
|
func setupLogger(level, format string) *slog.Logger {
|
||
|
|
var leveler slog.LevelVar
|
||
|
|
if err := leveler.UnmarshalText([]byte(level)); err != nil {
|
||
|
|
leveler.Set(slog.LevelInfo)
|
||
|
|
}
|
||
|
|
|
||
|
|
handlerOpts := &slog.HandlerOptions{
|
||
|
|
Level: leveler.Level(),
|
||
|
|
}
|
||
|
|
|
||
|
|
var handler slog.Handler
|
||
|
|
switch format {
|
||
|
|
case "json":
|
||
|
|
handler = slog.NewJSONHandler(os.Stdout, handlerOpts)
|
||
|
|
default:
|
||
|
|
handler = slog.NewTextHandler(os.Stdout, handlerOpts)
|
||
|
|
}
|
||
|
|
|
||
|
|
return slog.New(handler).With("service", serviceName, "version", buildVersion)
|
||
|
|
}
|