| |
| |
|
|
| import os |
| import re |
| import gc |
| import json |
| import yaml |
| import torch |
| import gradio as gr |
| import torch.nn.functional as F |
| from dataclasses import dataclass |
| from typing import Dict, Any, List, Tuple, Optional |
| from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModel |
|
|
# Flipped to True once the ontology index and both models finish loading below;
# the Gradio handlers refuse requests until then.
READY = False

# Model identifiers, overridable via environment variables.
BIOGPT_ID = os.getenv("BIOGPT_ID", "microsoft/BioGPT-Large")
EMB_MODEL_ID = os.getenv("EMB_MODEL_ID", "sentence-transformers/all-MiniLM-L6-v2")
# Ontology is read from the first of these files that exists (see load_ontology).
ONTOLOGY_PATHS = ["ontology.yaml", "ontology.json"]

# All inference is pinned to CPU.
DEVICE = torch.device("cpu")
| |
| |
|
|
| |
| |
| |
def red_flag_urgency_rules(text: str) -> bool:
    """Return True when the inquiry mentions any hard-coded emergency red flag.

    Case-insensitive regex screen applied before/alongside the model's triage
    so rule-based emergencies can override a low model urgency.
    """
    lowered = text.lower()
    red_flags = (
        r"severe chest pain",
        r"shortness of breath",
        r"loss of consciousness|faint(ed|ing)|syncope",
        r"stroke|facial droop|slurred speech|one[- ]sided weakness",
        r"massive bleeding|uncontrolled bleeding",
        r"neck stiffness.*fever",
        r"suicid(al|e)|self[- ]harm",
        r"anaphylaxis|swollen tongue|throat closing",
        r"new severe headache|worst headache of my life",
    )
    for pattern in red_flags:
        if re.search(pattern, lowered):
            return True
    return False
|
|
| |
| |
| |
def extract_json(text: str) -> Dict[str, Any]:
    """Extract and parse the JSON object embedded in a model completion.

    Strips Markdown code fences first, then takes the span from the first
    '{' to the last '}'. If strict parsing fails, retries after removing
    trailing commas (a common LLM output defect).

    Raises:
        ValueError: if no brace-delimited span is present.
        json.JSONDecodeError: if the span is not valid JSON even after repair.

    Note: the original stripped fences *after* slicing at the first '{',
    so the fence-handling branch was unreachable; fences are now removed
    from the raw text before locating the braces.
    """
    cleaned = text.replace("```json", "").replace("```", "")
    start = cleaned.find("{")
    end = cleaned.rfind("}")
    if start == -1 or end == -1 or end <= start:
        raise ValueError("No JSON braces found in model output.")
    candidate = cleaned[start:end + 1].strip()

    try:
        return json.loads(candidate)
    except json.JSONDecodeError:
        # Drop trailing commas before a closing bracket/brace, then retry.
        repaired = re.sub(r",\s*([\]}])", r"\1", candidate)
        return json.loads(repaired)
|
|
def biogpt_safe_fallback(user_text: str) -> Dict[str, Any]:
    """Conservative low-confidence triage payload used when generation or parsing fails."""
    triage = {"urgency": "Routine", "confidence": 0.3, "priority_score_1_to_10": 3}
    signals = {
        "key_symptoms": [],
        "suspected_conditions": [],
        "candidate_specialties": [],
    }
    return {
        "patient_inquiry": user_text,
        "triage": triage,
        "signals": signals,
        "recommended_actions": [],
    }
