# Source: Hugging Face Space upload by rabbitfishai — "Update app.py" (commit 9c15989, verified)
# app.py — DocMap Healthcare Lead Analyzer
# Dynamic ontology (YAML/JSON) + AI-derived matching + vector similarity (CPU-only, robust startup)
import os
import re
import gc
import json
import yaml
import torch
import gradio as gr
import torch.nn.functional as F
from dataclasses import dataclass
from typing import Dict, Any, List, Tuple, Optional
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModel

# Flipped to True at the bottom of this file once all models are loaded;
# the request handlers refuse work until then.
READY = False

# You can override these via environment variables if you want:
BIOGPT_ID = os.getenv("BIOGPT_ID", "microsoft/BioGPT-Large")  # switch to "microsoft/BioGPT" if RAM is tight
EMB_MODEL_ID = os.getenv("EMB_MODEL_ID", "sentence-transformers/all-MiniLM-L6-v2")
# Checked in order by load_ontology(); the first existing file wins.
ONTOLOGY_PATHS = ["ontology.yaml", "ontology.json"]  # user-extensible

# Force CPU-only for this build
DEVICE = torch.device("cpu")
# Optionally limit CPU threads to avoid oversubscription
# torch.set_num_threads(int(os.getenv("TORCH_NUM_THREADS", "4")))
# -----------------------------
# Emergency red-flag rules
# -----------------------------
def red_flag_urgency_rules(text: str) -> bool:
    """Return True when the inquiry matches any hard-coded emergency pattern.

    Acts as a deterministic safety net that can override the model's triage
    output downstream (see classify()).
    """
    emergency_patterns = (
        r"severe chest pain",
        r"shortness of breath",
        r"loss of consciousness|faint(ed|ing)|syncope",
        r"stroke|facial droop|slurred speech|one[- ]sided weakness",
        r"massive bleeding|uncontrolled bleeding",
        r"neck stiffness.*fever",
        r"suicid(al|e)|self[- ]harm",
        r"anaphylaxis|swollen tongue|throat closing",
        r"new severe headache|worst headache of my life",
    )
    lowered = text.lower()
    for pattern in emergency_patterns:
        if re.search(pattern, lowered):
            return True
    return False
# -----------------------------
# Robust JSON extraction
# -----------------------------
def extract_json(text: str) -> Dict[str, Any]:
    """Extract and parse the largest JSON object embedded in model output.

    Fix: the previous in-place code-fence check was unreachable — the slice
    taken from the first "{" can never start with "```". Fences are now
    removed from the raw text *before* locating the brace span.

    Steps: strip Markdown code fences, take the span from the first "{" to
    the last "}", then attempt a parse; on failure, apply a single repair
    pass that removes trailing commas before "]" or "}" (a common LLM
    artifact) and retry.

    Raises:
        ValueError: if no brace-delimited span is present.
        json.JSONDecodeError: if the span is invalid even after repair.
    """
    # Remove fence markers anywhere so the brace search sees clean JSON.
    # (Trade-off: a literal "```" inside a JSON string value would also be
    # removed, but that never appears in this app's prompted outputs.)
    cleaned = re.sub(r"```(?:json)?", "", text)
    start = cleaned.find("{")
    end = cleaned.rfind("}")
    if start == -1 or end == -1 or end <= start:
        raise ValueError("No JSON braces found in model output.")
    candidate = cleaned[start:end + 1].strip()
    # Try load -> try trailing-comma repair -> final fail
    try:
        return json.loads(candidate)
    except json.JSONDecodeError:
        repaired = re.sub(r",\s*([\]}])", r"\1", candidate)
        return json.loads(repaired)
