import torch
from monai.networks.nets import DynUNet
import os


def load_model(model_path="best_model_large_data.pth", device="cpu"):
    """Build a 2-D DynUNet and load trained weights into it.

    The architecture below is hard-coded; it must match the architecture the
    checkpoint was saved from, otherwise ``load_state_dict`` raises.

    Args:
        model_path: Path to the saved weights. Presumably the output of
            ``torch.save(model.state_dict(), ...)`` -- it is passed straight
            to ``load_state_dict``, so a plain state dict is expected.
        device: Device (string or ``torch.device``) the weights are mapped to
            and the model is moved to.

    Returns:
        The ``DynUNet`` on ``device``, switched to evaluation mode.

    Raises:
        Exception: Whatever DynUNet construction, ``torch.load`` or
            ``load_state_dict`` raised; the error is printed, then re-raised
            unchanged.
    """
    try:
        model = DynUNet(
            spatial_dims=2,
            in_channels=1,
            out_channels=1,
            kernel_size=[3, 3, 3, 3, 3],
            strides=[1, 2, 2, 2, 2],
            upsample_kernel_size=[2, 2, 2, 2],
            filters=[32, 64, 128, 256, 512],
            norm_name="INSTANCE",
            res_block=True,
            deep_supervision=False,
        )
        # A state dict only needs tensors and plain containers, so restrict
        # unpickling to those instead of executing arbitrary pickled objects.
        state_dict = torch.load(model_path, map_location=device, weights_only=True)
        model.load_state_dict(state_dict)
        model.to(device)
        model.eval()  # switch to evaluation mode before inference
        return model
    except Exception as e:
        print(f"❌ Model initialization failed: {e}")
        raise


def predict_mask(model, image_tensor):
    """Run ``model`` on a single-channel 2-D batch and apply a sigmoid.

    Args:
        model: Callable network, e.g. the one returned by ``load_model``.
            Its raw outputs are presumably logits (a sigmoid is applied here).
        image_tensor: Tensor of shape ``[B, 1, H, W]``. Only the rank and the
            channel dimension are validated; any batch size is accepted.

    Returns:
        ``torch.sigmoid(model(image_tensor))``, computed with gradient
        tracking disabled.

    Raises:
        ValueError: If ``image_tensor`` is not 4-D with exactly one channel.
        Exception: Any error from the forward pass; it is printed and
            re-raised unchanged.
    """
    try:
        if image_tensor.dim() != 4 or image_tensor.shape[1] != 1:
            # Message reflects what is actually checked: any batch size B.
            raise ValueError(
                f"Input tensor must be [B, 1, H, W]. Got {image_tensor.shape}"
            )
        with torch.no_grad():
            return torch.sigmoid(model(image_tensor))
    except Exception as e:
        print(f"Prediction failed: {e}")
        raise