|
|
| |
| |
| |
def load_ontology() -> Dict[str, Any]:
    """Load the specialty ontology from disk, or fall back to the built-in one.

    Tries each path in ONTOLOGY_PATHS in order; the first existing file wins
    (YAML parsed with safe_load, anything else parsed as JSON).

    Returns:
        Mapping of specialty name -> {"description": str,
        "subspecialties": {subspecialty name: description}}.
    """
    for p in ONTOLOGY_PATHS:
        if os.path.exists(p):
            with open(p, "r", encoding="utf-8") as f:
                if p.endswith((".yaml", ".yml")):
                    return yaml.safe_load(f)
                return json.load(f)

    # Built-in default ontology used when no ontology file is present.
    return {
        "Cardiology": {
            "description": "Heart and circulatory system disorders, chest pain, palpitations, dyspnea.",
            "subspecialties": {
                "Interventional Cardiology": "Coronary syndromes, stents, acute coronary occlusions.",
                "Electrophysiology": "Arrhythmias, palpitations, ablations.",
                "Heart Failure": "Fluid overload, reduced cardiac output, dyspnea on exertion."
            }
        },
        "Emergency Medicine": {
            "description": "Immediate assessment and stabilization of acute and life-threatening conditions.",
            "subspecialties": {"Acute Care": "Multi-system emergencies, trauma, anaphylaxis, sepsis."}
        },
        "Dermatology": {
            "description": "Skin, hair, and nail conditions including rashes, infections, inflammatory disease.",
            "subspecialties": {
                "General Dermatology": "Eczema, acne, psoriasis, non-urgent rashes.",
                "Infectious Dermatology": "Cellulitis, abscess, impetigo."
            }
        },
        "Orthopedics": {
            "description": "Bones, joints, ligaments; fractures, sports injuries, arthritis.",
            "subspecialties": {
                "Trauma/Fracture": "Acute fractures, dislocations, post-traumatic management.",
                "Sports Medicine": "Ligament/tendon injuries, meniscal tears, overuse syndromes."
            }
        },
        "Obstetrics & Gynecology": {
            "description": "Female reproductive health, pregnancy, pelvic pain, abnormal bleeding.",
            "subspecialties": {
                "Benign Gynecology": "Dysmenorrhea, fibroids, heavy periods.",
                "Early Pregnancy": "First-trimester concerns, ectopic, miscarriage.",
                "Reproductive Endocrinology": "Infertility, PCOS, ovulatory disorders."
            }
        },
        "Gastroenterology": {
            "description": "Digestive tract disorders; abdominal pain, reflux, bowel habit changes.",
            "subspecialties": {
                "Functional GI": "IBS, functional dyspepsia, bloating.",
                "Hepatology": "Liver disease, hepatitis, jaundice.",
                "Upper GI": "GERD, peptic ulcer disease, upper GI bleeding."
            }
        },
        "Neurology": {
            "description": "Brain, nerves; headaches, seizures, stroke, focal deficits.",
            "subspecialties": {
                "Stroke/Neurovascular": "Acute stroke/TIA, vascular diagnostics.",
                "Headache": "Migraine, tension-type, red-flag assessment."
            }
        },
        "Endocrinology": {
            "description": "Hormonal disorders; diabetes, thyroid, adrenal.",
            "subspecialties": {"Diabetes": "Hyper/hypoglycemia, long-term metabolic control.", "Thyroid": "Hypo/hyperthyroid symptoms, nodules."}
        },
        "Pulmonology": {
            "description": "Lung and respiratory disorders; asthma, COPD, pneumonia.",
            "subspecialties": {"Asthma/COPD": "Wheezing, cough, dyspnea, airflow limitation.", "Sleep/Ventilation": "Sleep apnea, nocturnal hypoventilation."}
        },
        "Urology": {
            "description": "Urinary tract and male reproductive; dysuria, hematuria, stones.",
            "subspecialties": {"General Urology": "Lower urinary tract symptoms, infections, stones."}
        },
        "Rheumatology": {
            "description": "Autoimmune and inflammatory joint/soft tissue disorders.",
            "subspecialties": {"Inflammatory Arthritis": "RA, psoriatic arthritis, morning stiffness.", "Connective Tissue": "Lupus, vasculitis, systemic features."}
        }
    }
|
|
| |
| |
| |
@dataclass
class EmbeddingIndex:
    """Pre-computed embedding index over ontology specialty/subspecialty entries."""
    # (specialty, subspecialty-or-None) pairs, aligned row-for-row with `embeddings`.
    labels: List[Tuple[str, Optional[str]]]
    # Sentence embeddings, one row per label (L2-normalized by encode_sentences).
    embeddings: torch.Tensor
|
|
def load_text_encoder(model_id: str):
    """Load the sentence-embedding tokenizer and encoder on CPU in eval mode.

    Returns (tokenizer, model, device).
    """
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    encoder = AutoModel.from_pretrained(
        model_id,
        trust_remote_code=False,  # never execute repo-provided code
        low_cpu_mem_usage=True,
    )
    encoder.to(DEVICE).eval()
    return tokenizer, encoder, DEVICE
