# Segmentation/utils.py
# Author: KaranNag — "Update utils.py", revision 2c10342 (verified)
import sys
import logging
import numpy as np
import torch
import torch.nn.functional as F
from torchvision import transforms
import numpy as np
from PIL import Image, ImageDraw
from skimage.measure import label, regionprops
import matplotlib.pyplot as plt
from matplotlib import cm
from io import BytesIO
from PIL import Image, ImageOps, ImageDraw
from skimage.transform import resize
from skimage.measure import label, regionprops # From scikit-image
from PIL import Image, ImageOps, ImageDraw
# Configure logging
logger = logging.getLogger(__name__)
class GradCAMSegmentation:
    """Grad-CAM for segmentation models.

    Hooks a target layer to capture its forward activations and backward
    gradients, then builds a class-activation map by weighting the
    activations with the spatially averaged gradients of the mean output.
    """

    def __init__(self, model, target_layer_name):
        """
        Args:
            model: the segmentation network (a torch.nn.Module).
            target_layer_name: dotted path to the layer to hook, e.g.
                "output_block.conv.conv"; purely numeric components index
                into sequential containers.

        Raises:
            ValueError: if target_layer_name cannot be resolved on model.
        """
        self.model = model
        self.target_layer = self._find_layer(target_layer_name)
        self.activations = None
        self.gradients = None
        self._handles = []  # hook handles, kept so hooks can be detached
        self._register_hooks()

    def _find_layer(self, target_layer_name):
        """Resolve a dotted layer path; raise ValueError if any step fails."""
        module = self.model
        for attr in target_layer_name.split('.'):
            try:
                if attr.isdigit():
                    module = module[int(attr)]
                else:
                    module = getattr(module, attr)
            # TypeError/KeyError cover indexing a module that is not a
            # (or is the wrong kind of) container.
            except (AttributeError, IndexError, KeyError, TypeError) as e:
                raise ValueError(f"Could not find layer {target_layer_name}: {str(e)}")
        return module

    def _register_hooks(self):
        """Attach forward/backward hooks on the target layer."""
        def forward_hook(module, input, output):
            self.activations = output.detach()

        def backward_hook(module, grad_input, grad_output):
            self.gradients = grad_output[0].detach()

        self._handles.append(self.target_layer.register_forward_hook(forward_hook))
        # register_full_backward_hook replaces the deprecated
        # register_backward_hook, which could report incorrect gradients
        # for modules with multiple inputs/outputs.
        self._handles.append(self.target_layer.register_full_backward_hook(backward_hook))

    def remove_hooks(self):
        """Detach all registered hooks (call when done to avoid leaks)."""
        for handle in self._handles:
            handle.remove()
        self._handles = []

    def __call__(self, input_tensor):
        """Compute a normalized Grad-CAM heat map for input_tensor.

        Returns:
            numpy array in [0, 1] with the (squeezed) spatial shape of the
            target layer's activations.

        Raises:
            RuntimeError: if the hooks did not capture gradients.
        """
        self.model.zero_grad()
        output = self.model(input_tensor)
        output.mean().backward()
        if self.gradients is None or self.activations is None:
            raise RuntimeError("Gradients not captured - check hook registration")
        # Channel weights: global-average-pool the gradients over H and W.
        weights = torch.mean(self.gradients, dim=[2, 3], keepdim=True)
        cam = torch.sum(self.activations * weights, dim=1, keepdim=True)
        cam = torch.relu(cam)
        # Min-max normalize; a flat map collapses to all zeros.
        cam_min, cam_max = torch.min(cam), torch.max(cam)
        if cam_max - cam_min > 1e-8:
            cam = (cam - cam_min) / (cam_max - cam_min)
        else:
            cam = torch.zeros_like(cam)
        return cam.squeeze().cpu().numpy()
def preprocess_image_pil(image):
    """Turn a PIL image into a (1, 1, 256, 256) float tensor.

    Returns:
        (tensor, grayscale_image) on success, or (None, None) if anything
        goes wrong (the failure is logged).
    """
    try:
        # Work in single-channel grayscale.
        gray = image if image.mode == 'L' else image.convert('L')
        # Reject degenerate (zero-area) inputs up front.
        width, height = gray.size
        if width <= 0 or height <= 0:
            raise ValueError(f"Invalid image dimensions: {gray.size}")
        pipeline = transforms.Compose([
            transforms.Resize((256, 256)),
            transforms.ToTensor(),
        ])
        batch = pipeline(gray).unsqueeze(0)  # prepend batch dimension
        logger.debug(f"Output tensor shape: {batch.shape}")
        if batch.shape != (1, 1, 256, 256):
            raise ValueError(f"Invalid output shape: {batch.shape}")
        return batch, gray
    except Exception as e:
        logger.error(f"Preprocessing failed: {e}")
        return None, None
