| """ |
| Query Planner Agent |
| |
| Decomposes complex queries into sub-queries and identifies query intent. |
| Follows the "Decomposed Prompting" approach from FAANG research. |
| |
| Key Features: |
| - Multi-hop query decomposition |
| - Query intent classification (factoid, comparison, aggregation, etc.) |
| - Dependency graph for sub-queries |
| - Query expansion with synonyms and related terms |
| """ |
|
|
| from typing import List, Optional, Dict, Any, Literal |
| from pydantic import BaseModel, Field |
| from loguru import logger |
| from enum import Enum |
| import json |
| import re |
|
|
| try: |
| import httpx |
| HTTPX_AVAILABLE = True |
| except ImportError: |
| HTTPX_AVAILABLE = False |
|
|
|
|
class QueryIntent(str, Enum):
    """Classification of query intent.

    Subclasses ``str`` so members compare equal to their raw string values
    and serialize cleanly inside the pydantic models below.
    """
    FACTOID = "factoid"          # single-fact lookup; also the default / fallback intent
    COMPARISON = "comparison"    # contrast two or more entities ("X vs Y")
    AGGREGATION = "aggregation"  # summarize / combine information ("overview of ...")
    CAUSAL = "causal"            # cause-effect questions ("why ...", "what causes ...")
    PROCEDURAL = "procedural"    # how-to / step-by-step questions
    DEFINITION = "definition"    # "what is ..." / "define ..."
    LIST = "list"                # enumerate items ("list all ...")
    MULTI_HOP = "multi_hop"      # requires chaining multiple sub-queries
|
|
|
|
class SubQuery(BaseModel):
    """A decomposed sub-query.

    One atomic retrieval step inside a :class:`QueryPlan`.  ``depends_on``
    holds the ``id`` values of other sub-queries in the same plan, forming
    a dependency graph for execution ordering.
    """
    id: str                                                # unique id within the plan, e.g. "sq1"
    query: str                                             # natural-language text of this step
    intent: QueryIntent                                    # classified intent of this step
    depends_on: List[str] = Field(default_factory=list)    # ids of sub-queries that must run first
    priority: int = Field(default=1, ge=1, le=5)           # constrained to 1..5 by pydantic
    filters: Dict[str, Any] = Field(default_factory=dict)  # retrieval filters; schema defined by the retriever — not visible here
    expected_answer_type: str = Field(default="text")      # "text" | "number" | "date" | "list" | "boolean" (per SYSTEM_PROMPT)
|
|
|
|
class QueryPlan(BaseModel):
    """Complete query execution plan.

    Produced by :class:`QueryPlannerAgent`; bundles the original query, its
    overall intent, the decomposed sub-queries, and retrieval hints.
    """
    original_query: str                                   # the user's query, verbatim
    intent: QueryIntent                                   # overall intent of the original query
    sub_queries: List[SubQuery]                           # at least one entry (the original query itself for simple cases)
    expanded_terms: List[str] = Field(default_factory=list)  # synonyms / related terms for query expansion
    requires_aggregation: bool = False                    # True for aggregation/list-style queries
    confidence: float = Field(default=1.0, ge=0.0, le=1.0)   # planner confidence; rule-based pass emits 0.8