|
|
@torch.inference_mode()
def encode_sentences(tok, mdl, device, texts: List[str]) -> torch.Tensor:
    """Embed texts via attention-masked mean pooling, then L2-normalize.

    Returns a (len(texts), hidden_dim) tensor of unit-norm embeddings.
    """
    encoded = tok(texts, padding=True, truncation=True, return_tensors="pt").to(device)
    hidden = mdl(**encoded).last_hidden_state
    mask = encoded["attention_mask"].unsqueeze(-1)
    summed = (hidden * mask).sum(dim=1)
    counts = mask.sum(dim=1).clamp(min=1e-6)  # guard against all-padding rows
    return F.normalize(summed / counts, p=2, dim=1)
|
|
def build_ontology_index(ontology: Dict[str, Any], tok, mdl, device) -> EmbeddingIndex:
    """Embed every specialty and subspecialty description into one searchable index."""
    entries: List[Tuple[str, Tuple[str, Optional[str]]]] = []
    for specialty, info in ontology.items():
        description = info.get("description", specialty)
        entries.append((f"{specialty}: {description}", (specialty, None)))
        for sub_name, sub_desc in (info.get("subspecialties", {}) or {}).items():
            entries.append((f"{specialty} > {sub_name}: {sub_desc}", (specialty, sub_name)))
    texts = [sentence for sentence, _ in entries]
    labels = [label for _, label in entries]
    return EmbeddingIndex(labels=labels, embeddings=encode_sentences(tok, mdl, device, texts))
|
|
| |
| |
| |
# Instruction preamble for the triage generation call: asks BioGPT to emit
# exactly one line of minified JSON matching this schema. The completion is
# parsed downstream by extract_json().
JSON_INSTRUCTIONS = """You are a medical triage classifier. Return ONLY valid, minified JSON (no comments, no code fences).

Schema:
{
  "patient_inquiry": string,
  "triage": {
    "urgency": "Emergency" | "Urgent" | "Routine",
    "confidence": number,
    "priority_score_1_to_10": integer
  },
  "signals": {
    "key_symptoms": [string],
    "suspected_conditions": [string],
    "candidate_specialties": [string]
  },
  "recommended_actions": [
    {"action": string, "priority": "Emergency" | "Urgent" | "Routine"}
  ]
}
Output JSON ONLY on a single line.
"""
|
|
def _load_causal_lm(model_id: str):
    """Load one causal LM + tokenizer on CPU, configured for left-padded generation."""
    tok = AutoTokenizer.from_pretrained(model_id)
    if tok.pad_token is None:
        tok.pad_token = tok.eos_token  # GPT-style models ship without a pad token
    tok.padding_side = "left"
    mdl = AutoModelForCausalLM.from_pretrained(
        model_id,
        torch_dtype=torch.float32,
        low_cpu_mem_usage=True,
    )
    mdl.to(DEVICE).eval()
    return mdl, tok


def load_biogpt():
    """Load BIOGPT_ID on CPU, falling back to the smaller microsoft/BioGPT on failure.

    Returns (tokenizer, model, device). The primary and fallback paths were
    previously duplicated verbatim; both now share _load_causal_lm, and a
    tokenizer-load failure of the primary model also triggers the fallback.
    """
    print(f"Loading BioGPT (CPU): {BIOGPT_ID}")
    try:
        mdl, tok = _load_causal_lm(BIOGPT_ID)
        return tok, mdl, DEVICE
    except Exception as e:
        print(f"[WARN] Failed to load {BIOGPT_ID} on CPU: {e}")
        fallback_id = "microsoft/BioGPT"
        print(f"[INFO] Falling back to {fallback_id}")
        mdl2, tok2 = _load_causal_lm(fallback_id)
        return tok2, mdl2, DEVICE
|
|
def biogpt_prompt(user_text: str) -> str:
    """Assemble the triage prompt: schema instructions, the inquiry, and a JSON cue."""
    return (
        f"{JSON_INSTRUCTIONS}\n"
        f"\nPatient Inquiry: {user_text}\n"
        "JSON:\n"
    )
