Spaces:
Running
Running
| """ | |
| NyayaSetu V2 Agent — Full Intelligence Layer. | |
| Pass 1 — ANALYSE: Understands message, detects tone/stage, | |
| builds structured fact web, updates hypotheses, | |
| forms targeted search queries, compresses summary. | |
| Pass 2 — RETRIEVE: Parallel FAISS search. No LLM call. | |
| Pass 3 — RESPOND: Dynamically assembled prompt + retrieved | |
| context + full case state. Format-intelligent output. | |
| 2 LLM calls per turn. src/agent.py untouched. | |
| """ | |
| import os, sys, json, time, logging | |
| from concurrent.futures import ThreadPoolExecutor, as_completed | |
| from typing import Dict, Any, List | |
| # sys.path must be set before any local imports | |
| sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) | |
| from src.embed import embed_text | |
| from src.retrieval import retrieve | |
| from src.verify import verify_citations | |
| from src.system_prompt import build_prompt, ANALYSIS_PROMPT | |
| from src.ner import extract_entities, augment_query | |
| logger = logging.getLogger(__name__) | |
| from tenacity import retry, stop_after_attempt, wait_exponential, retry_if_exception_type | |
| from dotenv import load_dotenv | |
| import threading | |
| import time | |
| from src.llm import call_llm_raw | |
| load_dotenv() | |
| # ── Circuit Breaker for Groq API ────────────────────────── | |
class CircuitBreaker:
    """Simple circuit breaker to detect when Groq API is down.

    Opens after ``failure_threshold`` consecutive recorded failures and
    allows a half-open retry once ``recovery_timeout`` seconds have passed
    since the most recent failure. All state transitions are lock-guarded.
    """

    def __init__(self, failure_threshold=5, recovery_timeout=60):
        # Breaker starts closed with a clean failure history.
        self.failure_threshold = failure_threshold
        self.recovery_timeout = recovery_timeout
        self.failure_count = 0
        self.last_failure_time = None
        self.is_open = False
        self.lock = threading.Lock()

    def record_success(self):
        """Reset the breaker after a successful API call."""
        with self.lock:
            self.failure_count = 0
            self.is_open = False

    def record_failure(self):
        """Count one failure; trip the breaker once the threshold is reached."""
        with self.lock:
            self.failure_count += 1
            self.last_failure_time = time.time()
            if self.failure_count < self.failure_threshold:
                return
            self.is_open = True
            logger.warning(f"Circuit breaker OPEN: {self.failure_count} failures detected")

    def can_attempt(self) -> bool:
        """Return True when a call may be attempted (closed, or recovery window elapsed)."""
        with self.lock:
            if not self.is_open:
                return True
            # Open breaker: refuse until the recovery window has elapsed.
            if time.time() - self.last_failure_time <= self.recovery_timeout:
                return False
            # Half-open: permit an attempt and optimistically reset state.
            logger.info("Circuit breaker attempting recovery...")
            self.is_open = False
            self.failure_count = 0
            return True

    def get_status(self) -> str:
        """Human-readable breaker state for log lines."""
        with self.lock:
            state = "OPEN" if self.is_open else "CLOSED"
            return f"{state} ({self.failure_count} failures)"
# Module-wide breaker shared by both LLM passes (analyse and respond).
_circuit_breaker = CircuitBreaker()
# ── Session store ───────────────────────────────────────────
# In-memory, per-process map: session_id -> {"summary", "last_3_messages", "case_state"}.
# NOTE(review): not persisted and not itself lock-protected — confirm the web
# server accesses it from a single worker or serialises per-session access.
sessions: Dict[str, Dict] = {}
def empty_case_state() -> Dict:
    """Build the initial case-state dict used for a brand-new session."""
    return dict(
        parties=[],
        events=[],
        documents=[],
        amounts=[],
        locations=[],
        timeline=[],
        disputes=[],
        hypotheses=[],
        stage="intake",
        last_response_type="none",
        turn_count=0,
        facts_missing=[],
        context_interpreted=False,
        last_radar_turn=-3,        # turn index when the proactive radar last fired
        last_format="none",
        format_override_turn=-1,   # turn when the user explicitly requested a format
    )