def biogpt_safe_fallback(user_text: str) -> Dict[str, Any]:
    """Build a minimal schema-valid payload so the pipeline never crashes.

    Used when every BioGPT generation attempt fails: routine urgency, low
    confidence, and empty signal/action lists.
    """
    empty_signals = {
        "key_symptoms": [],
        "suspected_conditions": [],
        "candidate_specialties": [],
    }
    fallback_payload = {
        "patient_inquiry": user_text,
        "triage": {
            "urgency": "Routine",
            "confidence": 0.3,
            "priority_score_1_to_10": 3,
        },
        "signals": empty_signals,
        "recommended_actions": [],
    }
    return fallback_payload
# -----------------------------
# Ontology loading (YAML/JSON)
# -----------------------------
def load_ontology() -> Dict[str, Any]:
    """Load the specialty ontology from disk, or fall back to a built-in seed.

    Checks ONTOLOGY_PATHS in order and parses the first file found (YAML via
    safe_load, otherwise JSON). Returned shape:
    {specialty: {"description": str, "subspecialties": {name: description}}}.
    """
    for p in ONTOLOGY_PATHS:
        if os.path.exists(p):
            with open(p, "r", encoding="utf-8") as f:
                if p.endswith((".yaml", ".yml")):
                    return yaml.safe_load(f)
                return json.load(f)
    # Fallback minimal seed (works out of the box; expand via ontology.yaml)
    return {
        "Cardiology": {
            "description": "Heart and circulatory system disorders, chest pain, palpitations, dyspnea.",
            "subspecialties": {
                "Interventional Cardiology": "Coronary syndromes, stents, acute coronary occlusions.",
                "Electrophysiology": "Arrhythmias, palpitations, ablations.",
                "Heart Failure": "Fluid overload, reduced cardiac output, dyspnea on exertion."
            }
        },
        "Emergency Medicine": {
            "description": "Immediate assessment and stabilization of acute and life-threatening conditions.",
            "subspecialties": {"Acute Care": "Multi-system emergencies, trauma, anaphylaxis, sepsis."}
        },
        "Dermatology": {
            "description": "Skin, hair, and nail conditions including rashes, infections, inflammatory disease.",
            "subspecialties": {
                "General Dermatology": "Eczema, acne, psoriasis, non-urgent rashes.",
                "Infectious Dermatology": "Cellulitis, abscess, impetigo."
            }
        },
        "Orthopedics": {
            "description": "Bones, joints, ligaments; fractures, sports injuries, arthritis.",
            "subspecialties": {
                "Trauma/Fracture": "Acute fractures, dislocations, post-traumatic management.",
                "Sports Medicine": "Ligament/tendon injuries, meniscal tears, overuse syndromes."
            }
        },
        "Obstetrics & Gynecology": {
            "description": "Female reproductive health, pregnancy, pelvic pain, abnormal bleeding.",
            "subspecialties": {
                "Benign Gynecology": "Dysmenorrhea, fibroids, heavy periods.",
                "Early Pregnancy": "First-trimester concerns, ectopic, miscarriage.",
                "Reproductive Endocrinology": "Infertility, PCOS, ovulatory disorders."
            }
        },
        "Gastroenterology": {
            "description": "Digestive tract disorders; abdominal pain, reflux, bowel habit changes.",
            "subspecialties": {
                "Functional GI": "IBS, functional dyspepsia, bloating.",
                "Hepatology": "Liver disease, hepatitis, jaundice.",
                "Upper GI": "GERD, peptic ulcer disease, upper GI bleeding."
            }
        },
        "Neurology": {
            "description": "Brain, nerves; headaches, seizures, stroke, focal deficits.",
            "subspecialties": {
                "Stroke/Neurovascular": "Acute stroke/TIA, vascular diagnostics.",
                "Headache": "Migraine, tension-type, red-flag assessment."
            }
        },
        "Endocrinology": {
            "description": "Hormonal disorders; diabetes, thyroid, adrenal.",
            "subspecialties": {"Diabetes": "Hyper/hypoglycemia, long-term metabolic control.", "Thyroid": "Hypo/hyperthyroid symptoms, nodules."}
        },
        "Pulmonology": {
            "description": "Lung and respiratory disorders; asthma, COPD, pneumonia.",
            "subspecialties": {"Asthma/COPD": "Wheezing, cough, dyspnea, airflow limitation.", "Sleep/Ventilation": "Sleep apnea, nocturnal hypoventilation."}
        },
        "Urology": {
            "description": "Urinary tract and male reproductive; dysuria, hematuria, stones.",
            "subspecialties": {"General Urology": "Lower urinary tract symptoms, infections, stones."}
        },
        "Rheumatology": {
            "description": "Autoimmune and inflammatory joint/soft tissue disorders.",
            "subspecialties": {"Inflammatory Arthritis": "RA, psoriatic arthritis, morning stiffness.", "Connective Tissue": "Lupus, vasculitis, systemic features."}
        }
    }
