AE-Shree commited on
Commit
f3e4ffb
Β·
1 Parent(s): 28046a7

Deploy BioStack RLHF Medical Demo

Browse files
Files changed (1) hide show
  1. server.py +340 -398
server.py CHANGED
@@ -2,7 +2,6 @@ import io
2
  import torch
3
  import torch.nn as nn
4
  import timm
5
- import pickle
6
  import traceback
7
  import os
8
  from PIL import Image
@@ -13,7 +12,7 @@ from transformers import T5ForConditionalGeneration, T5Tokenizer
13
  from huggingface_hub import hf_hub_download
14
 
15
  # ─────────────────────────────────────────────────────────────────────────────
16
- # CONFIGURATION
17
  # ─────────────────────────────────────────────────────────────────────────────
18
  CONFIG = {
19
  'coatnet_model': 'coatnet_1_rw_224',
@@ -21,6 +20,8 @@ CONFIG = {
21
  'img_emb_dim': 768,
22
  'train_last_stages': 2,
23
  'image_size': 224,
 
 
24
  }
25
 
26
  # ─────────────────────────────────────────────────────────────────────────────
@@ -30,18 +31,17 @@ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
30
  print(f"πŸ–₯️ Using device: {device}")
31
 
32
  # ─────────────────────────────────────────────────────────────────────────────
33
- # SECTION 7: Load Tokenizer and Image Transform
34
  # ─────────────────────────────────────────────────────────────────────────────
35
-
36
  print("\n" + "="*80)
37
- print("LOADING TOKENIZER AND IMAGE TRANSFORM")
38
  print("="*80)
39
-
40
- # Load tokenizer
41
  tokenizer = T5Tokenizer.from_pretrained(CONFIG['t5_model'])
42
  print(f"βœ“ Loaded tokenizer: {CONFIG['t5_model']}")
43
 
44
- # Define image transform
 
 
45
  transform = transforms.Compose([
46
  transforms.Resize((CONFIG['image_size'], CONFIG['image_size'])),
47
  transforms.ToTensor(),
@@ -52,487 +52,429 @@ transform = transforms.Compose([
52
  ])
53
  print(f"βœ“ Image transform defined (size: {CONFIG['image_size']}x{CONFIG['image_size']})")
54
 
55
- def preprocess_image(image_path: str) -> torch.Tensor:
56
- """Load and preprocess image."""
57
- image = Image.open(image_path).convert('RGB')
58
- return transform(image)
59
-
60
  # ─────────────────────────────────────────────────────────────────────────────
61
- # ARCHITECTURE 1 β€” CoAtNet Encoder (shared by all three models)
62
- # Matches BOTH notebooks exactly.
63
  # ─────────────────────────────────────────────────────────────────────────────
64
  class CoAtNetEncoder(nn.Module):
65
- def __init__(self, model_name=None, pretrained=False, train_last_stages=None):
66
  super().__init__()
67
- # Use CONFIG defaults if not specified
68
- model_name = model_name or CONFIG['coatnet_model']
69
- train_last_stages = train_last_stages or CONFIG['train_last_stages']
70
-
71
- # pretrained=False at inference time β€” weights come from .pt file
72
- self.backbone = timm.create_model(model_name, pretrained=pretrained)
73
-
74
- for name, param in self.backbone.named_parameters():
75
- param.requires_grad = False
76
- for i in range(5 - train_last_stages, 5):
77
- if f"stages.{i}" in name:
78
- param.requires_grad = True
79
- break
80
 
81
- # Detect feature_dim dynamically (same as RM/PPO notebook Cell 4)
82
- with torch.no_grad():
83
- dummy = torch.randn(1, 3, 224, 224)
84
- features = self.backbone.forward_features(dummy)
85
- if len(features.shape) == 4:
86
- features = features.mean(dim=[2, 3])
87
- self.feature_dim = features.shape[-1]
88
 
89
- print(f" CoAtNetEncoder feature_dim = {self.feature_dim}")
 
 
 
 
 
90
 
91
  def forward(self, x):
92
- features = self.backbone.forward_features(x)
93
- if len(features.shape) == 4:
94
- features = features.mean(dim=[2, 3])
95
- return features
96
 
97
 
98
  # ─────────────────────────────────────────────────────────────────────────────
99
- # ARCHITECTURE 2 β€” SFT VisionT5Model
100
- # BUG FIX: Uses self.t5 and self.proj β€” exactly matching best_model.pt keys
101
- # from SFT notebook Cell 33. Do NOT rename these to txt_model/img_proj.
102
  # ─────────────────────────────────────────────────────────────────────────────
103
- class SFTVisionT5Model(nn.Module):
104
  def __init__(self, img_encoder, txt_model_name="t5-small", img_emb_dim=768):
105
  super().__init__()
 
 
106
  self.img_encoder = img_encoder
107
- # ← self.t5 (NOT self.txt_model β€” must match saved keys)
 
108
  self.t5 = T5ForConditionalGeneration.from_pretrained(txt_model_name)
109
- # ← self.proj (NOT self.img_proj β€” must match saved keys)
 
110
  self.proj = nn.Linear(img_emb_dim, self.t5.config.d_model)
111
 
 
112
  for p in self.t5.shared.parameters():
113
  p.requires_grad = False
114
 
115
- def generate_reports(self, pixel_values, max_length=100):
116
- self.eval()
117
- with torch.no_grad():
118
- # Extract + project image features
119
- img_feats = self.img_encoder(pixel_values) # [B, feature_dim]
120
- img_feats = self.proj(img_feats) # [B, d_model]
121
- encoder_hidden_states = img_feats.unsqueeze(1) # [B, 1, d_model]
122
-
123
- # Encode
124
- encoder_outputs = self.t5.encoder(
125
- inputs_embeds=encoder_hidden_states
126
- )
127
 
128
- attn = torch.ones(
129
- encoder_hidden_states.size()[:2], device=pixel_values.device
130
- )
131
-
132
- # BUG FIX 3: repetition_penalty + no_repeat_ngram_size breaks
133
- # the "Projection: Projection: Projection:" loop
134
- generated_ids = self.t5.generate(
135
- encoder_outputs=encoder_outputs,
136
- attention_mask=attn,
137
- max_length=max_length,
138
- num_beams=4,
139
- early_stopping=True,
140
- no_repeat_ngram_size=3,
141
- repetition_penalty=1.3,
142
- )
143
 
144
- return generated_ids
145
-
146
-
147
- # ─────────────────────────────────────────────────────────────────────────────
148
- # ARCHITECTURE 3 β€” PPO VisionT5Model
149
- # Uses self.txt_model and self.img_proj β€” matching RM/PPO notebook Cell 4.
150
- # ─────────────────────────────────────────────────────────────────────────────
151
- class PPOVisionT5Model(nn.Module):
152
- def __init__(self, img_encoder, txt_model_name="t5-small", img_emb_dim=768):
153
- super().__init__()
154
- self.img_encoder = img_encoder
155
- # ← self.txt_model (matches PPO notebook Cell 4)
156
- self.txt_model = T5ForConditionalGeneration.from_pretrained(txt_model_name)
157
- # ← self.img_proj (matches PPO notebook Cell 4)
158
- self.img_proj = nn.Linear(img_emb_dim, self.txt_model.config.d_model)
159
 
160
- def generate_reports(self, images, max_length=128):
161
- self.eval()
162
- with torch.no_grad():
163
- img_features = self.img_encoder(images) # [B, feature_dim]
164
- img_emb = self.img_proj(img_features).unsqueeze(1) # [B, 1, d_model]
165
 
166
- batch_size = images.size(0)
167
- img_attn = torch.ones(batch_size, 1, device=images.device)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
168
 
169
- encoder_outputs = self.txt_model.encoder(
170
- inputs_embeds=img_emb,
171
- attention_mask=img_attn
172
- )
 
 
 
 
 
 
173
 
174
- # BUG FIX 3: same repetition guards as SFT
175
- generated = self.txt_model.generate(
176
- encoder_outputs=encoder_outputs,
177
- attention_mask=img_attn,
178
- max_length=max_length,
179
- num_beams=4,
180
- early_stopping=True,
181
- no_repeat_ngram_size=3,
182
- repetition_penalty=1.3,
183
- )
184
 
185
- return generated
186
 
 
187
 
188
  # ─────────────────────────────────────────────────────────────────────────────
189
- # ARCHITECTURE 4 β€” Reward Model
190
- # Matches RM/PPO notebook Cell 5 exactly.
191
  # ─────────────────────────────────────────────────────────────────────────────
192
- class RewardModel(nn.Module):
193
- def __init__(self, img_encoder, txt_model_name="t5-small"):
194
- super().__init__()
195
- self.img_encoder = img_encoder
196
- self.txt_encoder = T5ForConditionalGeneration.from_pretrained(txt_model_name).encoder
197
- img_dim = img_encoder.feature_dim
198
- txt_dim = self.txt_encoder.config.d_model
199
- self.img_proj = nn.Linear(img_dim, 512)
200
- self.txt_proj = nn.Linear(txt_dim, 512)
201
- self.reward_head = nn.Sequential(
202
- nn.Linear(1024, 512), nn.ReLU(), nn.Dropout(0.1),
203
- nn.Linear(512, 256), nn.ReLU(), nn.Dropout(0.1),
204
- nn.Linear(256, 1)
 
205
  )
206
 
207
- def forward(self, images, input_ids, attention_mask):
208
- img_features = self.img_encoder(images)
209
- img_emb = self.img_proj(img_features)
210
- txt_outputs = self.txt_encoder(input_ids=input_ids, attention_mask=attention_mask)
211
- txt_emb = txt_outputs.last_hidden_state.mean(dim=1)
212
- txt_emb = self.txt_proj(txt_emb)
213
- combined = torch.cat([img_emb, txt_emb], dim=1)
214
- return self.reward_head(combined).squeeze(-1)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
215
 
216
 
217
  # ─────────────────────────────────────────────────────────────────────────────
218
- # MODEL LOADER β€” handles both .pt (state_dict) and .pkl (full model)
219
- # Prints a key-match diagnostic so you can see exactly what loaded.
220
  # ─────────────────────────────────────────────────────────────────────────────
221
- def remap_keys(raw_sd: dict, label: str) -> dict:
 
 
 
 
222
  """
223
- Remap state_dict keys to match current model attribute names.
224
-
225
- Known mismatches discovered from diagnostic output:
226
- SFT notebook used:
227
- img_encoder.encoder.* β†’ we use img_encoder.backbone.*
228
- t5.* β†’ we use t5.* (already correct for SFTVisionT5Model)
229
- proj.* β†’ we use proj.* (already correct for SFTVisionT5Model)
230
- PPO/RM notebooks used:
231
- img_encoder.backbone.* β†’ already correct βœ…
232
- txt_model.* β†’ already correct βœ…
233
- img_proj.* β†’ already correct βœ…
234
  """
235
- remapped = {}
236
- changed = 0
237
- for k, v in raw_sd.items():
238
- new_k = k
239
- # SFT encoder used self.encoder, our CoAtNetEncoder uses self.backbone
240
- if "img_encoder.encoder." in new_k:
241
- new_k = new_k.replace("img_encoder.encoder.", "img_encoder.backbone.")
242
- changed += 1
243
- remapped[new_k] = v
244
- if changed:
245
- print(f" πŸ”§ Remapped {changed} keys: img_encoder.encoder.* β†’ img_encoder.backbone.*")
246
- return remapped
247
-
248
-
249
- def load_model(path: str, model_obj: nn.Module, label: str) -> nn.Module:
250
- print(f"\nπŸ“‚ Loading {label} from: {path}")
251
-
252
- if path.endswith(".pkl"):
253
- with open(path, "rb") as f:
254
- loaded = pickle.load(f)
255
- print(f" βœ… Loaded full pickle object: {type(loaded)}")
256
- return loaded.to(device)
257
-
258
- # .pt state_dict
259
- raw_sd = torch.load(path, map_location=device)
260
-
261
- # Print first 5 saved keys for diagnosis
262
- saved_keys = list(raw_sd.keys())
263
- print(f" Saved keys (first 5): {saved_keys[:5]}")
264
- model_keys = list(model_obj.state_dict().keys())
265
- print(f" Model keys (first 5): {model_keys[:5]}")
266
-
267
- # Remap any mismatched key prefixes
268
- raw_sd = remap_keys(raw_sd, label)
269
-
270
- result = model_obj.load_state_dict(raw_sd, strict=False)
271
-
272
- # Ignore known-safe missing keys:
273
- # head.fc.* - classification head, intentionally removed (num_classes=0)
274
- # num_batches_tracked - BatchNorm counter, not a learned weight
275
- SAFE_MISSING = ("num_batches_tracked", "head.fc.")
276
- missing = [k for k in result.missing_keys if not any(s in k for s in SAFE_MISSING)]
277
- unexpected = [k for k in result.unexpected_keys if "num_batches_tracked" not in k]
278
-
279
- if missing:
280
- print(f" Missing keys: {missing[:5]}{'...' if len(missing)>5 else ''}")
281
- print(f" WARNING: {len(missing)} missing keys - weights NOT loaded for those layers!")
282
- if unexpected:
283
- print(f" Unexpected keys: {unexpected[:5]}{'...' if len(unexpected)>5 else ''}")
284
- if not missing and not unexpected:
285
- print(f" OK: All keys matched perfectly!")
286
-
287
- return model_obj.to(device)
288
-
289
 
290
- # ─────────────────────────────────────────────────────────────────────────────
291
- # LOAD ALL THREE MODELS FROM HUGGING FACE HUB
292
- # Models are downloaded from Shree2604/BioStack repository
293
- # ─────────────────────────────────────────────────────────────────────────────
294
- def download_model_from_hf(model_filename: str, local_path: str = "models/") -> str:
295
- """Download model from Hugging Face Hub if not exists locally"""
296
- os.makedirs(local_path, exist_ok=True)
297
- full_path = os.path.join(local_path, model_filename)
298
-
299
- if not os.path.exists(full_path):
300
- print(f" Downloading {model_filename} from Hugging Face Hub...")
301
- try:
302
- downloaded_path = hf_hub_download(
303
- repo_id="Shree2604/BioStack",
304
- filename=model_filename,
305
- local_dir=local_path,
306
- local_dir_use_symlinks=False
307
  )
308
- print(f" Downloaded {model_filename}")
309
- return downloaded_path
310
- except Exception as e:
311
- print(f" Failed to download {model_filename}: {e}")
312
- raise
313
- else:
314
- print(f" Using local {model_filename}")
315
- return full_path
316
-
317
- print("\n" + "="*60)
318
- print(" LOADING MODELS FROM HUGGING FACE HUB")
319
- print("="*60)
320
-
321
- # Download models from Hugging Face
322
- SFT_MODEL_PATH = download_model_from_hf("best_model.pt")
323
- REWARD_MODEL_PATH = download_model_from_hf("reward_model.pt")
324
- PPO_MODEL_PATH = download_model_from_hf("rlhf_model.pt")
325
-
326
- # SFT
327
- _sft_enc = CoAtNetEncoder(pretrained=False)
328
- sft_model = load_model(SFT_MODEL_PATH, SFTVisionT5Model(_sft_enc), "SFT Model")
329
- sft_model.eval()
330
-
331
- # Reward
332
- _rm_enc = CoAtNetEncoder(pretrained=False)
333
- reward_model = load_model(REWARD_MODEL_PATH, RewardModel(_rm_enc), "Reward Model")
334
- reward_model.eval()
335
-
336
- # PPO
337
- _ppo_enc = CoAtNetEncoder(pretrained=False)
338
- ppo_model = load_model(PPO_MODEL_PATH, PPOVisionT5Model(_ppo_enc), "PPO Model")
339
- ppo_model.eval()
340
-
341
- print("\n All models loaded and ready!\n" + "="*60 + "\n")
342
 
 
 
343
 
344
- # ─────────────────────────────────────────────────────────────────────────────
345
- # IMAGE PREPROCESSING
346
- # Matches BOTH notebooks: RGB, 224Γ—224, ImageNet normalisation
347
- # ─────────────────────────────────────────────────────────────────────────────
348
- transform = transforms.Compose([
349
- transforms.Resize((224, 224)),
350
- transforms.ToTensor(),
351
- transforms.Normalize(mean=[0.485, 0.456, 0.406],
352
- std=[0.229, 0.224, 0.225])
353
- ])
354
 
355
- def preprocess(file_bytes: bytes) -> torch.Tensor:
356
- img = Image.open(io.BytesIO(file_bytes)).convert("RGB")
357
- return transform(img).unsqueeze(0).to(device) # [1, 3, 224, 224]
358
 
359
 
360
  # ─────────────────────────────────────────────────────────────────────────────
361
- # REWARD FEEDBACK GENERATOR
362
  # ─────────────────────────────────────────────────────────────────────────────
363
- KEY_MEDICAL_TERMS = [
364
- 'lung', 'heart', 'normal', 'clear', 'opacity', 'infiltrate',
365
- 'cardiomegaly', 'pleural', 'pulmonary', 'chest', 'thorax',
366
- 'pneumonia', 'edema', 'effusion', 'consolidation'
367
- ]
368
-
369
- def reward_feedback(report: str, score: float) -> str:
370
- rl = report.lower()
371
- present = [t for t in KEY_MEDICAL_TERMS if t in rl]
372
- missing = [t for t in KEY_MEDICAL_TERMS if t not in rl]
373
- words = len(report.split())
374
- length_q = "good" if 50 <= words <= 150 else ("too short" if words < 50 else "too long")
375
-
376
- # Quality factor assessments based on the score and analysis
377
- terminology_score = len(present) / len(KEY_MEDICAL_TERMS)
378
- completeness_score = min(1.0, words / 100.0) # Rough estimate based on length
379
- structure_score = 1.0 if 50 <= words <= 150 else 0.5 # Good structure if proper length
380
- radiological_score = score # The overall score represents alignment
381
-
382
- return (
383
- f"Reward Score: {score:.2f} | "
384
- f"Quality Factors - "
385
- f"Medical Terminology: {terminology_score:.1%} | "
386
- f"Clinical Completeness: {completeness_score:.1%} | "
387
- f"Report Structure: {structure_score:.1%}"
388
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
389
 
 
 
 
 
 
 
 
 
 
 
 
 
 
390
 
391
  # ─────────────────────────────────────────────────────────────────────────────
392
  # FASTAPI APP
393
  # ─────────────────────────────────────────────────────────────────────────────
394
- app = FastAPI(title="RLHF Medical Demo")
395
 
396
  app.add_middleware(
397
  CORSMiddleware,
398
- allow_origins=["*"], # Allow all origins for Hugging Face Spaces
399
  allow_methods=["*"],
400
  allow_headers=["*"],
401
  )
402
 
403
 
 
 
 
 
 
 
404
  @app.get("/health")
405
  def health():
406
- return {"status": "ok", "device": str(device)}
 
 
 
 
 
407
 
408
 
409
  @app.post("/sft")
410
  async def sft_inference(file: UploadFile = File(...)):
 
 
 
411
  try:
412
- tensor = preprocess(await file.read())
413
- generated_ids = sft_model.generate_reports(tensor)
 
 
 
 
 
 
 
 
 
 
414
  report = tokenizer.decode(generated_ids[0], skip_special_tokens=True).strip()
415
- # Strip any leading "Projection: X." prefix that leaked from training data
416
- if report.lower().startswith("projection:"):
417
- parts = report.split(".", 1)
418
- report = parts[1].strip() if len(parts) > 1 else report
419
  print(f"[SFT] Generated: {report}")
420
- return {"report": report[:81]}
421
- except Exception as e:
422
- traceback.print_exc()
423
- return {"report": f"ERROR: {str(e)}"}
424
-
425
-
426
- @app.post("/reward")
427
- async def reward_inference(file: UploadFile = File(...)):
428
- try:
429
- tensor = preprocess(await file.read())
430
-
431
- # First get the SFT report to score
432
- sft_generated_ids = sft_model.generate_reports(tensor)
433
- sft_report = tokenizer.decode(sft_generated_ids[0], skip_special_tokens=True).strip()
434
- # Strip any leading "Projection: X." prefix that leaked from training data
435
- if sft_report.lower().startswith("projection:"):
436
- parts = sft_report.split(".", 1)
437
- sft_report = parts[1].strip() if len(parts) > 1 else sft_report
438
- print(f"[REWARD] Scoring SFT report: {sft_report}")
439
-
440
- if not sft_report.strip():
441
- return {"score": 0.0, "feedback": "", "sft_report": ""}
442
-
443
- enc = tokenizer(
444
- [sft_report],
445
- max_length=128,
446
- padding="max_length",
447
- truncation=True,
448
- return_tensors="pt"
449
- )
450
- input_ids = enc.input_ids.to(device)
451
- attention_mask = enc.attention_mask.to(device)
452
-
453
- with torch.no_grad():
454
- raw_score = reward_model(tensor, input_ids, attention_mask).item()
455
-
456
- # Detailed debug logging
457
- print(f"[REWARD] Raw neural network output: {raw_score:.6f}")
458
- print(f"[REWARD] Clamping to [0,1] range: max(0.0, min(1.0, {raw_score:.6f})) = {max(0.0, min(1.0, raw_score)):.6f}")
459
-
460
- # Quality assessment details
461
- rl = sft_report.lower()
462
- present = [t for t in KEY_MEDICAL_TERMS if t in rl]
463
- missing = [t for t in KEY_MEDICAL_TERMS if t not in rl]
464
- words = len(sft_report.split())
465
- length_q = "good" if 50 <= words <= 150 else ("too short" if words < 50 else "too long")
466
-
467
- print(f"[REWARD] Report analysis:")
468
- print(f" - Total words: {words} ({length_q})")
469
- print(f" - Medical terms present ({len(present)}/{len(KEY_MEDICAL_TERMS)}): {present}")
470
- print(f" - Medical terms missing: {missing}")
471
- print(f" - Key terms list: {KEY_MEDICAL_TERMS}")
472
-
473
- # Reward model architecture details
474
- print(f"[REWARD] Model architecture:")
475
- print(f" - CoAtNet feature dim: {reward_model.img_encoder.feature_dim}")
476
- print(f" - T5 d_model: {reward_model.txt_encoder.config.d_model}")
477
- print(f" - Combined feature dim: 1024 (512 img + 512 text)")
478
- print(f" - Reward head: 1024β†’512β†’256β†’1")
479
-
480
- # Clamped score for display
481
- score = float(max(0.0, min(1.0, raw_score)))
482
- feedback = reward_feedback(sft_report, score)
483
- print(f"[REWARD] Final Score={score:.3f}")
484
- return {"score": score, "feedback": feedback, "sft_report": sft_report}
485
-
486
  except Exception as e:
487
  traceback.print_exc()
488
- return {"score": 0.0, "feedback": f"ERROR: {str(e)}", "sft_report": ""}
489
 
490
 
491
  @app.post("/ppo")
492
  async def ppo_inference(file: UploadFile = File(...)):
 
 
 
493
  try:
494
- tensor = preprocess(await file.read())
495
- generated_ids = ppo_model.generate_reports(tensor)
 
 
 
 
 
 
 
 
 
 
496
  report = tokenizer.decode(generated_ids[0], skip_special_tokens=True).strip()
497
- # Strip any leading "Projection: X." prefix that leaked from training data
498
- if report.lower().startswith("projection:"):
499
- parts = report.split(".", 1)
500
- report = parts[1].strip() if len(parts) > 1 else report
501
  print(f"[PPO] Generated: {report}")
502
- return {"report": report}
 
 
 
503
  except Exception as e:
504
  traceback.print_exc()
505
- return {"report": f"ERROR: {str(e)}"}
506
 
507
 
508
- # ─────────────────────────────────────────────────────────────────────────────
509
- # DIAGNOSTIC ENDPOINT β€” call GET /debug_keys to verify key names in your files
510
- # e.g. curl http://localhost:8000/debug_keys
511
- # ─────────────────────────────────────────────────────────────────────────────
512
- @app.get("/debug_keys")
513
- def debug_keys():
514
- import os
515
- result = {}
516
- for label, path in [("SFT", SFT_MODEL_PATH), ("Reward", REWARD_MODEL_PATH), ("PPO", PPO_MODEL_PATH)]:
517
- if not os.path.exists(path):
518
- result[label] = f"FILE NOT FOUND: {path}"
519
- continue
520
- try:
521
- sd = torch.load(path, map_location="cpu")
522
- keys = list(sd.keys())
523
- result[label] = {"first_10_keys": keys[:10], "total_keys": len(keys)}
524
- except Exception as e:
525
- result[label] = f"ERROR: {e}"
526
- return result
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
527
 
528
 
529
  # ─────────────────────────────────────────────────────────────────────────────
530
- # STATIC FILE SERVING - Mount React build directory AFTER all API routes
531
  # ─────────────────────────────────────────────────────────────────────────────
532
  from fastapi.staticfiles import StaticFiles
533
- import os
534
 
535
- # Check if build directory exists, create fallback if needed
536
  if os.path.exists("build"):
537
  app.mount("/", StaticFiles(directory="build", html=True), name="static")
538
  print("βœ… React app mounted at /")
 
2
  import torch
3
  import torch.nn as nn
4
  import timm
 
5
  import traceback
6
  import os
7
  from PIL import Image
 
12
  from huggingface_hub import hf_hub_download
13
 
14
  # ─────────────────────────────────────────────────────────────────────────────
15
+ # CONFIGURATION - Matching Colab Notebook Exactly
16
  # ─────────────────────────────────────────────────────────────────────────────
17
  CONFIG = {
18
  'coatnet_model': 'coatnet_1_rw_224',
 
20
  'img_emb_dim': 768,
21
  'train_last_stages': 2,
22
  'image_size': 224,
23
+ 'max_length': 100,
24
+ 'num_beams': 4,
25
  }
26
 
27
  # ─────────────────────────────────────────────────────────────────────────────
 
31
  print(f"πŸ–₯️ Using device: {device}")
32
 
33
  # ─────────────────────────────────────────────────────────────────────────────
34
+ # LOAD TOKENIZER - Matching Colab
35
  # ─────────────────────────────────────────────────────────────────────────────
 
36
  print("\n" + "="*80)
37
+ print("LOADING TOKENIZER")
38
  print("="*80)
 
 
39
  tokenizer = T5Tokenizer.from_pretrained(CONFIG['t5_model'])
40
  print(f"βœ“ Loaded tokenizer: {CONFIG['t5_model']}")
41
 
42
+ # ─────────────────────────────────────────────────────────────────────────────
43
+ # IMAGE TRANSFORM - Matching Colab Exactly
44
+ # ─────────────────────────────────────────────────────────────────────────────
45
  transform = transforms.Compose([
46
  transforms.Resize((CONFIG['image_size'], CONFIG['image_size'])),
47
  transforms.ToTensor(),
 
52
  ])
53
  print(f"βœ“ Image transform defined (size: {CONFIG['image_size']}x{CONFIG['image_size']})")
54
 
 
 
 
 
 
55
  # ─────────────────────────────────────────────────────────────────────────────
56
+ # ARCHITECTURE 1: CoAtNetEncoder - Exactly from Colab SECTION 6
 
57
  # ─────────────────────────────────────────────────────────────────────────────
58
  class CoAtNetEncoder(nn.Module):
59
+ def __init__(self, model_name="coatnet_1_rw_224", pretrained=True, train_last_stages=2):
60
  super().__init__()
61
+ self.encoder = timm.create_model(
62
+ model_name,
63
+ pretrained=pretrained,
64
+ num_classes=0,
65
+ global_pool="avg"
66
+ )
 
 
 
 
 
 
 
67
 
68
+ # Freeze all parameters
69
+ for p in self.encoder.parameters():
70
+ p.requires_grad = False
 
 
 
 
71
 
72
+ # Unfreeze last stages
73
+ if hasattr(self.encoder, "stages") and train_last_stages is not None:
74
+ stages = self.encoder.stages
75
+ for stage in stages[-train_last_stages:]:
76
+ for p in stage.parameters():
77
+ p.requires_grad = True
78
 
79
  def forward(self, x):
80
+ return self.encoder(x)
 
 
 
81
 
82
 
83
  # ─────────────────────────────────────────────────────────────────────────────
84
+ # ARCHITECTURE 2: VisionT5Model - Exactly from Colab SECTION 6
 
 
85
  # ─────────────────────────────────────────────────────────────────────────────
86
+ class VisionT5Model(nn.Module):
87
  def __init__(self, img_encoder, txt_model_name="t5-small", img_emb_dim=768):
88
  super().__init__()
89
+
90
+ # Vision encoder (CoAtNet)
91
  self.img_encoder = img_encoder
92
+
93
+ # Text decoder (T5)
94
  self.t5 = T5ForConditionalGeneration.from_pretrained(txt_model_name)
95
+
96
+ # Projection layer to match image features with T5 d_model
97
  self.proj = nn.Linear(img_emb_dim, self.t5.config.d_model)
98
 
99
+ # Freeze shared T5 embeddings for faster and stable training
100
  for p in self.t5.shared.parameters():
101
  p.requires_grad = False
102
 
103
+ def forward(self, pixel_values, input_ids, attention_mask, labels=None):
104
+ # Extract image features
105
+ img_feats = self.img_encoder(pixel_values)
 
 
 
 
 
 
 
 
 
106
 
107
+ # Project image features to T5 embedding space
108
+ img_feats = self.proj(img_feats)
 
 
 
 
 
 
 
 
 
 
 
 
 
109
 
110
+ # Add sequence dimension
111
+ encoder_hidden_states = img_feats.unsqueeze(1)
 
 
 
 
 
 
 
 
 
 
 
 
 
112
 
113
+ # Run T5 encoder using image embeddings
114
+ encoder_outputs = self.t5.encoder(
115
+ inputs_embeds=encoder_hidden_states
116
+ )
 
117
 
118
+ # Run T5 decoder and compute loss
119
+ outputs = self.t5(
120
+ encoder_outputs=encoder_outputs,
121
+ attention_mask=torch.ones(
122
+ encoder_hidden_states.size()[:2], device=device
123
+ ),
124
+ input_ids=input_ids,
125
+ labels=labels,
126
+ )
127
+ return outputs
128
+
129
+ def generate_reports(self, pixel_values, max_length=100, num_beams=4):
130
+ """
131
+ Generate reports - EXACTLY matching Colab SECTION 6
132
+ """
133
+ # Extract and project image features
134
+ img_feats = self.img_encoder(pixel_values)
135
+ img_feats = self.proj(img_feats)
136
+ encoder_hidden_states = img_feats.unsqueeze(1)
137
+
138
+ # Encode image features
139
+ encoder_outputs = self.t5.encoder(
140
+ inputs_embeds=encoder_hidden_states
141
+ )
142
 
143
+ # Generate report using beam search - EXACT parameters from Colab
144
+ generated_ids = self.t5.generate(
145
+ encoder_outputs=encoder_outputs,
146
+ attention_mask=torch.ones(
147
+ encoder_hidden_states.size()[:2], device=device
148
+ ),
149
+ max_length=max_length,
150
+ num_beams=num_beams,
151
+ early_stopping=True
152
+ )
153
 
154
+ return generated_ids
 
 
 
 
 
 
 
 
 
155
 
 
156
 
157
+ print("βœ“ Model architecture classes defined")
158
 
159
  # ─────────────────────────────────────────────────────────────────────────────
160
+ # MODEL LOADING FUNCTION - Exactly from Colab SECTION 8
 
161
  # ─────────────────────────────────────────────────────────────────────────────
162
+ def load_model_from_checkpoint(checkpoint_path: str, model_name: str, config: dict):
163
+ """
164
+ Load VisionT5Model from checkpoint - EXACT implementation from Colab
165
+ """
166
+ print(f"\nLoading {model_name} model...")
167
+ print(f" Checkpoint: {checkpoint_path}")
168
+
169
+ try:
170
+ # Create image encoder
171
+ print(f" Creating CoAtNet encoder: {config['coatnet_model']}")
172
+ img_encoder = CoAtNetEncoder(
173
+ model_name=config['coatnet_model'],
174
+ pretrained=False, # Weights will come from checkpoint
175
+ train_last_stages=config['train_last_stages']
176
  )
177
 
178
+ # Create full model
179
+ print(f" Creating VisionT5 model with T5: {config['t5_model']}")
180
+ model = VisionT5Model(
181
+ img_encoder=img_encoder,
182
+ txt_model_name=config['t5_model'],
183
+ img_emb_dim=config['img_emb_dim']
184
+ )
185
+
186
+ # Load checkpoint
187
+ print(f" Loading checkpoint weights...")
188
+ checkpoint = torch.load(checkpoint_path, map_location=device)
189
+
190
+ # Handle different checkpoint formats
191
+ if isinstance(checkpoint, dict):
192
+ if 'model_state_dict' in checkpoint:
193
+ state_dict = checkpoint['model_state_dict']
194
+ print(f" Found 'model_state_dict' in checkpoint")
195
+ elif 'state_dict' in checkpoint:
196
+ state_dict = checkpoint['state_dict']
197
+ print(f" Found 'state_dict' in checkpoint")
198
+ elif 'model' in checkpoint:
199
+ state_dict = checkpoint['model']
200
+ print(f" Found 'model' in checkpoint")
201
+ else:
202
+ # Assume checkpoint is the state dict
203
+ state_dict = checkpoint
204
+ print(f" Using checkpoint as state_dict directly")
205
+
206
+ # Print additional checkpoint info if available
207
+ if 'epoch' in checkpoint:
208
+ print(f" Checkpoint epoch: {checkpoint['epoch']}")
209
+ if 'loss' in checkpoint:
210
+ print(f" Checkpoint loss: {checkpoint['loss']:.4f}")
211
+ else:
212
+ state_dict = checkpoint
213
+ print(f" Checkpoint is a state_dict")
214
+
215
+ # Load state dict
216
+ missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
217
+
218
+ if missing_keys:
219
+ print(f" ⚠️ Missing keys: {len(missing_keys)}")
220
+ if len(missing_keys) <= 5:
221
+ for key in missing_keys:
222
+ print(f" - {key}")
223
+
224
+ if unexpected_keys:
225
+ print(f" ⚠️ Unexpected keys: {len(unexpected_keys)}")
226
+ if len(unexpected_keys) <= 5:
227
+ for key in unexpected_keys:
228
+ print(f" - {key}")
229
+
230
+ # Move to device and set to eval mode
231
+ model = model.to(device)
232
+ model.eval()
233
+
234
+ print(f"βœ“ {model_name} model loaded successfully!")
235
+ return model
236
+
237
+ except Exception as e:
238
+ print(f"❌ Error loading {model_name} model: {str(e)}")
239
+ import traceback
240
+ traceback.print_exc()
241
+ raise
242
 
243
 
244
  # ─────────────────────────────────────────────────────────────────────────────
245
+ # INFERENCE FUNCTION - Exactly from Colab SECTION 9
 
246
  # ─────────────────────────────────────────────────────────────────────────────
247
+ def generate_report(
248
+ image_path: str,
249
+ model: VisionT5Model,
250
+ config: dict
251
+ ) -> str:
252
  """
253
+ Generate medical report from X-ray image - EXACT implementation from Colab
 
 
 
 
 
 
 
 
 
 
254
  """
255
+ try:
256
+ # Preprocess image
257
+ image = Image.open(image_path).convert('RGB')
258
+ pixel_values = transform(image).unsqueeze(0).to(device)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
259
 
260
+ # Generate report - using EXACT parameters from Colab
261
+ with torch.no_grad():
262
+ generated_ids = model.generate_reports(
263
+ pixel_values,
264
+ max_length=config['max_length'],
265
+ num_beams=config['num_beams']
 
 
 
 
 
 
 
 
 
 
 
266
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
267
 
268
+ # Decode
269
+ report = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
270
 
271
+ return report.strip()
 
 
 
 
 
 
 
 
 
272
 
273
+ except Exception as e:
274
+ print(f"Error generating report for {image_path}: {str(e)}")
275
+ return ""
276
 
277
 
278
# ─────────────────────────────────────────────────────────────────────────────
# LOAD MODELS FROM HUGGINGFACE
# ─────────────────────────────────────────────────────────────────────────────
print("\n" + "="*80)
print("LOADING MODELS FROM HUGGINGFACE")
print("="*80)

# Download checkpoints from the Hub. On failure (offline run, local dev, quota)
# fall back to local paths; the SFT_MODEL_PATH / PPO_MODEL_PATH environment
# variables override the defaults, which keep the original Colab locations for
# backward compatibility.
try:
    SFT_MODEL_PATH = hf_hub_download(
        repo_id="vinaykumarhs2020/RLHF_radiology_model",
        filename="best_model.pt"
    )
    PPO_MODEL_PATH = hf_hub_download(
        repo_id="vinaykumarhs2020/RLHF_radiology_model",
        filename="rlhf_model.pt"
    )
    print(f"βœ“ Downloaded SFT model: {SFT_MODEL_PATH}")
    print(f"βœ“ Downloaded PPO model: {PPO_MODEL_PATH}")
except Exception as e:
    print(f"❌ Error downloading models: {e}")
    # Fallback: env-var override first, then the original Colab paths.
    SFT_MODEL_PATH = os.environ.get("SFT_MODEL_PATH", "/content/best_model.pt")
    PPO_MODEL_PATH = os.environ.get("PPO_MODEL_PATH", "/content/rlhf_model.pt")
    print(f"⚠️ Using local paths instead")

# Instantiate both checkpoints (weights loaded onto `device`, set to eval mode).
print("\n" + "="*80)
print("LOADING MODELS")
print("="*80)

sft_model = load_model_from_checkpoint(
    SFT_MODEL_PATH,
    "SFT",
    CONFIG
)

ppo_model = load_model_from_checkpoint(
    PPO_MODEL_PATH,
    "PPO",
    CONFIG
)

print("\nβœ“ Both models loaded successfully!")
322
 
323
# ─────────────────────────────────────────────────────────────────────────────
# FASTAPI APP
# ─────────────────────────────────────────────────────────────────────────────
app = FastAPI(title="Medical Report Generation - Matching Colab")

# Open CORS policy: the demo frontend may be served from any origin.
# NOTE(review): allow_origins=["*"] is fine for a public demo; tighten this
# list before exposing any authenticated functionality.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)
334
 
335
 
336
def preprocess_bytes(file_bytes: bytes) -> torch.Tensor:
    """Decode raw upload bytes into a normalized (1, C, H, W) tensor on `device`."""
    pil_img = Image.open(io.BytesIO(file_bytes)).convert("RGB")
    batch = transform(pil_img).unsqueeze(0)
    return batch.to(device)
340
+
341
+
342
  @app.get("/health")
343
  def health():
344
+ return {
345
+ "status": "ok",
346
+ "device": str(device),
347
+ "models_loaded": True,
348
+ "config": CONFIG
349
+ }
350
 
351
 
352
  @app.post("/sft")
353
  async def sft_inference(file: UploadFile = File(...)):
354
+ """
355
+ SFT model inference - EXACTLY matching Colab behavior
356
+ """
357
  try:
358
+ # Preprocess image
359
+ tensor = preprocess_bytes(await file.read())
360
+
361
+ # Generate report using EXACT Colab parameters
362
+ with torch.no_grad():
363
+ generated_ids = sft_model.generate_reports(
364
+ tensor,
365
+ max_length=CONFIG['max_length'],
366
+ num_beams=CONFIG['num_beams']
367
+ )
368
+
369
+ # Decode - EXACTLY as Colab does
370
  report = tokenizer.decode(generated_ids[0], skip_special_tokens=True).strip()
371
+
 
 
 
372
  print(f"[SFT] Generated: {report}")
373
+
374
+ # Return FULL report without truncation
375
+ return {"report": report, "model": "SFT", "config_used": CONFIG}
376
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
377
  except Exception as e:
378
  traceback.print_exc()
379
+ return {"report": f"ERROR: {str(e)}", "model": "SFT"}
380
 
381
 
382
  @app.post("/ppo")
383
  async def ppo_inference(file: UploadFile = File(...)):
384
+ """
385
+ PPO model inference - EXACTLY matching Colab behavior
386
+ """
387
  try:
388
+ # Preprocess image
389
+ tensor = preprocess_bytes(await file.read())
390
+
391
+ # Generate report using EXACT Colab parameters
392
+ with torch.no_grad():
393
+ generated_ids = ppo_model.generate_reports(
394
+ tensor,
395
+ max_length=CONFIG['max_length'],
396
+ num_beams=CONFIG['num_beams']
397
+ )
398
+
399
+ # Decode - EXACTLY as Colab does
400
  report = tokenizer.decode(generated_ids[0], skip_special_tokens=True).strip()
401
+
 
 
 
402
  print(f"[PPO] Generated: {report}")
403
+
404
+ # Return FULL report without truncation
405
+ return {"report": report, "model": "PPO", "config_used": CONFIG}
406
+
407
  except Exception as e:
408
  traceback.print_exc()
409
+ return {"report": f"ERROR: {str(e)}", "model": "PPO"}
410
 
411
 
412
@app.post("/compare")
async def compare_models(file: UploadFile = File(...)):
    """Run both checkpoints on the same image and return the two reports side by side."""
    try:
        tensor = preprocess_bytes(await file.read())

        # Decode with each model under a single no-grad context; identical
        # parameters so the outputs differ only by checkpoint.
        reports = {}
        with torch.no_grad():
            for tag, mdl in (("sft", sft_model), ("ppo", ppo_model)):
                ids = mdl.generate_reports(
                    tensor,
                    max_length=CONFIG['max_length'],
                    num_beams=CONFIG['num_beams']
                )
                reports[tag] = tokenizer.decode(ids[0], skip_special_tokens=True).strip()

        print(f"[COMPARE] SFT: {reports['sft']}")
        print(f"[COMPARE] PPO: {reports['ppo']}")

        return {
            "sft_report": reports["sft"],
            "ppo_report": reports["ppo"],
            "config_used": CONFIG
        }

    except Exception as e:
        traceback.print_exc()
        # Mirror the single-model endpoints: embed the error in both fields.
        return {
            "sft_report": f"ERROR: {str(e)}",
            "ppo_report": f"ERROR: {str(e)}"
        }
454
+
455
+
456
@app.get("/debug_config")
def debug_config():
    """Expose the runtime configuration and model-load status for debugging."""
    info = {
        "config": CONFIG,
        "device": str(device),
        "tokenizer": CONFIG['t5_model'],
    }
    # Surface the individual decoding/preprocessing knobs at the top level too.
    for key in ("image_size", "max_length", "num_beams"):
        info[key] = CONFIG[key]
    info["models_loaded"] = {
        name: mdl is not None
        for name, mdl in (("sft", sft_model), ("ppo", ppo_model))
    }
    return info
471
 
472
 
473
# ─────────────────────────────────────────────────────────────────────────────
# STATIC FILE SERVING
# ─────────────────────────────────────────────────────────────────────────────
from fastapi.staticfiles import StaticFiles

# Serve the compiled React frontend (if a `build/` directory exists) at the
# root path. Mounted after the API routes so /health, /sft, /ppo, /compare and
# /debug_config keep taking precedence.
if os.path.exists("build"):
    app.mount("/", StaticFiles(directory="build", html=True), name="static")
    print("βœ… React app mounted at /")