eslamESssamM commited on
Commit
8bb5fb0
·
verified ·
1 Parent(s): e8ca3c4

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +3 -0
main.py CHANGED
@@ -185,6 +185,9 @@ async def lifespan(app: FastAPI):
185
  if os.path.exists("DECODE_Frequency_Twin.pth"):
186
  ckpt = torch.load("DECODE_Frequency_Twin.pth", map_location='cpu')
187
  sd = ckpt['model_state_dict'] if isinstance(ckpt, dict) and 'model_state_dict' in ckpt else ckpt
 
 
 
188
  model = FrequencyPINN()
189
  model.load_state_dict(sd, strict=True)
190
  ml_assets["f_model"] = model.eval()
 
185
  if os.path.exists("DECODE_Frequency_Twin.pth"):
186
  ckpt = torch.load("DECODE_Frequency_Twin.pth", map_location='cpu')
187
  sd = ckpt['model_state_dict'] if isinstance(ckpt, dict) and 'model_state_dict' in ckpt else ckpt
188
+ # Remove unexpected 'theta_h' key if present
189
+ if 'theta_h' in sd:
190
+ sd = {k: v for k, v in sd.items() if k != 'theta_h'}
191
  model = FrequencyPINN()
192
  model.load_state_dict(sd, strict=True)
193
  ml_assets["f_model"] = model.eval()