| import io |
| import torch |
| import torch.nn as nn |
| import timm |
| import pickle |
| import traceback |
| import os |
| from PIL import Image |
| from fastapi import FastAPI, File, UploadFile |
| from fastapi.middleware.cors import CORSMiddleware |
| from torchvision import transforms |
| from transformers import T5ForConditionalGeneration, T5Tokenizer |
| from huggingface_hub import hf_hub_download |
|
|
| |
| |
| |
# Central configuration: model identifiers and preprocessing geometry.
CONFIG = dict(
    coatnet_model='coatnet_1_rw_224',
    t5_model='t5-small',
    img_emb_dim=768,
    train_last_stages=2,
    image_size=224,
)
|
|
| |
| |
| |
# Run on GPU when one is visible, otherwise fall back to CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"π₯οΈ Using device: {device}")
|
|
| |
| |
| |
|
|
# --- Tokenizer and shared image preprocessing pipeline ------------------------
_RULE = "=" * 80
print("\n" + _RULE)
print("LOADING TOKENIZER AND IMAGE TRANSFORM")
print(_RULE)

# One T5 tokenizer instance is shared by every model in this service.
tokenizer = T5Tokenizer.from_pretrained(CONFIG['t5_model'])
print(f"β Loaded tokenizer: {CONFIG['t5_model']}")

# Resize + tensor conversion + standard ImageNet normalization.
_side = CONFIG['image_size']
transform = transforms.Compose([
    transforms.Resize((_side, _side)),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
])
print(f"β Image transform defined (size: {CONFIG['image_size']}x{CONFIG['image_size']})")
|
|
def preprocess_image(image_path: str) -> torch.Tensor:
    """Load an image from disk and return the normalized CHW tensor.

    Applies the module-level ``transform`` (resize -> tensor -> ImageNet
    normalize).  FIX: the file handle opened by ``Image.open`` was never
    closed; the context manager closes it eagerly, and ``convert('RGB')``
    forces the pixel data to be read before the file is released.
    """
    with Image.open(image_path) as img:
        image = img.convert('RGB')
    return transform(image)
|
|
| |
| |
| |
| |
class CoAtNetEncoder(nn.Module):
    """Wrap a timm CoAtNet backbone and expose pooled feature vectors.

    All backbone parameters are frozen except those belonging to the last
    ``train_last_stages`` stages (matched by the ``stages.{i}`` substring
    in the parameter name).
    """

    def __init__(self, model_name=None, pretrained=False, train_last_stages=None):
        super().__init__()
        # FIX: use explicit None checks instead of `x or default` so a caller
        # can pass 0 for train_last_stages (freeze everything) without it
        # silently falling back to the configured value.
        if model_name is None:
            model_name = CONFIG['coatnet_model']
        if train_last_stages is None:
            train_last_stages = CONFIG['train_last_stages']

        self.backbone = timm.create_model(model_name, pretrained=pretrained)

        # Freeze everything, then re-enable grads for the trailing stages.
        # NOTE(review): the upper bound 5 assumes stage indices run 0..4 —
        # verify against the actual timm model's stage naming.
        for name, param in self.backbone.named_parameters():
            param.requires_grad = False
            for i in range(5 - train_last_stages, 5):
                if f"stages.{i}" in name:
                    param.requires_grad = True
                    break

        # Probe the backbone once to discover the pooled feature dimension.
        with torch.no_grad():
            dummy = torch.randn(1, 3, 224, 224)
            features = self.backbone.forward_features(dummy)
            if len(features.shape) == 4:
                # (B, C, H, W) -> global average pool over spatial dims.
                features = features.mean(dim=[2, 3])
            self.feature_dim = features.shape[-1]

        print(f" CoAtNetEncoder feature_dim = {self.feature_dim}")

    def forward(self, x):
        """Return pooled (B, feature_dim) features for a batch of images."""
        features = self.backbone.forward_features(x)
        if len(features.shape) == 4:
            features = features.mean(dim=[2, 3])
        return features
