project_02_DS / task /task_02 /step4_visualize.py
griddev's picture
Deploy Streamlit Space app
0710b5c verified
"""
step4_visualize.py
==================
STEP 4 — Overlay heatmaps on the image and save a 2×5 grid PNG.
Responsibilities:
- Accept the original image, list of tokens, and parallel list of heatmaps.
- Overlay each heatmap on the image using the INFERNO colormap.
- Compose a 2×5 grid (original image + up to 9 word panels).
- Save as a high-DPI PNG for clean presentation.
The overlay uses a pixel-level transparency mask so that only the
truly "hot" regions of the heatmap are blended, keeping the rest of
the image fully visible and unobscured.
"""
import os
import numpy as np
import cv2
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
from PIL import Image
# ────────────────────────────────────────────────────────────────────────────
def overlay_heatmap_on_image(
    image_pil: Image.Image,
    heatmap_np: np.ndarray,
    alpha: float = 0.50,
    hot_threshold: float = 0.10,
    colormap=cv2.COLORMAP_INFERNO,
) -> Image.Image:
    """
    Blend a normalised [0,1] heatmap on top of image_pil.

    Args:
        image_pil     : PIL image of any mode. It is converted to RGB and
                        resized to the heatmap's (W, H).
        heatmap_np    : (H, W) numeric array, values expected in [0,1].
                        NaNs are treated as 0 and out-of-range values are
                        clipped before the colormap is applied.
        alpha         : Opacity of hot regions (0 = invisible, 1 = fully colored).
        hot_threshold : Pixels whose heatmap value is not strictly greater
                        than this are NOT blended (the original image stays).
        colormap      : OpenCV colormap to apply.

    Returns:
        Blended RGB PIL image with the heatmap's width and height.

    Raises:
        ValueError: if heatmap_np is not a 2-D array.
    """
    raw = np.asarray(heatmap_np)
    if raw.ndim != 2:
        raise ValueError(
            f"heatmap_np must be a 2-D (H, W) array, got shape {raw.shape}"
        )
    h, w = raw.shape

    # Force three channels so the image always lines up with the RGB colormap
    # output; RGBA / grayscale / palette inputs would otherwise fail to
    # broadcast (or broadcast to the wrong shape) in the blend below.
    img = np.array(image_pil.convert("RGB").resize((w, h), Image.LANCZOS))

    # Sanitize before the uint8 cast: values outside [0, 1] would wrap around
    # modulo 256 and NaN has no defined uint8 value, both giving wrong colors.
    safe = np.clip(np.nan_to_num(raw, nan=0.0), 0.0, 1.0)
    hm_u8 = (255.0 * safe).astype(np.uint8)

    # Apply colormap (OpenCV returns BGR; convert to RGB to match the image).
    colored = cv2.applyColorMap(hm_u8, colormap)
    colored = cv2.cvtColor(colored, cv2.COLOR_BGR2RGB)

    # Pixel-level mask: only blend "hot" spots. Computed on the raw values so
    # thresholding is unchanged for valid input; NaN compares False (not hot).
    mask = (raw > hot_threshold).astype(np.float32)[..., None]
    weight = mask * alpha
    blended = img * (1 - weight) + colored * weight

    # Guard the uint8 cast against wrap-around when alpha is outside [0, 1].
    blended = np.clip(blended, 0, 255)
    return Image.fromarray(blended.astype(np.uint8))
def save_attention_grid(
    image_pil: Image.Image,
    tokens: list,
    heatmaps: list,
    out_path: str,
    n_rows: int = 2,
    n_cols: int = 5,
    dpi: int = 150,
    alpha: float = 0.50,
    verbose: bool = True,
) -> str:
    """
    Build and save a (n_rows × n_cols) grid of per-word attention overlays.

    Layout:
        Cell [0]: Original image.
        Cells [1..n_rows*n_cols-1]: Word heatmaps in generation order.

    Args:
        image_pil      : Original (or 224×224) PIL image.
        tokens         : List of decoded word strings.
        heatmaps       : Parallel list of 2-D numpy heatmaps. If it is
                         shorter than tokens, only the paired entries are drawn.
        out_path       : Path for the saved PNG. Missing parent directories
                         are created.
        n_rows, n_cols : Grid dimensions (each must be >= 1).
        dpi            : Output DPI (150 recommended for presentation).
        alpha          : Heatmap blend opacity.
        verbose        : Print save confirmation.

    Returns:
        out_path (so callers can chain this call).

    Raises:
        ValueError: if n_rows or n_cols is smaller than 1.
    """
    if n_rows < 1 or n_cols < 1:
        raise ValueError("n_rows and n_cols must both be >= 1")

    # Make sure the destination directory exists before rendering anything.
    out_dir = os.path.dirname(out_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    # squeeze=False always yields a 2-D array of Axes, so flatten() also
    # works for single-row, single-column and 1×1 grids.
    fig, axes = plt.subplots(
        n_rows,
        n_cols,
        figsize=(n_cols * 4.5, n_rows * 4.5),
        squeeze=False,
    )
    try:
        axes = axes.flatten()

        # Panel 0 — original image
        axes[0].imshow(image_pil)
        axes[0].set_title("Original Image", fontsize=13, fontweight="bold", pad=4)

        # Panels 1… — per-word heatmap overlays. zip() stops at the shortest
        # of (free panels, tokens, heatmaps), so a short heatmaps list or an
        # over-long caption can never index out of range.
        for ax, token, heatmap in zip(axes[1:], tokens, heatmaps):
            overlay = overlay_heatmap_on_image(image_pil, heatmap, alpha=alpha)
            ax.imshow(overlay)
            ax.set_title(f"'{token}'", fontsize=13, fontweight="bold", pad=4)

        # Hide ticks/frames on every panel, used or not.
        for ax in axes:
            ax.axis("off")

        caption = " ".join(str(token) for token in tokens)
        fig.suptitle(
            f"Attention Flow (Multi-Layer GradCAM)\nCaption: \"{caption}\"",
            fontsize=15,
            fontweight="bold",
            y=1.02,
        )

        # Operate on this figure explicitly instead of pyplot's implicit
        # "current figure", which may be a different one in long-lived apps.
        fig.tight_layout()
        fig.savefig(out_path, dpi=dpi, bbox_inches="tight")
    finally:
        # Always release the figure, even if rendering or saving failed.
        plt.close(fig)

    if verbose:
        print(f"✅ Attention grid saved → {out_path}")
    return out_path