# -----------------------------
# Embedding model (sentence-transformers style)
# -----------------------------
@dataclass
class EmbeddingIndex:
    """Pairing of ontology labels with their embedding rows for cosine ranking."""
    # Row i of `embeddings` encodes labels[i].
    labels: List[Tuple[str, Optional[str]]]  # (specialty, subspecialty or None)
    embeddings: torch.Tensor  # [N, D], L2-normalized by encode_sentences
def load_text_encoder(model_id: str):
    """Load a sentence-embedding tokenizer/model pair onto the CPU.

    Returns (tokenizer, model, device); the model is placed in eval mode
    and pinned to the module-level DEVICE (CPU for this build).
    """
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    encoder = AutoModel.from_pretrained(
        model_id,
        trust_remote_code=False,  # never execute repo-supplied code
        low_cpu_mem_usage=True,
    )
    encoder.to(DEVICE)
    encoder.eval()
    return tokenizer, encoder, DEVICE
@torch.inference_mode()
def encode_sentences(tok, mdl, device, texts: List[str]) -> torch.Tensor:
    """Mean-pool token embeddings into L2-normalized sentence vectors.

    Padded positions are excluded from the mean via the attention mask.
    Returns a [len(texts), D] tensor on `device`.
    """
    encoded = tok(texts, padding=True, truncation=True, return_tensors="pt").to(device)
    hidden = mdl(**encoded).last_hidden_state
    mask = encoded["attention_mask"].unsqueeze(-1)
    summed = (hidden * mask).sum(dim=1)
    counts = mask.sum(dim=1).clamp(min=1e-6)  # guard against all-pad rows
    return F.normalize(summed / counts, p=2, dim=1)
def build_ontology_index(ontology: Dict[str, Any], tok, mdl, device) -> EmbeddingIndex:
    """Embed every specialty and subspecialty description into one index.

    Each specialty contributes a "name: description" row and each
    subspecialty a "specialty > sub: description" row, so both levels can
    be ranked against a query with a single matrix multiply.
    """
    texts: List[str] = []
    labels: List[Tuple[str, Optional[str]]] = []
    for specialty, info in ontology.items():
        description = info.get("description", specialty)
        texts.append(f"{specialty}: {description}")
        labels.append((specialty, None))
        subspecialties = info.get("subspecialties", {}) or {}
        for sub_name, sub_description in subspecialties.items():
            texts.append(f"{specialty} > {sub_name}: {sub_description}")
            labels.append((specialty, sub_name))
    return EmbeddingIndex(labels=labels, embeddings=encode_sentences(tok, mdl, device, texts))
