"""In-memory session manager with TTL-based expiry + SQLite persistence."""

import time
from typing import Optional

from app.chat_db import save_message, get_chat_history


class Session:
    """Represents a single conversation session.

    Messages are held in memory (bounded by ``_trim_history``); every message
    added through ``add_user_message`` / ``add_assistant_message`` is also
    written to the database via ``save_message``.
    """

    def __init__(self, session_id: str):
        self.session_id = session_id
        # [{"role": "user"|"assistant", "content": str}]
        self.history: list[dict] = []
        self.current_topic: str = ""
        self.pending_clarification: Optional[str] = None
        self.last_confidence: float = 0.0
        self.last_action: str = ""  # "retrieve" | "reason" | "clarify"
        # Wall-clock timestamp (time.time()) of the last activity; compared
        # against the TTL by SessionManager._cleanup_expired.
        self.last_active: float = time.time()

    def add_user_message(self, message: str):
        """Record a user message in memory and persist it."""
        self._add_message("user", message)

    def add_assistant_message(self, message: str):
        """Record an assistant message in memory and persist it."""
        self._add_message("assistant", message)

    def _add_message(self, role: str, message: str):
        """Append a message, refresh activity time, trim, then persist it.

        Shared by both public add_* methods so the in-memory bookkeeping and
        the persistence step are identical for every role.
        """
        self.history.append({"role": role, "content": message})
        self.last_active = time.time()
        self._trim_history()
        # Persist to SQLite
        save_message(self.session_id, role, message)

    def get_history_text(self, max_messages: int = 10) -> str:
        """Return formatted conversation history for LLM context.

        Args:
            max_messages: Number of most recent messages to include
                (default 10). Values <= 0 produce an empty string.

        Returns:
            One ``"User: ..."`` / ``"Assistant: ..."`` line per message,
            joined with newlines.
        """
        # Guard explicitly: history[-0:] would return the whole list.
        if max_messages <= 0:
            return ""
        lines = []
        for msg in self.history[-max_messages:]:
            role = "User" if msg["role"] == "user" else "Assistant"
            lines.append(f"{role}: {msg['content']}")
        return "\n".join(lines)

    def _trim_history(self, max_turns: int = 20):
        """Keep only the last ``max_turns`` messages in memory.

        Despite the parameter name, this counts individual messages, not
        user/assistant pairs. Only the in-memory list is trimmed; nothing
        here touches the database. Values <= 0 clear the in-memory history.
        """
        if max_turns <= 0:
            # history[-0:] keeps everything and a negative bound would drop
            # the head instead, so handle non-positive limits explicitly.
            self.history = []
        elif len(self.history) > max_turns:
            self.history = self.history[-max_turns:]


class SessionManager:
    """Manages all active sessions with TTL expiry."""

    def __init__(self, ttl_minutes: int = 30):
        self._sessions: dict[str, Session] = {}
        self._ttl_seconds = ttl_minutes * 60

    def get_or_create(self, session_id: str) -> Session:
        """Get an existing session or create a new one.

        Expired sessions are purged first. A session that is not in memory is
        rebuilt from the stored chat history (if any). The returned session's
        ``last_active`` is refreshed.

        Args:
            session_id: Identifier of the conversation.

        Returns:
            The in-memory Session for ``session_id``.
        """
        self._cleanup_expired()
        session = self._sessions.get(session_id)
        if session is None:
            session = self._load_session(session_id)
            self._sessions[session_id] = session
        session.last_active = time.time()
        return session

    @staticmethod
    def _load_session(session_id: str) -> Session:
        """Build a Session whose history is seeded from the database."""
        session = Session(session_id)
        # Assign history directly rather than via add_*: those methods call
        # save_message and would duplicate messages that are already stored.
        session.history = [
            {"role": msg["role"], "content": msg["content"]}
            for msg in get_chat_history(session_id)
        ]
        session._trim_history()
        return session

    def _cleanup_expired(self):
        """Remove sessions that have been idle longer than the TTL.

        Only the in-memory entry is dropped; this method does not touch the
        database, so get_or_create can reload the history on next access.
        """
        now = time.time()
        expired = [
            sid
            for sid, s in self._sessions.items()
            if now - s.last_active > self._ttl_seconds
        ]
        for sid in expired:
            del self._sessions[sid]


# Global singleton
session_manager = SessionManager()