|
|
|
|
| |
| |
| |
| |
| |
class SFTVisionT5Model(nn.Module):
    """Supervised fine-tuned vision-to-text model.

    A CoAtNet encoder produces one feature vector per image; a linear
    projection maps it into T5's hidden size so it can serve as a
    (length-1) encoder output sequence for the T5 decoder.
    """

    def __init__(self, img_encoder, txt_model_name="t5-small", img_emb_dim=768):
        super().__init__()
        # Image encoder producing (B, img_emb_dim) pooled features.
        self.img_encoder = img_encoder
        self.t5 = T5ForConditionalGeneration.from_pretrained(txt_model_name)
        # Bridge from image-feature space into T5's d_model.
        self.proj = nn.Linear(img_emb_dim, self.t5.config.d_model)

        # Freeze T5's shared token-embedding table.
        for p in self.t5.shared.parameters():
            p.requires_grad = False

    def generate_reports(self, pixel_values, max_length=100):
        """Generate one report string per image in ``pixel_values``.

        The projected image embedding is run through the T5 encoder as a
        single-token sequence, then beam search (4 beams) decodes the text.
        Returns a list of cleaned report strings.
        """
        with torch.no_grad():
            img_feats = self.img_encoder(pixel_values)      # (B, img_emb_dim)
            img_feats = self.proj(img_feats)                # (B, d_model)
            encoder_hidden_states = img_feats.unsqueeze(1)  # (B, 1, d_model)

            # Pre-compute encoder outputs so `generate` skips its own
            # encoder pass and consumes these directly.
            encoder_outputs = self.t5.encoder(
                inputs_embeds=encoder_hidden_states
            )

            # Attention mask over the length-1 "source sequence".
            attn = torch.ones(
                encoder_hidden_states.size()[:2], device=pixel_values.device
            )

            generated_ids = self.t5.generate(
                encoder_outputs=encoder_outputs,
                attention_mask=attn,
                max_length=max_length,
                num_beams=4,
                early_stopping=True,
            )

            # Uses the module-level tokenizer.
            reports = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)

            # Strip a leading "projection: ..." prefix (training-data artifact):
            # drop everything up to and including the first period.
            cleaned = []
            for r in reports:
                if r.lower().startswith("projection:"):
                    parts = r.split(".", 1)
                    r = parts[1].strip() if len(parts) > 1 else r
                cleaned.append(r)
            return cleaned
|
|
|
|
| |
| |
| |
| |
class PPOVisionT5Model(nn.Module):
    """PPO-finetuned vision-to-text model (same topology as the SFT model)."""

    def __init__(self, img_encoder, txt_model_name="t5-small", img_emb_dim=768):
        super().__init__()
        self.img_encoder = img_encoder
        self.txt_model = T5ForConditionalGeneration.from_pretrained(txt_model_name)
        # Bridge from image-feature space into T5's hidden size.
        self.img_proj = nn.Linear(img_emb_dim, self.txt_model.config.d_model)

    def generate_reports(self, images, max_length=128):
        """Beam-search decode one report per input image; returns a list of strings."""
        with torch.no_grad():
            pooled = self.img_encoder(images)
            # Treat the projected feature as a length-1 encoder sequence.
            single_token = self.img_proj(pooled).unsqueeze(1)

            mask = torch.ones(images.size(0), 1, device=images.device)

            enc_out = self.txt_model.encoder(
                inputs_embeds=single_token,
                attention_mask=mask
            )

            ids = self.txt_model.generate(
                encoder_outputs=enc_out,
                attention_mask=mask,
                max_length=max_length,
                num_beams=4,
                early_stopping=True,
            )

            decoded = tokenizer.batch_decode(ids, skip_special_tokens=True)

            # Drop a leading "projection: ..." prefix up to the first period.
            results = []
            for text in decoded:
                if text.lower().startswith("projection:"):
                    _head, sep, tail = text.partition(".")
                    if sep:
                        text = tail.strip()
                results.append(text)
            return results