# -----------------------------
# BioGPT (classification JSON + signals)
# -----------------------------
# Prompt preamble for BioGPT triage: demands single-line, minified JSON
# matching the schema below. The output is parsed by extract_json().
JSON_INSTRUCTIONS = """You are a medical triage classifier. Return ONLY valid, minified JSON (no comments, no code fences).
Schema:
{
"patient_inquiry": string,
"triage": {
"urgency": "Emergency" | "Urgent" | "Routine",
"confidence": number,
"priority_score_1_to_10": integer
},
"signals": {
"key_symptoms": [string],
"suspected_conditions": [string],
"candidate_specialties": [string]
},
"recommended_actions": [
{"action": string, "priority": "Emergency" | "Urgent" | "Routine"}
]
}
Output JSON ONLY on a single line.
"""
def load_biogpt():
    """Load the BioGPT causal LM on CPU, falling back to the base model.

    Tries BIOGPT_ID first; if it cannot be loaded (typically insufficient
    RAM for the Large checkpoint), retries with "microsoft/BioGPT".
    Returns (tokenizer, model, device) with left padding and pad==eos
    configured for generation.
    """
    print(f"Loading BioGPT (CPU): {BIOGPT_ID}")
    tok = AutoTokenizer.from_pretrained(BIOGPT_ID)
    if tok.pad_token is None:
        tok.pad_token = tok.eos_token
    tok.padding_side = "left"
    try:
        model = AutoModelForCausalLM.from_pretrained(
            BIOGPT_ID,
            torch_dtype=torch.float32,  # CPU = FP32
            low_cpu_mem_usage=True,
        )
        model.to(DEVICE)
        model.eval()
        return tok, model, DEVICE
    except Exception as exc:
        # Fallback to base model if Large cannot load on CPU RAM
        print(f"[WARN] Failed to load {BIOGPT_ID} on CPU: {exc}")
        fallback_id = "microsoft/BioGPT"
        print(f"[INFO] Falling back to {fallback_id}")
        fb_tok = AutoTokenizer.from_pretrained(fallback_id)
        if fb_tok.pad_token is None:
            fb_tok.pad_token = fb_tok.eos_token
        fb_tok.padding_side = "left"
        fb_model = AutoModelForCausalLM.from_pretrained(
            fallback_id,
            torch_dtype=torch.float32,
            low_cpu_mem_usage=True,
        )
        fb_model.to(DEVICE)
        fb_model.eval()
        return fb_tok, fb_model, DEVICE
def biogpt_prompt(user_text: str) -> str:
    """Compose the triage prompt: schema instructions, inquiry, JSON cue.

    Kept short and explicit; demanding minified JSON reduces parse failures.
    """
    return JSON_INSTRUCTIONS + "\nPatient Inquiry: " + user_text + "\nJSON:\n"
@torch.inference_mode()
def biogpt_json(tok, mdl, device, text: str) -> Dict[str, Any]:
    """Generate and parse triage JSON from BioGPT, with retries.

    Up to three attempts: one sampled at T=0.1, then two greedy passes with
    shrinking token budgets. Sampling parameters are only passed when
    do_sample=True (prevents warnings on HF >= 4.44). If every attempt
    fails, a safe minimal payload is returned so startup never crashes.
    """
    prompt = biogpt_prompt(text)
    attempts = [(0.1, 200), (0.0, 180), (0.0, 150)]
    for temperature, budget in attempts:
        try:
            encoded = tok(
                prompt, return_tensors="pt", truncation=True,
                max_length=700, padding=True,
            ).to(device)
            generation_args = {
                "max_new_tokens": budget,
                "eos_token_id": tok.eos_token_id,
                "pad_token_id": tok.eos_token_id,
            }
            if temperature > 0.0:
                generation_args["do_sample"] = True
                generation_args["temperature"] = temperature
                generation_args["top_p"] = 0.9
            else:
                generation_args["do_sample"] = False
            generated = mdl.generate(**encoded, **generation_args)
            decoded = tok.decode(generated[0], skip_special_tokens=True)
            # Drop the echoed prompt so extract_json sees only the answer.
            if decoded.startswith(prompt):
                decoded = decoded[len(prompt):].strip()
            return extract_json(decoded)
        except Exception:
            continue  # try the next, more conservative configuration
    # Final fallback so the app never crashes
    return biogpt_safe_fallback(text)