def get_or_create_session(session_id: str) -> Dict:
    """Fetch the session for *session_id*, creating a blank one on first sight."""
    existing = sessions.get(session_id)
    if existing is None:
        existing = {
            "summary": "",
            "last_3_messages": [],
            "case_state": empty_case_state(),
        }
        sessions[session_id] = existing
    return existing
def update_session(session_id: str, analysis: Dict, user_message: str, response: str):
    """Merge one turn's Pass-1 analysis into the stored session state.

    Updates the rolling summary, fact web, hypotheses, stage markers and the
    short message history for ``session_id`` (which must already exist in
    ``sessions`` — callers go through get_or_create_session first).
    """
    session = sessions[session_id]
    cs = session["case_state"]

    # Rolling conversation summary (only replaced when Pass 1 produced one).
    if analysis.get("updated_summary"):
        session["summary"] = analysis["updated_summary"]

    # Fold newly extracted facts into the fact web, deduplicating per field.
    facts = analysis.get("facts_extracted", {})
    if facts:
        for key in ["parties", "events", "documents", "amounts", "locations", "disputes"]:
            new_items = facts.get(key, [])
            existing = cs.get(key, [])
            for item in new_items:
                if item and item not in existing:
                    existing.append(item)
            cs[key] = existing
        for ev in facts.get("timeline_events", []):
            if ev and ev not in cs["timeline"]:
                cs["timeline"].append(ev)

    # Merge hypotheses: unseen claims are appended; known claims get their
    # confidence refreshed and new evidence folded in. The LLM JSON is not
    # schema-enforced, so use .get() instead of direct subscripts — the old
    # h["claim"] / h["confidence"] lookups could KeyError on partial output.
    for nh in analysis.get("hypotheses", []):
        existing_claims = [h.get("claim") for h in cs["hypotheses"]]
        if nh.get("claim") and nh["claim"] not in existing_claims:
            cs["hypotheses"].append(nh)
        else:
            for h in cs["hypotheses"]:
                if h.get("claim") == nh.get("claim"):
                    h["confidence"] = nh.get("confidence", h.get("confidence"))
                    for e in nh.get("evidence", []):
                        if e not in h.get("evidence", []):
                            h.setdefault("evidence", []).append(e)

    # Per-turn bookkeeping.
    cs["stage"] = analysis.get("stage", cs["stage"])
    cs["last_response_type"] = analysis.get("action_needed", "none")
    cs["facts_missing"] = analysis.get("facts_missing", [])
    cs["last_format"] = analysis.get("format_decision", "none")
    cs["turn_count"] = cs.get("turn_count", 0) + 1
    if cs["turn_count"] >= 3:
        cs["context_interpreted"] = True

    # Keep only the last 3 user/assistant exchanges (6 messages); assistant
    # replies are truncated to 400 chars to bound future prompt size.
    session["last_3_messages"].append({"role": "user", "content": user_message})
    session["last_3_messages"].append({"role": "assistant", "content": response[:400]})
    if len(session["last_3_messages"]) > 6:
        session["last_3_messages"] = session["last_3_messages"][-6:]
