# nyayasetu/src/agent_v2.py
# Uploaded via huggingface_hub by CaffeinatedCoding (commit adf245d, verified)
"""
NyayaSetu V2 Agent — Full Intelligence Layer.
Pass 1 — ANALYSE: Understands message, detects tone/stage,
builds structured fact web, updates hypotheses,
forms targeted search queries, compresses summary.
Pass 2 — RETRIEVE: Parallel FAISS search. No LLM call.
Pass 3 — RESPOND: Dynamically assembled prompt + retrieved
context + full case state. Format-intelligent output.
2 LLM calls per turn. src/agent.py untouched.
"""
import os, sys, json, time, logging
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import Dict, Any, List
# sys.path must be set before any local imports
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from src.embed import embed_text
from src.retrieval import retrieve
from src.verify import verify_citations
from src.system_prompt import build_prompt, ANALYSIS_PROMPT
from src.ner import extract_entities, augment_query
logger = logging.getLogger(__name__)
from tenacity import retry, stop_after_attempt, wait_exponential, retry_if_exception_type
from dotenv import load_dotenv
import threading
import time
from src.llm import call_llm_raw
load_dotenv()
# ── Circuit Breaker for Groq API ──────────────────────────
class CircuitBreaker:
    """Track consecutive Groq API failures and short-circuit calls while the API looks down.

    The breaker opens after ``failure_threshold`` consecutive recorded
    failures and allows a recovery attempt once ``recovery_timeout`` seconds
    have passed since the last failure. All state changes are lock-guarded.
    """

    def __init__(self, failure_threshold=5, recovery_timeout=60):
        self.failure_count = 0
        self.failure_threshold = failure_threshold
        self.recovery_timeout = recovery_timeout
        self.last_failure_time = None
        self.is_open = False
        self.lock = threading.Lock()

    def record_success(self):
        """Reset the breaker after a successful API call."""
        with self.lock:
            self.failure_count = 0
            self.is_open = False

    def record_failure(self):
        """Count one failed API call; open the breaker once the threshold is reached."""
        with self.lock:
            self.failure_count += 1
            self.last_failure_time = time.time()
            if self.failure_count >= self.failure_threshold:
                self.is_open = True
                logger.warning(f"Circuit breaker OPEN: {self.failure_count} failures detected")

    def can_attempt(self) -> bool:
        """Return True when a call may proceed (breaker closed, or recovery timeout elapsed)."""
        with self.lock:
            if not self.is_open:
                return True
            # Open: allow a recovery probe only after the timeout has elapsed.
            elapsed = time.time() - self.last_failure_time
            if elapsed > self.recovery_timeout:
                logger.info("Circuit breaker attempting recovery...")
                self.is_open = False
                self.failure_count = 0
                return True
            return False

    def get_status(self) -> str:
        """Return a short human-readable state string for log messages."""
        with self.lock:
            state = "OPEN" if self.is_open else "CLOSED"
            return f"{state} ({self.failure_count} failures)"


# Module-wide singleton shared by both LLM passes.
_circuit_breaker = CircuitBreaker()
# ── Session store ─────────────────────────────────────────
# In-memory per-process session store keyed by session_id; not persisted across restarts.
sessions: Dict[str, Dict] = {}
def empty_case_state() -> Dict:
    """Return a fresh, fully-initialised case-state dict for a new session."""
    return dict(
        parties=[],
        events=[],
        documents=[],
        amounts=[],
        locations=[],
        timeline=[],
        disputes=[],
        hypotheses=[],
        stage="intake",
        last_response_type="none",
        turn_count=0,
        facts_missing=[],
        context_interpreted=False,
        last_radar_turn=-3,       # track when radar last fired
        last_format="none",
        format_override_turn=-1,  # track when user explicitly requested a format
    )
def get_or_create_session(session_id: str) -> Dict:
    """Fetch the session for *session_id*, creating a blank one on first sight."""
    existing = sessions.get(session_id)
    if existing is not None:
        return existing
    fresh = {
        "summary": "",
        "last_3_messages": [],
        "case_state": empty_case_state(),
    }
    sessions[session_id] = fresh
    return fresh