# -----------------------------
# Ranking logic (AI-derived matching)
# -----------------------------
def ai_enriched_query(signals: Dict[str, Any], user_text: str) -> str:
    """Augment the raw inquiry with BioGPT-extracted signals for retrieval.

    Non-empty symptom/condition/specialty lists are appended as labeled
    segments joined by " | ", sharpening the embedding query.
    """
    segments = [user_text]
    labeled_lists = [
        ("Key symptoms: ", signals.get("key_symptoms") or []),
        ("Suspected: ", signals.get("suspected_conditions") or []),
        ("Candidate specialties: ", signals.get("candidate_specialties") or []),
    ]
    for prefix, values in labeled_lists:
        if values:
            segments.append(prefix + "; ".join(values))
    return " | ".join(segments)
def rank_specialties(query_emb: torch.Tensor, index: EmbeddingIndex, top_k: int = 8):
    """Return the top_k ontology entries most similar to the query embedding.

    Both sides are L2-normalized, so the dot product equals cosine
    similarity. Each result is {"specialty", "subspecialty", "score"}.
    """
    similarities = (query_emb @ index.embeddings.T).squeeze(0)
    k = min(top_k, similarities.numel())
    scores, positions = torch.topk(similarities, k=k)
    ranked = []
    for score, pos in zip(scores.tolist(), positions.tolist()):
        specialty, subspecialty = index.labels[pos]
        ranked.append({
            "specialty": specialty,
            "subspecialty": subspecialty,
            "score": float(score),
        })
    return ranked
# -----------------------------
# Care suggestions (AI-derived, per top match)
# -----------------------------
# Prompt preamble for per-match care suggestions; parsed by extract_json()
# in suggest_care(), which reads the "suggested_care" list.
CARE_INSTRUCTIONS = """You are assisting care pathway planning. For the given specialty/subspecialty and symptoms, suggest initial tests/treatments. Return JSON:
{
"suggested_care": [string]
}
JSON ONLY, single line.
"""
def care_prompt(spec: str, sub: Optional[str], inquiry: str) -> str:
    """Build the care-suggestion prompt for one specialty/subspecialty match."""
    if sub:
        title = f"{spec} > {sub}"
    else:
        title = spec
    return f"""{CARE_INSTRUCTIONS}
Specialty: {title}
Patient Inquiry: {inquiry}
JSON:
"""
@torch.inference_mode()
def suggest_care(tok, mdl, device, spec: str, sub: Optional[str], inquiry: str) -> List[str]:
    """Ask BioGPT for up to five initial tests/treatments for one match.

    Two attempts (sampled at T=0.1, then greedy with a smaller budget);
    any failure falls through to the next attempt, and an empty list is
    returned if both fail.
    """
    prompt = care_prompt(spec, sub, inquiry)
    for temperature, budget in [(0.1, 120), (0.0, 100)]:
        try:
            encoded = tok(
                prompt, return_tensors="pt", truncation=True,
                max_length=512, padding=True,
            ).to(device)
            generation_args = {
                "max_new_tokens": budget,
                "eos_token_id": tok.eos_token_id,
                "pad_token_id": tok.eos_token_id,
            }
            if temperature > 0.0:
                generation_args.update(do_sample=True, temperature=temperature, top_p=0.9)
            else:
                generation_args["do_sample"] = False
            generated = mdl.generate(**encoded, **generation_args)
            decoded = tok.decode(generated[0], skip_special_tokens=True)
            # Drop the echoed prompt before parsing.
            if decoded.startswith(prompt):
                decoded = decoded[len(prompt):].strip()
            suggestions = extract_json(decoded).get("suggested_care") or []
            return [str(item) for item in suggestions][:5]
        except Exception:
            continue
    return []
