Files

134 lines
2.7 KiB
Go
Raw Permalink Normal View History

package store
import (
"context"
"log/slog"
"sync"
"time"
)
// Timing parameters for the session store.
const (
	// SessionTTL is the maximum age, measured from Session.CreatedAt, that a
	// session may reach before Cleanup removes it.
	SessionTTL time.Duration = 5 * time.Minute

	// CleanupInterval is the period of the ticker that drives the background
	// sweep started by StartCleanup.
	CleanupInterval time.Duration = time.Minute
)
// Session holds the state tracked for one OAuth session. SessionStore keys
// sessions by an external code string, not by any field of this struct.
type Session struct {
	// Credentials associated with the session.
	// NOTE(review): Password is held as a plain string — confirm callers never
	// log or serialize it.
	Email    string
	Password string

	// Request parameters; presumably copied from the authorization request —
	// verify against the handlers that populate them.
	State    string
	Nonce    string
	ClientID string

	// PKCE values (assumed from the field names — TODO confirm).
	CodeChallenge string
	CodeVerifier  string

	// CustomClaims is a list of free-form claim maps.
	CustomClaims []map[string]any

	// CreatedAt is overwritten with time.Now() by SessionStore.Create and
	// SessionStore.Update, and is the timestamp Cleanup compares to SessionTTL.
	CreatedAt time.Time
}
// SessionStore is an in-memory map of sessions keyed by code. Every method in
// this file takes mu before touching the maps, so the maps themselves are safe
// for concurrent use; expiry is applied by Cleanup using SessionTTL.
type SessionStore struct {
	// logger is set once at construction and used by the background sweep.
	logger *slog.Logger

	// mu guards sessions and challenges.
	mu sync.RWMutex
	// sessions maps a code to its session.
	sessions map[string]*Session
	// challenges is written in lock-step with sessions and always maps a code
	// to itself. NOTE(review): nothing in this file reads it — confirm whether
	// it is still needed.
	challenges map[string]string
}
// NewSessionStore returns an empty, ready-to-use SessionStore.
//
// logger receives the messages emitted by the background sweep started with
// StartCleanup. A nil logger falls back to slog.Default(), so that the sweep's
// logging call can never dereference a nil *slog.Logger.
func NewSessionStore(logger *slog.Logger) *SessionStore {
	if logger == nil {
		logger = slog.Default()
	}
	return &SessionStore{
		sessions:   make(map[string]*Session),
		challenges: make(map[string]string),
		logger:     logger,
	}
}
// Create registers session under code, replacing anything already stored
// under that code. The session's CreatedAt is set to the current time, so its
// age is measured from insertion. The pointer is stored as given, not copied.
func (s *SessionStore) Create(code string, session *Session) {
	s.mu.Lock()
	defer s.mu.Unlock()

	session.CreatedAt = time.Now()

	// Both indexes are always written together.
	s.challenges[code] = code
	s.sessions[code] = session
}
// Get returns the session stored under code.
//
// The second result is false when no session is stored under code, or when the
// stored session is older than SessionTTL. Expired entries are only hidden
// here, not removed; removal is left to Cleanup, so Get needs only the read
// lock. Without this check an expired session would stay retrievable until the
// next sweep, or indefinitely if no sweep is running.
//
// The returned pointer is the stored session itself, not a copy.
// NOTE(review): reads or writes through it after Get returns are not covered
// by the store's lock and can race with Update's callback.
func (s *SessionStore) Get(code string) (*Session, bool) {
	s.mu.RLock()
	defer s.mu.RUnlock()

	session, ok := s.sessions[code]
	if !ok || session == nil {
		return nil, false
	}
	// Same expiry rule as Cleanup: strictly older than SessionTTL.
	if time.Since(session.CreatedAt) > SessionTTL {
		return nil, false
	}
	return session, true
}
// Update applies updateFn to the session stored under oldCode, resets its
// CreatedAt to the current time, and, when newCode differs from oldCode, moves
// it so that it is stored under newCode only.
//
// It returns false, without calling updateFn, when no session is stored under
// oldCode or when the stored session is already older than SessionTTL. The
// expiry check stops an expired session that has not been swept yet from being
// revived by the timestamp refresh. A nil updateFn is allowed and means
// "refresh and re-index only".
//
// updateFn runs while the store's write lock is held, so it must not call back
// into the store. If newCode is already in use, the session stored there is
// replaced.
func (s *SessionStore) Update(oldCode, newCode string, updateFn func(*Session)) bool {
	s.mu.Lock()
	defer s.mu.Unlock()

	session, ok := s.sessions[oldCode]
	if !ok || session == nil {
		return false
	}
	// Same expiry rule as Cleanup: strictly older than SessionTTL.
	if time.Since(session.CreatedAt) > SessionTTL {
		return false
	}

	if updateFn != nil {
		updateFn(session)
	}
	session.CreatedAt = time.Now() // refresh so the TTL restarts from this update

	if oldCode != newCode {
		delete(s.sessions, oldCode)
		delete(s.challenges, oldCode)
		s.sessions[newCode] = session
		s.challenges[newCode] = newCode
	}
	return true
}
// Delete drops whatever is stored under code from both indexes. Deleting a
// code that is not present is a no-op.
func (s *SessionStore) Delete(code string) {
	s.mu.Lock()
	defer s.mu.Unlock()

	delete(s.challenges, code)
	delete(s.sessions, code)
}
// Cleanup sweeps the store once, removing every session whose age (measured
// from CreatedAt against a single time.Now() reading) is strictly greater than
// SessionTTL. It returns how many sessions were removed. The write lock is
// held for the whole sweep.
func (s *SessionStore) Cleanup() int {
	s.mu.Lock()
	defer s.mu.Unlock()

	var (
		now     = time.Now()
		removed int
	)
	for code, sess := range s.sessions {
		if now.Sub(sess.CreatedAt) <= SessionTTL {
			continue // still within its TTL
		}
		delete(s.challenges, code)
		delete(s.sessions, code)
		removed++
	}
	return removed
}
// StartCleanup launches a goroutine that calls Cleanup once per
// CleanupInterval and logs the number of removed sessions whenever that number
// is non-zero. It returns immediately. Cancelling ctx stops the ticker and
// ends the goroutine.
//
// NOTE(review): callers have no way to wait for the goroutine to exit, and
// s.logger must be non-nil because it is dereferenced when a sweep removes
// anything.
func (s *SessionStore) StartCleanup(ctx context.Context) {
	ticker := time.NewTicker(CleanupInterval)

	go func() {
		defer ticker.Stop()

		for {
			select {
			case <-ticker.C:
				removed := s.Cleanup()
				if removed <= 0 {
					continue
				}
				s.logger.Info("cleaned up expired sessions", "count", removed)
			case <-ctx.Done():
				return
			}
		}
	}()
}