|
|
@torch.inference_mode()
def biogpt_json(tok, mdl, device, text: str) -> Dict[str, Any]:
    """Ask BioGPT for the triage JSON, with retries and a guaranteed fallback.

    Tries up to three generations with progressively smaller token budgets;
    any tokenization/generation/parsing error moves on to the next attempt.
    Sampling parameters (temperature/top_p) are passed only when
    do_sample=True, which avoids warnings on transformers >= 4.44.
    Returns biogpt_safe_fallback(text) if every attempt fails, so callers
    never see an exception.
    """
    prompt = biogpt_prompt(text)
    attempts = [(0.1, 200), (0.0, 180), (0.0, 150)]

    for temperature, budget in attempts:
        try:
            enc = tok(prompt, return_tensors="pt", truncation=True, max_length=700, padding=True).to(device)
            gen_kwargs: Dict[str, Any] = {
                "max_new_tokens": budget,
                "eos_token_id": tok.eos_token_id,
                "pad_token_id": tok.eos_token_id,
            }
            if temperature > 0.0:
                gen_kwargs["do_sample"] = True
                gen_kwargs["temperature"] = temperature
                gen_kwargs["top_p"] = 0.9
            else:
                gen_kwargs["do_sample"] = False

            completion = tok.decode(mdl.generate(**enc, **gen_kwargs)[0], skip_special_tokens=True)
            # Strip the echoed prompt so only the generated JSON remains.
            if completion.startswith(prompt):
                completion = completion[len(prompt):].strip()
            return extract_json(completion)
        except Exception:
            continue

    return biogpt_safe_fallback(text)
|
|
| |
| |
| |
def ai_enriched_query(signals: Dict[str, Any], user_text: str) -> str:
    """Append model-extracted signals to the raw inquiry to sharpen retrieval."""
    segments = [user_text]
    symptoms = signals.get("key_symptoms") or []
    if symptoms:
        segments.append("Key symptoms: " + "; ".join(symptoms))
    conditions = signals.get("suspected_conditions") or []
    if conditions:
        segments.append("Suspected: " + "; ".join(conditions))
    specialties = signals.get("candidate_specialties") or []
    if specialties:
        segments.append("Candidate specialties: " + "; ".join(specialties))
    return " | ".join(segments)
|
|
def rank_specialties(query_emb: torch.Tensor, index: EmbeddingIndex, top_k: int = 8):
    """Return the top_k index entries by dot-product similarity.

    Embeddings are unit-norm (see encode_sentences), so the dot product is
    cosine similarity. Returns dicts: {"specialty", "subspecialty", "score"}.
    """
    scores = (query_emb @ index.embeddings.T).squeeze(0)
    k = min(top_k, scores.numel())
    top_scores, top_rows = torch.topk(scores, k=k)
    ranked = []
    for score, row in zip(top_scores.tolist(), top_rows.tolist()):
        specialty, subspecialty = index.labels[row]
        ranked.append({"specialty": specialty, "subspecialty": subspecialty, "score": float(score)})
    return ranked
|
|
| |
| |
| |
# Instruction preamble for the per-specialty care-suggestion call; the
# completion is parsed by extract_json() and read from "suggested_care".
CARE_INSTRUCTIONS = """You are assisting care pathway planning. For the given specialty/subspecialty and symptoms, suggest initial tests/treatments. Return JSON:
{
  "suggested_care": [string]
}
JSON ONLY, single line.
"""
|
|
def care_prompt(spec: str, sub: Optional[str], inquiry: str) -> str:
    """Assemble the care-suggestion prompt for one specialty (and optional subspecialty)."""
    heading = f"{spec} > {sub}" if sub else f"{spec}"
    return (
        f"{CARE_INSTRUCTIONS}\n"
        f"\nSpecialty: {heading}\n"
        f"Patient Inquiry: {inquiry}\n"
        "JSON:\n"
    )