# -----------------------------
# Rendering
# -----------------------------
def render(payload: Dict[str, Any]) -> str:
    """Format the analysis payload as a Markdown report for the Gradio UI.

    Expected keys: "patient_inquiry", "triage" (urgency plus optional
    confidence / priority_score_1_to_10), optional "matches" (ranked
    specialties, each possibly with "suggested_care") and optional
    "recommended_actions".
    """
    inquiry = payload["patient_inquiry"]
    triage = payload["triage"]
    urgency = triage["urgency"]
    conf = triage.get("confidence")
    score = triage.get("priority_score_1_to_10")
    # Unknown urgency values fall back to the green (routine) badge.
    urg_emoji = {"Emergency": "🚨", "Urgent": "🟠", "Routine": "🟢"}.get(urgency, "🟢")
    lines = []
    lines.append("## 🏥 Healthcare Lead Analysis\n")
    lines.append(f"**Patient Inquiry:** \"{inquiry}\"\n")
    head = f"- **Urgency:** {urg_emoji} **{urgency}**"
    # Only append metrics when the model actually returned numbers.
    if isinstance(conf, (int, float)): head += f" · Confidence: {conf:.2f}"
    if isinstance(score, (int, float)): head += f" · Priority: {score}/10"
    lines.append(head + "\n")
    if payload.get("matches"):
        lines.append("### 🩺 Recommended Specialty Pathways:")
        for m in payload["matches"]:
            title = f"**{m['specialty']}**" + (f" › {m['subspecialty']}" if m.get("subspecialty") else "")
            sc = f" ({m['score']:.2f})"
            lines.append(f"- {title}{sc}")
            care = m.get("suggested_care") or []
            if care:
                lines.append(" - _Suggested initial tests/treatments_: " + "; ".join(care))
        # NOTE(review): blank separator placed after the loop — original
        # indentation was lost; confirm it was not per-match.
        lines.append("")
    acts = payload.get("recommended_actions") or []
    if acts:
        lines.append("### ✅ Recommended Actions:")
        for a in acts:
            pr = a.get("priority")
            badge = f" **[{pr}]**" if pr else ""
            lines.append(f"- {a.get('action','')}{badge}")
        lines.append("")
    lines.append("\n---\n*Analysis powered by BioGPT + vector retrieval over your ontology*")
    return "\n".join(lines).strip()
# -----------------------------
# Classify handlers
# -----------------------------
def classify(text: str) -> str:
    """Run the full single-lead pipeline and return a Markdown report.

    Steps: regex red-flag check -> BioGPT triage/signal extraction ->
    embedding-based ontology ranking -> per-match care suggestions ->
    render. Relies on the module-level models/index and the READY flag
    set during startup.
    """
    if not READY:
        return "Model is still loading. Please try again in a moment."
    if not text or not text.strip():
        return "Please enter some text to classify."
    # Deterministic safety net: regex red flags can force Emergency triage.
    forced_emergency = red_flag_urgency_rules(text)
    # 1) BioGPT: get triage + signals (robust — falls back, never raises)
    biogpt_out = biogpt_json(bio_tok, bio_mdl, bio_device, text)
    triage = biogpt_out.get("triage", {}) if isinstance(biogpt_out.get("triage", {}), dict) else {}
    urgency = (triage.get("urgency") or "Routine").title()
    if urgency not in ["Emergency", "Urgent", "Routine"]:
        urgency = "Routine"
    conf = triage.get("confidence")
    score = triage.get("priority_score_1_to_10")
    actions = biogpt_out.get("recommended_actions") or []
    signals = biogpt_out.get("signals") or {}
    # Emergency override: rules outrank the model's own triage.
    if forced_emergency:
        urgency = "Emergency"
        score = max(9, int(score) if isinstance(score, int) else 9)
        # NOTE(review): ER action added only under the override — original
        # indentation was lost; confirm this nesting against the author.
        if not actions:
            actions = [{"action": "Immediate ER referral or call emergency services", "priority": "Emergency"}]
    # 2) Build AI-enriched query and rank against ontology
    query_text = ai_enriched_query(signals, text)
    q_emb = encode_sentences(emb_tok, emb_mdl, emb_device, [query_text])  # [1, D]
    ranked = rank_specialties(q_emb, ONT_INDEX, top_k=8)
    # 3) For top-N, ask BioGPT to suggest care steps
    top_matches = []
    for r in ranked[:5]:
        care = suggest_care(bio_tok, bio_mdl, bio_device, r["specialty"], r.get("subspecialty"), text)
        top_matches.append({
            "specialty": r["specialty"],
            "subspecialty": r.get("subspecialty"),
            "score": r["score"],
            "suggested_care": care
        })
    payload = {
        "patient_inquiry": text,
        "triage": {"urgency": urgency, "confidence": conf, "priority_score_1_to_10": score},
        "matches": top_matches,
        "recommended_actions": actions
    }
    gc.collect()  # release per-request tensors on this CPU-only build
    return render(payload)