def update_session(session_id: str, analysis: Dict, user_message: str, response: str):
    """Merge one turn's Pass-1 analysis into the stored session state.

    Updates the rolling summary, appends unseen facts to the fact web, merges
    hypotheses by claim, advances stage/turn bookkeeping and records the
    exchange in the 3-turn message window.

    Args:
        session_id: Key into the module-level ``sessions`` store (must exist).
        analysis: Parsed JSON dict produced by Pass 1 for this turn.
        user_message: Raw user message for this turn.
        response: Assistant answer; only its first 400 chars are kept.
    """
    session = sessions[session_id]
    cs = session["case_state"]

    if analysis.get("updated_summary"):
        session["summary"] = analysis["updated_summary"]

    facts = analysis.get("facts_extracted", {})
    if facts:
        # Fact lists grow monotonically: append only items not already present.
        for key in ["parties", "events", "documents", "amounts", "locations", "disputes"]:
            existing = cs.get(key, [])
            for item in facts.get(key, []):
                if item and item not in existing:
                    existing.append(item)
            cs[key] = existing
        for ev in facts.get("timeline_events", []):
            if ev and ev not in cs["timeline"]:
                cs["timeline"].append(ev)

    # Merge hypotheses: new claims are appended; known claims get their
    # confidence refreshed and any new evidence folded in. Index once by
    # claim instead of rebuilding the claim list per incoming hypothesis
    # (the original recomputed it every iteration — accidental O(n^2)) and
    # use .get so malformed stored hypotheses without "claim"/"confidence"
    # no longer raise KeyError.
    by_claim = {h.get("claim"): h for h in cs["hypotheses"] if h.get("claim")}
    for nh in analysis.get("hypotheses", []):
        claim = nh.get("claim")
        if not claim:
            continue
        h = by_claim.get(claim)
        if h is None:
            cs["hypotheses"].append(nh)
            by_claim[claim] = nh
        else:
            h["confidence"] = nh.get("confidence", h.get("confidence"))
            for e in nh.get("evidence", []):
                if e not in h.get("evidence", []):
                    h.setdefault("evidence", []).append(e)

    cs["stage"] = analysis.get("stage", cs["stage"])
    cs["last_response_type"] = analysis.get("action_needed", "none")
    cs["facts_missing"] = analysis.get("facts_missing", [])
    cs["last_format"] = analysis.get("format_decision", "none")
    cs["turn_count"] = cs.get("turn_count", 0) + 1
    if cs["turn_count"] >= 3:
        # After 3 turns we consider the context understood (see Pass 3 gating).
        cs["context_interpreted"] = True

    # Rolling window: keep only the last 3 user/assistant exchanges (6 entries).
    session["last_3_messages"].append({"role": "user", "content": user_message})
    session["last_3_messages"].append({"role": "assistant", "content": response[:400]})
    if len(session["last_3_messages"]) > 6:
        session["last_3_messages"] = session["last_3_messages"][-6:]