| # ── Pass 1: Analyse ─────────────────────────────────────── | |
# NOTE(review): this comment claimed "retry up to 5 times with exponential backoff
# (1s to 16s)", but no @retry decorator is applied to analyse() — tenacity is
# imported yet unused, so transient failures propagate immediately. Either apply
# tenacity's retry here or keep this corrected note.
def analyse(user_message: str, session: Dict) -> Dict:
    """Pass 1 — ANALYSE: one LLM call turning the new user message plus stored
    session state into a structured JSON analysis (tone, stage, extracted
    facts, hypotheses, search queries, format decision).

    Falls back to a minimal default analysis when the model's JSON cannot be
    parsed, so the pipeline can always continue.

    Raises:
        Exception: when the circuit breaker is open (Groq presumed down).
    """
    # Fail fast while the Groq API is considered down.
    if not _circuit_breaker.can_attempt():
        logger.error(f"Circuit breaker OPEN - skipping Pass 1. Status: {_circuit_breaker.get_status()}")
        raise Exception("Groq API circuit breaker is open - service unavailable")

    summary = session.get("summary", "")
    last_msgs = session.get("last_3_messages", [])
    cs = session["case_state"]
    last_response_type = cs.get("last_response_type", "none")
    turn_count = cs.get("turn_count", 0)

    # Last 4 messages, each truncated to 250 chars, as plain-text history.
    history_text = "\n".join(
        f"{m['role'].upper()}: {m['content'][:250]}"
        for m in last_msgs[-4:]
    ) if last_msgs else ""

    # Render the accumulated fact web only once any facts exist.
    fact_web = ""
    if any(cs.get(k) for k in ["parties", "events", "documents", "amounts", "disputes"]):
        hyp_lines = "\n".join(
            f" - {h['claim']} [{h.get('confidence','?')}]"
            for h in cs.get("hypotheses", [])[:3]
        ) or " none yet"
        fact_web = f"""
CURRENT FACT WEB:
- Parties: {', '.join(cs.get('parties', [])) or 'none'}
- Events: {', '.join(cs.get('events', [])) or 'none'}
- Documents/Evidence: {', '.join(cs.get('documents', [])) or 'none'}
- Amounts: {', '.join(cs.get('amounts', [])) or 'none'}
- Disputes: {', '.join(cs.get('disputes', [])) or 'none'}
- Active hypotheses:
{hyp_lines}"""

    # Assemble the Pass-1 user prompt. The trailing Rules block steers the
    # model's action_needed / format_decision choices per turn.
    user_content = f"""CONVERSATION SUMMARY:
{summary if summary else "First message — no prior context."}
RECENT MESSAGES:
{history_text if history_text else "None"}
LAST RESPONSE TYPE: {last_response_type}
TURN COUNT: {turn_count}
{fact_web}
NEW USER MESSAGE:
{user_message}
Rules:
- If last_response_type was "question", action_needed CANNOT be "question"
- action_needed SHOULD differ from last_response_type for variety
- Extract ALL facts from user message even if implied
- Update hypothesis confidence based on new evidence
- search_queries must be specific legal questions for vector search
- format_decision must be chosen fresh each turn based on THIS message's content
- NEVER carry over format_decision from previous turn unless user explicitly requests it again
- If user requested a specific format last turn, revert to most natural format this turn"""

    response = call_llm_raw(
        messages=[
            {"role": "system", "content": ANALYSIS_PROMPT},
            {"role": "user", "content": user_content}
        ]
    )
    _circuit_breaker.record_success()  # API call succeeded

    # The model sometimes wraps its JSON in markdown fences; strip them.
    raw = response.strip()
    raw = raw.replace("```json", "").replace("```", "").strip()
    try:
        analysis = json.loads(raw)
    except json.JSONDecodeError:
        # Unparseable output: substitute a minimal, safe default analysis.
        logger.warning(f"Pass 1 JSON parse failed: {raw[:200]}")
        analysis = {
            "tone": "casual", "format_requested": "none",
            "subject": "legal query", "action_needed": "advice",
            "urgency": "medium",
            "hypotheses": [{"claim": user_message[:80], "confidence": "low", "evidence": []}],
            "facts_extracted": {}, "facts_missing": [],
            "legal_issues": [], "clarifying_question": {},
            "stage": "understanding", "last_response_type": last_response_type,
            "updated_summary": f"{summary} | {user_message[:100]}",
            "search_queries": [user_message[:200]],
            "should_interpret_context": False,
            "format_decision": "none"
        }
    return analysis
