"""
Context Relevance Classification Module

Uses LLM inference to identify relevant session contexts and generate
dynamic summaries.
"""

import asyncio
import logging
from datetime import datetime
from typing import Dict, List

logger = logging.getLogger(__name__)


class ContextRelevanceClassifier:
    """
    Classify which session contexts are relevant to the current conversation
    and generate a 2-line summary for each relevant session.

    Performance priorities:
    - LLM inference first (accuracy over speed)
    - Parallel processing for multiple sessions
    - Caching for repeated queries
    - Graceful degradation on failures
    """

    def __init__(self, llm_router):
        """
        Initialize the classifier with an LLM router.

        Args:
            llm_router: LLMRouter instance for inference calls
        """
        self.llm_router = llm_router
        self._relevance_cache = {}
        self._summary_cache = {}
        self._cache_ttl = 3600  # cache entry lifetime, in seconds
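
    # Note: `llm_router` is duck-typed. From its use below, it is assumed to
    # expose `await route_inference(task_type=..., prompt=..., max_tokens=...,
    # temperature=...)` and return a plain string; any object with that shape
    # (or None, to force the keyword fallbacks) will work.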

    async def classify_and_summarize_relevant_contexts(self,
                                                       current_input: str,
                                                       session_contexts: List[Dict],
                                                       user_id: str = "Test_Any") -> Dict:
        """
        Main entry point: classify relevant contexts AND generate 2-line summaries.

        Performance strategy:
        1. Extract the current topic (a single LLM call)
        2. Score relevance for all sessions in parallel (one LLM call each)
        3. Generate summaries in parallel (only for relevant sessions)

        Args:
            current_input: Current user query
            session_contexts: List of session context dictionaries
            user_id: User identifier for logging

        Returns:
            {
                'relevant_summaries': List[str],   # 2-line summaries
                'combined_user_context': str,      # Combined summaries
                'relevance_scores': Dict,          # Score per session
                'classification_confidence': float,
                'topic': str,
                'processing_time': float
            }
        """
        start_time = datetime.now()

        try:
            if not session_contexts:
                logger.info(f"No session contexts provided for classification (user: {user_id})")
                return {
                    'relevant_summaries': [],
                    'combined_user_context': '',
                    'relevance_scores': {},
                    'classification_confidence': 1.0,
                    'topic': '',
                    'processing_time': 0.0
                }

            # Step 1: one LLM call to extract the current topic
            current_topic = await self._extract_current_topic(current_input)
            logger.info(f"Extracted current topic: '{current_topic}'")

            # Step 2: score every session's relevance concurrently
            relevance_tasks = []
            for session_ctx in session_contexts:
                task = self._calculate_relevance_with_cache(
                    current_topic,
                    current_input,
                    session_ctx
                )
                relevance_tasks.append((session_ctx, task))

            relevance_results = await asyncio.gather(
                *[task for _, task in relevance_tasks],
                return_exceptions=True
            )

            # Keep only sessions at or above the relevance threshold
            relevant_sessions = []
            relevance_scores = {}

            for (session_ctx, _), result in zip(relevance_tasks, relevance_results):
                if isinstance(result, Exception):
                    logger.error(f"Error calculating relevance: {result}")
                    continue

                session_id = session_ctx.get('session_id', 'unknown')
                score = result.get('score', 0.0)
                relevance_scores[session_id] = score

                if score >= 0.6:
                    relevant_sessions.append({
                        'session_id': session_id,
                        'summary': session_ctx.get('summary', ''),
                        'relevance_score': score,
                        'interaction_contexts': session_ctx.get('interaction_contexts', []),
                        'created_at': session_ctx.get('created_at', '')
                    })

            logger.info(f"Found {len(relevant_sessions)} relevant sessions out of {len(session_contexts)}")

            # Step 3: generate 2-line summaries concurrently, relevant sessions only
            summary_tasks = [
                self._generate_session_summary(relevant_session, current_input, current_topic)
                for relevant_session in relevant_sessions
            ]

            summary_results = await asyncio.gather(*summary_tasks, return_exceptions=True)

            # Keep non-empty string results; log (but tolerate) failures
            valid_summaries = []
            for summary in summary_results:
                if isinstance(summary, str) and summary.strip():
                    valid_summaries.append(summary.strip())
                elif isinstance(summary, Exception):
                    logger.error(f"Error generating summary: {summary}")

            combined_user_context = self._combine_summaries(valid_summaries, current_topic)

            processing_time = (datetime.now() - start_time).total_seconds()

            logger.info(
                f"Relevance classification complete: {len(valid_summaries)} summaries, "
                f"topic '{current_topic}', time: {processing_time:.2f}s"
            )

            return {
                'relevant_summaries': valid_summaries,
                'combined_user_context': combined_user_context,
                'relevance_scores': relevance_scores,
                'classification_confidence': 0.8,  # fixed heuristic for the LLM path
                'topic': current_topic,
                'processing_time': processing_time
            }

        except Exception as e:
            logger.error(f"Error in relevance classification: {e}", exc_info=True)
            processing_time = (datetime.now() - start_time).total_seconds()

            # Graceful degradation: return an empty result instead of raising
            return {
                'relevant_summaries': [],
                'combined_user_context': '',
                'relevance_scores': {},
                'classification_confidence': 0.0,
                'topic': '',
                'processing_time': processing_time,
                'error': str(e)
            }
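
    # Expected shape of each `session_contexts` element, inferred from the
    # keys read above (values are illustrative):
    #   {
    #       'session_id': 'abc123',
    #       'summary': 'Discussed PostgreSQL index tuning.',
    #       'interaction_contexts': [{'summary': '...'}],
    #       'created_at': '2024-01-01T00:00:00',
    #   }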

    async def _extract_current_topic(self, user_input: str) -> str:
        """
        Extract the main topic from the current input using LLM inference.

        Performance: a single LLM call, with caching.
        """
        try:
            # Key the cache on a prefix of the input to bound key size
            cache_key = f"topic_{hash(user_input[:200])}"
            if cache_key in self._relevance_cache:
                cached = self._relevance_cache[cache_key]
                if cached.get('timestamp', 0) + self._cache_ttl > datetime.now().timestamp():
                    return cached['value']

            if not self.llm_router:
                # No LLM available: fall back to the first few words of the query
                words = user_input.split()[:5]
                return ' '.join(words) if words else 'general query'

            prompt = f"""Extract the main topic (2-5 words) from this query:

Query: "{user_input}"

Respond with ONLY the topic name. Maximum 5 words."""

            result = await self.llm_router.route_inference(
                task_type="classification",
                prompt=prompt,
                max_tokens=20,
                temperature=0.2
            )

            topic = result.strip() if result else user_input[:100]

            self._relevance_cache[cache_key] = {
                'value': topic,
                'timestamp': datetime.now().timestamp()
            }

            return topic

        except Exception as e:
            logger.error(f"Error extracting topic: {e}", exc_info=True)
            # Last resort: a truncated copy of the input stands in for the topic
            return user_input[:100]
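
    # Illustrative behavior: a query like "How do I speed up slow PostgreSQL
    # queries on large tables?" would be expected to yield a topic such as
    # "PostgreSQL query optimization" (actual output depends on the model).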

    async def _calculate_relevance_with_cache(self,
                                              current_topic: str,
                                              current_input: str,
                                              session_ctx: Dict) -> Dict:
        """
        Calculate a relevance score with caching to reduce LLM calls.

        Returns: {'score': float, 'cached': bool}
        """
        try:
            session_id = session_ctx.get('session_id', 'unknown')
            session_summary = session_ctx.get('summary', '')

            # Expiry is lazy: stale entries are ignored and overwritten, not evicted
            cache_key = f"rel_{session_id}_{hash(current_input[:100] + current_topic)}"
            if cache_key in self._relevance_cache:
                cached = self._relevance_cache[cache_key]
                if cached.get('timestamp', 0) + self._cache_ttl > datetime.now().timestamp():
                    return {'score': cached['value'], 'cached': True}

            score = await self._calculate_relevance(
                current_topic,
                current_input,
                session_summary
            )

            self._relevance_cache[cache_key] = {
                'value': score,
                'timestamp': datetime.now().timestamp()
            }

            return {'score': score, 'cached': False}

        except Exception as e:
            logger.error(f"Error in cached relevance calculation: {e}", exc_info=True)
            # Neutral score on failure so one bad call cannot sink a session
            return {'score': 0.5, 'cached': False}
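
    # Cache entry shape (illustrative): self._relevance_cache["rel_s1_..."] ==
    # {'value': 0.72, 'timestamp': 1700000000.0}; `timestamp` plus the
    # 3600 s TTL is compared against `datetime.now().timestamp()` above.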

    async def _calculate_relevance(self,
                                   current_topic: str,
                                   current_input: str,
                                   context_text: str) -> float:
        """
        Calculate a relevance score (0.0 to 1.0) using LLM inference.

        Performance: one LLM call per session context.
        """
        try:
            if not context_text:
                return 0.0

            if not self.llm_router:
                # No LLM available: fall back to keyword overlap
                return self._simple_keyword_relevance(current_input, context_text)

            prompt = f"""Rate the relevance (0.0 to 1.0) of this session context to the current conversation.

Current Topic: {current_topic}
Current Query: "{current_input[:200]}"

Session Context:
"{context_text[:500]}"

Consider:
- Topic similarity (0.0-1.0)
- Discussion depth alignment
- Information continuity

Respond with ONLY a number between 0.0 and 1.0 (e.g., 0.75)."""

            result = await self.llm_router.route_inference(
                task_type="general_reasoning",
                prompt=prompt,
                max_tokens=10,
                temperature=0.1
            )

            if result:
                try:
                    # Clamp the parsed score into [0.0, 1.0]
                    score = float(result.strip())
                    return max(0.0, min(1.0, score))
                except ValueError:
                    logger.warning(f"Could not parse relevance score: {result}")

            # Empty or unparseable response: fall back to keyword overlap
            return self._simple_keyword_relevance(current_input, context_text)

        except Exception as e:
            logger.error(f"Error calculating relevance: {e}", exc_info=True)
            return 0.5

    def _simple_keyword_relevance(self, current_input: str, context_text: str) -> float:
        """Fallback relevance: Jaccard similarity over stop-word-filtered keywords."""
        try:
            current_lower = current_input.lower()
            context_lower = context_text.lower()

            current_words = set(current_lower.split())
            context_words = set(context_lower.split())

            # Drop common stop words before comparing
            stop_words = {'the', 'a', 'an', 'and', 'or', 'but', 'in', 'on', 'at', 'to', 'for', 'of', 'with', 'by'}
            current_words = current_words - stop_words
            context_words = context_words - stop_words

            if not current_words:
                return 0.5

            # Jaccard similarity: |intersection| / |union|
            intersection = len(current_words & context_words)
            union = len(current_words | context_words)

            return (intersection / union) if union > 0 else 0.0

        except Exception:
            return 0.5
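
    # Worked example (illustrative): comparing "fix the database index" with
    # "we discussed database index tuning yesterday", the stop-word-filtered
    # sets share {"database", "index"} out of 7 distinct words in the union,
    # giving 2 / 7 ≈ 0.29, well below the 0.6 relevance threshold.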

    async def _generate_session_summary(self,
                                        session_data: Dict,
                                        current_input: str,
                                        current_topic: str) -> str:
        """
        Generate a 2-line summary for a relevant session context.

        Performance: LLM inference with caching and timeout protection.
        Line 1 captures the breadth of the discussion; line 2 its depth.
        """
        try:
            session_id = session_data.get('session_id', 'unknown')
            session_summary = session_data.get('summary', '')
            interaction_contexts = session_data.get('interaction_contexts', [])

            cache_key = f"summary_{session_id}_{hash(current_topic)}"
            if cache_key in self._summary_cache:
                cached = self._summary_cache[cache_key]
                if cached.get('timestamp', 0) + self._cache_ttl > datetime.now().timestamp():
                    return cached['value']

            if not session_summary and not interaction_contexts:
                logger.warning(f"No content for summarization: session {session_id}")
                return f"Previous discussion on {current_topic}.\nContext details unavailable."

            # Build the source text: session summary plus the last few interactions
            session_context_text = session_summary[:500] if session_summary else ""

            if interaction_contexts:
                recent_interactions = "\n".join([
                    ic.get('summary', '')[:100]
                    for ic in interaction_contexts[-5:]
                    if ic.get('summary')
                ])
                if recent_interactions:
                    session_context_text = f"{session_context_text}\n\nRecent interactions:\n{recent_interactions[:400]}"

            # Bound the prompt size
            if len(session_context_text) > 1000:
                session_context_text = session_context_text[:1000] + "..."

            if not self.llm_router:
                # No LLM available: degrade to a template built from the stored summary
                return f"Previous {current_topic} discussion.\nCovered: {session_summary[:80]}..."

            prompt = f"""Generate a precise 2-line summary (maximum 2 sentences, ~100 tokens total) that captures the depth and breadth of the topic discussion:

Current Topic: {current_topic}
Current Query: "{current_input[:150]}"

Previous Session Context:
{session_context_text}

Requirements:
- Line 1: Summarize the MAIN TOPICS/SUBJECTS discussed (breadth/width)
- Line 2: Summarize the DEPTH/LEVEL of discussion (technical depth, detail level, approach)
- Focus on relevance to: "{current_topic}"
- Keep total under 100 tokens
- Be specific about what was covered

Respond with ONLY the 2-line summary, no explanations."""

            try:
                result = await asyncio.wait_for(
                    self.llm_router.route_inference(
                        task_type="general_reasoning",
                        prompt=prompt,
                        max_tokens=100,
                        temperature=0.4
                    ),
                    timeout=10.0
                )
            except asyncio.TimeoutError:
                logger.warning(f"Summary generation timeout for session {session_id}")
                return f"Previous {current_topic} discussion.\nDepth and approach covered in prior session."

            if result and isinstance(result, str) and result.strip():
                lines = [line.strip() for line in result.strip().split('\n') if line.strip()]

                if lines:
                    if len(lines) > 2:
                        # Collapse any overflow lines into the second line
                        combined = f"{lines[0]}\n{'. '.join(lines[1:])}"
                        formatted_summary = combined[:200]
                    else:
                        formatted_summary = '\n'.join(lines[:2])[:200]

                    # Guard against degenerate (too short) model output
                    if len(formatted_summary) < 20:
                        formatted_summary = f"Previous {current_topic} discussion.\nDetails from previous session."

                    self._summary_cache[cache_key] = {
                        'value': formatted_summary,
                        'timestamp': datetime.now().timestamp()
                    }

                    return formatted_summary

                return f"Previous {current_topic} discussion.\nContext from previous session."

            logger.warning(f"Invalid summary result for session {session_id}")
            return f"Previous {current_topic} discussion.\nDepth and approach covered previously."

        except Exception as e:
            logger.error(f"Error generating session summary: {e}", exc_info=True)
            session_summary = session_data.get('summary', '')[:100] if session_data.get('summary') else 'topic discussion'
            return f"{session_summary}...\n{current_topic} discussion from previous session."

    def _combine_summaries(self, summaries: List[str], current_topic: str) -> str:
        """
        Combine multiple 2-line summaries into one coherent user context.

        Builds width (multiple sessions) and depth (summarized discussions).
        """
        try:
            if not summaries:
                return ''

            if len(summaries) == 1:
                return summaries[0]

            combined = f"Relevant Previous Discussions (Topic: {current_topic}):\n\n"

            for idx, summary in enumerate(summaries, 1):
                combined += f"[Session {idx}]\n{summary}\n\n"

            combined += f"These sessions provide context for {current_topic} discussions, covering multiple aspects and depth levels."

            return combined

        except Exception as e:
            logger.error(f"Error combining summaries: {e}", exc_info=True)
            # Last resort: newline-join the first few summaries verbatim
            return '\n\n'.join(summaries[:5])
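

# Minimal usage sketch (illustrative, not part of the module's API).
# `StubRouter` is a hypothetical stand-in for the real LLMRouter: it matches
# the `route_inference(task_type, prompt, max_tokens, temperature)` call shape
# used above and returns canned strings so the flow can be exercised without a
# live model. Passing `llm_router=None` instead exercises the keyword fallbacks.
if __name__ == "__main__":
    class StubRouter:
        async def route_inference(self, task_type, prompt, max_tokens, temperature):
            if task_type == "classification":
                return "database indexing"  # topic extraction
            if "Rate the relevance" in prompt:
                return "0.8"                # relevance scoring
            return ("Covered B-tree vs. hash indexes for PostgreSQL.\n"
                    "Deep technical detail, including planner behavior.")

    async def _demo():
        classifier = ContextRelevanceClassifier(llm_router=StubRouter())
        result = await classifier.classify_and_summarize_relevant_contexts(
            current_input="How should I index a large PostgreSQL table?",
            session_contexts=[{
                'session_id': 's1',
                'summary': 'Discussed B-tree vs. hash indexes for PostgreSQL.',
                'interaction_contexts': [],
                'created_at': '2024-01-01T00:00:00',
            }],
        )
        print(result['relevance_scores'])       # e.g. {'s1': 0.8}
        print(result['combined_user_context'])  # the stubbed 2-line summary

    asyncio.run(_demo())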