"""
NER inference module.
Loads fine-tuned DistilBERT and extracts legal entities from query text.
Loaded once at FastAPI startup via load_ner_model().
Fails gracefully — app runs without NER if model not found.
Example:
Input: "What did Justice Chandrachud say about Section 302 IPC?"
Output: {"JUDGE": ["Justice Chandrachud"],
"PROVISION": ["Section 302"],
"STATUTE": ["IPC"]}
The augmented query becomes:
"What did Justice Chandrachud say about Section 302 IPC?
JUDGE: Justice Chandrachud PROVISION: Section 302 STATUTE: IPC"
"""
import os
import logging

from transformers import pipeline, AutoTokenizer, AutoModelForTokenClassification

logger = logging.getLogger(__name__)

# Directory holding the fine-tuned DistilBERT checkpoint; overridable via env var.
NER_MODEL_PATH = os.getenv("NER_MODEL_PATH", "models/ner_model")

# Entity labels kept by extract_entities(); any other label is discarded.
TARGET_ENTITIES = {
    "JUDGE", "COURT", "STATUTE", "PROVISION",
    "CASE_NUMBER", "DATE", "PRECEDENT", "LAWYER",
    "PETITIONER", "RESPONDENT", "GPE", "ORG"
}

# Module-level singleton set by load_ner_model(); None means NER is disabled.
_ner_pipeline = None
def load_ner_model():
    """Load the NER pipeline once at startup.

    Reads the checkpoint directory from ``NER_MODEL_PATH`` and stores a
    token-classification pipeline in the module-level ``_ner_pipeline``.
    Fails gracefully: if the directory is missing or loading raises, the
    pipeline stays ``None`` and the app runs without NER.

    Call this from api/main.py after download_models().
    """
    global _ner_pipeline
    if not os.path.exists(NER_MODEL_PATH):
        # A missing model is an expected deployment state — warn, don't crash.
        # Lazy %-args: the message is only formatted if the record is emitted.
        logger.warning(
            "NER model not found at %s. "
            "Entity extraction disabled. App will run without NER.",
            NER_MODEL_PATH,
        )
        return
    try:
        logger.info("Loading NER model from %s...", NER_MODEL_PATH)
        tokenizer = AutoTokenizer.from_pretrained(NER_MODEL_PATH)
        model = AutoModelForTokenClassification.from_pretrained(NER_MODEL_PATH)
        # "simple" aggregation merges sub-word pieces into whole entity spans.
        _ner_pipeline = pipeline(
            "ner",
            model=model,
            tokenizer=tokenizer,
            aggregation_strategy="simple",
        )
        logger.info("NER model ready.")
    except Exception:
        # Broad catch is deliberate: no load failure may take the app down.
        # logger.exception records the full traceback, which logger.error
        # with a bare str(e) was losing.
        logger.exception("NER model load failed. Entity extraction disabled.")
        _ner_pipeline = None
def extract_entities(text: str) -> dict:
    """Run NER over ``text`` and group hits by entity type.

    Returns ``{entity_type: [entity_text, ...]}`` restricted to
    ``TARGET_ENTITIES``, with sub-2-character hits dropped and duplicates
    removed while preserving first-seen order. Returns an empty dict when
    the pipeline is not loaded, the input is blank, or inference raises.
    """
    if _ner_pipeline is None:
        return {}
    if not text.strip():
        return {}
    try:
        # Truncate to 512 chars to bound input size.
        # NOTE(review): this is a character cut, not a token cut — assumed
        # adequate for short queries; confirm for long inputs.
        results = _ner_pipeline(text[:512])
    except Exception as e:
        # Best-effort: a failed inference degrades to "no entities".
        logger.warning("NER inference failed: %s", e)
        return {}
    entities: dict[str, list[str]] = {}
    for result in results:
        entity_type = result["entity_group"]
        entity_text = result["word"].strip()
        if entity_type not in TARGET_ENTITIES:
            continue
        if len(entity_text) < 2:
            # Single characters are almost always tokenization noise.
            continue
        # setdefault replaces the manual "if key not in dict" dance.
        bucket = entities.setdefault(entity_type, [])
        if entity_text not in bucket:
            bucket.append(entity_text)
    return entities
def augment_query(query: str, entities: dict) -> str:
    """Append "TYPE: text" pairs to the query to aid FAISS retrieval.

    Iterates entity types in dict order and texts in list order, so the
    augmentation is deterministic. Returns the query unchanged when no
    entities were extracted.
    """
    if not entities:
        return query
    pairs = []
    for etype, texts in entities.items():
        for etext in texts:
            pairs.append(f"{etype}: {etext}")
    return f"{query} {' '.join(pairs)}"