|
|
|
|
| |
| |
| |
| |
class RewardModel(nn.Module):
    """Score an (image, report) pair with a single scalar reward.

    Image and text are embedded separately, projected to 512-d each,
    concatenated, and fed through a small MLP regression head.
    """

    def __init__(self, img_encoder, txt_model_name="t5-small"):
        super().__init__()
        self.img_encoder = img_encoder
        # Only the encoder half of T5 is needed for text embedding.
        self.txt_encoder = T5ForConditionalGeneration.from_pretrained(txt_model_name).encoder
        self.img_proj = nn.Linear(img_encoder.feature_dim, 512)
        self.txt_proj = nn.Linear(self.txt_encoder.config.d_model, 512)
        self.reward_head = nn.Sequential(
            nn.Linear(1024, 512), nn.ReLU(), nn.Dropout(0.1),
            nn.Linear(512, 256), nn.ReLU(), nn.Dropout(0.1),
            nn.Linear(256, 1)
        )

    def forward(self, images, input_ids, attention_mask):
        """Return a (B,) tensor of reward scores."""
        image_vec = self.img_proj(self.img_encoder(images))
        text_out = self.txt_encoder(input_ids=input_ids, attention_mask=attention_mask)
        # Mean-pool the token states before projecting.
        text_vec = self.txt_proj(text_out.last_hidden_state.mean(dim=1))
        fused = torch.cat([image_vec, text_vec], dim=1)
        return self.reward_head(fused).squeeze(-1)
|
|
|
|
| |
| |
| |
| |
def remap_keys(raw_sd: dict, label: str) -> dict:
    """
    Remap state_dict keys to match current model attribute names.

    Known mismatch discovered from diagnostic output: the SFT notebook
    saved the image encoder under ``img_encoder.encoder.*`` while the
    current model uses ``img_encoder.backbone.*``.  PPO/RM checkpoints
    already match the current attribute names and pass through untouched.
    """
    old_prefix = "img_encoder.encoder."
    new_prefix = "img_encoder.backbone."
    remapped = {}
    hits = 0
    for key, tensor in raw_sd.items():
        if old_prefix in key:
            key = key.replace(old_prefix, new_prefix)
            hits += 1
        remapped[key] = tensor
    if hits:
        print(f" π§ Remapped {hits} keys: img_encoder.encoder.* β img_encoder.backbone.*")
    return remapped
|
|
|
|
def load_model(path: str, model_obj: nn.Module, label: str) -> nn.Module:
    """Load checkpoint *path* into *model_obj* and move it to `device`.

    ``.pkl`` files are unpickled as complete model objects; anything else
    is treated as a state_dict.  Keys are remapped (see ``remap_keys``)
    and loaded with ``strict=False``; benign mismatches are filtered out
    before the warning report is printed.
    """
    print(f"\nπ Loading {label} from: {path}")

    if path.endswith(".pkl"):
        # SECURITY: pickle.load executes arbitrary code — only load
        # checkpoints from a trusted source (here: our own HF repo).
        with open(path, "rb") as f:
            loaded = pickle.load(f)
        # FIX: this print was broken across two physical lines in the
        # original; reconstructed as a single line like the other logs.
        print(f" β Loaded full pickle object: {type(loaded)}")
        return loaded.to(device)

    # SECURITY: torch.load also unpickles; same trust caveat applies.
    raw_sd = torch.load(path, map_location=device)

    # Log a few keys from each side to make mismatches easy to spot.
    saved_keys = list(raw_sd.keys())
    print(f" Saved keys (first 5): {saved_keys[:5]}")
    model_keys = list(model_obj.state_dict().keys())
    print(f" Model keys (first 5): {model_keys[:5]}")

    raw_sd = remap_keys(raw_sd, label)

    result = model_obj.load_state_dict(raw_sd, strict=False)

    # Filter out mismatches that are expected and harmless:
    #  - BatchNorm bookkeeping buffers (num_batches_tracked)
    #  - the classifier head we never use for feature extraction
    SAFE_MISSING = ("num_batches_tracked", "head.fc.")
    missing = [k for k in result.missing_keys if not any(s in k for s in SAFE_MISSING)]
    unexpected = [k for k in result.unexpected_keys if "num_batches_tracked" not in k]

    if missing:
        print(f" Missing keys: {missing[:5]}{'...' if len(missing)>5 else ''}")
        print(f" WARNING: {len(missing)} missing keys - weights NOT loaded for those layers!")
    if unexpected:
        print(f" Unexpected keys: {unexpected[:5]}{'...' if len(unexpected)>5 else ''}")
    if not missing and not unexpected:
        print(f" OK: All keys matched perfectly!")

    return model_obj.to(device)
