# NOTE(review): the lines "Spaces:" / "Runtime error" that appeared here were
# paste artifacts from an online code runner's output banner, not source code.
| import sys | |
| import logging | |
| import numpy as np | |
| import torch | |
| import torch.nn.functional as F | |
| from torchvision import transforms | |
| import numpy as np | |
| from PIL import Image, ImageDraw | |
| from skimage.measure import label, regionprops | |
| import matplotlib.pyplot as plt | |
| from matplotlib import cm | |
| from io import BytesIO | |
| from PIL import Image, ImageOps, ImageDraw | |
| from skimage.transform import resize | |
| from skimage.measure import label, regionprops # From scikit-image | |
| from PIL import Image, ImageOps, ImageDraw | |
# Module-level logger, named after the module per the stdlib logging convention.
logger = logging.getLogger(__name__)
class GradCAMSegmentation:
    """Grad-CAM for segmentation models.

    Hooks a named layer of ``model`` to capture its forward activations and
    backward gradients, then produces a min-max-normalized class-activation
    map in [0, 1] for a given input.
    """

    def __init__(self, model, target_layer_name):
        self.model = model
        self.target_layer = self._find_layer(target_layer_name)
        self.activations = None
        self.gradients = None
        # Keep hook handles so they can be detached; without this, every
        # instantiation leaks a pair of hooks on the model forever.
        self._handles = []
        self._register_hooks()

    def _find_layer(self, target_layer_name):
        """Resolve a dotted path such as ``'encoder.0.conv'`` to a sub-module.

        Purely numeric path components index into container modules
        (e.g. ``nn.Sequential``); others are attribute lookups.

        Raises:
            ValueError: if any component of the path cannot be resolved.
        """
        module = self.model
        for attr in target_layer_name.split('.'):
            try:
                if attr.isdigit():
                    module = module[int(attr)]
                else:
                    module = getattr(module, attr)
            except (AttributeError, IndexError, KeyError, TypeError) as e:
                raise ValueError(f"Could not find layer {target_layer_name}: {str(e)}")
        return module

    def _register_hooks(self):
        """Attach forward/backward hooks on the target layer, keeping handles."""
        def forward_hook(module, input, output):
            self.activations = output.detach()

        def backward_hook(module, grad_input, grad_output):
            self.gradients = grad_output[0].detach()

        self._handles.append(self.target_layer.register_forward_hook(forward_hook))
        # register_backward_hook is deprecated and can report wrong gradients
        # for modules with multiple inputs; use the full variant when present.
        if hasattr(self.target_layer, "register_full_backward_hook"):
            self._handles.append(
                self.target_layer.register_full_backward_hook(backward_hook))
        else:
            self._handles.append(
                self.target_layer.register_backward_hook(backward_hook))

    def remove_hooks(self):
        """Detach all registered hooks (call when done with this instance)."""
        for handle in self._handles:
            handle.remove()
        self._handles = []

    def __call__(self, input_tensor):
        """Run forward + backward and return the CAM as a 2-D numpy array in [0, 1].

        Raises:
            RuntimeError: if the hooks captured nothing (layer never fired).
        """
        self.model.zero_grad()
        output = self.model(input_tensor)
        output.mean().backward()
        if self.gradients is None or self.activations is None:
            raise RuntimeError("Gradients not captured - check hook registration")
        # Grad-CAM: global-average-pool the gradients into per-channel weights,
        # then take the ReLU'd weighted sum of the activations.
        weights = torch.mean(self.gradients, dim=[2, 3], keepdim=True)
        cam = torch.sum(self.activations * weights, dim=1, keepdim=True)
        cam = torch.relu(cam)
        # Min-max normalize; a flat map falls back to zeros to avoid 0/0.
        cam_min, cam_max = torch.min(cam), torch.max(cam)
        if cam_max - cam_min > 1e-8:
            cam = (cam - cam_min) / (cam_max - cam_min)
        else:
            cam = torch.zeros_like(cam)
        return cam.squeeze().cpu().numpy()
def preprocess_image_pil(image):
    """Turn a PIL image into a (1, 1, 256, 256) float tensor.

    Returns a ``(tensor, grayscale_image)`` pair on success, or
    ``(None, None)`` on failure — errors are logged, never raised.
    """
    try:
        # Work in single-channel grayscale throughout.
        if image.mode != 'L':
            image = image.convert('L')
        # Reject degenerate (zero-sized) images up front.
        width, height = image.size
        if width <= 0 or height <= 0:
            raise ValueError(f"Invalid image dimensions: {image.size}")
        pipeline = transforms.Compose([
            transforms.Resize((256, 256)),
            transforms.ToTensor(),
        ])
        batched = pipeline(image).unsqueeze(0)  # prepend batch dimension
        logger.debug(f"Output tensor shape: {batched.shape}")
        if batched.shape != (1, 1, 256, 256):
            raise ValueError(f"Invalid output shape: {batched.shape}")
        return batched, image
    except Exception as e:
        logger.error(f"Preprocessing failed: {e}")
        return None, None