# ── Pass 1: Analyse ───────────────────────────────────────
# Retry up to 5 times with exponential backoff (1s to 16s) to handle transient failures
@retry(stop=stop_after_attempt(5), wait=wait_exponential(min=1, max=16, multiplier=1.5))
def analyse(user_message: str, session: Dict) -> Dict:
    """Pass 1 (ANALYSE): one LLM call that returns a structured analysis dict.

    Builds a compact prompt from the session summary, recent messages and the
    accumulated fact web, sends it under ANALYSIS_PROMPT, and parses the reply
    as JSON. If parsing fails, a conservative fallback analysis is returned so
    the pipeline can continue.

    Raises:
        Exception: when the circuit breaker is open (tenacity retries this).
    """
    # Fail fast while the Groq API is considered down; the raise still passes
    # through tenacity's retry loop, so recovery is re-checked each attempt.
    if not _circuit_breaker.can_attempt():
        logger.error(f"Circuit breaker OPEN - skipping Pass 1. Status: {_circuit_breaker.get_status()}")
        raise Exception("Groq API circuit breaker is open - service unavailable")
    summary = session.get("summary", "")
    last_msgs = session.get("last_3_messages", [])
    cs = session["case_state"]
    last_response_type = cs.get("last_response_type", "none")
    turn_count = cs.get("turn_count", 0)
    # Last 4 stored messages rendered as "ROLE: text" lines, truncated to 250 chars each.
    history_text = "\n".join(
        f"{m['role'].upper()}: {m['content'][:250]}"
        for m in last_msgs[-4:]
    ) if last_msgs else ""
    # Only include the fact web once at least one fact category is populated.
    fact_web = ""
    if any(cs.get(k) for k in ["parties", "events", "documents", "amounts", "disputes"]):
        hyp_lines = "\n".join(
            f" - {h['claim']} [{h.get('confidence','?')}]"
            for h in cs.get("hypotheses", [])[:3]
        ) or " none yet"
        fact_web = f"""
CURRENT FACT WEB:
- Parties: {', '.join(cs.get('parties', [])) or 'none'}
- Events: {', '.join(cs.get('events', [])) or 'none'}
- Documents/Evidence: {', '.join(cs.get('documents', [])) or 'none'}
- Amounts: {', '.join(cs.get('amounts', [])) or 'none'}
- Disputes: {', '.join(cs.get('disputes', [])) or 'none'}
- Active hypotheses:
{hyp_lines}"""
    user_content = f"""CONVERSATION SUMMARY:
{summary if summary else "First message — no prior context."}
RECENT MESSAGES:
{history_text if history_text else "None"}
LAST RESPONSE TYPE: {last_response_type}
TURN COUNT: {turn_count}
{fact_web}
NEW USER MESSAGE:
{user_message}
Rules:
- If last_response_type was "question", action_needed CANNOT be "question"
- action_needed SHOULD differ from last_response_type for variety
- Extract ALL facts from user message even if implied
- Update hypothesis confidence based on new evidence
- search_queries must be specific legal questions for vector search
- format_decision must be chosen fresh each turn based on THIS message's content
- NEVER carry over format_decision from previous turn unless user explicitly requests it again
- If user requested a specific format last turn, revert to most natural format this turn"""
    response = call_llm_raw(
        messages=[
            {"role": "system", "content": ANALYSIS_PROMPT},
            {"role": "user", "content": user_content}
        ]
    )
    _circuit_breaker.record_success()  # API call succeeded
    # Strip markdown code fences the model sometimes wraps around its JSON.
    raw = response.strip()
    raw = raw.replace("```json", "").replace("```", "").strip()
    try:
        analysis = json.loads(raw)
    except json.JSONDecodeError:
        # Fallback: minimal but schema-complete analysis so the turn proceeds.
        logger.warning(f"Pass 1 JSON parse failed: {raw[:200]}")
        analysis = {
            "tone": "casual", "format_requested": "none",
            "subject": "legal query", "action_needed": "advice",
            "urgency": "medium",
            "hypotheses": [{"claim": user_message[:80], "confidence": "low", "evidence": []}],
            "facts_extracted": {}, "facts_missing": [],
            "legal_issues": [], "clarifying_question": {},
            "stage": "understanding", "last_response_type": last_response_type,
            "updated_summary": f"{summary} | {user_message[:100]}",
            "search_queries": [user_message[:200]],
            "should_interpret_context": False,
            "format_decision": "none"
        }
    return analysis
