# nyayasetu/src/agent.py
# (Uploaded to the Hugging Face Hub via huggingface_hub; commit 0214972.)
"""
NyayaSetu RAG Agent — single-pass function.
Every user query goes through exactly these steps in order:
1. NER extraction (if model available, else skip gracefully)
2. Query augmentation (append extracted entities)
3. Embed augmented query with MiniLM
4. FAISS retrieval (top-5 chunks)
5. Out-of-domain check (empty results = no relevant judgments)
6. Context assembly (build prompt context from expanded windows)
7. Single LLM call with retry
8. Citation verification
9. Return structured result
WHY single-pass and no while loop?
A while loop that retries the whole pipeline masks failures.
If retrieval returned bad results, retrying with the same query
returns the same bad results. Better to fail honestly and tell
the user, than to loop silently and return garbage.
"""
import os
import sys
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from src.embed import embed_text
from src.retrieval import retrieve
from src.llm import call_llm
from src.verify import verify_citations
from typing import Dict, Any
# NER is optional — if not trained yet, pipeline runs without it
# This is the Cut Line Rule from the blueprint:
# ship without NER rather than blocking the whole project
# Module-level feature flag: flipped to True only if the optional NER model
# imports cleanly. Evaluated once at import time, not per query.
NER_AVAILABLE = False
try:
    from src.ner import extract_entities
    NER_AVAILABLE = True
    print("NER model loaded — query augmentation active")
except Exception as e:
    # Broad catch is deliberate: a missing or broken NER model must not
    # prevent the rest of the pipeline from loading (Cut Line Rule above).
    print(f"NER not available, running without entity augmentation: {e}")
