# nyayasetu/src/ner.py
# Uploaded by CaffeinatedCoding via huggingface_hub (commit b2d0640, verified)
"""
NER inference module.
Loads fine-tuned DistilBERT and extracts legal entities from query text.
Loaded once at FastAPI startup via load_ner_model().
Fails gracefully — app runs without NER if model not found.
Example:
Input: "What did Justice Chandrachud say about Section 302 IPC?"
Output: {"JUDGE": ["Justice Chandrachud"],
"PROVISION": ["Section 302"],
"STATUTE": ["IPC"]}
The augmented query becomes:
"What did Justice Chandrachud say about Section 302 IPC?
JUDGE: Justice Chandrachud PROVISION: Section 302 STATUTE: IPC"
"""
import os
import logging
from transformers import pipeline, AutoTokenizer, AutoModelForTokenClassification
logger = logging.getLogger(__name__)

# Directory holding the fine-tuned DistilBERT token-classification model;
# overridable through the NER_MODEL_PATH environment variable.
NER_MODEL_PATH = os.getenv("NER_MODEL_PATH", "models/ner_model")

# Entity labels kept by extract_entities(); any other label the model
# emits is silently dropped.
TARGET_ENTITIES = {
    "JUDGE", "COURT", "STATUTE", "PROVISION",
    "CASE_NUMBER", "DATE", "PRECEDENT", "LAWYER",
    "PETITIONER", "RESPONDENT", "GPE", "ORG"
}

# Populated once by load_ner_model(); None means NER is disabled and
# extract_entities() returns {} for every input.
_ner_pipeline = None
def load_ner_model():
    """
    Load the NER pipeline a single time at application startup.

    Degrades gracefully: if the model directory is missing, or loading
    raises for any reason, the app keeps running with entity extraction
    disabled (``_ner_pipeline`` stays ``None``).
    Intended to be called from api/main.py after download_models().
    """
    global _ner_pipeline

    model_dir = NER_MODEL_PATH
    # Guard clause: no model on disk -> warn once and leave NER disabled.
    if not os.path.exists(model_dir):
        logger.warning(
            f"NER model not found at {model_dir}. "
            "Entity extraction disabled. App will run without NER."
        )
        return

    try:
        logger.info(f"Loading NER model from {model_dir}...")
        # "simple" aggregation merges word-piece tokens back into whole
        # words/spans, which extract_entities() relies on.
        _ner_pipeline = pipeline(
            "ner",
            model=AutoModelForTokenClassification.from_pretrained(model_dir),
            tokenizer=AutoTokenizer.from_pretrained(model_dir),
            aggregation_strategy="simple",
        )
    except Exception as e:
        # Any load failure (corrupt weights, version mismatch, ...) is
        # non-fatal: log it and keep the pipeline disabled.
        logger.error(f"NER model load failed: {e}. Entity extraction disabled.")
        _ner_pipeline = None
    else:
        logger.info("NER model ready.")
def extract_entities(text: str) -> dict:
    """
    Run NER over *text* and group the hits by entity type.

    Returns {entity_type: [entity_text, ...]} with duplicates removed
    while preserving first-seen order. Returns an empty dict when the
    pipeline is not loaded, the text is blank, or inference fails.
    """
    # No pipeline or nothing to analyse -> nothing to extract.
    if _ner_pipeline is None or not text.strip():
        return {}

    try:
        # Character-level cap; presumably keeps inputs within the model's
        # token limit — TODO confirm against the tokenizer's max length.
        spans = _ner_pipeline(text[:512])
    except Exception as e:
        logger.warning(f"NER inference failed: {e}")
        return {}

    entities: dict = {}
    for span in spans:
        label = span["entity_group"]
        word = span["word"].strip()
        # Keep only the legal entity types we care about, and drop
        # single-character fragments (usually tokenizer noise).
        if label in TARGET_ENTITIES and len(word) >= 2:
            bucket = entities.setdefault(label, [])
            if word not in bucket:
                bucket.append(word)
    return entities
def augment_query(query: str, entities: dict) -> str:
    """
    Append "TYPE: text" tags for every extracted entity to the query,
    boosting FAISS retrieval recall on entity terms.

    Returns the query unchanged when *entities* is empty.
    """
    if not entities:
        return query

    tags = []
    for etype, texts in entities.items():
        for etext in texts:
            tags.append(f"{etype}: {etext}")
    # Note: the separator format (single space between query and tags)
    # must match what retrieval was tuned on.
    return query + " " + " ".join(tags)