Files
auth0mock/cmd/service/service.go
T

195 lines
5.6 KiB
Go
Raw Normal View History

package main
import (
"context"
"errors"
"fmt"
"log/slog"
"net/http"
"os"
"os/signal"
"sync"
"syscall"
"time"
"github.com/alecthomas/kong"
"github.com/rs/cors"
"gitlab.com/unboundsoftware/auth0mock/auth"
"gitlab.com/unboundsoftware/auth0mock/handlers"
"gitlab.com/unboundsoftware/auth0mock/store"
)
var (
	// serviceName identifies this binary; it is attached to every log
	// record produced by the logger from setupLogger.
	serviceName = "auth0mock"

	// buildVersion is logged at startup and attached to every log record.
	// It defaults to "dev"; presumably overridden at link time
	// (e.g. -ldflags "-X main.buildVersion=...") — TODO confirm in build setup.
	buildVersion = "dev"
)
// CLI defines the command-line interface.
//
// Each option can be supplied as a flag or through the environment variable
// named in its env tag; main populates it with kong.Parse.
type CLI struct {
	// Port is the TCP port the HTTP server listens on; main binds ":<Port>",
	// i.e. all interfaces.
	Port int `name:"port" env:"PORT" help:"Listen port" default:"3333"`
	// Issuer is the host[:port] of the token issuer. main turns it into the
	// issuer URL "https://<Issuer>/" before creating the JWT service.
	Issuer string `name:"issuer" env:"ISSUER" help:"JWT issuer (without https://)" default:"localhost:3333"`
	// Audience is passed unchanged to auth.NewJWTService.
	Audience string `name:"audience" env:"AUDIENCE" help:"JWT audience" default:"https://generic-audience"`
	// UsersFile is loaded into the user store at startup. A load failure is
	// only logged as a warning; startup continues.
	UsersFile string `name:"users-file" env:"USERS_FILE" help:"Path to initial users JSON file" default:"./users.json"`
	// AdminClaim and EmailClaim are custom claim keys handed to
	// auth.NewJWTService; presumably used as claim names in issued tokens —
	// confirm in the auth package.
	AdminClaim string `name:"admin-claim" env:"ADMIN_CUSTOM_CLAIM" help:"Admin custom claim key" default:"https://unbound.se/admin"`
	EmailClaim string `name:"email-claim" env:"EMAIL_CUSTOM_CLAIM" help:"Email custom claim key" default:"https://unbound.se/email"`
	// LogLevel and LogFormat are passed to setupLogger to build the logger.
	LogLevel string `name:"log-level" env:"LOG_LEVEL" help:"Log level" default:"info" enum:"debug,info,warn,error"`
	LogFormat string `name:"log-format" env:"LOG_FORMAT" help:"Log format" default:"text" enum:"json,text"`
}
// main parses the CLI configuration, wires up the stores, JWT service and
// HTTP handlers, and serves them until SIGINT/SIGTERM arrives or the HTTP
// server stops on its own. On shutdown, in-flight requests get up to 10
// seconds to drain.
//
// The process exits with status 1 when the JWT service or OAuth handler
// cannot be created, or when the HTTP server fails with an error other than
// http.ErrServerClosed (for example, the port is already in use).
func main() {
	var cli CLI
	// No subcommands are defined, so the returned *kong.Context is not needed.
	_ = kong.Parse(&cli)

	// Setup logger
	logger := setupLogger(cli.LogLevel, cli.LogFormat)
	logger.Info("starting auth0mock",
		"version", buildVersion,
		"port", cli.Port,
		"issuer", cli.Issuer,
	)

	// Initialize stores. A users file that fails to load is not fatal: the
	// failure is logged as a warning and startup continues.
	userStore := store.NewUserStore()
	if err := userStore.LoadFromFile(cli.UsersFile); err != nil {
		logger.Warn("failed to load users file", "path", cli.UsersFile, "error", err)
	}
	sessionStore := store.NewSessionStore(logger)

	// Initialize JWT service
	issuerURL := fmt.Sprintf("https://%s/", cli.Issuer)
	jwtService, err := auth.NewJWTService(issuerURL, cli.Audience, cli.AdminClaim, cli.EmailClaim)
	if err != nil {
		logger.Error("failed to create JWT service", "error", err)
		os.Exit(1)
	}

	// Initialize handlers
	discoveryHandler := handlers.NewDiscoveryHandler(jwtService)
	oauthHandler, err := handlers.NewOAuthHandler(jwtService, sessionStore, logger)
	if err != nil {
		logger.Error("failed to create OAuth handler", "error", err)
		os.Exit(1)
	}
	managementHandler := handlers.NewManagementHandler(userStore, logger)
	sessionHandler := handlers.NewSessionHandler(jwtService, sessionStore, logger)

	// Setup routes
	mux := http.NewServeMux()

	// CORS middleware: deliberately permissive, since this is a mock server.
	corsHandler := cors.New(cors.Options{
		AllowedOrigins:   []string{"*"},
		AllowedMethods:   []string{"GET", "POST", "PATCH", "OPTIONS"},
		AllowedHeaders:   []string{"*"},
		AllowCredentials: true,
	})

	// Discovery endpoints
	mux.HandleFunc("GET /.well-known/openid-configuration", discoveryHandler.OpenIDConfiguration)
	mux.HandleFunc("GET /.well-known/jwks.json", discoveryHandler.JWKS)

	// OAuth endpoints
	mux.HandleFunc("POST /oauth/token", oauthHandler.Token)
	mux.HandleFunc("GET /authorize", oauthHandler.Authorize)
	mux.HandleFunc("POST /code", oauthHandler.Code)

	// Session endpoints
	mux.HandleFunc("GET /userinfo", sessionHandler.UserInfo)
	mux.HandleFunc("POST /tokeninfo", sessionHandler.TokenInfo)
	mux.HandleFunc("GET /v2/logout", sessionHandler.Logout)

	// Management API endpoints
	mux.HandleFunc("GET /api/v2/users-by-email", managementHandler.GetUsersByEmail)
	mux.HandleFunc("POST /api/v2/users", managementHandler.CreateUser)
	mux.HandleFunc("PATCH /api/v2/users/", managementHandler.UpdateUser)
	mux.HandleFunc("POST /api/v2/tickets/password-change", managementHandler.PasswordChangeTicket)

	// Health check
	mux.HandleFunc("GET /health", func(w http.ResponseWriter, _ *http.Request) {
		// A failed write means the client is gone; there is nothing to
		// recover, so only note it at debug level.
		if _, err := w.Write([]byte("OK")); err != nil {
			logger.Debug("failed to write health response", "error", err)
		}
	})

	// Static files
	mux.Handle("GET /favicon.ico", http.FileServer(http.Dir("public")))

	// Create HTTP server
	httpSrv := &http.Server{
		Addr:    fmt.Sprintf(":%d", cli.Port),
		Handler: corsHandler.Handler(mux),
		// Bound the time a client may take to send request headers so slow
		// or idle clients cannot hold connections open indefinitely.
		ReadHeaderTimeout: 10 * time.Second,
	}

	// Start session cleanup. rootCtx is cancelled on a shutdown signal or
	// when the HTTP server stops, whichever happens first.
	rootCtx, rootCancel := context.WithCancel(context.Background())
	sessionStore.StartCleanup(rootCtx)

	// Graceful shutdown
	wg := sync.WaitGroup{}
	sigint := make(chan os.Signal, 1)
	signal.Notify(sigint, os.Interrupt, syscall.SIGTERM)

	// Signal handler goroutine. The first SIGINT/SIGTERM starts shutdown, and
	// signal.Reset restores default handling so a second signal kills the
	// process immediately. Receiving nil means the shutdown goroutine closed
	// sigint because the server stopped without a signal.
	wg.Add(1)
	go func() {
		defer wg.Done()
		if sig := <-sigint; sig != nil {
			signal.Reset(os.Interrupt, syscall.SIGTERM)
			logger.Info("received shutdown signal")
			rootCancel()
		}
	}()

	// Shutdown handler goroutine
	wg.Add(1)
	go func() {
		defer wg.Done()
		<-rootCtx.Done()
		shutdownCtx, shutdownRelease := context.WithTimeout(context.Background(), 10*time.Second)
		defer shutdownRelease()
		if err := httpSrv.Shutdown(shutdownCtx); err != nil {
			logger.Error("failed to shutdown HTTP server", "error", err)
		}
		// Unregister sigint before closing it. The signal package panics when
		// it delivers to a closed channel, and signal.Stop guarantees no more
		// deliveries once it returns. Closing sigint unblocks the signal
		// goroutine when no signal was received.
		signal.Stop(sigint)
		close(sigint)
	}()

	// HTTP server goroutine. serveErr is written before wg.Done and read only
	// after wg.Wait, so no further synchronization is needed.
	var serveErr error
	wg.Add(1)
	go func() {
		defer wg.Done()
		defer rootCancel()
		logger.Info("listening", "port", cli.Port)
		if err := httpSrv.ListenAndServe(); !errors.Is(err, http.ErrServerClosed) {
			serveErr = err
			logger.Error("HTTP server error", "error", err)
		}
	}()

	wg.Wait()
	logger.Info("shutdown complete")
	if serveErr != nil {
		// The server never started or crashed; report failure to the caller
		// (shell, container runtime) instead of a clean exit.
		os.Exit(1)
	}
}
// setupLogger builds the service's structured logger writing to stdout.
//
// level is parsed with slog.Level.UnmarshalText ("debug", "info", "warn",
// "error", case-insensitive, optionally with an offset such as "info+2");
// an unparseable value falls back to info. format "json" selects the JSON
// handler; any other value selects the text handler. Every record carries
// the "service" and "version" attributes.
func setupLogger(level, format string) *slog.Logger {
	minLevel := slog.LevelInfo
	if err := minLevel.UnmarshalText([]byte(level)); err != nil {
		// UnmarshalText leaves the receiver untouched on failure; reset
		// explicitly anyway so the fallback is obvious to readers.
		minLevel = slog.LevelInfo
	}
	opts := &slog.HandlerOptions{Level: minLevel}

	var h slog.Handler
	if format == "json" {
		h = slog.NewJSONHandler(os.Stdout, opts)
	} else {
		h = slog.NewTextHandler(os.Stdout, opts)
	}

	base := slog.New(h)
	return base.With("service", serviceName, "version", buildVersion)
}