|
|
@torch.inference_mode()
def suggest_care(tok, mdl, device, spec: str, sub: Optional[str], inquiry: str) -> List[str]:
    """Generate up to five initial test/treatment suggestions for a specialty.

    Two attempts with shrinking token budgets; sampling parameters are only
    passed when do_sample=True. Returns [] when every attempt fails, so the
    caller can render matches without care suggestions.
    """
    prompt = care_prompt(spec, sub, inquiry)
    for temperature, budget in ((0.1, 120), (0.0, 100)):
        try:
            enc = tok(prompt, return_tensors="pt", truncation=True, max_length=512, padding=True).to(device)
            gen_kwargs = {
                "max_new_tokens": budget,
                "eos_token_id": tok.eos_token_id,
                "pad_token_id": tok.eos_token_id,
            }
            if temperature > 0.0:
                gen_kwargs.update(do_sample=True, temperature=temperature, top_p=0.9)
            else:
                gen_kwargs["do_sample"] = False

            completion = tok.decode(mdl.generate(**enc, **gen_kwargs)[0], skip_special_tokens=True)
            # Strip the echoed prompt so only the generated JSON remains.
            if completion.startswith(prompt):
                completion = completion[len(prompt):].strip()
            parsed = extract_json(completion)
            suggestions = parsed.get("suggested_care") or []
            return [str(item) for item in suggestions][:5]
        except Exception:
            continue
    return []
|
|
| |
| |
| |
def render(payload: Dict[str, Any]) -> str:
    """Format the analysis payload as the Markdown report shown in the Gradio UI.

    Expects keys: patient_inquiry, triage (urgency required; confidence and
    priority_score_1_to_10 optional), plus optional matches and
    recommended_actions lists.
    """
    triage = payload["triage"]
    urgency = triage["urgency"]
    confidence = triage.get("confidence")
    priority = triage.get("priority_score_1_to_10")
    # Unknown urgency values fall back to the green marker.
    emoji = {"Emergency": "🚨", "Urgent": "🟠", "Routine": "🟢"}.get(urgency, "🟢")

    out = [
        "## 🏥 Healthcare Lead Analysis\n",
        f"**Patient Inquiry:** \"{payload['patient_inquiry']}\"\n",
    ]
    header = f"- **Urgency:** {emoji} **{urgency}**"
    if isinstance(confidence, (int, float)):
        header += f" · Confidence: {confidence:.2f}"
    if isinstance(priority, (int, float)):
        header += f" · Priority: {priority}/10"
    out.append(header + "\n")

    matches = payload.get("matches")
    if matches:
        out.append("### 🩺 Recommended Specialty Pathways:")
        for match in matches:
            label = f"**{match['specialty']}**"
            if match.get("subspecialty"):
                label += f" › {match['subspecialty']}"
            out.append(f"- {label} ({match['score']:.2f})")
            care = match.get("suggested_care") or []
            if care:
                out.append(" - _Suggested initial tests/treatments_: " + "; ".join(care))
        out.append("")

    actions = payload.get("recommended_actions") or []
    if actions:
        out.append("### ✅ Recommended Actions:")
        for action in actions:
            tag = action.get("priority")
            suffix = f" **[{tag}]**" if tag else ""
            out.append(f"- {action.get('action', '')}{suffix}")
        out.append("")

    out.append("\n---\n*Analysis powered by BioGPT + vector retrieval over your ontology*")
    return "\n".join(out).strip()
|
|
| |
| |
| |
def classify(text: str) -> str:
    """End-to-end pipeline for one inquiry: triage -> retrieval -> care -> Markdown.

    Returns a user-facing message if models are still loading or the input is
    blank; otherwise the rendered analysis report.
    """
    if not READY:
        return "Model is still loading. Please try again in a moment."
    if not text or not text.strip():
        return "Please enter some text to classify."

    # Rule-based red flags can override the model's urgency call below.
    forced_emergency = red_flag_urgency_rules(text)

    # --- BioGPT triage ---
    model_out = biogpt_json(bio_tok, bio_mdl, bio_device, text)
    raw_triage = model_out.get("triage", {})
    triage = raw_triage if isinstance(raw_triage, dict) else {}
    urgency = (triage.get("urgency") or "Routine").title()
    if urgency not in ("Emergency", "Urgent", "Routine"):
        urgency = "Routine"
    confidence = triage.get("confidence")
    priority = triage.get("priority_score_1_to_10")
    actions = model_out.get("recommended_actions") or []
    signals = model_out.get("signals") or {}

    # --- Red-flag override: force Emergency, floor priority at 9 ---
    if forced_emergency:
        urgency = "Emergency"
        priority = max(9, priority) if isinstance(priority, int) else 9
        if not actions:
            actions = [{"action": "Immediate ER referral or call emergency services", "priority": "Emergency"}]

    # --- Retrieval over the ontology index ---
    enriched = ai_enriched_query(signals, text)
    query_embedding = encode_sentences(emb_tok, emb_mdl, emb_device, [enriched])
    ranked = rank_specialties(query_embedding, ONT_INDEX, top_k=8)

    # --- Care suggestions for the top five matches ---
    matches = []
    for candidate in ranked[:5]:
        matches.append({
            "specialty": candidate["specialty"],
            "subspecialty": candidate.get("subspecialty"),
            "score": candidate["score"],
            "suggested_care": suggest_care(
                bio_tok, bio_mdl, bio_device,
                candidate["specialty"], candidate.get("subspecialty"), text
            ),
        })

    result = {
        "patient_inquiry": text,
        "triage": {"urgency": urgency, "confidence": confidence, "priority_score_1_to_10": priority},
        "matches": matches,
        "recommended_actions": actions,
    }
    gc.collect()  # CPU-only deployment; reclaim generation buffers between requests
    return render(result)