|
|
|
|
| |
| |
| |
| |
def download_model_from_hf(model_filename: str, local_path: str = "models/") -> str:
    """Download model from Hugging Face Hub if not exists locally"""
    os.makedirs(local_path, exist_ok=True)
    full_path = os.path.join(local_path, model_filename)

    # Guard clause: a cached copy wins over a fresh download.
    if os.path.exists(full_path):
        print(f" Using local {model_filename}")
        return full_path

    print(f" Downloading {model_filename} from Hugging Face Hub...")
    try:
        downloaded_path = hf_hub_download(
            repo_id="Shree2604/BioStack",
            filename=model_filename,
            local_dir=local_path,
            local_dir_use_symlinks=False
        )
    except Exception as e:
        print(f" Failed to download {model_filename}: {e}")
        raise
    print(f" Downloaded {model_filename}")
    return downloaded_path
|
|
# ---- Fetch checkpoints and construct the three models ------------------------
_BANNER = "=" * 60
print("\n" + _BANNER)
print(" LOADING MODELS FROM HUGGING FACE HUB")
print(_BANNER)

SFT_MODEL_PATH = download_model_from_hf("best_model.pt")
REWARD_MODEL_PATH = download_model_from_hf("reward_model.pt")
PPO_MODEL_PATH = download_model_from_hf("rlhf_model.pt")

# Each model gets its own randomly-initialized encoder; the real weights
# come from the checkpoint via load_model, and inference runs in eval mode.
_sft_enc = CoAtNetEncoder(pretrained=False)
sft_model = load_model(SFT_MODEL_PATH, SFTVisionT5Model(_sft_enc), "SFT Model")
sft_model.eval()

_rm_enc = CoAtNetEncoder(pretrained=False)
reward_model = load_model(REWARD_MODEL_PATH, RewardModel(_rm_enc), "Reward Model")
reward_model.eval()

_ppo_enc = CoAtNetEncoder(pretrained=False)
ppo_model = load_model(PPO_MODEL_PATH, PPOVisionT5Model(_ppo_enc), "PPO Model")
ppo_model.eval()

print("\n All models loaded and ready!\n" + "="*60 + "\n")
|
|
|
|
| |
| |
| |
| |
# NOTE(review): this transform duplicates the one defined near the top of
# the file (same 224x224 resize and ImageNet stats); it rebinds the same
# module-level name. Kept as-is to avoid changing behavior.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])
])


def preprocess(file_bytes: bytes) -> torch.Tensor:
    """Decode raw upload bytes into a normalized (1, 3, 224, 224) batch on `device`."""
    img = Image.open(io.BytesIO(file_bytes)).convert("RGB")
    return transform(img).unsqueeze(0).to(device)