# ── Pass 2: Retrieve ──────────────────────────────────────
def retrieve_parallel(search_queries: List[str], top_k: int = 5) -> List[Dict]:
    """Pass 2 (RETRIEVE): embed each query and search FAISS concurrently.

    Results from all queries are pooled, deduplicated by chunk id (keeping the
    best-scoring copy), and the ``top_k`` lowest-scoring chunks are returned.
    No LLM call is made in this pass.
    """
    if not search_queries:
        return []

    def _search(query):
        # One failing embed/search must not sink the other queries.
        try:
            return retrieve(embed_text(query), top_k=top_k)
        except Exception as e:
            logger.warning(f"FAISS search failed: {e}")
            return []

    pooled: List[Dict] = []
    with ThreadPoolExecutor(max_workers=min(3, len(search_queries))) as pool:
        pending = [pool.submit(_search, q) for q in search_queries]
        for done in as_completed(pending):
            pooled.extend(done.result())

    # Deduplicate by chunk id, keeping the best (lowest) similarity score.
    best: Dict[str, Dict] = {}
    for chunk in pooled:
        cid = chunk.get("chunk_id") or chunk.get("judgment_id", "")
        score = chunk.get("similarity_score", 999)
        current = best.get(cid)
        if current is None or score < current["similarity_score"]:
            best[cid] = chunk

    ranked = sorted(best.values(), key=lambda c: c.get("similarity_score", 999))
    return ranked[:top_k]
# ── Pass 3: Respond ───────────────────────────────────────
# Retry up to 5 times with exponential backoff (2s to 32s) — more aggressive than Pass 1
@retry(stop=stop_after_attempt(5), wait=wait_exponential(min=2, max=32, multiplier=1.5))
def respond(user_message: str, analysis: Dict, chunks: List[Dict], session: Dict) -> str:
    """Pass 3 (RESPOND): compose the final answer with one LLM call.

    Assembles a dynamic system prompt from the analysis, formats up to five
    retrieved chunks and the accumulated case state into the user prompt, and
    returns the raw LLM reply text.

    Raises:
        Exception: when the circuit breaker is open (tenacity retries this).
    """
    if not _circuit_breaker.can_attempt():
        logger.error(f"Circuit breaker OPEN - skipping Pass 3. Status: {_circuit_breaker.get_status()}")
        raise Exception("Groq API circuit breaker is open - service unavailable")
    system_prompt = build_prompt(analysis)
    cs = session["case_state"]
    turn_count = cs.get("turn_count", 0)
    # Render each chunk with a typed header line, body truncated to 800 chars.
    context_parts = []
    for chunk in chunks[:5]:
        source_type = chunk.get("source_type", "case_law")
        title = chunk.get("title", "Unknown")
        year = chunk.get("year", "")
        jid = chunk.get("judgment_id", "")
        # Prefer the widest text field available for this chunk.
        text = chunk.get("expanded_context") or chunk.get("chunk_text") or chunk.get("text", "")
        type_labels = {
            "statute": f"[STATUTE: {title} | {year}]",
            "procedure": f"[PROCEDURE: {title}]",
            "law_commission": f"[LAW COMMISSION: {title}]",
            "legal_reference": f"[LEGAL REFERENCE: {title}]",
            "statute_qa": f"[LEGAL QA: {title}]",
        }
        header = type_labels.get(source_type, f"[CASE: {title} | {year} | {jid}]")
        context_parts.append(f"{header}\n{text[:800]}")
    context = "\n\n".join(context_parts) if context_parts else "No relevant sources retrieved."
    case_summary = ""
    if cs.get("parties") or cs.get("hypotheses"):
        hyp_text = "\n".join(
            f" - {h['claim']} [{h.get('confidence','?')} confidence] "
            f"| evidence: {', '.join(h.get('evidence', [])) or 'none yet'}"
            for h in cs.get("hypotheses", [])[:4]
        ) or " none established"
        case_summary = f"""
CASE STATE (built across {turn_count} turns):
Parties: {', '.join(cs.get('parties', [])) or 'unspecified'}
Events: {', '.join(cs.get('events', [])) or 'unspecified'}
Evidence: {', '.join(cs.get('documents', [])) or 'none mentioned'}
Amounts: {', '.join(cs.get('amounts', [])) or 'none'}
Active hypotheses:
{hyp_text}
Missing facts: {', '.join(cs.get('facts_missing', [])) or 'none critical'}
Stage: {cs.get('stage', 'intake')}"""
    # Context interpretation — only once per conversation at turn 2
    interpret_instruction = ""
    should_interpret = analysis.get("should_interpret_context", False)
    if should_interpret and not cs.get("context_interpreted") and turn_count == 2:
        interpret_instruction = "\nIn one sentence only, reflect back your understanding of the situation before responding."
    # Radar — only fires every 3 turns, not every turn
    last_radar_turn = cs.get("last_radar_turn", -3)
    if (turn_count - last_radar_turn) >= 3:
        # NOTE(review): state is mutated before the LLM call below; a failed
        # call that tenacity retries will not re-fire the radar — confirm intended.
        cs["last_radar_turn"] = turn_count
        radar_instruction = """
PROACTIVE RADAR — only if a genuinely non-obvious legal angle exists that hasn't been mentioned yet:
Add a single "⚡ You Should Also Know:" line (1-2 sentences max).
Skip entirely if the response already covers all relevant angles or if this is a question/understanding turn."""
    else:
        radar_instruction = "Do NOT add a 'You Should Also Know' section this turn."
    summary = session.get("summary", "")
    last_msgs = session.get("last_3_messages", [])
    # Last 4 messages as "ROLE: text" lines, truncated to 300 chars each.
    history_text = "\n".join(
        f"{m['role'].upper()}: {m['content'][:300]}"
        for m in last_msgs[-4:]
    ) if last_msgs else ""
    user_content = f"""CONVERSATION SUMMARY:
{summary if summary else "First message."}
RECENT CONVERSATION:
{history_text if history_text else "None"}
{case_summary}
RETRIEVED LEGAL SOURCES:
{context}
USER MESSAGE: {user_message}
THIS TURN:
- Legal hypotheses: {', '.join(h['claim'] for h in analysis.get('hypotheses', [])[:3]) or 'analysing'}
- Stage: {analysis.get('stage', 'understanding')}
- Urgency: {analysis.get('urgency', 'medium')}
- Response type: {analysis.get('action_needed', 'advice')}
- Format: {analysis.get('format_decision', 'appropriate for content')}
{interpret_instruction}
Instructions:
- Cite specific sources when making legal claims
- Use your legal knowledge for reasoning and context
- Format: {analysis.get('format_decision', 'use the most appropriate format for the content type')}
- Opposition war-gaming: if giving strategy, include what the other side will argue
{radar_instruction}"""
    response = call_llm_raw(
        messages=[
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_content}
        ]
    )
    _circuit_breaker.record_success()  # API call succeeded
    return response
