File size: 4,268 Bytes
0710b5c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
"""
step4_visualize.py
==================
STEP 4 — Overlay heatmaps on the image and save a 2×5 grid PNG.

Responsibilities:
- Accept the original image, list of tokens, and parallel list of heatmaps.
- Overlay each heatmap on the image using the INFERNO colormap.
- Compose a 2×5 grid (original image + up to 9 word panels).
- Save as a high-DPI PNG for clean presentation.

The overlay uses a pixel-level transparency mask so that only the
truly "hot" regions of the heatmap are blended, keeping the rest of
the image fully visible and unobscured.
"""

import os
import numpy as np
import cv2
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
from PIL import Image


# ────────────────────────────────────────────────────────────────────────────
def overlay_heatmap_on_image(
    image_pil: Image.Image,
    heatmap_np: np.ndarray,
    alpha: float = 0.50,
    hot_threshold: float = 0.10,
    colormap=cv2.COLORMAP_INFERNO,
) -> Image.Image:
    """
    Blend a normalised [0,1] heatmap on top of image_pil.

    Only pixels whose heatmap value is strictly greater than
    ``hot_threshold`` are tinted with the colormap; every other pixel keeps
    the original image colour.

    Args:
        image_pil     : PIL image in any mode. It is converted to RGB and
                        resized (LANCZOS) to the heatmap's spatial size.
        heatmap_np    : (H, W) float array, expected in [0,1]. NaN is treated
                        as 0, +inf as 1, -inf as 0, and values outside [0,1]
                        are clipped.
        alpha         : Opacity of hot regions (0 = invisible, 1 = fully colored).
        hot_threshold : Pixels at or below this value are NOT blended (they keep
                        the original image).
        colormap      : OpenCV colormap to apply.

    Returns:
        blended RGB PIL image with the heatmap's (W, H) size.
    """
    # Sanitise the heatmap: replace non-finite values, then clamp to [0, 1]
    # so the uint8 conversion below cannot wrap around (e.g. 1.01*255 -> 1).
    heat = np.nan_to_num(
        np.asarray(heatmap_np, dtype=np.float32), nan=0.0, posinf=1.0, neginf=0.0
    )
    heat = np.clip(heat, 0.0, 1.0)
    h, w = heat.shape

    # Force 3-channel RGB so grayscale ("L") or RGBA inputs broadcast
    # correctly against the 3-channel colormap output below.
    img = np.array(image_pil.convert("RGB").resize((w, h), Image.LANCZOS))

    # Apply colormap (OpenCV returns BGR; convert to RGB to match PIL).
    hm_u8 = (255.0 * heat).astype(np.uint8)
    colored = cv2.applyColorMap(hm_u8, colormap)
    colored = cv2.cvtColor(colored, cv2.COLOR_BGR2RGB)

    # Pixel-level mask: only blend "hot" spots; everything else stays the
    # untouched original image.
    mask = (heat > hot_threshold).astype(np.float32)[..., None]
    blended = img * (1 - mask * alpha) + colored * (mask * alpha)

    return Image.fromarray(blended.astype(np.uint8))


def save_attention_grid(
    image_pil: Image.Image,
    tokens: list,
    heatmaps: list,
    out_path: str,
    n_rows: int = 2,
    n_cols: int = 5,
    dpi: int = 150,
    alpha: float = 0.50,
    verbose: bool = True,
) -> str:
    """
    Build and save a (n_rows × n_cols) grid of per-word attention overlays.

    Layout:
        Cell [0]: Original image.
        Cells [1..n_rows*n_cols-1]: Word heatmaps in generation order.
        Any cells left over are blanked.

    Only as many word panels are drawn as fit in the grid AND have both a
    token and a heatmap. The suptitle caption always joins *all* tokens.

    Args:
        image_pil   : Original (or 224×224) PIL image.
        tokens      : List of decoded word strings.
        heatmaps    : Parallel list of (224, 224) numpy heatmaps.
        out_path    : Path for the saved PNG; missing parent directories are
                      created.
        n_rows, n_cols : Grid dimensions.
        dpi         : Output DPI (150 recommended for presentation).
        alpha       : Heatmap blend opacity.
        verbose     : Print save confirmation.

    Returns:
        out_path (so callers can chain this call).
    """
    n_panels = n_rows * n_cols  # total slots
    # Slot 0 is the original image. Never index past either parallel list,
    # in case `heatmaps` is shorter than `tokens`.
    n_words = min(n_panels - 1, len(tokens), len(heatmaps))

    # Make sure the destination directory exists (dirname is "" for a bare
    # filename, meaning the current directory).
    out_dir = os.path.dirname(out_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    # squeeze=False always yields a 2-D ndarray of Axes (even for a 1×1 grid),
    # so ravel() is safe.
    fig, axes = plt.subplots(
        n_rows, n_cols, figsize=(n_cols * 4.5, n_rows * 4.5), squeeze=False
    )
    axes = axes.ravel()

    try:
        # Panel 0 - original image
        axes[0].imshow(image_pil)
        axes[0].set_title("Original Image", fontsize=13, fontweight="bold", pad=4)
        axes[0].axis("off")

        # Panels 1..n_words - per-word heatmap overlays
        for ax, token, heatmap in zip(axes[1:], tokens[:n_words], heatmaps[:n_words]):
            overlay = overlay_heatmap_on_image(image_pil, heatmap, alpha=alpha)
            ax.imshow(overlay)
            ax.set_title(f"'{token}'", fontsize=13, fontweight="bold", pad=4)
            ax.axis("off")

        # Turn off unused panels
        for ax in axes[n_words + 1:]:
            ax.axis("off")

        caption = " ".join(tokens)
        fig.suptitle(
            f"Attention Flow (Multi-Layer GradCAM)\nCaption: \"{caption}\"",
            fontsize=15,
            fontweight="bold",
            y=1.02,
        )
        fig.tight_layout()
        fig.savefig(out_path, dpi=dpi, bbox_inches="tight")
    finally:
        # Close *this* figure explicitly. Bare plt.close() closes whatever is
        # "current". Closing even when drawing/saving fails avoids leaking
        # figures across repeated calls.
        plt.close(fig)

    if verbose:
        print(f"βœ… Attention grid saved β†’ {out_path}")

    return out_path