|
|
|
|
| |
| |
| |
# Vocabulary used to judge how "medical" a generated report sounds.
# Order is preserved because feedback lists terms in this order.
KEY_MEDICAL_TERMS = [
    'lung', 'heart', 'normal',
    'clear', 'opacity', 'infiltrate',
    'cardiomegaly', 'pleural', 'pulmonary',
    'chest', 'thorax', 'pneumonia',
    'edema', 'effusion', 'consolidation',
]
|
|
def reward_feedback(report: str, score: float) -> str:
    """Build a human-readable feedback line for a scored report.

    Quality factors:
      - terminology: fraction of KEY_MEDICAL_TERMS present in the report
      - completeness: word count relative to a 100-word target (capped at 1.0)
      - structure: 1.0 when the report is 50-150 words, else 0.5

    FIX: removed unused locals (`missing`, `length_q`, `radiological_score`)
    that were computed and then discarded.
    """
    text = report.lower()
    present = [t for t in KEY_MEDICAL_TERMS if t in text]
    words = len(report.split())

    terminology_score = len(present) / len(KEY_MEDICAL_TERMS)
    completeness_score = min(1.0, words / 100.0)
    structure_score = 1.0 if 50 <= words <= 150 else 0.5

    return (
        f"Reward Score: {score:.2f} | "
        f"Quality Factors - "
        f"Medical Terminology: {terminology_score:.1%} | "
        f"Clinical Completeness: {completeness_score:.1%} | "
        f"Report Structure: {structure_score:.1%}"
    )
|
|
|
|
| |
| |
| |
# FastAPI application with fully permissive CORS so any frontend origin can
# call the API (demo configuration).
app = FastAPI(title="RLHF Medical Demo")

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],   # NOTE(review): wide open — tighten for production
    allow_methods=["*"],
    allow_headers=["*"],
)
|
|
|
|
@app.get("/health")
def health():
    """Liveness probe: report service status and the active torch device."""
    return {"status": "ok", "device": f"{device}"}
|
|
|
|
@app.post("/sft")
async def sft_inference(file: UploadFile = File(...)):
    """Generate a report for the uploaded image with the SFT model.

    Returns ``{"report": <text>}`` or ``{"report": "ERROR: ..."}``.
    """
    try:
        # `preprocess` decodes the upload bytes in memory, so the previous
        # write-to-temp-file/read-back/delete round trip is unnecessary.
        tensor = preprocess(await file.read())
        report = sft_model.generate_reports(tensor)[0]
        print(f"[SFT] Generated: {report}")
        # FIX: previously truncated to report[:81] — an arbitrary cut that
        # was inconsistent with /ppo and /reward; return the full report.
        return {"report": report}
    except Exception as e:
        traceback.print_exc()
        return {"report": f"ERROR: {str(e)}"}
|
|
|
|
@app.post("/reward")
async def reward_inference(file: UploadFile = File(...)):
    """Score the SFT model's report for the uploaded image.

    Pipeline: save upload -> preprocess -> generate SFT report -> tokenize
    report -> reward model -> clamp score to [0, 1] -> build feedback string.
    Returns ``{"score": float, "feedback": str, "sft_report": str}``.
    """
    try:
        import tempfile
        # Persist the upload so preprocess_image can read it from disk.
        with tempfile.NamedTemporaryFile(delete=False, suffix='.png') as temp_file:
            temp_file.write(await file.read())
            temp_path = temp_file.name

        try:
            tensor = preprocess_image(temp_path).unsqueeze(0).to(device)
            # The reward model scores (image, report) pairs, so first
            # generate the report to be scored.
            sft_report = sft_model.generate_reports(tensor)[0]
            print(f"[REWARD] Scoring SFT report: {sft_report}")

            # Empty generation -> nothing to score.
            if not sft_report.strip():
                return {"score": 0.0, "feedback": "", "sft_report": ""}

            # Tokenize the report exactly as during reward-model training
            # (fixed 128-token padded length).
            enc = tokenizer(
                [sft_report],
                max_length=128,
                padding="max_length",
                truncation=True,
                return_tensors="pt"
            )
            input_ids = enc.input_ids.to(device)
            attention_mask = enc.attention_mask.to(device)

            with torch.no_grad():
                raw_score = reward_model(tensor, input_ids, attention_mask).item()

            print(f"[REWARD] Raw neural network output: {raw_score:.6f}")
            print(f"[REWARD] Clamping to [0,1] range: max(0.0, min(1.0, {raw_score:.6f})) = {max(0.0, min(1.0, raw_score)):.6f}")

            # Diagnostic breakdown of the report text (logging only; does
            # not affect the returned score).
            rl = sft_report.lower()
            present = [t for t in KEY_MEDICAL_TERMS if t in rl]
            missing = [t for t in KEY_MEDICAL_TERMS if t not in rl]
            words = len(sft_report.split())
            length_q = "good" if 50 <= words <= 150 else ("too short" if words < 50 else "too long")

            print(f"[REWARD] Report analysis:")
            print(f" - Total words: {words} ({length_q})")
            print(f" - Medical terms present ({len(present)}/{len(KEY_MEDICAL_TERMS)}): {present}")
            print(f" - Medical terms missing: {missing}")
            print(f" - Key terms list: {KEY_MEDICAL_TERMS}")

            # Static architecture description (logging only).
            print(f"[REWARD] Model architecture:")
            print(f" - CoAtNet feature dim: {reward_model.img_encoder.feature_dim}")
            print(f" - T5 d_model: {reward_model.txt_encoder.config.d_model}")
            print(f" - Combined feature dim: 1024 (512 img + 512 text)")
            print(f" - Reward head: 1024β512β256β1")

            # Clamp the raw regressor output into a displayable [0, 1] score.
            score = float(max(0.0, min(1.0, raw_score)))
            feedback = reward_feedback(sft_report, score)
            print(f"[REWARD] Final Score={score:.3f}")
            return {"score": score, "feedback": feedback, "sft_report": sft_report}
        finally:
            # Always remove the temp file, even on error.
            os.unlink(temp_path)

    except Exception as e:
        traceback.print_exc()
        return {"score": 0.0, "feedback": f"ERROR: {str(e)}", "sft_report": ""}
