Shree2604 committed (verified)
Commit b9b96ce · 1 Parent(s): 361b4d2

Update server.py

Files changed (1): server.py (+93 -60)
server.py CHANGED
@@ -113,7 +113,7 @@ class SFTVisionT5Model(nn.Module):
             p.requires_grad = False
 
     def generate_reports(self, pixel_values, max_length=100):
-        self.eval()
+        # Removed self.eval() to match Colab behavior
        with torch.no_grad():
            # Extract + project image features
            img_feats = self.img_encoder(pixel_values)  # [B, feature_dim]
@@ -167,7 +167,6 @@ class PPOVisionT5Model(nn.Module):
         self.img_proj = nn.Linear(img_emb_dim, self.txt_model.config.d_model)
 
     def generate_reports(self, images, max_length=128):
-        self.eval()
        with torch.no_grad():
            img_features = self.img_encoder(images)  # [B, feature_dim]
            img_emb = self.img_proj(img_features).unsqueeze(1)  # [B, 1, d_model]
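Both generate_reports hunks above drop self.eval() and keep only torch.no_grad(). The two calls are not interchangeable: no_grad() merely disables gradient tracking, while eval() is what switches Dropout and BatchNorm layers to their inference behavior. A minimal sketch of the difference, assuming plain PyTorch and an illustrative nn.Dropout layer:

```python
import torch
import torch.nn as nn

# Minimal sketch: torch.no_grad() does not turn off dropout;
# only module.eval() switches Dropout/BatchNorm to inference mode.
layer = nn.Dropout(p=0.5)
x = torch.ones(1, 8)

layer.train()          # training mode (the default for a fresh module)
with torch.no_grad():
    noisy = layer(x)   # dropout still zeroes roughly half the entries

layer.eval()           # inference mode
with torch.no_grad():
    clean = layer(x)   # dropout becomes a no-op: clean == x

print(noisy)  # mixture of 0.0 and 2.0 (survivors scaled by 1/(1-p))
print(clean)  # all ones
```

So the removal only changes generation output if the image or text stacks still contain modules in training mode at request time; if the models are put in eval mode once at load, behavior is identical.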
@@ -427,10 +426,22 @@ def health():
 @app.post("/sft")
 async def sft_inference(file: UploadFile = File(...)):
     try:
-        tensor = preprocess(await file.read())
-        report = sft_model.generate_reports(tensor)[0]
-        print(f"[SFT] Generated: {report}")
-        return {"report": report[:81]}
+        # Save uploaded file to temp path (matching Colab approach)
+        import tempfile
+        with tempfile.NamedTemporaryFile(delete=False, suffix='.png') as temp_file:
+            temp_file.write(await file.read())
+            temp_path = temp_file.name
+
+        try:
+            # Use file path preprocessing (exact Colab match)
+            tensor = preprocess_image(temp_path)
+            report = sft_model.generate_reports(tensor)[0]
+            print(f"[SFT] Generated: {report}")
+            return {"report": report[:81]}
+        finally:
+            # Clean up temp file
+            os.unlink(temp_path)
+
     except Exception as e:
         traceback.print_exc()
         return {"report": f"ERROR: {str(e)}"}
@@ -439,57 +450,67 @@ async def sft_inference(file: UploadFile = File(...)):
 @app.post("/reward")
 async def reward_inference(file: UploadFile = File(...)):
     try:
-        tensor = preprocess(await file.read())
-
-        # First get the SFT report to score
-        sft_report = sft_model.generate_reports(tensor)[0]
-        print(f"[REWARD] Scoring SFT report: {sft_report}")
-
-        if not sft_report.strip():
-            return {"score": 0.0, "feedback": "", "sft_report": ""}
-
-        enc = tokenizer(
-            [sft_report],
-            max_length=128,
-            padding="max_length",
-            truncation=True,
-            return_tensors="pt"
-        )
-        input_ids = enc.input_ids.to(device)
-        attention_mask = enc.attention_mask.to(device)
+        # Save uploaded file to temp path (matching Colab approach)
+        import tempfile
+        with tempfile.NamedTemporaryFile(delete=False, suffix='.png') as temp_file:
+            temp_file.write(await file.read())
+            temp_path = temp_file.name
 
-        with torch.no_grad():
-            raw_score = reward_model(tensor, input_ids, attention_mask).item()
-
-        # Detailed debug logging
-        print(f"[REWARD] Raw neural network output: {raw_score:.6f}")
-        print(f"[REWARD] Clamping to [0,1] range: max(0.0, min(1.0, {raw_score:.6f})) = {max(0.0, min(1.0, raw_score)):.6f}")
-
-        # Quality assessment details
-        rl = sft_report.lower()
-        present = [t for t in KEY_MEDICAL_TERMS if t in rl]
-        missing = [t for t in KEY_MEDICAL_TERMS if t not in rl]
-        words = len(sft_report.split())
-        length_q = "good" if 50 <= words <= 150 else ("too short" if words < 50 else "too long")
-
-        print(f"[REWARD] Report analysis:")
-        print(f" - Total words: {words} ({length_q})")
-        print(f" - Medical terms present ({len(present)}/{len(KEY_MEDICAL_TERMS)}): {present}")
-        print(f" - Medical terms missing: {missing}")
-        print(f" - Key terms list: {KEY_MEDICAL_TERMS}")
-
-        # Reward model architecture details
-        print(f"[REWARD] Model architecture:")
-        print(f" - CoAtNet feature dim: {reward_model.img_encoder.feature_dim}")
-        print(f" - T5 d_model: {reward_model.txt_encoder.config.d_model}")
-        print(f" - Combined feature dim: 1024 (512 img + 512 text)")
-        print(f" - Reward head: 1024→512→256→1")
-
-        # Clamped score for display
-        score = float(max(0.0, min(1.0, raw_score)))
-        feedback = reward_feedback(sft_report, score)
-        print(f"[REWARD] Final Score={score:.3f}")
-        return {"score": score, "feedback": feedback, "sft_report": sft_report}
+        try:
+            # Use file path preprocessing (exact Colab match)
+            tensor = preprocess_image(temp_path)
+            # First get the SFT report to score
+            sft_report = sft_model.generate_reports(tensor)[0]
+            print(f"[REWARD] Scoring SFT report: {sft_report}")
+
+            if not sft_report.strip():
+                return {"score": 0.0, "feedback": "", "sft_report": ""}
+
+            enc = tokenizer(
+                [sft_report],
+                max_length=128,
+                padding="max_length",
+                truncation=True,
+                return_tensors="pt"
+            )
+            input_ids = enc.input_ids.to(device)
+            attention_mask = enc.attention_mask.to(device)
+
+            with torch.no_grad():
+                raw_score = reward_model(tensor, input_ids, attention_mask).item()
+
+            # Detailed debug logging
+            print(f"[REWARD] Raw neural network output: {raw_score:.6f}")
+            print(f"[REWARD] Clamping to [0,1] range: max(0.0, min(1.0, {raw_score:.6f})) = {max(0.0, min(1.0, raw_score)):.6f}")
+
+            # Quality assessment details
+            rl = sft_report.lower()
+            present = [t for t in KEY_MEDICAL_TERMS if t in rl]
+            missing = [t for t in KEY_MEDICAL_TERMS if t not in rl]
+            words = len(sft_report.split())
+            length_q = "good" if 50 <= words <= 150 else ("too short" if words < 50 else "too long")
+
+            print(f"[REWARD] Report analysis:")
+            print(f" - Total words: {words} ({length_q})")
+            print(f" - Medical terms present ({len(present)}/{len(KEY_MEDICAL_TERMS)}): {present}")
+            print(f" - Medical terms missing: {missing}")
+            print(f" - Key terms list: {KEY_MEDICAL_TERMS}")
+
+            # Reward model architecture details
+            print(f"[REWARD] Model architecture:")
+            print(f" - CoAtNet feature dim: {reward_model.img_encoder.feature_dim}")
+            print(f" - T5 d_model: {reward_model.txt_encoder.config.d_model}")
+            print(f" - Combined feature dim: 1024 (512 img + 512 text)")
+            print(f" - Reward head: 1024→512→256→1")
+
+            # Clamped score for display
+            score = float(max(0.0, min(1.0, raw_score)))
+            feedback = reward_feedback(sft_report, score)
+            print(f"[REWARD] Final Score={score:.3f}")
+            return {"score": score, "feedback": feedback, "sft_report": sft_report}
+        finally:
+            # Clean up temp file
+            os.unlink(temp_path)
 
     except Exception as e:
         traceback.print_exc()
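The /reward path above tokenizes the SFT report, scores it with the multimodal reward model, clamps the raw output to [0, 1], and logs a term-coverage and length analysis. That analysis is pure string handling and can be read in isolation; here is a self-contained sketch, where KEY_MEDICAL_TERMS is a placeholder list rather than the server's actual vocabulary:

```python
# Stand-alone sketch of the report-analysis logic from the hunk above.
# KEY_MEDICAL_TERMS is a placeholder list, not the server's real one.
KEY_MEDICAL_TERMS = ["heart", "lungs", "pleural", "effusion", "pneumothorax"]

def analyze_report(report: str) -> dict:
    rl = report.lower()
    present = [t for t in KEY_MEDICAL_TERMS if t in rl]
    missing = [t for t in KEY_MEDICAL_TERMS if t not in rl]
    words = len(report.split())
    length_q = "good" if 50 <= words <= 150 else ("too short" if words < 50 else "too long")
    return {"words": words, "length": length_q, "present": present, "missing": missing}

print(analyze_report("Heart size is normal. Lungs are clear. No pleural effusion."))
# {'words': 10, 'length': 'too short', 'present': ['heart', 'lungs', 'pleural', 'effusion'],
#  'missing': ['pneumothorax']}
```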
@@ -499,10 +520,22 @@ async def reward_inference(file: UploadFile = File(...)):
 @app.post("/ppo")
 async def ppo_inference(file: UploadFile = File(...)):
     try:
-        tensor = preprocess(await file.read())
-        report = ppo_model.generate_reports(tensor)[0]
-        print(f"[PPO] Generated: {report}")
-        return {"report": report}
+        # Save uploaded file to temp path (matching Colab approach)
+        import tempfile
+        with tempfile.NamedTemporaryFile(delete=False, suffix='.png') as temp_file:
+            temp_file.write(await file.read())
+            temp_path = temp_file.name
+
+        try:
+            # Use file path preprocessing (exact Colab match)
+            tensor = preprocess_image(temp_path)
+            report = ppo_model.generate_reports(tensor)[0]
+            print(f"[PPO] Generated: {report}")
+            return {"report": report}
+        finally:
+            # Clean up temp file
+            os.unlink(temp_path)
+
     except Exception as e:
         traceback.print_exc()
         return {"report": f"ERROR: {str(e)}"}
 