# Segmentation / model.py
# (page header residue) KaranNag — "Update model.py" — commit 8769ca8 (verified)
import torch
from monai.networks.nets import DynUNet
import os
def load_model(model_path="best_model_large_data.pth", device="cpu"):
    """Build the 2-D DynUNet segmentation network and load trained weights.

    Args:
        model_path: Path handed to ``torch.load``. The loaded object is passed
            straight to ``load_state_dict``, so it is expected to be a plain
            state dict matching the architecture below.
        device: Used both as ``map_location`` while loading and as the device
            the returned model is moved to.

    Returns:
        The DynUNet instance, on ``device`` and switched to eval mode.

    Raises:
        Any exception from construction, loading, or moving the model is
        re-raised after a short failure message is printed.
    """
    # Architecture hyper-parameters. They must match the ones used when the
    # checkpoint was saved, otherwise load_state_dict (strict by default)
    # rejects the weights.
    architecture = dict(
        spatial_dims=2,
        in_channels=1,
        out_channels=1,
        kernel_size=[3, 3, 3, 3, 3],
        strides=[1, 2, 2, 2, 2],  # input stage at stride 1, then four 2x downsamplings
        upsample_kernel_size=[2, 2, 2, 2],
        filters=[32, 64, 128, 256, 512],
        norm_name="INSTANCE",
        res_block=True,
        deep_supervision=False,
    )
    try:
        net = DynUNet(**architecture)
        # NOTE(review): torch.load unpickles the file, so only load trusted
        # checkpoints. On torch versions that support it, consider passing
        # weights_only=True.
        weights = torch.load(model_path, map_location=device)
        net.load_state_dict(weights)
        # Module.to() and Module.eval() both return the module itself.
        return net.to(device).eval()
    except Exception as err:
        print(f"❌ Model initialization failed: {err}")
        raise
def predict_mask(model, image_tensor):
    """Predict segmentation mask with sigmoid activation.

    Runs ``model`` on a single-channel image batch and returns the
    element-wise sigmoid of its output, computed under ``torch.no_grad()``.

    Args:
        model: Callable network, typically a ``torch.nn.Module`` such as the
            one returned by ``load_model``. If it is a module with parameters,
            the input is moved to the device of its first parameter.
        image_tensor: 4-D tensor of shape ``[N, 1, H, W]``. Any batch size N
            is accepted, but there must be exactly one channel.

    Returns:
        ``torch.sigmoid(model(image_tensor))``, without gradient tracking.

    Raises:
        ValueError: If ``image_tensor`` is not 4-D with exactly one channel.
        Any exception from the forward pass is re-raised after a short
        failure message is printed.
    """
    try:
        if image_tensor.dim() != 4 or image_tensor.shape[1] != 1:
            # The check accepts any batch size, so the message must not claim N == 1.
            raise ValueError(f"Input tensor must be [N, 1, H, W]. Got {image_tensor.shape}")
        # load_model(device=...) may place the weights on a non-CPU device.
        # Move the input there so the forward pass does not fail on a device
        # mismatch. This is a no-op when the devices already match, and
        # parameter-free modules and plain callables are left alone.
        if isinstance(model, torch.nn.Module):
            first_param = next(model.parameters(), None)
            if first_param is not None:
                image_tensor = image_tensor.to(first_param.device)
        with torch.no_grad():
            return torch.sigmoid(model(image_tensor))
    except Exception as e:
        print(f"Prediction failed: {e}")
        raise