Spaces:
Sleeping
Sleeping
"""
step4_visualize.py
==================
STEP 4 — Overlay heatmaps on the image and save a 2×5 grid PNG.

Responsibilities:
  - Accept the original image, a list of tokens, and a parallel list of heatmaps.
  - Overlay each heatmap on the image using the INFERNO colormap.
  - Compose a 2×5 grid (original image + up to 9 word panels).
  - Save as a high-DPI PNG for clean presentation.

The overlay uses a pixel-level transparency mask so that only the
truly "hot" regions of the heatmap are blended, keeping the rest of
the image fully visible and unobscured.
"""
| import os | |
| import numpy as np | |
| import cv2 | |
| import matplotlib | |
| matplotlib.use("Agg") | |
| import matplotlib.pyplot as plt | |
| from PIL import Image | |
# ────────────────────────────────────────────────────────────────────────────
def overlay_heatmap_on_image(
    image_pil: Image.Image,
    heatmap_np: np.ndarray,
    alpha: float = 0.50,
    hot_threshold: float = 0.10,
    colormap=cv2.COLORMAP_INFERNO,
) -> Image.Image:
    """
    Blend a normalised [0, 1] heatmap on top of ``image_pil``.

    Only pixels whose heatmap value is strictly greater than ``hot_threshold``
    are tinted with the colormap; every other pixel keeps the original colour.

    Args:
        image_pil     : PIL image of any mode; it is converted to RGB and
                        resized to the heatmap's (W, H).
        heatmap_np    : 2-D (H, W) float array, expected in [0, 1]. NaNs are
                        treated as 0 and values outside [0, 1] are clipped
                        before colouring, so the uint8 cast cannot wrap.
        alpha         : Opacity of hot regions (0 = invisible, 1 = fully coloured).
        hot_threshold : Pixels with value <= this are NOT blended (they stay
                        identical to the resized original image).
        colormap      : OpenCV colormap id passed to ``cv2.applyColorMap``.

    Returns:
        The blended image as an RGB PIL image of size (W, H).

    Raises:
        ValueError: if ``heatmap_np`` is not 2-D.
    """
    heat = np.asarray(heatmap_np)
    if heat.ndim != 2:
        raise ValueError(f"heatmap_np must be 2-D (H, W), got shape {heat.shape}")
    # Sanitise: NaN -> 0, then clamp so 255 * heat always fits in uint8.
    heat = np.clip(np.nan_to_num(heat, nan=0.0), 0.0, 1.0)
    h, w = heat.shape

    # Force 3-channel RGB so grayscale / RGBA / palette inputs broadcast
    # correctly against the (H, W, 3) colormap below.
    rgb = image_pil.convert("RGB").resize((w, h), Image.LANCZOS)
    img = np.asarray(rgb, dtype=np.float32)

    # OpenCV returns BGR; convert to RGB to match the PIL image.
    colored_bgr = cv2.applyColorMap((255.0 * heat).astype(np.uint8), colormap)
    colored = cv2.cvtColor(colored_bgr, cv2.COLOR_BGR2RGB).astype(np.float32)

    # Per-pixel blend weight: alpha on hot pixels, 0 elsewhere.
    weight = alpha * (heat > hot_threshold).astype(np.float32)[..., None]
    blended = img * (1.0 - weight) + colored * weight
    return Image.fromarray(np.clip(blended, 0, 255).astype(np.uint8))
def save_attention_grid(
    image_pil: Image.Image,
    tokens: list,
    heatmaps: list,
    out_path: str,
    n_rows: int = 2,
    n_cols: int = 5,
    dpi: int = 150,
    alpha: float = 0.50,
    verbose: bool = True,
) -> str:
    """
    Build and save an (n_rows × n_cols) grid of per-word attention overlays.

    Layout:
        Cell [0]                     : Original image.
        Cells [1 .. n_rows*n_cols-1] : Word heatmaps in generation order.
        Remaining cells              : Left blank (axes hidden).

    Only the first ``min(n_rows*n_cols - 1, len(tokens), len(heatmaps))``
    token/heatmap pairs are drawn. The suptitle caption still joins *all*
    tokens.

    Args:
        image_pil      : Original (or resized) PIL image.
        tokens         : Decoded word strings, in generation order.
        heatmaps       : (H, W) numpy heatmaps parallel to ``tokens``; each is
                         passed to ``overlay_heatmap_on_image``.
        out_path       : Destination PNG path. Missing parent directories
                         are created.
        n_rows, n_cols : Grid dimensions.
        dpi            : Output DPI (150 recommended for presentation).
        alpha          : Heatmap blend opacity.
        verbose        : Print a save confirmation.

    Returns:
        out_path (so callers can chain this call).
    """
    n_panels = n_rows * n_cols
    # Slot 0 is the original image; never index past either parallel list.
    n_words = min(n_panels - 1, len(tokens), len(heatmaps))

    out_dir = os.path.dirname(out_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    # squeeze=False always yields a 2-D ndarray of Axes (even for a 1x1
    # grid), so ravel() is safe.
    fig, axes = plt.subplots(
        n_rows, n_cols, figsize=(n_cols * 4.5, n_rows * 4.5), squeeze=False
    )
    axes = axes.ravel()
    try:
        # Panel 0 — original image
        axes[0].imshow(image_pil)
        axes[0].set_title("Original Image", fontsize=13, fontweight="bold", pad=4)
        axes[0].axis("off")

        # Panels 1…n_words — per-word heatmap overlays
        # (zip stops after n_words because the axes slice has that length).
        for ax, token, heatmap in zip(axes[1 : n_words + 1], tokens, heatmaps):
            overlay = overlay_heatmap_on_image(image_pil, heatmap, alpha=alpha)
            ax.imshow(overlay)
            ax.set_title(f"'{token}'", fontsize=13, fontweight="bold", pad=4)
            ax.axis("off")

        # Hide unused panels
        for ax in axes[n_words + 1 :]:
            ax.axis("off")

        caption = " ".join(tokens)
        fig.suptitle(
            f"Attention Flow (Multi-Layer GradCAM)\nCaption: \"{caption}\"",
            fontsize=15,
            fontweight="bold",
            y=1.02,
        )
        fig.tight_layout()
        fig.savefig(out_path, dpi=dpi, bbox_inches="tight")
    finally:
        # Always release the figure, even if drawing or saving raised.
        plt.close(fig)

    if verbose:
        print(f"✓ Attention grid saved → {out_path}")
    return out_path