Spaces:
Running
Running
| """ | |
| Cross-encoder reranker. | |
| Reranks FAISS retrieval results by true query-document relevance. | |
| WHY a cross-encoder on top of the bi-encoder retriever (MiniLM embeddings)? | |
| The bi-encoder embeds query and document independently — fast but approximate. | |
| Cross-encoder sees query+document together — slower but much more accurate. | |
| Used post-retrieval on top-15 candidates to select best top-5. | |
| WHY ms-marco-MiniLM-L-6-v2? | |
| Trained on MS-MARCO passage ranking — transfers well to legal QA. | |
| Small enough to load on HF Spaces free tier (~80MB). | |
| Fast enough for reranking 15 candidates in ~200ms on CPU. | |
| Interview answer: | |
| "I added a cross-encoder reranker post-retrieval to boost precision@5 | |
| by focusing on true relevance rather than embedding similarity alone. | |
| Legal domain papers show 8-15% precision lift from reranking." | |
| """ | |
| import logging | |
| from typing import List, Dict | |
logger = logging.getLogger(__name__)
# Module-level reranker state: written by load_reranker(), read by rerank()
# and is_loaded().
_reranker = None  # the CrossEncoder instance once loaded, otherwise None
_reranker_loaded = False  # set to True only after the model loads successfully
def load_reranker():
    """
    Load the cross-encoder reranker into module state.

    Intended to be called once at startup (e.g. from api/main.py after the
    other models load). Calling it again after a successful load is a no-op,
    so the model is never rebuilt by accident.

    Fails gracefully: any error while importing or constructing the model is
    logged as a warning and leaves the module in the "not loaded" state, in
    which case rerank() returns chunks in their incoming (FAISS) order.
    """
    global _reranker, _reranker_loaded

    # Already loaded: skip the expensive model construction.
    if _reranker_loaded and _reranker is not None:
        return

    try:
        # Imported lazily so a missing or broken sentence_transformers install
        # degrades to "no reranker" instead of breaking import of this module.
        from sentence_transformers import CrossEncoder

        logger.info("Loading cross-encoder reranker...")
        _reranker = CrossEncoder(
            "cross-encoder/ms-marco-MiniLM-L-6-v2",
            max_length=512,
        )
        _reranker_loaded = True
        logger.info("Cross-encoder reranker ready")
    except Exception as e:
        logger.warning(f"Reranker load failed: {e}. Retrieval will use FAISS scores only.")
        # Reset both pieces of state together so they can never disagree.
        _reranker = None
        _reranker_loaded = False
def rerank(
    query: str,
    chunks: List[Dict],
    top_k: int = 5,
    *,
    max_chars: int = 512,
) -> List[Dict]:
    """
    Rerank chunks by cross-encoder relevance score.

    Args:
        query: user query string
        chunks: list of retrieved chunk dicts. The text scored for each chunk
            is the first non-empty value among the keys "expanded_context",
            "chunk_text" and "text"; a chunk with none of them is scored
            against an empty string.
        top_k: number of top chunks to return after reranking. Values <= 0
            return an empty list.
        max_chars: maximum number of characters (not tokens) of each chunk's
            text that is passed to the model.

    Returns:
        Up to top_k chunks sorted by reranker score descending. Each chunk
        dict has a float "reranker_score" written into it in place.
        If the reranker is not loaded, or scoring raises, returns
        chunks[:top_k] in the incoming order.
    """
    # Nothing to rank or nothing requested. Also keeps a negative top_k from
    # slicing from the end of the list.
    if not chunks or top_k <= 0:
        return []

    if not _reranker_loaded or _reranker is None:
        return chunks[:top_k]

    try:
        # Build query-document pairs.
        pairs = []
        for chunk in chunks:
            # The trailing `or ""` covers a "text" key that is present but
            # None/empty, which would otherwise raise on slicing and abort
            # reranking for every chunk.
            text = (
                chunk.get("expanded_context")
                or chunk.get("chunk_text")
                or chunk.get("text")
                or ""
            )
            pairs.append([query, text[:max_chars]])

        # Score all pairs and convert up front, so a conversion failure cannot
        # leave only some chunks annotated.
        scores = [float(score) for score in _reranker.predict(pairs, batch_size=16)]

        # Attach scores and sort.
        for chunk, score in zip(chunks, scores):
            chunk["reranker_score"] = score
        reranked = sorted(chunks, key=lambda x: x.get("reranker_score", 0), reverse=True)

        logger.info(
            f"Reranked {len(chunks)} chunks → top {top_k}. "
            f"Top score: {reranked[0].get('reranker_score', 0):.3f}"
        )
        return reranked[:top_k]
    except Exception as e:
        # Deliberate best-effort: retrieval must keep working without reranking.
        logger.warning(f"Reranking failed: {e}. Using FAISS order.")
        return chunks[:top_k]
def is_loaded() -> bool:
    """
    Report whether the cross-encoder reranker is ready for use.

    Applies the same readiness condition rerank() checks before scoring: the
    loaded flag must be set and a model object must be present.
    """
    return _reranker_loaded and _reranker is not None