def postprocess_mask(mask_tensor, original_size):
    """Convert model output to a binary PIL mask via thresholding.

    Args:
        mask_tensor: torch.Tensor of per-pixel probabilities (any shape
            that squeezes to 2-D, e.g. (1, 1, H, W)).
        original_size: (width, height) to resize the mask back to.

    Returns:
        A mode-"L" PIL Image whose pixels are 0 or 255.

    Raises:
        ValueError: if mask_tensor is not a torch.Tensor (other failures
        are logged and re-raised).
    """
    try:
        if not isinstance(mask_tensor, torch.Tensor):
            raise ValueError("Input must be a torch.Tensor")
        mask_np = mask_tensor.squeeze().cpu().numpy()
        logger.info(f"Mask range before threshold: {mask_np.min():.2f}-{mask_np.max():.2f}")
        # Apply threshold and convert to binary mask
        binary_mask = (mask_np > 0.5).astype(np.uint8) * 255
        mask_img = Image.fromarray(binary_mask).convert("L")
        # Resize to original dimensions. NEAREST keeps the mask strictly
        # binary; PIL's default (bicubic) resampling would introduce
        # intermediate gray values along mask edges.
        if original_size != (256, 256):
            mask_img = mask_img.resize(original_size, resample=Image.NEAREST)
            logger.info(f"Resized mask to original size: {original_size}")
        return mask_img
    except Exception as e:
        logger.error(f"Postprocessing failed: {str(e)}", exc_info=True)
        raise
def overlay_mask(original, mask):
    """Blend a red-tinted segmentation mask over the original image.

    Both arguments must be PIL Images. Returns an RGB PIL Image; on
    failure, returns a black image with the error text drawn on it.
    """
    try:
        for img in (original, mask):
            if not isinstance(img, Image.Image):
                raise ValueError("Both inputs must be PIL Images")
        base = original if original.mode == 'RGB' else original.convert("RGB")
        # Colorize the mask: background stays black, mask pixels go red.
        tinted = ImageOps.colorize(mask.convert("L"), black="black", white="red")
        # 40% opacity blend of the red overlay onto the base image.
        blended = Image.blend(base.convert("RGBA"), tinted.convert("RGBA"), alpha=0.4)
        return blended.convert("RGB")
    except Exception as e:
        logger.error(f"Overlay creation failed: {str(e)}", exc_info=True)
        fallback = Image.new("RGB", (256, 256), color="black")
        ImageDraw.Draw(fallback).text((10, 10), f"Overlay Error: {str(e)}", fill="white")
        return fallback
def generate_gradcam(model, input_tensor, original_size,
                     target_layer_name="output_block.conv.conv"):
    """Render a Grad-CAM heat map for input_tensor as a color PIL image.

    Args:
        model: the segmentation network.
        input_tensor: torch.Tensor input batch, e.g. (1, 1, 256, 256).
        original_size: (width, height) for the returned image.
        target_layer_name: dotted path of the layer to inspect; defaults
            to the previously hard-coded "output_block.conv.conv", so
            existing callers are unaffected.

    Returns:
        A blue->yellow->red colorized heat map resized to original_size;
        on any failure, a black image with the error message drawn on it.
    """
    try:
        # Verify input
        if not isinstance(input_tensor, torch.Tensor):
            raise ValueError("Input must be a torch.Tensor")
        # Generate CAM (normalized to [0, 1] by GradCAMSegmentation)
        gradcam = GradCAMSegmentation(model, target_layer_name)
        cam = gradcam(input_tensor)
        # Map to 8-bit grayscale, then colorize: blue=low, red=high.
        cam_uint8 = np.uint8(255 * cam)
        heatmap = Image.fromarray(cam_uint8).convert('L')
        heatmap_color = ImageOps.colorize(
            heatmap,
            black='blue',
            white='red',
            mid='yellow'
        ).resize(original_size)
        return heatmap_color
    except Exception as e:
        # Fall back to an error image rather than raising, so callers
        # always receive a displayable PIL image.
        error_img = Image.new("RGB", original_size, color="black")
        draw = ImageDraw.Draw(error_img)
        draw.text((10, 10), "GRAD-CAM ERROR", fill="red")
        draw.text((10, 40), str(e)[:50], fill="white")
        return error_img
def overlay_gradcam(original, gradcam):
    """Blend the Grad-CAM heat map over the original image at 50% opacity.

    Returns an RGB PIL Image; on failure, a black error placeholder.
    """
    try:
        base = original.convert("RGB")
        heat = gradcam.convert("RGB")
        return Image.blend(base, heat, alpha=0.5)
    except Exception:
        fallback = Image.new("RGB", (256, 256), color="black")
        ImageDraw.Draw(fallback).text((10, 10), "OVERLAY ERROR", fill="red")
        return fallback