package main import ( "context" "errors" "fmt" "log/slog" "net/http" "os" "os/signal" "sync" "syscall" "time" "github.com/alecthomas/kong" "github.com/rs/cors" "gitlab.com/unboundsoftware/auth0mock/auth" "gitlab.com/unboundsoftware/auth0mock/handlers" "gitlab.com/unboundsoftware/auth0mock/store" ) var ( buildVersion = "dev" serviceName = "auth0mock" ) // CLI defines the command-line interface type CLI struct { Port int `name:"port" env:"PORT" help:"Listen port" default:"3333"` Issuer string `name:"issuer" env:"ISSUER" help:"JWT issuer (without https://)" default:"localhost:3333"` Audience string `name:"audience" env:"AUDIENCE" help:"JWT audience" default:"https://generic-audience"` UsersFile string `name:"users-file" env:"USERS_FILE" help:"Path to initial users JSON file" default:"./users.json"` AdminClaim string `name:"admin-claim" env:"ADMIN_CUSTOM_CLAIM" help:"Admin custom claim key" default:"https://unbound.se/admin"` EmailClaim string `name:"email-claim" env:"EMAIL_CUSTOM_CLAIM" help:"Email custom claim key" default:"https://unbound.se/email"` LogLevel string `name:"log-level" env:"LOG_LEVEL" help:"Log level" default:"info" enum:"debug,info,warn,error"` LogFormat string `name:"log-format" env:"LOG_FORMAT" help:"Log format" default:"text" enum:"json,text"` } func main() { var cli CLI _ = kong.Parse(&cli) // Setup logger logger := setupLogger(cli.LogLevel, cli.LogFormat) logger.Info("starting auth0mock", "version", buildVersion, "port", cli.Port, "issuer", cli.Issuer, ) // Initialize stores userStore := store.NewUserStore() if err := userStore.LoadFromFile(cli.UsersFile); err != nil { logger.Warn("failed to load users file", "path", cli.UsersFile, "error", err) } sessionStore := store.NewSessionStore(logger) // Initialize JWT service issuerURL := fmt.Sprintf("https://%s/", cli.Issuer) jwtService, err := auth.NewJWTService(issuerURL, cli.Audience, cli.AdminClaim, cli.EmailClaim) if err != nil { logger.Error("failed to create JWT service", "error", err) os.Exit(1) } // Initialize handlers discoveryHandler := handlers.NewDiscoveryHandler(jwtService) oauthHandler, err := handlers.NewOAuthHandler(jwtService, sessionStore, logger) if err != nil { logger.Error("failed to create OAuth handler", "error", err) os.Exit(1) } managementHandler := handlers.NewManagementHandler(userStore, logger) sessionHandler := handlers.NewSessionHandler(jwtService, sessionStore, logger) // Setup routes mux := http.NewServeMux() // CORS middleware corsHandler := cors.New(cors.Options{ AllowedOrigins: []string{"*"}, AllowedMethods: []string{"GET", "POST", "PATCH", "OPTIONS"}, AllowedHeaders: []string{"*"}, AllowCredentials: true, }) // Discovery endpoints mux.HandleFunc("GET /.well-known/openid-configuration", discoveryHandler.OpenIDConfiguration) mux.HandleFunc("GET /.well-known/jwks.json", discoveryHandler.JWKS) // OAuth endpoints mux.HandleFunc("POST /oauth/token", oauthHandler.Token) mux.HandleFunc("GET /authorize", oauthHandler.Authorize) mux.HandleFunc("POST /code", oauthHandler.Code) // Session endpoints mux.HandleFunc("GET /userinfo", sessionHandler.UserInfo) mux.HandleFunc("POST /tokeninfo", sessionHandler.TokenInfo) mux.HandleFunc("GET /v2/logout", sessionHandler.Logout) // Management API endpoints mux.HandleFunc("GET /api/v2/users-by-email", managementHandler.GetUsersByEmail) mux.HandleFunc("POST /api/v2/users", managementHandler.CreateUser) mux.HandleFunc("PATCH /api/v2/users/", managementHandler.UpdateUser) mux.HandleFunc("POST /api/v2/tickets/password-change", managementHandler.PasswordChangeTicket) // Health check mux.HandleFunc("GET /health", func(w http.ResponseWriter, r *http.Request) { w.Write([]byte("OK")) }) // Static files mux.Handle("GET /favicon.ico", http.FileServer(http.Dir("public"))) // Create HTTP server httpSrv := &http.Server{ Addr: fmt.Sprintf(":%d", cli.Port), Handler: corsHandler.Handler(mux), } // Start session cleanup rootCtx, rootCancel := context.WithCancel(context.Background()) sessionStore.StartCleanup(rootCtx) // Graceful shutdown wg := sync.WaitGroup{} sigint := make(chan os.Signal, 1) signal.Notify(sigint, os.Interrupt, syscall.SIGTERM) // Signal handler goroutine wg.Add(1) go func() { defer wg.Done() sig := <-sigint if sig != nil { signal.Reset(os.Interrupt, syscall.SIGTERM) logger.Info("received shutdown signal") rootCancel() } }() // Shutdown handler goroutine wg.Add(1) go func() { defer wg.Done() <-rootCtx.Done() shutdownCtx, shutdownRelease := context.WithTimeout(context.Background(), 10*time.Second) defer shutdownRelease() if err := httpSrv.Shutdown(shutdownCtx); err != nil { logger.Error("failed to shutdown HTTP server", "error", err) } close(sigint) }() // HTTP server goroutine wg.Add(1) go func() { defer wg.Done() defer rootCancel() logger.Info("listening", "port", cli.Port) if err := httpSrv.ListenAndServe(); !errors.Is(err, http.ErrServerClosed) { logger.Error("HTTP server error", "error", err) } }() wg.Wait() logger.Info("shutdown complete") } func setupLogger(level, format string) *slog.Logger { var leveler slog.LevelVar if err := leveler.UnmarshalText([]byte(level)); err != nil { leveler.Set(slog.LevelInfo) } handlerOpts := &slog.HandlerOptions{ Level: leveler.Level(), } var handler slog.Handler switch format { case "json": handler = slog.NewJSONHandler(os.Stdout, handlerOpts) default: handler = slog.NewTextHandler(os.Stdout, handlerOpts) } return slog.New(handler).With("service", serviceName, "version", buildVersion) }