def classify_batch(texts: str) -> str:
    """Split newline/semicolon-separated input into leads and analyze each.

    A single lead is returned as-is; multiple leads are numbered and
    joined with horizontal rules.
    """
    if not READY:
        return "Model is still loading. Please try again in a moment."
    if not texts or not texts.strip():
        return "Please enter text to analyze."
    normalized = texts.replace(";", "\n")
    leads = [entry.strip() for entry in normalized.split("\n") if entry.strip()]
    if len(leads) == 1:
        return classify(leads[0])
    sections = []
    for number, lead in enumerate(leads, 1):
        sections.append(f"### Lead {number}\n\n{classify(lead)}")
    return "\n\n---\n\n".join(sections)
# -----------------------------
# Load models & ontology on startup (CPU-only)
# -----------------------------
# Runs at import time so the Gradio app starts with everything in memory.
print("Loading ontology...")
ONTOLOGY = load_ontology()
print("Loading embedding encoder (CPU)...")
emb_tok, emb_mdl, emb_device = load_text_encoder(EMB_MODEL_ID)
print("Building ontology index...")
ONT_INDEX = build_ontology_index(ONTOLOGY, emb_tok, emb_mdl, emb_device)
print("Loading BioGPT (CPU)...")
bio_tok, bio_mdl, bio_device = load_biogpt()
print("Models loaded. Server will start...")
# Signal ready AFTER all models are loaded
READY = True
print("Server ready.")
# -----------------------------
# Gradio UI
# -----------------------------
# Single-lead tab: one inquiry in, one Markdown report out.
single = gr.Interface(
    fn=classify,
    inputs=gr.Textbox(
        label="📝 Enter Patient Inquiry or Lead Description",
        placeholder="e.g., Severe chest pain radiating to left arm with shortness of breath.",
        lines=4,
        info="Describe symptoms, urgency, and what the patient is seeking."
    ),
    outputs=gr.Markdown(label="🏥 Analysis"),
    title="DocMap Healthcare Lead Analyzer",
    description="Dynamic ontology + AI-derived specialty/subspecialty matching with suggested initial care.",
    examples=[
        ["Emergency: Severe chest pain and shortness of breath, need immediate care"],
        ["Period pain for first three hours and tummy pain"],
        ["Knee pain after football, swelling and trouble bearing weight"],
        ["Chronic cough, wheezing at night, worse with exercise"],
        ["Severe headache with photophobia and nausea"]
    ],
    cache_examples=False,  # ✅ prevents startup caching (and 500s)
    theme=gr.themes.Soft()
)
# Batch tab: newline/semicolon-separated inquiries handled by classify_batch.
batch = gr.Interface(
    fn=classify_batch,
    inputs=gr.Textbox(
        label="📋 Batch Analysis (Multiple Leads)",
        placeholder="One inquiry per line or separated by semicolons…",
        lines=6
    ),
    outputs=gr.Markdown(label="🏥 Batch Results"),
    title="DocMap Batch Analyzer",
    description="Analyze multiple leads with AI-derived matching",
    cache_examples=False  # ✅ prevents startup caching
)
demo = gr.TabbedInterface([single, batch], ["Single Lead", "Batch Processing"])
if __name__ == "__main__":
    # Disable SSR (it can surface startup-errors as fatal in some envs)
    demo.queue(max_size=20).launch(server_name="0.0.0.0", ssr_mode=False)