| # ── Pass 2: Retrieve ────────────────────────────────────── | |
def retrieve_parallel(search_queries: List[str], top_k: int = 5) -> List[Dict]:
    """Pass 2 — RETRIEVE: embed each query and search FAISS in parallel.

    No LLM call. Results from all queries are pooled, deduplicated by chunk
    id (keeping the best — i.e. lowest — similarity score per id), and the
    global top_k are returned best-first. Individual query failures degrade
    to empty results rather than aborting the pass.
    """
    if not search_queries:
        return []

    all_results = []

    def search_one(query):
        # Each worker embeds + searches independently.
        try:
            embedding = embed_text(query)
            return retrieve(embedding, top_k=top_k)
        except Exception as e:
            logger.warning(f"FAISS search failed: {e}")
            return []

    with ThreadPoolExecutor(max_workers=min(3, len(search_queries))) as executor:
        futures = {executor.submit(search_one, q): q for q in search_queries}
        for future in as_completed(futures):
            all_results.extend(future.result())

    # Deduplicate by chunk id, keeping the lowest (best) similarity score.
    # BUGFIX: chunks with neither chunk_id nor judgment_id previously all
    # collapsed under the key "", silently dropping distinct results — give
    # them unique fallback keys instead. Also read the stored chunk's score
    # with .get() so a missing "similarity_score" key cannot raise.
    seen = {}
    for idx, chunk in enumerate(all_results):
        cid = chunk.get("chunk_id") or chunk.get("judgment_id") or f"__anon_{idx}"
        score = chunk.get("similarity_score", 999)
        if cid not in seen or score < seen[cid].get("similarity_score", 999):
            seen[cid] = chunk
    return sorted(seen.values(), key=lambda x: x.get("similarity_score", 999))[:top_k]
| # ── Pass 3: Respond ─────────────────────────────────────── | |
# NOTE(review): this comment claimed "retry up to 5 times with exponential backoff
# (2s to 32s) — more aggressive than Pass 1", but respond() carries no @retry
# decorator, so no retries actually happen. Apply tenacity's retry or rely on
# run_query_v2's fallback path, but don't advertise behavior that isn't there.
def respond(user_message: str, analysis: Dict, chunks: List[Dict], session: Dict) -> str:
    """Pass 3 — RESPOND: final answer LLM call.

    Assembles a dynamic system prompt plus retrieved sources, case-state
    digest, conversation history and per-turn instructions, then calls the
    LLM once and returns the raw answer text.

    Side effect: advances case_state["last_radar_turn"] whenever the
    proactive-radar instruction fires (at most once every 3 turns).

    Raises:
        Exception: when the circuit breaker is open (Groq presumed down).
    """
    # Fail fast while the Groq API is considered down.
    if not _circuit_breaker.can_attempt():
        logger.error(f"Circuit breaker OPEN - skipping Pass 3. Status: {_circuit_breaker.get_status()}")
        raise Exception("Groq API circuit breaker is open - service unavailable")

    system_prompt = build_prompt(analysis)
    cs = session["case_state"]
    turn_count = cs.get("turn_count", 0)

    # Render up to 5 retrieved chunks, each under a source-type header.
    context_parts = []
    for chunk in chunks[:5]:
        source_type = chunk.get("source_type", "case_law")
        title = chunk.get("title", "Unknown")
        year = chunk.get("year", "")
        jid = chunk.get("judgment_id", "")
        # Prefer the richest available text field for each chunk.
        text = chunk.get("expanded_context") or chunk.get("chunk_text") or chunk.get("text", "")
        type_labels = {
            "statute": f"[STATUTE: {title} | {year}]",
            "procedure": f"[PROCEDURE: {title}]",
            "law_commission": f"[LAW COMMISSION: {title}]",
            "legal_reference": f"[LEGAL REFERENCE: {title}]",
            "statute_qa": f"[LEGAL QA: {title}]",
        }
        header = type_labels.get(source_type, f"[CASE: {title} | {year} | {jid}]")
        context_parts.append(f"{header}\n{text[:800]}")
    context = "\n\n".join(context_parts) if context_parts else "No relevant sources retrieved."

    # Case-state digest (only once parties or hypotheses exist).
    case_summary = ""
    if cs.get("parties") or cs.get("hypotheses"):
        hyp_text = "\n".join(
            f" - {h['claim']} [{h.get('confidence','?')} confidence] "
            f"| evidence: {', '.join(h.get('evidence', [])) or 'none yet'}"
            for h in cs.get("hypotheses", [])[:4]
        ) or " none established"
        case_summary = f"""
CASE STATE (built across {turn_count} turns):
Parties: {', '.join(cs.get('parties', [])) or 'unspecified'}
Events: {', '.join(cs.get('events', [])) or 'unspecified'}
Evidence: {', '.join(cs.get('documents', [])) or 'none mentioned'}
Amounts: {', '.join(cs.get('amounts', [])) or 'none'}
Active hypotheses:
{hyp_text}
Missing facts: {', '.join(cs.get('facts_missing', [])) or 'none critical'}
Stage: {cs.get('stage', 'intake')}"""

    # Context interpretation — only once per conversation at turn 2
    interpret_instruction = ""
    should_interpret = analysis.get("should_interpret_context", False)
    if should_interpret and not cs.get("context_interpreted") and turn_count == 2:
        interpret_instruction = "\nIn one sentence only, reflect back your understanding of the situation before responding."

    # Radar — only fires every 3 turns, not every turn
    last_radar_turn = cs.get("last_radar_turn", -3)
    if (turn_count - last_radar_turn) >= 3:
        cs["last_radar_turn"] = turn_count  # NOTE: mutates session state here
        radar_instruction = """
PROACTIVE RADAR — only if a genuinely non-obvious legal angle exists that hasn't been mentioned yet:
Add a single "⚡ You Should Also Know:" line (1-2 sentences max).
Skip entirely if the response already covers all relevant angles or if this is a question/understanding turn."""
    else:
        radar_instruction = "Do NOT add a 'You Should Also Know' section this turn."

    summary = session.get("summary", "")
    last_msgs = session.get("last_3_messages", [])
    # Last 4 messages, each truncated to 300 chars (Pass 1 uses 250).
    history_text = "\n".join(
        f"{m['role'].upper()}: {m['content'][:300]}"
        for m in last_msgs[-4:]
    ) if last_msgs else ""

    # Assemble the Pass-3 user prompt from all of the above.
    user_content = f"""CONVERSATION SUMMARY:
{summary if summary else "First message."}
RECENT CONVERSATION:
{history_text if history_text else "None"}
{case_summary}
RETRIEVED LEGAL SOURCES:
{context}
USER MESSAGE: {user_message}
THIS TURN:
- Legal hypotheses: {', '.join(h['claim'] for h in analysis.get('hypotheses', [])[:3]) or 'analysing'}
- Stage: {analysis.get('stage', 'understanding')}
- Urgency: {analysis.get('urgency', 'medium')}
- Response type: {analysis.get('action_needed', 'advice')}
- Format: {analysis.get('format_decision', 'appropriate for content')}
{interpret_instruction}
Instructions:
- Cite specific sources when making legal claims
- Use your legal knowledge for reasoning and context
- Format: {analysis.get('format_decision', 'use the most appropriate format for the content type')}
- Opposition war-gaming: if giving strategy, include what the other side will argue
{radar_instruction}"""

    response = call_llm_raw(
        messages=[
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_content}
        ]
    )
    _circuit_breaker.record_success()  # API call succeeded
    return response