|
|
def classify_batch(texts: str) -> str:
    """Run classify() over multiple inquiries separated by newlines or semicolons."""
    if not READY:
        return "Model is still loading. Please try again in a moment."
    if not texts or not texts.strip():
        return "Please enter text to analyze."
    leads = [chunk.strip() for chunk in texts.replace(";", "\n").split("\n") if chunk.strip()]
    # A single lead skips the per-lead headers entirely.
    if len(leads) == 1:
        return classify(leads[0])
    sections = []
    for number, lead in enumerate(leads, 1):
        sections.append(f"### Lead {number}\n\n{classify(lead)}")
    return "\n\n---\n\n".join(sections)
|
|
| |
| |
| |
# ---- Eager startup: load everything at import time, before the UI exists ----
print("Loading ontology...")
ONTOLOGY = load_ontology()

# Sentence encoder + pre-computed index used for specialty retrieval.
print("Loading embedding encoder (CPU)...")
emb_tok, emb_mdl, emb_device = load_text_encoder(EMB_MODEL_ID)
print("Building ontology index...")
ONT_INDEX = build_ontology_index(ONTOLOGY, emb_tok, emb_mdl, emb_device)

# Generative triage/care model.
print("Loading BioGPT (CPU)...")
bio_tok, bio_mdl, bio_device = load_biogpt()
print("Models loaded. Server will start...")

# Loading above is synchronous, so request handlers always observe READY=True.
READY = True
print("Server ready.")
|
|
| |
| |
| |
# Single-lead tab: one inquiry in, one rendered Markdown analysis out.
single = gr.Interface(
    fn=classify,
    inputs=gr.Textbox(
        label="📝 Enter Patient Inquiry or Lead Description",
        placeholder="e.g., Severe chest pain radiating to left arm with shortness of breath.",
        lines=4,
        info="Describe symptoms, urgency, and what the patient is seeking."
    ),
    outputs=gr.Markdown(label="🏥 Analysis"),
    title="DocMap Healthcare Lead Analyzer",
    description="Dynamic ontology + AI-derived specialty/subspecialty matching with suggested initial care.",
    examples=[
        ["Emergency: Severe chest pain and shortness of breath, need immediate care"],
        ["Period pain for first three hours and tummy pain"],
        ["Knee pain after football, swelling and trouble bearing weight"],
        ["Chronic cough, wheezing at night, worse with exercise"],
        ["Severe headache with photophobia and nausea"]
    ],
    cache_examples=False,
    theme=gr.themes.Soft()
)

# Batch tab: multiple inquiries separated by newlines or semicolons.
batch = gr.Interface(
    fn=classify_batch,
    inputs=gr.Textbox(
        label="📋 Batch Analysis (Multiple Leads)",
        placeholder="One inquiry per line or separated by semicolons…",
        lines=6
    ),
    outputs=gr.Markdown(label="🏥 Batch Results"),
    title="DocMap Batch Analyzer",
    description="Analyze multiple leads with AI-derived matching",
    cache_examples=False
)

demo = gr.TabbedInterface([single, batch], ["Single Lead", "Batch Processing"])


if __name__ == "__main__":
    # queue() bounds concurrent requests; 0.0.0.0 exposes the server to the host network.
    demo.queue(max_size=20).launch(server_name="0.0.0.0", ssr_mode=False)
|
|