def postprocess_mask(mask_tensor, original_size):
    """Convert model output to a binary PIL mask via 0.5 thresholding.

    Args:
        mask_tensor: torch.Tensor of per-pixel scores (any shape that squeezes
            to 2-D; assumed to be probabilities in [0, 1] — TODO confirm).
        original_size: (width, height) target size for the returned mask.

    Returns:
        PIL 'L' image containing only 0 and 255, resized to ``original_size``.

    Raises:
        ValueError: if ``mask_tensor`` is not a torch.Tensor.
    """
    try:
        if not isinstance(mask_tensor, torch.Tensor):
            raise ValueError("Input must be a torch.Tensor")
        mask_np = mask_tensor.squeeze().cpu().numpy()
        logger.info(f"Mask range before threshold: {mask_np.min():.2f}-{mask_np.max():.2f}")
        # Apply threshold and convert to binary mask
        binary_mask = (mask_np > 0.5).astype(np.uint8) * 255
        mask_img = Image.fromarray(binary_mask).convert("L")
        # Resize with NEAREST so the mask stays strictly binary — the default
        # resample filter would interpolate gray values along mask edges.
        if original_size != (256, 256):
            mask_img = mask_img.resize(original_size, Image.NEAREST)
            logger.info(f"Resized mask to original size: {original_size}")
        return mask_img
    except Exception as e:
        logger.error(f"Postprocessing failed: {str(e)}", exc_info=True)
        raise
def overlay_mask(original, mask):
    """Blend a red-tinted segmentation mask over the original image.

    Returns an RGB PIL image. On failure, returns a black image with the
    error text drawn on it rather than raising.
    """
    try:
        for img in (original, mask):
            if not isinstance(img, Image.Image):
                raise ValueError("Both inputs must be PIL Images")
        # Normalize the base image to RGB.
        base = original if original.mode == 'RGB' else original.convert("RGB")
        # Colorize the mask so foreground pixels appear red.
        red_mask = ImageOps.colorize(
            mask.convert("L"), black="black", white="red"
        ).convert("RGBA")
        # 40% mask over 60% original.
        blended = Image.blend(base.convert("RGBA"), red_mask, alpha=0.4)
        return blended.convert("RGB")
    except Exception as e:
        logger.error(f"Overlay creation failed: {str(e)}", exc_info=True)
        error_img = Image.new("RGB", (256, 256), color="black")
        draw = ImageDraw.Draw(error_img)
        draw.text((10, 10), f"Overlay Error: {str(e)}", fill="white")
        return error_img
def generate_gradcam(model, input_tensor, original_size, target_layer="output_block.conv.conv"):
    """Render a Grad-CAM heatmap for ``input_tensor`` as a colorized PIL image.

    Args:
        model: model to probe.
        input_tensor: torch.Tensor input batch.
        original_size: (width, height) of the returned heatmap.
        target_layer: dotted path of the layer to hook. Defaults to the value
            that was previously hard-coded, so existing callers are unchanged.

    Returns:
        RGB PIL heatmap (blue=low, yellow=mid, red=high). On any failure,
        a black image with the error message drawn on it — never raises.
    """
    try:
        if not isinstance(input_tensor, torch.Tensor):
            raise ValueError("Input must be a torch.Tensor")
        gradcam = GradCAMSegmentation(model, target_layer)
        cam = gradcam(input_tensor)  # 2-D float array normalized to [0, 1]
        # Scale to 8-bit and build a grayscale heatmap image.
        cam_uint8 = np.uint8(255 * cam)
        heatmap = Image.fromarray(cam_uint8).convert('L')
        # Map intensity to a blue->yellow->red ramp and match the target size.
        heatmap_color = ImageOps.colorize(
            heatmap,
            black='blue',
            white='red',
            mid='yellow'
        ).resize(original_size)
        return heatmap_color
    except Exception as e:
        # Create error image with detailed message
        error_img = Image.new("RGB", original_size, color="black")
        draw = ImageDraw.Draw(error_img)
        draw.text((10, 10), "GRAD-CAM ERROR", fill="red")
        draw.text((10, 40), str(e)[:50], fill="white")
        return error_img
def overlay_gradcam(original, gradcam):
    """Blend a Grad-CAM heatmap 50/50 over the original image.

    Returns an RGB PIL image. On failure, returns a black placeholder with
    an error banner — never raises.
    """
    try:
        original = original.convert("RGB")
        gradcam = gradcam.convert("RGB")
        # Image.blend requires identically sized inputs; resize the heatmap
        # to match the original instead of letting the blend fail.
        if gradcam.size != original.size:
            gradcam = gradcam.resize(original.size)
        return Image.blend(original, gradcam, alpha=0.5)
    except Exception as e:
        # Best effort: match the original's size for the placeholder when
        # it is readable, else fall back to the historical 256x256.
        try:
            size = original.size
        except Exception:
            size = (256, 256)
        error_img = Image.new("RGB", size, color="black")
        draw = ImageDraw.Draw(error_img)
        draw.text((10, 10), "OVERLAY ERROR", fill="red")
        return error_img