| # ── Main entry point ────────────────────────────────────── | |
def run_query_v2(user_message: str, session_id: str) -> Dict[str, Any]:
    """Run the full 3-pass pipeline for one user turn.

    Pass 1 analyses the message (LLM call #1), Pass 2 retrieves + reranks
    legal sources (no LLM), Pass 3 generates the answer (LLM call #2).
    Every pass degrades gracefully: a failed analysis falls back to a
    default analysis, failed retrieval yields no sources, and a failed
    response falls back to raw source excerpts.

    Returns:
        The API payload dict (answer, sources, verification info, entities,
        session id, and a compact analysis echo).
    """
    session = get_or_create_session(session_id)

    # Pass 1 — structured analysis.
    try:
        analysis = analyse(user_message, session)
    except Exception as e:
        error_type = type(e).__name__
        logger.error(f"Pass 1 failed after retries: {error_type}: {e}. Circuit breaker: {_circuit_breaker.get_status()}")
        # Record API failure if it was a connection error
        if "APIConnectionError" in error_type or "RateLimitError" in error_type:
            _circuit_breaker.record_failure()
        # Minimal default analysis so the rest of the pipeline can continue.
        analysis = {
            "tone": "casual", "format_requested": "none",
            "subject": "legal query", "action_needed": "advice",
            "urgency": "medium",
            "hypotheses": [{"claim": user_message[:80], "confidence": "low", "evidence": []}],
            "facts_extracted": {}, "facts_missing": [],
            "legal_issues": [], "clarifying_question": {},
            "stage": "understanding", "last_response_type": "none",
            "updated_summary": user_message[:200],
            "search_queries": [user_message[:200]],
            "should_interpret_context": False,
            "format_decision": "none"
        }

    # Extract entities and augment queries for better retrieval
    entities = extract_entities(user_message)
    augmented_message = augment_query(user_message, entities)

    # Pass 2 — build search queries from analysis + legal issues
    search_queries = analysis.get("search_queries", [augmented_message])
    if not search_queries:
        search_queries = [augmented_message]
    # Add queries from issue spotter
    for issue in analysis.get("legal_issues", []):
        statutes = issue.get("relevant_statutes", [])
        specific = issue.get("specific_issue", "")
        if specific:
            issue_query = f"{specific} {' '.join(statutes[:2])}".strip()
            if issue_query not in search_queries:
                search_queries.append(issue_query)
    if augmented_message not in search_queries:
        search_queries.append(augmented_message)

    chunks = []
    try:
        # Retrieve more candidates (top_k=10) so the reranker has headroom.
        raw_chunks = retrieve_parallel(search_queries[:3], top_k=10)
        # Rerank candidates by true relevance (local import: optional module).
        from src.reranker import rerank
        chunks = rerank(user_message, raw_chunks, top_k=5)
        # Append up to 2 precedents cited by the retrieved judgments.
        from src.citation_graph import get_precedent_chain
        retrieved_ids = [c.get("judgment_id", "") for c in chunks]
        precedents = get_precedent_chain(retrieved_ids, max_precedents=2)
        if precedents:
            chunks.extend(precedents)
    except Exception as e:
        logger.error(f"Pass 2 failed: {e}")

    # Pass 3 — answer generation.
    try:
        answer = respond(user_message, analysis, chunks, session)
    except Exception as e:
        error_type = type(e).__name__
        logger.error(f"Pass 3 failed after retries: {error_type}: {e}. Circuit breaker: {_circuit_breaker.get_status()}")
        # Record API failure if it was a connection error
        if "APIConnectionError" in error_type or "RateLimitError" in error_type:
            _circuit_breaker.record_failure()
        if chunks:
            # BUGFIX: read chunk text through the same fallback chain used
            # everywhere else (expanded_context -> chunk_text -> text); the
            # old code read only "text" and often rendered empty excerpts.
            fallback = "\n\n".join(
                f"[{c.get('title', 'Source')}]\n"
                f"{(c.get('expanded_context') or c.get('chunk_text') or c.get('text', ''))[:400]}"
                for c in chunks[:3]
            )
            answer = f"LLM service temporarily unavailable. Most relevant excerpts:\n\n{fallback}"
        else:
            answer = "I encountered an issue processing your request. Please try again."

    verification_status, unverified_quotes = verify_citations(answer, chunks)
    update_session(session_id, analysis, user_message, answer)

    # Shape retrieved chunks into the API "sources" payload.
    sources = []
    for c in chunks:
        title = c.get("title", "")
        jid = c.get("judgment_id", "")
        sources.append({
            "meta": {
                "judgment_id": jid,
                "title": title if title and title != jid else jid,
                "year": c.get("year", ""),
                "chunk_index": c.get("chunk_index", 0),
                "source_type": c.get("source_type", "case_law"),
                "court": c.get("court", "Supreme Court of India")
            },
            "text": (c.get("expanded_context") or c.get("chunk_text") or c.get("text", ""))[:600]
        })
    return {
        "query": user_message,
        "answer": answer,
        "sources": sources,
        "verification_status": verification_status,
        "unverified_quotes": unverified_quotes,
        "entities": entities,
        "num_sources": len(chunks),
        "truncated": False,
        "session_id": session_id,
        "analysis": {
            "tone": analysis.get("tone"),
            "stage": analysis.get("stage"),
            "urgency": analysis.get("urgency"),
            "hypotheses": [h["claim"] for h in analysis.get("hypotheses", [])]
        }
    }