Spaces:
Sleeping
Sleeping
File size: 4,268 Bytes
"""
step4_visualize.py
==================
STEP 4 — Overlay heatmaps on the image and save a 2×5 grid PNG.
Responsibilities:
- Accept the original image, list of tokens, and parallel list of heatmaps.
- Overlay each heatmap on the image using the INFERNO colormap.
- Compose a 2×5 grid (original image + up to 9 word panels).
- Save as a high-DPI PNG for clean presentation.
The overlay uses a pixel-level transparency mask so that only the
truly "hot" regions of the heatmap are blended, keeping the rest of
the image fully visible and unobscured.
"""
import os
import numpy as np
import cv2
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
from PIL import Image
# ────────────────────────────────────────────────────────────────────────────
def overlay_heatmap_on_image(
    image_pil: Image.Image,
    heatmap_np: np.ndarray,
    alpha: float = 0.50,
    hot_threshold: float = 0.10,
    colormap=cv2.COLORMAP_INFERNO,
) -> Image.Image:
    """
    Blend a normalised [0,1] heatmap on top of image_pil.

    Only pixels whose heatmap value exceeds ``hot_threshold`` are tinted;
    every other pixel keeps the (resized) original image value.

    Args:
        image_pil     : PIL image of any mode (converted to RGB and resized
                        to match the heatmap).
        heatmap_np    : (H, W) float array, values expected in [0,1].
                        Out-of-range values are clipped and NaNs become 0.
        alpha         : Opacity of hot regions (0 = invisible, 1 = fully colored).
        hot_threshold : Pixels at or below this value are NOT blended
                        (they stay the original image).
        colormap      : OpenCV colormap to apply.

    Returns:
        Blended RGB PIL image of size (W, H).

    Raises:
        ValueError: if ``heatmap_np`` is not two-dimensional.
    """
    heatmap = np.asarray(heatmap_np, dtype=np.float32)
    if heatmap.ndim != 2:
        raise ValueError(
            f"heatmap_np must be 2-D (H, W), got shape {heatmap.shape}"
        )

    # Sanitise before the uint8 cast: values outside [0,1] would wrap around
    # (e.g. 1.01 * 255 -> 1) and NaN has no defined uint8 value.
    heatmap = np.nan_to_num(heatmap, nan=0.0, posinf=1.0, neginf=0.0)
    heatmap = np.clip(heatmap, 0.0, 1.0)

    h, w = heatmap.shape
    # Force 3-channel RGB so grayscale / RGBA / palette images broadcast
    # against the 3-channel colormap output below.
    rgb = image_pil.convert("RGB").resize((w, h), Image.LANCZOS)
    img = np.asarray(rgb, dtype=np.float32)

    # Apply colormap (OpenCV returns BGR; convert to RGB to match PIL).
    hm_u8 = (255.0 * heatmap).astype(np.uint8)
    colored = cv2.applyColorMap(hm_u8, colormap)
    colored = cv2.cvtColor(colored, cv2.COLOR_BGR2RGB)

    # Pixel-level weight: alpha on "hot" pixels, 0 elsewhere.
    weight = (heatmap > hot_threshold).astype(np.float32)[..., None] * alpha
    blended = img * (1.0 - weight) + colored * weight

    # Clip so an alpha outside [0,1] cannot wrap around in the uint8 cast.
    return Image.fromarray(np.clip(blended, 0, 255).astype(np.uint8))
def save_attention_grid(
    image_pil: Image.Image,
    tokens: list,
    heatmaps: list,
    out_path: str,
    n_rows: int = 2,
    n_cols: int = 5,
    dpi: int = 150,
    alpha: float = 0.50,
    verbose: bool = True,
) -> str:
    """
    Build and save a (n_rows × n_cols) grid of per-word attention overlays.

    Layout:
        Cell [0]: Original image.
        Cells [1..n_rows*n_cols-1]: Word heatmaps in generation order.

    Args:
        image_pil      : Original PIL image.
        tokens         : List of decoded word strings.
        heatmaps       : Parallel list of (H, W) numpy heatmaps. If it is
                         shorter than ``tokens``, only the paired entries
                         are drawn.
        out_path       : Path for the saved PNG; its parent directory is
                         created if it does not exist.
        n_rows, n_cols : Grid dimensions.
        dpi            : Output DPI (150 recommended for presentation).
        alpha          : Heatmap blend opacity.
        verbose        : Print save confirmation.

    Returns:
        out_path (so callers can chain this call).
    """
    n_panels = n_rows * n_cols  # total slots
    # Slot 0 is the original image; never index past the shorter of the two
    # parallel lists.
    n_words = max(0, min(n_panels - 1, len(tokens), len(heatmaps)))

    out_dir = os.path.dirname(out_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    # squeeze=False always yields a 2-D array of Axes, so a 1x1 or 1xN grid
    # can be flattened the same way as the default 2x5 grid.
    fig, axes = plt.subplots(
        n_rows,
        n_cols,
        figsize=(n_cols * 4.5, n_rows * 4.5),
        squeeze=False,
    )
    try:
        axes = axes.flatten()

        # Panel 0 — original image
        axes[0].imshow(image_pil)
        axes[0].set_title("Original Image", fontsize=13, fontweight="bold", pad=4)
        axes[0].axis("off")

        # Panels 1…n_words — per-word heatmap overlays
        word_axes = axes[1 : n_words + 1]
        for ax, token, heatmap in zip(word_axes, tokens, heatmaps):
            overlay = overlay_heatmap_on_image(image_pil, heatmap, alpha=alpha)
            ax.imshow(overlay)
            ax.set_title(f"'{token}'", fontsize=13, fontweight="bold", pad=4)
            ax.axis("off")

        # Turn off unused panels
        for ax in axes[n_words + 1 : n_panels]:
            ax.axis("off")

        caption = " ".join(tokens)
        fig.suptitle(
            f"Attention Flow (Multi-Layer GradCAM)\nCaption: \"{caption}\"",
            fontsize=15,
            fontweight="bold",
            y=1.02,
        )
        # Operate on this figure explicitly rather than pyplot's implicit
        # "current figure", which may be a different one.
        fig.tight_layout()
        fig.savefig(out_path, dpi=dpi, bbox_inches="tight")
    finally:
        # Always release the figure, even if drawing or saving failed.
        plt.close(fig)

    if verbose:
        print(f"✅ Attention grid saved → {out_path}")
    return out_path
|