|
|
|
|
@app.post("/ppo")
async def ppo_inference(file: UploadFile = File(...)):
    """Generate a report for the uploaded image with the PPO (RLHF) model.

    Returns ``{"report": <text>}`` or ``{"report": "ERROR: ..."}``.
    """
    try:
        # `preprocess` decodes the upload bytes in memory, so the previous
        # write-to-temp-file/read-back/delete round trip is unnecessary.
        tensor = preprocess(await file.read())
        report = ppo_model.generate_reports(tensor)[0]
        print(f"[PPO] Generated: {report}")
        return {"report": report}
    except Exception as e:
        traceback.print_exc()
        return {"report": f"ERROR: {str(e)}"}
|
|
|
|
| |
| |
| |
| |
@app.get("/debug_keys")
def debug_keys():
    """Inspect the first keys of each checkpoint on disk (diagnostics only).

    FIX: removed the redundant function-local ``import os`` — os is already
    imported at module level.
    """
    result = {}
    for label, path in [("SFT", SFT_MODEL_PATH), ("Reward", REWARD_MODEL_PATH), ("PPO", PPO_MODEL_PATH)]:
        if not os.path.exists(path):
            result[label] = f"FILE NOT FOUND: {path}"
            continue
        try:
            # NOTE: torch.load unpickles — these paths are our own checkpoints.
            sd = torch.load(path, map_location="cpu")
            keys = list(sd.keys())
            result[label] = {"first_10_keys": keys[:10], "total_keys": len(keys)}
        except Exception as e:
            result[label] = f"ERROR: {e}"
    return result
|
|
|
|
| |
| |
| |
from fastapi.staticfiles import StaticFiles

# Serve the compiled React frontend (if present) from the same app.
# NOTE: mounting at "/" must happen after all API routes are registered,
# otherwise the static mount would shadow them.
# FIX: the success-branch print was split across two physical lines by a
# mangled character (invalid syntax as pasted); reconstructed as a single
# line matching the file's other status prints. The redundant local
# `import os` was dropped — os is already imported at module level.
if os.path.exists("build"):
    app.mount("/", StaticFiles(directory="build", html=True), name="static")
    print("β React app mounted at /")
else:
    print("β οΈ Build directory not found, serving API only")
|
|
|
|
# Launch the API server when executed directly (not when imported).
if __name__ == "__main__":
    import uvicorn
    # reload=False: production-style single process on port 7860.
    uvicorn.run(app, host="0.0.0.0", port=7860, reload=False)