Spaces:
Sleeping
Sleeping
Update main.py
Browse files
main.py
CHANGED
|
@@ -185,6 +185,9 @@ async def lifespan(app: FastAPI):
|
|
| 185 |
if os.path.exists("DECODE_Frequency_Twin.pth"):
|
| 186 |
ckpt = torch.load("DECODE_Frequency_Twin.pth", map_location='cpu')
|
| 187 |
sd = ckpt['model_state_dict'] if isinstance(ckpt, dict) and 'model_state_dict' in ckpt else ckpt
|
|
|
|
|
|
|
|
|
|
| 188 |
model = FrequencyPINN()
|
| 189 |
model.load_state_dict(sd, strict=True)
|
| 190 |
ml_assets["f_model"] = model.eval()
|
|
|
|
| 185 |
if os.path.exists("DECODE_Frequency_Twin.pth"):
|
| 186 |
ckpt = torch.load("DECODE_Frequency_Twin.pth", map_location='cpu')
|
| 187 |
sd = ckpt['model_state_dict'] if isinstance(ckpt, dict) and 'model_state_dict' in ckpt else ckpt
|
| 188 |
+
# Remove unexpected 'theta_h' key if present
|
| 189 |
+
if 'theta_h' in sd:
|
| 190 |
+
sd = {k: v for k, v in sd.items() if k != 'theta_h'}
|
| 191 |
model = FrequencyPINN()
|
| 192 |
model.load_state_dict(sd, strict=True)
|
| 193 |
ml_assets["f_model"] = model.eval()
|