def run_query(query: str) -> Dict[str, Any]:
    """
    Run the full single-pass RAG pipeline for one user query.

    Steps: optional NER augmentation -> embed -> FAISS retrieve ->
    out-of-domain check -> context assembly -> one LLM call (with its own
    internal retries) -> citation verification. There is deliberately no
    retry loop around the whole pipeline: retrying retrieval with the same
    query returns the same results, so we fail honestly instead (see the
    module docstring).

    Args:
        query: Raw user question string.

    Returns:
        Dict with keys: query, augmented_query, answer, sources,
        verification_status, unverified_quotes, entities, num_sources,
        truncated.
    """
    # ── Step 1: NER ──────────────────────────────────────────
    entities = {}
    augmented_query = query
    if NER_AVAILABLE:
        try:
            entities = extract_entities(query)
            # Flatten {entity_type: [texts]} into "TYPE: text ..." and append
            # it to the query so the embedding sees the entities explicitly.
            entity_string = " ".join(
                f"{etype}: {etext}"
                for etype, texts in entities.items()
                for etext in texts
            )
            if entity_string:
                augmented_query = f"{query} {entity_string}"
        except Exception as e:
            # NER is best-effort: on any failure fall back to the raw query.
            print(f"NER failed, using raw query: {e}")
            augmented_query = query
    # ── Step 2: Embed ─────────────────────────────────────────
    query_embedding = embed_text(augmented_query)
    # ── Step 3: Retrieve ──────────────────────────────────────
    retrieved_chunks = retrieve(query_embedding, top_k=5)
    # ── Step 4: Out-of-domain check ───────────────────────────
    # Empty retrieval means nothing in the index was relevant enough —
    # answer without spending an LLM call.
    if not retrieved_chunks:
        return {
            "query": query,
            "augmented_query": augmented_query,
            "answer": "Your query doesn't appear to relate to Indian law. "
                      "NyayaSetu can answer questions about Supreme Court judgments, "
                      "constitutional rights, statutes, and legal provisions. "
                      "Please ask a legal question.",
            "sources": [],
            "verification_status": "No sources retrieved",
            "unverified_quotes": [],
            "entities": entities,
            "num_sources": 0,
            "truncated": False
        }
    # ── Step 5: Context assembly ──────────────────────────────
    # Check total token estimate — rough rule: 1 token ≈ 4 characters
    # LLM context limit ~6000 tokens for context = ~24000 chars
    LLM_CONTEXT_LIMIT_CHARS = 24000
    truncated = False
    context_parts = []
    total_chars = 0
    for i, chunk in enumerate(retrieved_chunks, 1):
        excerpt = chunk["expanded_context"]
        # BUGFIX: the original header omitted the " | " separator after the
        # excerpt number, producing e.g. "[EXCERPT 1Some Title | 1978 | ...]".
        header = f"[EXCERPT {i} | {chunk['title']} | {chunk['year']} | ID: {chunk['judgment_id']}]\n"
        part = header + excerpt + "\n"
        if total_chars + len(part) > LLM_CONTEXT_LIMIT_CHARS:
            if not context_parts:
                # Edge case: even the single best chunk exceeds the budget.
                # Include a hard-truncated copy rather than calling the LLM
                # with an empty context.
                context_parts.append(part[:LLM_CONTEXT_LIMIT_CHARS])
            # Drop remaining chunks — too long for LLM context
            truncated = True
            print(f"Context truncated at {len(context_parts)} of {len(retrieved_chunks)} chunks")
            break
        context_parts.append(part)
        total_chars += len(part)
    context = "\n".join(context_parts)
    # ── Step 6: LLM call ──────────────────────────────────────
    try:
        answer = call_llm(query=query, context=context)
    except Exception as e:
        # All retries inside call_llm failed — degrade gracefully by
        # returning the raw excerpts instead of an answer.
        print(f"LLM call failed after retries: {e}")
        fallback_excerpts = "\n\n".join(
            f"[{c['title']} | {c['year']}]\n{c['chunk_text'][:500]}"
            for c in retrieved_chunks
        )
        return {
            "query": query,
            "augmented_query": augmented_query,
            "answer": f"LLM service temporarily unavailable. "
                      f"Most relevant excerpts shown below:\n\n{fallback_excerpts}",
            "sources": _build_sources(retrieved_chunks),
            "verification_status": "LLM unavailable",
            "unverified_quotes": [],
            "entities": entities,
            "num_sources": len(retrieved_chunks),
            "truncated": truncated
        }
    # ── Step 7: Citation verification ─────────────────────────
    verification_status, unverified_quotes = verify_citations(answer, retrieved_chunks)
    # ── Step 8: Return ────────────────────────────────────────
    return {
        "query": query,
        "augmented_query": augmented_query,
        "answer": answer,
        "sources": _build_sources(retrieved_chunks),
        "verification_status": verification_status,
        "unverified_quotes": unverified_quotes,
        "entities": entities,
        "num_sources": len(retrieved_chunks),
        "truncated": truncated
    }
def _build_sources(chunks) -> list:
"""Format retrieved chunks for API response."""
return [
{
"judgment_id": c["judgment_id"],
"title": c["title"],
"year": c["year"],
"similarity_score": round(c["similarity_score"], 4),
"excerpt": c["chunk_text"][:300] + "..."
}
for c in chunks
]
if __name__ == "__main__":
    # Manual smoke test: exercise the pipeline end to end from the CLI.
    # Two in-domain legal questions plus one deliberately off-topic query
    # that should come back with zero sources.
    smoke_queries = [
        "What are the rights of an arrested person under Article 22?",
        "What did the Supreme Court say about freedom of speech?",
        "How do I bake a cake?"
    ]
    for smoke_query in smoke_queries:
        print(f"\n{'='*60}")
        print(f"QUERY: {smoke_query}")
        outcome = run_query(smoke_query)
        print(f"SOURCES: {outcome['num_sources']}")
        print(f"VERIFICATION: {outcome['verification_status']}")
        print(f"ANSWER (first 300 chars):\n{outcome['answer'][:300]}")