# app.py β€” DocMap Healthcare Lead Analyzer # Dynamic ontology (YAML/JSON) + AI-derived matching + vector similarity (CPU-only, robust startup) import os import re import gc import json import yaml import torch import gradio as gr import torch.nn.functional as F from dataclasses import dataclass from typing import Dict, Any, List, Tuple, Optional from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModel READY = False # You can override these via environment variables if you want: BIOGPT_ID = os.getenv("BIOGPT_ID", "microsoft/BioGPT-Large") # switch to "microsoft/BioGPT" if RAM is tight EMB_MODEL_ID = os.getenv("EMB_MODEL_ID", "sentence-transformers/all-MiniLM-L6-v2") ONTOLOGY_PATHS = ["ontology.yaml", "ontology.json"] # user-extensible # Force CPU-only for this build DEVICE = torch.device("cpu") # Optionally limit CPU threads to avoid oversubscription # torch.set_num_threads(int(os.getenv("TORCH_NUM_THREADS", "4"))) # ----------------------------- # Emergency red-flag rules # ----------------------------- def red_flag_urgency_rules(text: str) -> bool: t = text.lower() patterns = [ r"severe chest pain", r"shortness of breath", r"loss of consciousness|faint(ed|ing)|syncope", r"stroke|facial droop|slurred speech|one[- ]sided weakness", r"massive bleeding|uncontrolled bleeding", r"neck stiffness.*fever", r"suicid(al|e)|self[- ]harm", r"anaphylaxis|swollen tongue|throat closing", r"new severe headache|worst headache of my life", ] return any(re.search(p, t) for p in patterns) # ----------------------------- # Robust JSON extraction # ----------------------------- def extract_json(text: str) -> Dict[str, Any]: # Try to grab the largest JSON-looking block (from first { to last }) start = text.find("{") end = text.rfind("}") if start == -1 or end == -1 or end <= start: raise ValueError("No JSON braces found in model output.") candidate = text[start:end+1] # Strip code fences if present candidate = candidate.strip() if candidate.startswith("```"): candidate = re.sub(r"^```(json)?\s*", "", candidate) candidate = candidate.rstrip("`").rstrip() if candidate.endswith("```"): candidate = candidate[:-3].rstrip() # Try load -> try trailing-comma repair -> final fail try: return json.loads(candidate) except Exception: repaired = re.sub(r",\s*([\]}])", r"\1", candidate) return json.loads(repaired) def biogpt_safe_fallback(user_text: str) -> Dict[str, Any]: # Minimal valid JSON to keep the pipeline alive return { "patient_inquiry": user_text, "triage": {"urgency": "Routine", "confidence": 0.3, "priority_score_1_to_10": 3}, "signals": { "key_symptoms": [], "suspected_conditions": [], "candidate_specialties": [] }, "recommended_actions": [] } # ----------------------------- # Ontology loading (YAML/JSON) # ----------------------------- def load_ontology() -> Dict[str, Any]: for p in ONTOLOGY_PATHS: if os.path.exists(p): with open(p, "r", encoding="utf-8") as f: if p.endswith((".yaml", ".yml")): return yaml.safe_load(f) return json.load(f) # Fallback minimal seed (works out of the box; expand via ontology.yaml) return { "Cardiology": { "description": "Heart and circulatory system disorders, chest pain, palpitations, dyspnea.", "subspecialties": { "Interventional Cardiology": "Coronary syndromes, stents, acute coronary occlusions.", "Electrophysiology": "Arrhythmias, palpitations, ablations.", "Heart Failure": "Fluid overload, reduced cardiac output, dyspnea on exertion." } }, "Emergency Medicine": { "description": "Immediate assessment and stabilization of acute and life-threatening conditions.", "subspecialties": {"Acute Care": "Multi-system emergencies, trauma, anaphylaxis, sepsis."} }, "Dermatology": { "description": "Skin, hair, and nail conditions including rashes, infections, inflammatory disease.", "subspecialties": { "General Dermatology": "Eczema, acne, psoriasis, non-urgent rashes.", "Infectious Dermatology": "Cellulitis, abscess, impetigo." } }, "Orthopedics": { "description": "Bones, joints, ligaments; fractures, sports injuries, arthritis.", "subspecialties": { "Trauma/Fracture": "Acute fractures, dislocations, post-traumatic management.", "Sports Medicine": "Ligament/tendon injuries, meniscal tears, overuse syndromes." } }, "Obstetrics & Gynecology": { "description": "Female reproductive health, pregnancy, pelvic pain, abnormal bleeding.", "subspecialties": { "Benign Gynecology": "Dysmenorrhea, fibroids, heavy periods.", "Early Pregnancy": "First-trimester concerns, ectopic, miscarriage.", "Reproductive Endocrinology": "Infertility, PCOS, ovulatory disorders." } }, "Gastroenterology": { "description": "Digestive tract disorders; abdominal pain, reflux, bowel habit changes.", "subspecialties": { "Functional GI": "IBS, functional dyspepsia, bloating.", "Hepatology": "Liver disease, hepatitis, jaundice.", "Upper GI": "GERD, peptic ulcer disease, upper GI bleeding." } }, "Neurology": { "description": "Brain, nerves; headaches, seizures, stroke, focal deficits.", "subspecialties": { "Stroke/Neurovascular": "Acute stroke/TIA, vascular diagnostics.", "Headache": "Migraine, tension-type, red-flag assessment." } }, "Endocrinology": { "description": "Hormonal disorders; diabetes, thyroid, adrenal.", "subspecialties": {"Diabetes": "Hyper/hypoglycemia, long-term metabolic control.", "Thyroid": "Hypo/hyperthyroid symptoms, nodules."} }, "Pulmonology": { "description": "Lung and respiratory disorders; asthma, COPD, pneumonia.", "subspecialties": {"Asthma/COPD": "Wheezing, cough, dyspnea, airflow limitation.", "Sleep/Ventilation": "Sleep apnea, nocturnal hypoventilation."} }, "Urology": { "description": "Urinary tract and male reproductive; dysuria, hematuria, stones.", "subspecialties": {"General Urology": "Lower urinary tract symptoms, infections, stones."} }, "Rheumatology": { "description": "Autoimmune and inflammatory joint/soft tissue disorders.", "subspecialties": {"Inflammatory Arthritis": "RA, psoriatic arthritis, morning stiffness.", "Connective Tissue": "Lupus, vasculitis, systemic features."} } } # ----------------------------- # Embedding model (sentence-transformers style) # ----------------------------- @dataclass class EmbeddingIndex: labels: List[Tuple[str, Optional[str]]] # (specialty, subspecialty or None) embeddings: torch.Tensor # [N, D] def load_text_encoder(model_id: str): tok = AutoTokenizer.from_pretrained(model_id) mdl = AutoModel.from_pretrained( model_id, trust_remote_code=False, low_cpu_mem_usage=True ) mdl.to(DEVICE).eval() return tok, mdl, DEVICE @torch.inference_mode() def encode_sentences(tok, mdl, device, texts: List[str]) -> torch.Tensor: batch = tok(texts, padding=True, truncation=True, return_tensors="pt").to(device) out = mdl(**batch) attn = batch["attention_mask"].unsqueeze(-1) emb = (out.last_hidden_state * attn).sum(dim=1) / attn.sum(dim=1).clamp(min=1e-6) emb = F.normalize(emb, p=2, dim=1) return emb def build_ontology_index(ontology: Dict[str, Any], tok, mdl, device) -> EmbeddingIndex: texts, labels = [], [] for spec, info in ontology.items(): desc = info.get("description", spec) texts.append(f"{spec}: {desc}"); labels.append((spec, None)) for sub, sdesc in (info.get("subspecialties", {}) or {}).items(): texts.append(f"{spec} > {sub}: {sdesc}"); labels.append((spec, sub)) embs = encode_sentences(tok, mdl, device, texts) return EmbeddingIndex(labels=labels, embeddings=embs) # ----------------------------- # BioGPT (classification JSON + signals) # ----------------------------- JSON_INSTRUCTIONS = """You are a medical triage classifier. Return ONLY valid, minified JSON (no comments, no code fences). Schema: { "patient_inquiry": string, "triage": { "urgency": "Emergency" | "Urgent" | "Routine", "confidence": number, "priority_score_1_to_10": integer }, "signals": { "key_symptoms": [string], "suspected_conditions": [string], "candidate_specialties": [string] }, "recommended_actions": [ {"action": string, "priority": "Emergency" | "Urgent" | "Routine"} ] } Output JSON ONLY on a single line. """ def load_biogpt(): print(f"Loading BioGPT (CPU): {BIOGPT_ID}") tok = AutoTokenizer.from_pretrained(BIOGPT_ID) if tok.pad_token is None: tok.pad_token = tok.eos_token tok.padding_side = "left" try: mdl = AutoModelForCausalLM.from_pretrained( BIOGPT_ID, torch_dtype=torch.float32, # CPU = FP32 low_cpu_mem_usage=True ) mdl.to(DEVICE).eval() dev = DEVICE return tok, mdl, dev except Exception as e: # Fallback to base model if Large cannot load on CPU RAM print(f"[WARN] Failed to load {BIOGPT_ID} on CPU: {e}") fallback_id = "microsoft/BioGPT" print(f"[INFO] Falling back to {fallback_id}") tok2 = AutoTokenizer.from_pretrained(fallback_id) if tok2.pad_token is None: tok2.pad_token = tok2.eos_token tok2.padding_side = "left" mdl2 = AutoModelForCausalLM.from_pretrained( fallback_id, torch_dtype=torch.float32, low_cpu_mem_usage=True ) mdl2.to(DEVICE).eval() return tok2, mdl2, DEVICE def biogpt_prompt(user_text: str) -> str: # Keep it short and explicit; minified JSON reduces parse failures return f"""{JSON_INSTRUCTIONS} Patient Inquiry: {user_text} JSON: """ @torch.inference_mode() def biogpt_json(tok, mdl, device, text: str) -> Dict[str, Any]: """ Robust JSON generation with retries and a safe fallback to avoid startup crashes. Only pass temperature/top_p when do_sample=True (prevents warnings on HF >=4.44). """ prompt = biogpt_prompt(text) temps = [0.1, 0.0, 0.0] max_tokens = [200, 180, 150] for t, mx in zip(temps, max_tokens): try: inputs = tok(prompt, return_tensors="pt", truncation=True, max_length=700, padding=True).to(device) gen_kwargs = dict( max_new_tokens=mx, eos_token_id=tok.eos_token_id, pad_token_id=tok.eos_token_id, ) if t > 0.0: gen_kwargs.update(temperature=t, do_sample=True, top_p=0.9) else: gen_kwargs.update(do_sample=False) outputs = mdl.generate(**inputs, **gen_kwargs) decoded = tok.decode(outputs[0], skip_special_tokens=True) if decoded.startswith(prompt): decoded = decoded[len(prompt):].strip() return extract_json(decoded) except Exception: continue # Final fallback so the app never crashes return biogpt_safe_fallback(text) # ----------------------------- # Ranking logic (AI-derived matching) # ----------------------------- def ai_enriched_query(signals: Dict[str, Any], user_text: str) -> str: ks = signals.get("key_symptoms") or [] conds = signals.get("suspected_conditions") or [] cands = signals.get("candidate_specialties") or [] parts = [user_text] if ks: parts.append("Key symptoms: " + "; ".join(ks)) if conds: parts.append("Suspected: " + "; ".join(conds)) if cands: parts.append("Candidate specialties: " + "; ".join(cands)) return " | ".join(parts) def rank_specialties(query_emb: torch.Tensor, index: EmbeddingIndex, top_k: int = 8): sims = torch.matmul(query_emb, index.embeddings.T).squeeze(0) # cosine on normalized vectors vals, idxs = torch.topk(sims, k=min(top_k, sims.numel())) results = [] for v, i in zip(vals.tolist(), idxs.tolist()): spec, sub = index.labels[i] results.append({"specialty": spec, "subspecialty": sub, "score": float(v)}) return results # ----------------------------- # Care suggestions (AI-derived, per top match) # ----------------------------- CARE_INSTRUCTIONS = """You are assisting care pathway planning. For the given specialty/subspecialty and symptoms, suggest initial tests/treatments. Return JSON: { "suggested_care": [string] } JSON ONLY, single line. """ def care_prompt(spec: str, sub: Optional[str], inquiry: str) -> str: title = f"{spec}" + (f" > {sub}" if sub else "") return f"""{CARE_INSTRUCTIONS} Specialty: {title} Patient Inquiry: {inquiry} JSON: """ @torch.inference_mode() def suggest_care(tok, mdl, device, spec: str, sub: Optional[str], inquiry: str) -> List[str]: prompt = care_prompt(spec, sub, inquiry) temps = [0.1, 0.0] max_tokens = [120, 100] for t, mx in zip(temps, max_tokens): try: inputs = tok(prompt, return_tensors="pt", truncation=True, max_length=512, padding=True).to(device) gen_kwargs = dict( max_new_tokens=mx, eos_token_id=tok.eos_token_id, pad_token_id=tok.eos_token_id, ) if t > 0.0: gen_kwargs.update(temperature=t, do_sample=True, top_p=0.9) else: gen_kwargs.update(do_sample=False) outputs = mdl.generate(**inputs, **gen_kwargs) decoded = tok.decode(outputs[0], skip_special_tokens=True) if decoded.startswith(prompt): decoded = decoded[len(prompt):].strip() data = extract_json(decoded) items = data.get("suggested_care") or [] return [str(x) for x in items][:5] except Exception: continue return [] # ----------------------------- # Rendering # ----------------------------- def render(payload: Dict[str, Any]) -> str: inquiry = payload["patient_inquiry"] triage = payload["triage"] urgency = triage["urgency"] conf = triage.get("confidence") score = triage.get("priority_score_1_to_10") urg_emoji = {"Emergency": "🚨", "Urgent": "🟠", "Routine": "🟒"}.get(urgency, "🟒") lines = [] lines.append("## πŸ₯ Healthcare Lead Analysis\n") lines.append(f"**Patient Inquiry:** \"{inquiry}\"\n") head = f"- **Urgency:** {urg_emoji} **{urgency}**" if isinstance(conf, (int, float)): head += f" Β· Confidence: {conf:.2f}" if isinstance(score, (int, float)): head += f" Β· Priority: {score}/10" lines.append(head + "\n") if payload.get("matches"): lines.append("### 🩺 Recommended Specialty Pathways:") for m in payload["matches"]: title = f"**{m['specialty']}**" + (f" β€Ί {m['subspecialty']}" if m.get("subspecialty") else "") sc = f" ({m['score']:.2f})" lines.append(f"- {title}{sc}") care = m.get("suggested_care") or [] if care: lines.append(" - _Suggested initial tests/treatments_: " + "; ".join(care)) lines.append("") acts = payload.get("recommended_actions") or [] if acts: lines.append("### βœ… Recommended Actions:") for a in acts: pr = a.get("priority") badge = f" **[{pr}]**" if pr else "" lines.append(f"- {a.get('action','')}{badge}") lines.append("") lines.append("\n---\n*Analysis powered by BioGPT + vector retrieval over your ontology*") return "\n".join(lines).strip() # ----------------------------- # Classify handlers # ----------------------------- def classify(text: str) -> str: if not READY: return "Model is still loading. Please try again in a moment." if not text or not text.strip(): return "Please enter some text to classify." forced_emergency = red_flag_urgency_rules(text) # 1) BioGPT: get triage + signals (robust) biogpt_out = biogpt_json(bio_tok, bio_mdl, bio_device, text) triage = biogpt_out.get("triage", {}) if isinstance(biogpt_out.get("triage", {}), dict) else {} urgency = (triage.get("urgency") or "Routine").title() if urgency not in ["Emergency", "Urgent", "Routine"]: urgency = "Routine" conf = triage.get("confidence") score = triage.get("priority_score_1_to_10") actions = biogpt_out.get("recommended_actions") or [] signals = biogpt_out.get("signals") or {} # Emergency override if forced_emergency: urgency = "Emergency" score = max(9, int(score) if isinstance(score, int) else 9) if not actions: actions = [{"action": "Immediate ER referral or call emergency services", "priority": "Emergency"}] # 2) Build AI-enriched query and rank against ontology query_text = ai_enriched_query(signals, text) q_emb = encode_sentences(emb_tok, emb_mdl, emb_device, [query_text]) # [1, D] ranked = rank_specialties(q_emb, ONT_INDEX, top_k=8) # 3) For top-N, ask BioGPT to suggest care steps top_matches = [] for r in ranked[:5]: care = suggest_care(bio_tok, bio_mdl, bio_device, r["specialty"], r.get("subspecialty"), text) top_matches.append({ "specialty": r["specialty"], "subspecialty": r.get("subspecialty"), "score": r["score"], "suggested_care": care }) payload = { "patient_inquiry": text, "triage": {"urgency": urgency, "confidence": conf, "priority_score_1_to_10": score}, "matches": top_matches, "recommended_actions": actions } gc.collect() return render(payload) def classify_batch(texts: str) -> str: if not READY: return "Model is still loading. Please try again in a moment." if not texts or not texts.strip(): return "Please enter text to analyze." items = [t.strip() for t in texts.replace(";", "\n").split("\n") if t.strip()] if len(items) == 1: return classify(items[0]) parts = [f"### Lead {i}\n\n{classify(t)}" for i, t in enumerate(items, 1)] return "\n\n---\n\n".join(parts) # ----------------------------- # Load models & ontology on startup (CPU-only) # ----------------------------- print("Loading ontology...") ONTOLOGY = load_ontology() print("Loading embedding encoder (CPU)...") emb_tok, emb_mdl, emb_device = load_text_encoder(EMB_MODEL_ID) print("Building ontology index...") ONT_INDEX = build_ontology_index(ONTOLOGY, emb_tok, emb_mdl, emb_device) print("Loading BioGPT (CPU)...") bio_tok, bio_mdl, bio_device = load_biogpt() print("Models loaded. Server will start...") # Signal ready AFTER all models are loaded READY = True print("Server ready.") # ----------------------------- # Gradio UI # ----------------------------- single = gr.Interface( fn=classify, inputs=gr.Textbox( label="πŸ“ Enter Patient Inquiry or Lead Description", placeholder="e.g., Severe chest pain radiating to left arm with shortness of breath.", lines=4, info="Describe symptoms, urgency, and what the patient is seeking." ), outputs=gr.Markdown(label="πŸ₯ Analysis"), title="DocMap Healthcare Lead Analyzer", description="Dynamic ontology + AI-derived specialty/subspecialty matching with suggested initial care.", examples=[ ["Emergency: Severe chest pain and shortness of breath, need immediate care"], ["Period pain for first three hours and tummy pain"], ["Knee pain after football, swelling and trouble bearing weight"], ["Chronic cough, wheezing at night, worse with exercise"], ["Severe headache with photophobia and nausea"] ], cache_examples=False, # βœ… prevents startup caching (and 500s) theme=gr.themes.Soft() ) batch = gr.Interface( fn=classify_batch, inputs=gr.Textbox( label="πŸ“‹ Batch Analysis (Multiple Leads)", placeholder="One inquiry per line or separated by semicolons…", lines=6 ), outputs=gr.Markdown(label="πŸ₯ Batch Results"), title="DocMap Batch Analyzer", description="Analyze multiple leads with AI-derived matching", cache_examples=False # βœ… prevents startup caching ) demo = gr.TabbedInterface([single, batch], ["Single Lead", "Batch Processing"]) if __name__ == "__main__": # Disable SSR (it can surface startup-errors as fatal in some envs) demo.queue(max_size=20).launch(server_name="0.0.0.0", ssr_mode=False)