"""
NyayaSetu RAG Agent — single-pass function.

Every user query goes through exactly these steps in order:
  1. NER extraction (if model available, else skip gracefully)
  2. Query augmentation (append extracted entities)
  3. Embed augmented query with MiniLM
  4. FAISS retrieval (top-5 chunks)
  5. Out-of-domain check (empty results = no relevant judgments)
  6. Context assembly (build prompt context from expanded windows)
  7. Single LLM call with retry
  8. Citation verification
  9. Return structured result

WHY single-pass and no while loop?
A while loop that retries the whole pipeline masks failures.
If retrieval returned bad results, retrying with the same query
returns the same bad results. Better to fail honestly and tell
the user, than to loop silently and return garbage.
"""
| import os | |
| import sys | |
| sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) | |
| from src.embed import embed_text | |
| from src.retrieval import retrieve | |
| from src.llm import call_llm | |
| from src.verify import verify_citations | |
| from typing import Dict, Any | |
| # NER is optional — if not trained yet, pipeline runs without it | |
| # This is the Cut Line Rule from the blueprint: | |
| # ship without NER rather than blocking the whole project | |
| NER_AVAILABLE = False | |
| try: | |
| from src.ner import extract_entities | |
| NER_AVAILABLE = True | |
| print("NER model loaded — query augmentation active") | |
| except Exception as e: | |
| print(f"NER not available, running without entity augmentation: {e}") | |
| def run_query(query: str) -> Dict[str, Any]: | |
| """ | |
| Main pipeline. Input: user query string. | |
| Output: structured dict with answer, sources, verification. | |
| """ | |
| # ── Step 1: NER ────────────────────────────────────────── | |
| entities = {} | |
| augmented_query = query | |
| if NER_AVAILABLE: | |
| try: | |
| entities = extract_entities(query) | |
| entity_string = " ".join( | |
| f"{etype}: {etext}" | |
| for etype, texts in entities.items() | |
| for etext in texts | |
| ) | |
| if entity_string: | |
| augmented_query = f"{query} {entity_string}" | |
| except Exception as e: | |
| print(f"NER failed, using raw query: {e}") | |
| augmented_query = query | |
| # ── Step 2: Embed ───────────────────────────────────────── | |
| query_embedding = embed_text(augmented_query) | |
| # ── Step 3: Retrieve ────────────────────────────────────── | |
| retrieved_chunks = retrieve(query_embedding, top_k=5) | |
| # ── Step 4: Out-of-domain check ─────────────────────────── | |
| if not retrieved_chunks: | |
| return { | |
| "query": query, | |
| "augmented_query": augmented_query, | |
| "answer": "Your query doesn't appear to relate to Indian law. " | |
| "NyayaSetu can answer questions about Supreme Court judgments, " | |
| "constitutional rights, statutes, and legal provisions. " | |
| "Please ask a legal question.", | |
| "sources": [], | |
| "verification_status": "No sources retrieved", | |
| "unverified_quotes": [], | |
| "entities": entities, | |
| "num_sources": 0, | |
| "truncated": False | |
| } | |
| # ── Step 5: Context assembly ────────────────────────────── | |
| # Check total token estimate — rough rule: 1 token ≈ 4 characters | |
| # LLM context limit ~6000 tokens for context = ~24000 chars | |
| LLM_CONTEXT_LIMIT_CHARS = 24000 | |
| truncated = False | |
| context_parts = [] | |
| total_chars = 0 | |
| for i, chunk in enumerate(retrieved_chunks, 1): | |
| excerpt = chunk["expanded_context"] | |
| header = f"[EXCERPT {i} — {chunk['title']} | {chunk['year']} | ID: {chunk['judgment_id']}]\n" | |
| part = header + excerpt + "\n" | |
| if total_chars + len(part) > LLM_CONTEXT_LIMIT_CHARS: | |
| # Drop remaining chunks — too long for LLM context | |
| truncated = True | |
| print(f"Context truncated at {i-1} of {len(retrieved_chunks)} chunks") | |
| break | |
| context_parts.append(part) | |
| total_chars += len(part) | |
| context = "\n".join(context_parts) | |
| # ── Step 6: LLM call ────────────────────────────────────── | |
| try: | |
| answer = call_llm(query=query, context=context) | |
| except Exception as e: | |
| # All 3 retries failed — return raw excerpts as fallback | |
| print(f"LLM call failed after retries: {e}") | |
| fallback_excerpts = "\n\n".join( | |
| f"[{c['title']} | {c['year']}]\n{c['chunk_text'][:500]}" | |
| for c in retrieved_chunks | |
| ) | |
| return { | |
| "query": query, | |
| "augmented_query": augmented_query, | |
| "answer": f"LLM service temporarily unavailable. " | |
| f"Most relevant excerpts shown below:\n\n{fallback_excerpts}", | |
| "sources": _build_sources(retrieved_chunks), | |
| "verification_status": "LLM unavailable", | |
| "unverified_quotes": [], | |
| "entities": entities, | |
| "num_sources": len(retrieved_chunks), | |
| "truncated": truncated | |
| } | |
| # ── Step 7: Citation verification ───────────────────────── | |
| verification_status, unverified_quotes = verify_citations(answer, retrieved_chunks) | |
| # ── Step 8: Return ──────────────────────────────────────── | |
| return { | |
| "query": query, | |
| "augmented_query": augmented_query, | |
| "answer": answer, | |
| "sources": _build_sources(retrieved_chunks), | |
| "verification_status": verification_status, | |
| "unverified_quotes": unverified_quotes, | |
| "entities": entities, | |
| "num_sources": len(retrieved_chunks), | |
| "truncated": truncated | |
| } | |
| def _build_sources(chunks) -> list: | |
| """Format retrieved chunks for API response.""" | |
| return [ | |
| { | |
| "judgment_id": c["judgment_id"], | |
| "title": c["title"], | |
| "year": c["year"], | |
| "similarity_score": round(c["similarity_score"], 4), | |
| "excerpt": c["chunk_text"][:300] + "..." | |
| } | |
| for c in chunks | |
| ] | |
| if __name__ == "__main__": | |
| # Smoke test — run directly to verify pipeline works end to end | |
| test_queries = [ | |
| "What are the rights of an arrested person under Article 22?", | |
| "What did the Supreme Court say about freedom of speech?", | |
| "How do I bake a cake?" # Out of domain — should return no results | |
| ] | |
| for query in test_queries: | |
| print(f"\n{'='*60}") | |
| print(f"QUERY: {query}") | |
| result = run_query(query) | |
| print(f"SOURCES: {result['num_sources']}") | |
| print(f"VERIFICATION: {result['verification_status']}") | |
| print(f"ANSWER (first 300 chars):\n{result['answer'][:300]}") |