|
|
|
|
class QueryPlannerAgent:
    """
    Plans and decomposes queries for optimal retrieval.

    Capabilities:
    1. Identify query complexity and intent
    2. Decompose multi-hop queries into atomic sub-queries
    3. Build dependency graph for sub-query execution
    4. Expand queries with related terms

    Planning is two-tiered: a fast rule-based pass always runs; when
    ``use_llm`` is True and httpx is importable, an LLM pass (Ollama's
    ``/api/generate`` endpoint) refines it.  Any LLM failure falls back to
    the rule-based plan, so :meth:`plan` never raises for planning errors.
    """

    SYSTEM_PROMPT = """You are a query planning expert. Your job is to analyze user queries and create optimal retrieval plans.

For each query, you must:
1. Classify the query intent (factoid, comparison, aggregation, causal, procedural, definition, list, multi_hop)
2. Decompose complex queries into simpler sub-queries
3. Identify dependencies between sub-queries
4. Suggest query expansions (synonyms, related terms)

Output your analysis as JSON with this structure:
{
  "intent": "factoid|comparison|aggregation|causal|procedural|definition|list|multi_hop",
  "sub_queries": [
    {
      "id": "sq1",
      "query": "the sub-query text",
      "intent": "factoid",
      "depends_on": [],
      "priority": 1,
      "expected_answer_type": "text|number|date|list|boolean"
    }
  ],
  "expanded_terms": ["synonym1", "related_term1"],
  "requires_aggregation": false,
  "confidence": 0.95
}

For simple queries, return a single sub-query matching the original.
For complex queries requiring multiple steps, break them down logically.
"""

    # "vs" must match as a standalone word (optionally "vs."); a bare
    # substring test misclassifies words like "canvas" as comparisons.
    _VS_PATTERN = re.compile(r"\bvs\.?\b")

    def __init__(
        self,
        model: str = "llama3.2:3b",
        base_url: str = "http://localhost:11434",
        temperature: float = 0.1,
        use_llm: bool = True,
    ):
        """
        Initialize Query Planner.

        Args:
            model: LLM model for planning
            base_url: Ollama API URL
            temperature: LLM temperature (lower = more deterministic)
            use_llm: If False, use rule-based planning only
        """
        self.model = model
        self.base_url = base_url.rstrip("/")  # normalize so f"{base_url}/api/..." has no double slash
        self.temperature = temperature
        self.use_llm = use_llm

        logger.info(f"QueryPlannerAgent initialized (model={model}, use_llm={use_llm})")

    def plan(self, query: str) -> QueryPlan:
        """
        Create execution plan for a query.

        Args:
            query: User's natural language query

        Returns:
            QueryPlan with sub-queries and metadata

        Never raises on planning failure: LLM errors are logged and the
        rule-based plan is returned instead.
        """
        # The rule-based plan is always computed: it is both the fallback
        # and a source of expansion terms merged into the LLM plan.
        rule_based_plan = self._rule_based_planning(query)

        if not self.use_llm or not HTTPX_AVAILABLE:
            return rule_based_plan

        try:
            llm_plan = self._llm_planning(query)

            # Union expansions from both passes; sorted for deterministic
            # output (a raw set() gives nondeterministic ordering).
            if rule_based_plan.expanded_terms:
                llm_plan.expanded_terms = sorted(
                    set(llm_plan.expanded_terms + rule_based_plan.expanded_terms)
                )

            return llm_plan

        except Exception as e:
            logger.warning(f"LLM planning failed, using rule-based: {e}")
            return rule_based_plan

    def _rule_based_planning(self, query: str) -> QueryPlan:
        """Fast rule-based query planning (no network calls)."""
        query_lower = query.lower().strip()

        intent = self._detect_intent(query_lower)
        expansions = self._expand_query(query)
        sub_queries = self._decompose_if_needed(query, intent)

        return QueryPlan(
            original_query=query,
            intent=intent,
            sub_queries=sub_queries,
            expanded_terms=expansions,
            requires_aggregation=intent in [QueryIntent.AGGREGATION, QueryIntent.LIST],
            # Heuristic planning is less certain than LLM planning (0.9 default).
            confidence=0.8,
        )

    def _detect_intent(self, query: str) -> QueryIntent:
        """Detect query intent from keyword patterns.

        Args:
            query: Lower-cased, stripped query text.

        Returns:
            First matching intent in priority order; FACTOID as fallback.
        """
        # Definition: query *starts* with a definitional phrase.
        if re.match(r"^(what is|define|what are|what does .* mean)", query):
            return QueryIntent.DEFINITION

        # Comparison: word-boundary match for "vs"/"vs." (see _VS_PATTERN),
        # plus explicit comparison phrases.
        comparison_phrases = ["compare", "difference between", "versus", "better than"]
        if self._VS_PATTERN.search(query) or any(p in query for p in comparison_phrases):
            return QueryIntent.COMPARISON

        if any(p in query for p in ["list", "what are all", "give me all", "enumerate"]):
            return QueryIntent.LIST

        if any(p in query for p in ["why", "how does", "what causes", "reason for"]):
            return QueryIntent.CAUSAL

        if any(p in query for p in ["how to", "steps to", "process for", "how can i"]):
            return QueryIntent.PROCEDURAL

        if any(p in query for p in ["summarize", "overview", "summary of", "main points"]):
            return QueryIntent.AGGREGATION

        # Multi-hop heuristics: a conjoined question, or several questions.
        if " and " in query and "?" in query:
            return QueryIntent.MULTI_HOP
        if query.count("?") > 1:
            return QueryIntent.MULTI_HOP

        return QueryIntent.FACTOID

    def _expand_query(self, query: str) -> List[str]:
        """Generate query expansions (synonyms, related terms).

        Uses a small static synonym map (IP/legal domain).  Returns at most
        10 terms, de-duplicated while preserving first-seen order so the
        output is deterministic across runs.
        """
        expansions: List[str] = []
        query_lower = query.lower()

        expansion_map = {
            "patent": ["intellectual property", "IP", "invention", "claim"],
            "license": ["licensing", "agreement", "contract", "terms"],
            "royalty": ["royalties", "payment", "fee", "compensation"],
            "open source": ["OSS", "FOSS", "free software", "open-source"],
            "trademark": ["brand", "mark", "logo"],
            "copyright": ["rights", "authorship", "protection"],
            "infringement": ["violation", "breach", "unauthorized use"],
            "disclosure": ["reveal", "publish", "filing"],
        }

        for term, synonyms in expansion_map.items():
            if term in query_lower:
                expansions.extend(synonyms)

        # dict.fromkeys de-duplicates and keeps insertion order
        # (list(set(...)) made the truncated top-10 nondeterministic).
        return list(dict.fromkeys(expansions))[:10]

    def _decompose_if_needed(self, query: str, intent: QueryIntent) -> List[SubQuery]:
        """Decompose query into sub-queries if complex.

        Comparison queries become one lookup per entity plus a final
        comparison step depending on all lookups.  Multi-hop conjunctions
        are split on "and".  Everything else passes through as a single
        sub-query.
        """
        if intent == QueryIntent.COMPARISON:
            entities = self._extract_comparison_entities(query)
            if len(entities) >= 2:
                sub_queries = [
                    SubQuery(
                        id=f"sq{idx + 1}",
                        query=f"What are the key characteristics of {entity}?",
                        intent=QueryIntent.FACTOID,
                        priority=1,
                        expected_answer_type="text",
                    )
                    for idx, entity in enumerate(entities)
                ]
                # Final synthesis step: runs after all per-entity lookups.
                sub_queries.append(SubQuery(
                    id=f"sq{len(entities) + 1}",
                    query=query,
                    intent=QueryIntent.COMPARISON,
                    depends_on=[f"sq{idx + 1}" for idx in range(len(entities))],
                    priority=2,
                    expected_answer_type="text",
                ))
                return sub_queries

        if intent == QueryIntent.MULTI_HOP and " and " in query.lower():
            raw_parts = re.split(r"\s+and\s+", query, flags=re.IGNORECASE)
            # Drop empty fragments so e.g. a trailing "and ?" does not
            # produce a blank "?" sub-query.
            parts = [p.strip().rstrip("?") for p in raw_parts]
            parts = [p for p in parts if p]
            if parts:
                return [
                    SubQuery(
                        id=f"sq{idx + 1}",
                        query=part + "?",
                        intent=QueryIntent.FACTOID,
                        # Clamp: SubQuery.priority is constrained to <= 5, so
                        # queries splitting into >5 parts must not overflow it.
                        priority=min(idx + 1, 5),
                        expected_answer_type="text",
                    )
                    for idx, part in enumerate(parts)
                ]

        # Simple case: a single sub-query mirroring the original.
        return [SubQuery(
            id="sq1",
            query=query,
            intent=intent,
            priority=1,
            expected_answer_type="text",
        )]

    def _extract_comparison_entities(self, query: str) -> List[str]:
        """Extract the two entities being compared, or [] if none found.

        Patterns are tried in order of specificity; "vs." (with dot) is
        accepted alongside "vs"/"versus".
        """
        patterns = [
            r"(?:compare|difference between)\s+(.+?)\s+(?:and|vs\.?|versus)\s+(.+?)(?:\?|$)",
            r"(.+?)\s+(?:vs\.?|versus)\s+(.+?)(?:\?|$)",
            r"(?:between)\s+(.+?)\s+(?:and)\s+(.+?)(?:\?|$)",
        ]

        for pattern in patterns:
            match = re.search(pattern, query, re.IGNORECASE)
            if match:
                return [match.group(1).strip(), match.group(2).strip()]

        return []

    def _llm_planning(self, query: str) -> QueryPlan:
        """Use LLM for sophisticated query planning.

        Calls Ollama's /api/generate with SYSTEM_PROMPT and parses the JSON
        reply.  LLM-supplied fields are sanitized (unknown intents fall back
        to FACTOID; priority/confidence are clamped) so that a sloppy model
        response degrades gracefully instead of raising a ValidationError
        and discarding the whole plan.

        Raises:
            httpx.HTTPError: on network/HTTP failure (caught by plan()).
        """
        prompt = f"""Analyze this query and create a retrieval plan:

Query: {query}

Provide your analysis as JSON."""

        with httpx.Client(timeout=30.0) as client:
            response = client.post(
                f"{self.base_url}/api/generate",
                json={
                    "model": self.model,
                    "prompt": prompt,
                    "system": self.SYSTEM_PROMPT,
                    "stream": False,
                    "options": {
                        "temperature": self.temperature,
                        "num_predict": 1024,
                    },
                },
            )
            response.raise_for_status()
            result = response.json()

        response_text = result.get("response", "")
        plan_data = self._parse_json_response(response_text)

        sub_queries = []
        for sq_data in plan_data.get("sub_queries", []):
            sub_queries.append(SubQuery(
                id=sq_data.get("id", "sq1"),
                query=sq_data.get("query", query),
                intent=self._safe_intent(sq_data.get("intent")),
                depends_on=sq_data.get("depends_on", []),
                priority=self._clamp_int(sq_data.get("priority", 1), 1, 5),
                expected_answer_type=sq_data.get("expected_answer_type", "text"),
            ))

        if not sub_queries:
            # LLM produced no usable sub-queries: fall back to the original.
            sub_queries = [SubQuery(
                id="sq1",
                query=query,
                intent=QueryIntent.FACTOID,
                priority=1,
            )]

        return QueryPlan(
            original_query=query,
            intent=self._safe_intent(plan_data.get("intent")),
            sub_queries=sub_queries,
            expanded_terms=plan_data.get("expanded_terms", []),
            requires_aggregation=bool(plan_data.get("requires_aggregation", False)),
            confidence=self._clamp_float(plan_data.get("confidence", 0.9), 0.0, 1.0),
        )

    @staticmethod
    def _safe_intent(value: Any) -> QueryIntent:
        """Coerce an LLM-provided intent to QueryIntent; FACTOID on bad input."""
        try:
            return QueryIntent(value)
        except (ValueError, TypeError):
            return QueryIntent.FACTOID

    @staticmethod
    def _clamp_int(value: Any, low: int, high: int) -> int:
        """Coerce *value* to int and clamp into [low, high]; *low* on bad input."""
        try:
            return max(low, min(high, int(value)))
        except (TypeError, ValueError):
            return low

    @staticmethod
    def _clamp_float(value: Any, low: float, high: float) -> float:
        """Coerce *value* to float and clamp into [low, high]; *low* on bad input."""
        try:
            return max(low, min(high, float(value)))
        except (TypeError, ValueError):
            return low

    def _parse_json_response(self, text: str) -> Dict[str, Any]:
        """Extract the first JSON object from an LLM response.

        The greedy braces regex tolerates surrounding prose and markdown
        code fences.  On parse failure, returns a conservative default plan
        skeleton (empty sub_queries triggers the fallback in _llm_planning).
        """
        json_match = re.search(r'\{[\s\S]*\}', text)
        if json_match:
            try:
                return json.loads(json_match.group())
            except json.JSONDecodeError:
                pass

        return {
            "intent": "factoid",
            "sub_queries": [],
            "expanded_terms": [],
            "requires_aggregation": False,
            "confidence": 0.7,
        }
|
|