# nyayasetu/src/retrieval.py
# (Hugging Face upload header preserved as a comment: uploaded by
#  CaffeinatedCoding via huggingface_hub, commit 0430f42, verified.)
"""
FAISS retrieval module.
Loads the FAISS index and chunk metadata once at startup.
Given a query embedding, returns the top-k most similar chunks
plus an expanded context window from the parent judgment.
WHY load at startup and not per request?
Loading a 650MB index takes ~3 seconds. If you loaded it per request,
every user query would take 3+ seconds just for setup. Loading once
at startup means retrieval takes ~5ms per query.
"""
import json
import numpy as np
import faiss
import os
from typing import List, Dict
# Artifact locations — overridable via environment variables so deployments
# can relocate the index/metadata without code changes.
INDEX_PATH = os.getenv("FAISS_INDEX_PATH", "models/faiss_index/index.faiss")
METADATA_PATH = os.getenv("METADATA_PATH", "models/faiss_index/chunk_metadata.jsonl")
PARENT_PATH = os.getenv("PARENT_PATH", "data/parent_judgments.jsonl")
# Default number of chunks returned per query by retrieve().
TOP_K = 5
# Similarity threshold for out-of-domain detection.
# This index uses L2 distance — HIGHER score = FURTHER AWAY = worse match.
# Legal queries typically score 0.6 - 0.8.
# Out-of-domain queries (cricket, Bollywood) score 0.9+.
# Block anything where the best match is above this threshold.
SIMILARITY_THRESHOLD = 0.85
def _load_resources():
    """Read the FAISS index, chunk metadata and parent-judgment store.

    Runs exactly once, at module import time, so requests after startup
    pay only the cheap per-query search cost (see module docstring).

    Returns:
        (index, metadata, parent_store) where ``metadata`` is a list of
        chunk dicts whose position matches the FAISS vector id, and
        ``parent_store`` maps judgment_id -> full judgment text.
    """
    print("Loading FAISS index...")
    index = faiss.read_index(INDEX_PATH)
    print(f"Index loaded: {index.ntotal} vectors")

    print("Loading chunk metadata...")
    with open(METADATA_PATH, "r", encoding="utf-8") as f:
        metadata = [json.loads(line) for line in f]
    print(f"Metadata loaded: {len(metadata)} chunks")

    print("Loading parent judgments...")
    with open(PARENT_PATH, "r", encoding="utf-8") as f:
        records = (json.loads(line) for line in f)
        parent_store = {rec["judgment_id"]: rec["text"] for rec in records}
    print(f"Parent store loaded: {len(parent_store)} judgments")

    return index, metadata, parent_store
# Eager one-time load at import; see module docstring for the rationale.
_index, _metadata, _parent_store = _load_resources()
def retrieve(query_embedding: np.ndarray, top_k: int = TOP_K) -> List[Dict]:
    """
    Return the top-k chunks closest to the query embedding.

    The index measures L2 distance, so a LOWER score means a CLOSER
    match.  If even the single best hit lies beyond
    SIMILARITY_THRESHOLD, the query is treated as out of domain and an
    empty list is returned for the agent to handle.
    """
    vec = query_embedding.reshape(1, -1).astype(np.float32)
    distances, ids = _index.search(vec, top_k)

    # Out-of-domain guard: even the nearest neighbour is too far away.
    best_score = float(distances[0][0])
    if best_score > SIMILARITY_THRESHOLD:
        return []

    hits: List[Dict] = []
    for dist, vec_id in zip(distances[0], ids[0]):
        # FAISS pads with -1 when fewer than top_k vectors exist.
        if vec_id == -1:
            continue
        meta = _metadata[vec_id]
        hits.append({
            "chunk_id": meta["chunk_id"],
            "judgment_id": meta["judgment_id"],
            "title": meta.get("title", ""),
            "year": meta.get("year", ""),
            "chunk_text": meta["text"],
            "expanded_context": _get_expanded_context(
                meta["judgment_id"], meta["text"]
            ),
            "similarity_score": float(dist),
        })
    return hits
def _get_expanded_context(judgment_id: str, chunk_text: str) -> str:
    """
    Pull a wider (~1024-token) window around the chunk from its parent
    judgment.

    WHY expand context?
    The 512-token chunk is enough for retrieval, but the LLM needs more
    surrounding text to answer completely, so we locate the chunk inside
    the full judgment and slice out roughly 1024 chars before it and
    4096 chars after it.  Falls back to the bare chunk text when the
    parent is missing or the chunk cannot be located in it.
    """
    parent_text = _parent_store.get(judgment_id, "")
    if not parent_text:
        return chunk_text

    # Locate the chunk in the parent via its first 80 characters.
    anchor_pos = parent_text.find(chunk_text[:80])
    if anchor_pos == -1:
        return chunk_text

    # ~4 chars per token, 1024 tokens = ~4096 chars
    WINDOW = 4096
    lo = max(0, anchor_pos - WINDOW // 4)
    hi = min(len(parent_text), anchor_pos + WINDOW)
    return parent_text[lo:hi]