# ── Main entry point ──────────────────────────────────────
def run_query_v2(user_message: str, session_id: str) -> Dict[str, Any]:
    """Main entry point: run the 3-pass pipeline (analyse → retrieve → respond) for one turn.

    Each LLM pass has a non-LLM fallback, so this always returns a complete
    result dict even when the LLM service is unavailable.

    Args:
        user_message: Raw user message for this turn.
        session_id: Stable id used to accumulate per-conversation state.

    Returns:
        Dict with the answer, cited sources, citation-verification status,
        extracted entities and a condensed view of this turn's analysis.
    """
    start = time.time()  # NOTE(review): never read afterwards — elapsed time is not reported
    session = get_or_create_session(session_id)
    # Pass 1 — analysis LLM call (retried internally; may still raise).
    try:
        analysis = analyse(user_message, session)
    except Exception as e:
        error_type = type(e).__name__
        logger.error(f"Pass 1 failed after retries: {error_type}: {e}. Circuit breaker: {_circuit_breaker.get_status()}")
        # Record API failure if it was a connection error
        if "APIConnectionError" in error_type or "RateLimitError" in error_type:
            _circuit_breaker.record_failure()
        # Schema-complete fallback analysis so the pipeline can continue without Pass 1.
        analysis = {
            "tone": "casual", "format_requested": "none",
            "subject": "legal query", "action_needed": "advice",
            "urgency": "medium",
            "hypotheses": [{"claim": user_message[:80], "confidence": "low", "evidence": []}],
            "facts_extracted": {}, "facts_missing": [],
            "legal_issues": [], "clarifying_question": {},
            "stage": "understanding", "last_response_type": "none",
            "updated_summary": user_message[:200],
            "search_queries": [user_message[:200]],
            "should_interpret_context": False,
            "format_decision": "none"
        }
    # Extract entities and augment queries for better retrieval
    entities = extract_entities(user_message)
    augmented_message = augment_query(user_message, entities)
    # Pass 2 — build search queries from analysis + legal issues
    search_queries = analysis.get("search_queries", [augmented_message])
    if not search_queries:
        search_queries = [augmented_message]
    # Add queries from issue spotter
    for issue in analysis.get("legal_issues", []):
        statutes = issue.get("relevant_statutes", [])
        specific = issue.get("specific_issue", "")
        if specific:
            issue_query = f"{specific} {' '.join(statutes[:2])}".strip()
            if issue_query not in search_queries:
                search_queries.append(issue_query)
    if augmented_message not in search_queries:
        search_queries.append(augmented_message)
    chunks = []
    try:
        # Retrieve more candidates for reranker to work with
        raw_chunks = retrieve_parallel(search_queries[:3], top_k=10)
        # Rerank candidates by true relevance (lazy import: only needed on this path)
        from src.reranker import rerank
        chunks = rerank(user_message, raw_chunks, top_k=5)
        # Add precedent chain
        from src.citation_graph import get_precedent_chain
        retrieved_ids = [c.get("judgment_id", "") for c in chunks]
        precedents = get_precedent_chain(retrieved_ids, max_precedents=2)
        if precedents:
            chunks.extend(precedents)
    except Exception as e:
        # Best-effort retrieval: Pass 3 can still answer with no chunks.
        logger.error(f"Pass 2 failed: {e}")
    # Pass 3 — response LLM call, with an extractive fallback on failure.
    try:
        answer = respond(user_message, analysis, chunks, session)
    except Exception as e:
        error_type = type(e).__name__
        logger.error(f"Pass 3 failed after retries: {error_type}: {e}. Circuit breaker: {_circuit_breaker.get_status()}")
        # Record API failure if it was a connection error
        if "APIConnectionError" in error_type or "RateLimitError" in error_type:
            _circuit_breaker.record_failure()
        if chunks:
            # NOTE(review): reads only c['text'] — chunks storing their body under
            # 'chunk_text'/'expanded_context' render empty here; confirm intended.
            fallback = "\n\n".join(
                f"[{c.get('title', 'Source')}]\n{c.get('text', '')[:400]}"
                for c in chunks[:3]
            )
            answer = f"LLM service temporarily unavailable. Most relevant excerpts:\n\n{fallback}"
        else:
            answer = "I encountered an issue processing your request. Please try again."
    # Verify quoted citations against retrieved chunks, then persist the turn.
    verification_status, unverified_quotes = verify_citations(answer, chunks)
    update_session(session_id, analysis, user_message, answer)
    # Shape sources for the API response (metadata + truncated text).
    sources = []
    for c in chunks:
        title = c.get("title", "")
        jid = c.get("judgment_id", "")
        sources.append({
            "meta": {
                "judgment_id": jid,
                "title": title if title and title != jid else jid,
                "year": c.get("year", ""),
                "chunk_index": c.get("chunk_index", 0),
                "source_type": c.get("source_type", "case_law"),
                "court": c.get("court", "Supreme Court of India")
            },
            "text": (c.get("expanded_context") or c.get("chunk_text") or c.get("text", ""))[:600]
        })
    return {
        "query": user_message,
        "answer": answer,
        "sources": sources,
        "verification_status": verification_status,
        "unverified_quotes": unverified_quotes,
        "entities": entities,
        "num_sources": len(chunks),
        "truncated": False,
        "session_id": session_id,
        "analysis": {
            "tone": analysis.get("tone"),
            "stage": analysis.get("stage"),
            "urgency": analysis.get("urgency"),
            "hypotheses": [h["claim"] for h in analysis.get("hypotheses", [])]
        }
    }