// Package store provides an in-memory session store whose entries expire
// after a fixed time-to-live.
package store
|
||
|
|
|
||
|
|
import (
|
||
|
|
"context"
|
||
|
|
"log/slog"
|
||
|
|
"sync"
|
||
|
|
"time"
|
||
|
|
)
|
||
|
|
|
||
|
|
// Timing parameters used by SessionStore.
const (
	// SessionTTL is how long a session remains valid, measured from its
	// CreatedAt timestamp.
	SessionTTL = 5 * time.Minute

	// CleanupInterval is the period between background sweeps that remove
	// expired sessions.
	CleanupInterval = time.Minute
)
|
||
|
|
|
||
|
|
// Session represents an OAuth session.
//
// SessionStore only reads and writes CreatedAt; every other field is opaque
// caller-supplied data as far as the store is concerned.
type Session struct {
	Email string
	// NOTE(review): Password is held as a plain string for the lifetime of
	// the session; confirm callers never log or persist it.
	Password string
	State    string
	Nonce    string
	ClientID string

	CodeChallenge string
	CodeVerifier  string

	// CustomClaims holds caller-defined claim maps with arbitrary values.
	CustomClaims []map[string]any

	// CreatedAt is stamped by SessionStore.Create and refreshed by
	// SessionStore.Update. SessionStore.Cleanup compares it against
	// SessionTTL to decide whether the session has expired.
	CreatedAt time.Time
}
|
||
|
|
|
||
|
|
// SessionStore provides thread-safe session storage with TTL
|
||
|
|
type SessionStore struct {
|
||
|
|
mu sync.RWMutex
|
||
|
|
sessions map[string]*Session
|
||
|
|
challenges map[string]string
|
||
|
|
logger *slog.Logger
|
||
|
|
}
|
||
|
|
|
||
|
|
// NewSessionStore creates a new session store
|
||
|
|
func NewSessionStore(logger *slog.Logger) *SessionStore {
|
||
|
|
return &SessionStore{
|
||
|
|
sessions: make(map[string]*Session),
|
||
|
|
challenges: make(map[string]string),
|
||
|
|
logger: logger,
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
// Create stores a new session
|
||
|
|
func (s *SessionStore) Create(code string, session *Session) {
|
||
|
|
s.mu.Lock()
|
||
|
|
defer s.mu.Unlock()
|
||
|
|
|
||
|
|
session.CreatedAt = time.Now()
|
||
|
|
s.sessions[code] = session
|
||
|
|
s.challenges[code] = code
|
||
|
|
}
|
||
|
|
|
||
|
|
// Get retrieves a session by code
|
||
|
|
func (s *SessionStore) Get(code string) (*Session, bool) {
|
||
|
|
s.mu.RLock()
|
||
|
|
defer s.mu.RUnlock()
|
||
|
|
|
||
|
|
session, ok := s.sessions[code]
|
||
|
|
return session, ok
|
||
|
|
}
|
||
|
|
|
||
|
|
// Update updates an existing session and optionally re-indexes it
|
||
|
|
func (s *SessionStore) Update(oldCode, newCode string, updateFn func(*Session)) bool {
|
||
|
|
s.mu.Lock()
|
||
|
|
defer s.mu.Unlock()
|
||
|
|
|
||
|
|
session, ok := s.sessions[oldCode]
|
||
|
|
if !ok {
|
||
|
|
return false
|
||
|
|
}
|
||
|
|
|
||
|
|
updateFn(session)
|
||
|
|
session.CreatedAt = time.Now() // Refresh timestamp
|
||
|
|
|
||
|
|
if oldCode != newCode {
|
||
|
|
s.sessions[newCode] = session
|
||
|
|
s.challenges[newCode] = newCode
|
||
|
|
delete(s.sessions, oldCode)
|
||
|
|
delete(s.challenges, oldCode)
|
||
|
|
}
|
||
|
|
|
||
|
|
return true
|
||
|
|
}
|
||
|
|
|
||
|
|
// Delete removes a session
|
||
|
|
func (s *SessionStore) Delete(code string) {
|
||
|
|
s.mu.Lock()
|
||
|
|
defer s.mu.Unlock()
|
||
|
|
|
||
|
|
delete(s.sessions, code)
|
||
|
|
delete(s.challenges, code)
|
||
|
|
}
|
||
|
|
|
||
|
|
// Cleanup removes expired sessions
|
||
|
|
func (s *SessionStore) Cleanup() int {
|
||
|
|
s.mu.Lock()
|
||
|
|
defer s.mu.Unlock()
|
||
|
|
|
||
|
|
now := time.Now()
|
||
|
|
cleaned := 0
|
||
|
|
|
||
|
|
for code, session := range s.sessions {
|
||
|
|
if now.Sub(session.CreatedAt) > SessionTTL {
|
||
|
|
delete(s.sessions, code)
|
||
|
|
delete(s.challenges, code)
|
||
|
|
cleaned++
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
return cleaned
|
||
|
|
}
|
||
|
|
|
||
|
|
// StartCleanup starts a background goroutine to clean up expired sessions
|
||
|
|
func (s *SessionStore) StartCleanup(ctx context.Context) {
|
||
|
|
ticker := time.NewTicker(CleanupInterval)
|
||
|
|
go func() {
|
||
|
|
for {
|
||
|
|
select {
|
||
|
|
case <-ctx.Done():
|
||
|
|
ticker.Stop()
|
||
|
|
return
|
||
|
|
case <-ticker.C:
|
||
|
|
if cleaned := s.Cleanup(); cleaned > 0 {
|
||
|
|
s.logger.Info("cleaned up expired sessions", "count", cleaned)
|
||
|
|
}
|
||
|
|
}
|
||
|
|
}
|
||
|
|
}()
|
||
|
|
}
|