| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| import math |
| import warnings |
| from typing import Any, Dict, List, Optional, Tuple |
|
|
| import gradio as gr |
| import numpy as np |
| import torch |
| from PIL import Image, ImageDraw, ImageFont |
| from sklearn.cluster import KMeans |
| from sklearn.decomposition import PCA |
| from transformers import ( |
| AutoImageProcessor, |
| ViTModel, |
| ViTForImageClassification, |
| AutoConfig, |
| ) |
| import plotly.express as px |
|
|
# NOTE: this silences *every* warning for the whole process (library
# deprecation notices included) so the demo console stays quiet.
warnings.simplefilter("ignore")

# Single Hugging Face checkpoint shared by the bare encoder and the classifier.
MODEL_NAME = "google/vit-base-patch16-224"

# Run on the GPU when torch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
else:
    DEVICE = torch.device("cpu")

# Lazily-filled model/processor caches; load_models() populates them on first use.
BASE_MODEL = CLF_MODEL = PROCESSOR = None
|
|
|
|
| |
def load_models():
    """Lazily load, cache and return ``(BASE_MODEL, CLF_MODEL, PROCESSOR)``.

    The first call loads ``MODEL_NAME`` three ways:

    * an ``AutoImageProcessor`` for preprocessing,
    * a bare ``ViTModel`` whose config asks for per-layer attentions and
      hidden states (used for every visualisation),
    * a ``ViTForImageClassification`` (used for the top-5 predictions).

    Both models are moved to ``DEVICE`` and put in eval mode. Later calls
    return the cached module-level objects without reloading.
    """
    global BASE_MODEL, CLF_MODEL, PROCESSOR
    if BASE_MODEL is not None and CLF_MODEL is not None and PROCESSOR is not None:
        return BASE_MODEL, CLF_MODEL, PROCESSOR

    processor = AutoImageProcessor.from_pretrained(MODEL_NAME)

    cfg = AutoConfig.from_pretrained(MODEL_NAME)
    cfg.output_attentions = True
    cfg.output_hidden_states = True

    # Attention weights are only materialised by the "eager" attention path.
    # The documented way to choose the backend is the `attn_implementation`
    # argument of `from_pretrained`. Assigning an attribute on an
    # already-built config may be ignored, depending on the transformers
    # version.
    base = ViTModel.from_pretrained(MODEL_NAME, config=cfg, attn_implementation="eager")
    base.to(DEVICE).eval()

    clf = ViTForImageClassification.from_pretrained(MODEL_NAME)
    clf.to(DEVICE).eval()

    # Publish the caches only once everything loaded, so a failed load leaves
    # no half-initialised globals behind and the next call simply retries.
    BASE_MODEL, CLF_MODEL, PROCESSOR = base, clf, processor
    return BASE_MODEL, CLF_MODEL, PROCESSOR
|
|
|
|
| |
def patch_grid_info(image_size: int = 224, patch_size: int = 16):
    """Describe the square patch grid laid over an ``image_size`` image.

    Returns ``(grid_size, centers)``:

    * ``grid_size`` is ``image_size // patch_size``.
    * ``centers`` holds the integer ``(x, y)`` pixel centre of every patch in
      row-major order (left→right within a row, rows top→bottom).
    """
    side = image_size // patch_size
    centers = [
        (int((col + 0.5) * patch_size), int((row + 0.5) * patch_size))
        for row in range(side)
        for col in range(side)
    ]
    return side, centers
|
|
|
|
| |
def draw_patch_grid(img: Image.Image, patch_size: int = 16, outline=(0, 180, 0)) -> Image.Image:
    """Return a 224x224 RGB copy of ``img`` with the patch boundaries drawn.

    A 1-px line in colour ``outline`` is drawn every ``patch_size`` pixels,
    starting at 0. All vertical lines are drawn first, then the horizontal ones.
    """
    canvas = img.convert("RGB").resize((224, 224))
    pen = ImageDraw.Draw(canvas)
    width, height = canvas.size
    verticals = [[(x, 0), (x, height)] for x in range(0, width, patch_size)]
    horizontals = [[(0, y), (width, y)] for y in range(0, height, patch_size)]
    for segment in verticals + horizontals:
        pen.line(segment, fill=outline, width=1)
    return canvas
|
|
|
|
def draw_cluster_blocks(img: Image.Image, labels: np.ndarray, n_clusters: int = 4, patch_size: int = 16):
    """Tint every patch of a 224x224 copy of ``img`` with its cluster colour.

    Args:
        img: source image; it is converted to RGB and resized to 224x224.
        labels: one cluster id per patch, in row-major order
            (left→right, then top→bottom).
        n_clusters: not used for drawing. Kept for call compatibility; colours
            cycle through the fixed palette below by ``label % len(palette)``.
        patch_size: side of a square patch in pixels. The grid has
            ``224 // patch_size`` columns.

    Returns:
        The tinted RGB image.
    """
    img = img.convert("RGB").resize((224, 224))
    draw = ImageDraw.Draw(img, "RGBA")
    # Derive the grid from the caller's patch size, not a fixed default.
    grid_size = max(1, img.width // patch_size)
    palette = [
        (255, 99, 71, 140),
        (60, 179, 113, 140),
        (65, 105, 225, 140),
        (255, 215, 0, 140),
        (199, 21, 133, 140),
        (0, 206, 209, 140),
    ]
    for idx, lab in enumerate(labels):
        row, col = divmod(idx, grid_size)
        x0 = col * patch_size
        y0 = row * patch_size
        # PIL rectangles include both corner pixels. Stop one pixel short so
        # neighbouring translucent blocks do not overlap and double-blend
        # into darker seams.
        x1 = x0 + patch_size - 1
        y1 = y0 + patch_size - 1
        draw.rectangle([x0, y0, x1, y1], fill=palette[int(lab) % len(palette)])
    return img
|
|
|
|
def draw_attention_arrows(img: Image.Image, att_matrix: np.ndarray, top_k: int = 3, query_idx: Optional[int] = None):
    """Draw red arrows from one query patch to the patches it attends to most.

    Args:
        img: source image; it is converted to RGB and resized to 224x224.
        att_matrix: patch-to-patch attention, indexed ``[query, key]``.
            Presumably ``(n_patches, n_patches)`` with the CLS token already
            removed by the caller.
        top_k: how many of the strongest *other* patches to connect.
            ``top_k <= 0`` draws no arrows.
        query_idx: row-major index of the query patch. Defaults to the patch
            at the centre of the grid.

    Returns:
        The annotated image. The query patch is circled in blue.
    """
    img = img.convert("RGB").resize((224, 224))
    draw = ImageDraw.Draw(img, "RGBA")
    grid_size, positions = patch_grid_info()

    if query_idx is None:
        # Middle row, middle column. `(grid*grid)//2` would pick the first
        # column of the middle row when the grid side is even.
        query_idx = (grid_size // 2) * grid_size + grid_size // 2
    qpos = positions[query_idx]

    weights = np.asarray(att_matrix[query_idx])
    # Strongest keys first. The query itself is skipped because its arrow
    # would have zero length. max(0, ...) guards against the `[-0:]` trap,
    # which selects everything.
    ranked = [int(k) for k in np.argsort(weights)[::-1] if k != query_idx]
    for key in ranked[: max(0, top_k)]:
        kpos = positions[key]
        draw.line([qpos, kpos], fill=(255, 0, 0, 200), width=3)

        # Arrowhead: two points `head` px back along the line, +/-0.3 rad.
        ang = math.atan2(kpos[1] - qpos[1], kpos[0] - qpos[0])
        head = 8
        p1 = (kpos[0] - head * math.cos(ang - 0.3), kpos[1] - head * math.sin(ang - 0.3))
        p2 = (kpos[0] - head * math.cos(ang + 0.3), kpos[1] - head * math.sin(ang + 0.3))
        draw.polygon([kpos, p1, p2], fill=(255, 0, 0, 200))

    r = 10
    draw.ellipse([qpos[0] - r, qpos[1] - r, qpos[0] + r, qpos[1] + r], outline=(0, 0, 255, 220), width=2)
    return img
|
|
|
|
def make_focus_overlay(img: Image.Image, heat_grid: np.ndarray, alpha: float = 0.6):
    """Blend a coarse heat map over a 224x224 RGB copy of ``img``.

    Steps:

    1. ``heat_grid`` (a 2-D float map, e.g. one value per patch) is
       min-max normalised to [0, 1]. A constant map becomes all zeros.
    2. It is bilinearly upsampled to 224x224.
    3. Every pixel whose heat ``v`` exceeds 0.05 is tinted with colour
       ``(255*v, 200*(1-v), 40)`` at opacity ``alpha * v``.

    The blend is done with vectorised numpy instead of one ``draw.point``
    call per pixel. Results match per-pixel alpha drawing up to integer
    rounding.
    """
    img = img.convert("RGB").resize((224, 224))

    g = np.nan_to_num(np.asarray(heat_grid, dtype=np.float32))
    span = float(g.max() - g.min()) if g.size else 0.0
    g = (g - g.min()) / span if span > 0 else np.zeros_like(g)

    heat_img = Image.fromarray((g * 255).astype(np.uint8)).resize((224, 224), Image.BILINEAR)
    heat = np.asarray(heat_img, dtype=np.float32) / 255.0

    # Colour and opacity are truncated to integers, as the original
    # per-pixel int() fill values were.
    alpha = float(np.clip(alpha, 0.0, 1.0))
    tint = np.stack(
        [np.floor(255.0 * heat), np.floor(200.0 * (1.0 - heat)), np.full_like(heat, 40.0)],
        axis=-1,
    )
    opacity = np.where(heat > 0.05, np.floor(255.0 * alpha * heat) / 255.0, 0.0)[..., None]

    base = np.asarray(img, dtype=np.float32)
    blended = base * (1.0 - opacity) + tint * opacity
    return Image.fromarray(np.clip(np.rint(blended), 0, 255).astype(np.uint8))
|
|
|
|
| |
def compute_attention_rollout(all_attentions: List[torch.Tensor]) -> np.ndarray:
    """Aggregate per-layer attention into a single rollout matrix.

    For each layer:

    1. Take batch item 0 and average over heads (dim 0 after batch indexing).
    2. Add the identity to account for the residual connection.
    3. Re-normalise each row to sum to 1.

    Then chain the layers: ``R = A_L @ ... @ A_2 @ A_1``. Returns the
    ``(seq, seq)`` numpy matrix ``R``.
    """
    head_means = [layer[0].mean(dim=0).detach().cpu().numpy() for layer in all_attentions]
    identity = np.eye(head_means[0].shape[0])

    rollout = None
    for mean_att in head_means:
        augmented = mean_att + identity
        totals = augmented.sum(axis=-1, keepdims=True)
        totals[totals == 0] = 1.0
        augmented = augmented / totals
        rollout = augmented if rollout is None else augmented @ rollout
    return rollout
|
|
|
|
| |
def pca_plot_from_hidden(hidden_states: List[torch.Tensor], layers: List[int]):
    """Scatter-plot 2-D PCA projections of the patch tokens of selected layers.

    For each index in ``layers``:

    * batch item 0 of ``hidden_states[index]`` is taken;
    * the first token (CLS) is dropped;
    * the remaining patch tokens get their own 2-component PCA fit.

    Axes are therefore not shared between layers. Points are coloured by
    layer index. Returns a Plotly figure.
    """
    coord_blocks = []
    tag_blocks = []
    for layer in layers:
        tokens = hidden_states[layer][0].detach().cpu().numpy()[1:, :]
        projected = PCA(n_components=2).fit_transform(tokens)
        coord_blocks.append(projected)
        tag_blocks.append(np.array([layer] * projected.shape[0]))

    coords = np.vstack(coord_blocks)
    frame = {
        "x": coords[:, 0],
        "y": coords[:, 1],
        "layer": np.concatenate(tag_blocks).astype(str),
    }
    figure = px.scatter(frame, x="x", y="y", color="layer", title="Patch embeddings across layers (PCA)")
    figure.update_traces(marker=dict(size=6))
    figure.update_layout(height=480)
    return figure
|
|
|
|
| |
def _scores_to_grid(scores: np.ndarray, grid_size: int) -> np.ndarray:
    """Fit a flat per-patch score vector into a ``(grid_size, grid_size)`` map.

    Shorter vectors are zero-padded and longer ones truncated, so a
    patch-count mismatch never raises.
    """
    flat = np.zeros(grid_size * grid_size, dtype=float)
    n = min(len(scores), flat.size)
    flat[:n] = scores[:n]
    return flat.reshape(grid_size, grid_size)


def analyze_all(img: Optional[Image.Image], mode_simple: bool):
    """Run the full ViT visualisation pipeline on one image.

    Args:
        img: uploaded PIL image, or ``None`` when nothing was uploaded.
        mode_simple: currently unused. Both tabs receive the same values and
            each displays the ones it needs.

    Returns:
        An 11-tuple in this order:

        0. patch-grid image
        1. cluster image
        2. attention-arrow image
        3. rollout focus image
        4. top-5 text
        5. simple explanation markdown
        6. last-layer CLS focus image
        7. PCA figure
        8. top-5 text
        9. advanced explanation markdown
        10. state dict consumed by the ``advanced_update_*`` callbacks
    """
    if img is None:
        # The UI wires 11 outputs; returning fewer makes Gradio raise.
        return None, None, None, None, "", None, None, None, "", None, None

    base, clf, processor = load_models()

    img224 = img.convert("RGB").resize((224, 224))
    inputs = processor(images=img224, return_tensors="pt").to(DEVICE)

    with torch.no_grad():
        # Request attentions/hidden states explicitly rather than relying
        # solely on the config flags set at load time.
        outputs = base(**inputs, output_attentions=True, output_hidden_states=True)
        logits = clf(**inputs).logits[0].cpu().numpy()

    attentions = outputs.attentions
    hidden_states = outputs.hidden_states

    grid_size, _ = patch_grid_info()
    n_patches = attentions[0].shape[-1] - 1  # sequence length minus the CLS token

    # Step 1: chop into patches.
    patch_grid_img = draw_patch_grid(img224.copy())

    # Step 2: group similar patches using last-layer patch embeddings (CLS dropped).
    patch_embeddings = hidden_states[-1][0, 1:].detach().cpu().numpy()
    n_clusters = 4
    try:
        # Pin n_init so results don't change with sklearn's default.
        cluster_labels = KMeans(n_clusters=n_clusters, random_state=0, n_init=10).fit(patch_embeddings).labels_
    except Exception:
        # Best effort: fall back to a single cluster rather than failing the request.
        cluster_labels = np.zeros(n_patches, dtype=int)
    cluster_img = draw_cluster_blocks(img224.copy(), cluster_labels, n_clusters=n_clusters)

    # Step 3: patch-to-patch attention of the last layer, averaged over heads.
    last_att = attentions[-1][0].mean(dim=0).detach().cpu().numpy()
    patch_to_patch = last_att[1:, 1:]
    # Centre patch of the grid. n_patches // 2 would be the left edge of the
    # middle row on an even-sized grid.
    center_idx = min((grid_size // 2) * grid_size + grid_size // 2, max(0, n_patches - 1))
    arrow_img = draw_attention_arrows(img224.copy(), patch_to_patch, top_k=4, query_idx=center_idx)

    # Step 4: attention-rollout focus (CLS row, patch columns).
    rollout = compute_attention_rollout(attentions)
    focus_img = make_focus_overlay(img224.copy(), _scores_to_grid(rollout[0, 1:], grid_size), alpha=0.6)

    # Top-5 predictions via a numerically stable softmax.
    probs = np.exp(logits - logits.max())
    probs = probs / probs.sum()
    top5 = probs.argsort()[-5:][::-1]
    id2label = clf.config.id2label
    preds_text = "\n".join([f"{id2label[int(i)]} — {probs[i]*100:.2f}%" for i in top5])

    pca_fig = pca_plot_from_hidden(hidden_states, [0, max(0, len(hidden_states) // 2), len(hidden_states) - 1])

    # Default advanced overlay: last-layer CLS -> patch attention, head-averaged.
    cls_grid = _scores_to_grid(last_att[0, 1:], grid_size)
    focus_overlay_default = make_focus_overlay(img224.copy(), cls_grid, alpha=0.5)

    state = {
        "attentions": [a.cpu() for a in attentions],
        "hidden_states": [h.cpu() for h in hidden_states],
        "grid_size": grid_size,
        "num_layers": len(attentions),
        "num_heads": attentions[0].shape[1],
        "base_image": img,
    }

    simple_explain = """
**How ViT Sees — Simple Steps**

1) **Chop** — The image is chopped into small square tiles (patches) like LEGO pieces.
2) **Understand** — Each piece gets a code that describes colors/edges. Pieces that look similar are grouped.
3) **Talk** — Pieces tell each other what they see (we draw arrows to show that).
4) **Focus & Guess** — The model merges clues and focuses on important areas, then guesses what the image shows.
"""

    advanced_explain = """
**Advanced View:** Explore attention per layer/head, the PCA of patch embeddings, and the model's internal focus.
Use sliders to change layer/head and see how ViT's attention evolves.
"""

    return (
        patch_grid_img,
        cluster_img,
        arrow_img,
        focus_img,
        preds_text,
        simple_explain,
        focus_overlay_default,
        pca_fig,
        preds_text,
        advanced_explain,
        state,
    )
|
|
|
|
| |
def advanced_update_attention(state: Dict[str, Any], layer_idx: int, head_idx: int):
    """Overlay one head's CLS -> patch attention for one layer on the base image.

    Args:
        state: the analysis state dict, or a falsy value if no image has been
            analysed yet (then ``None`` is returned).
        layer_idx: requested layer; clamped into ``[0, num_layers - 1]``.
        head_idx: requested head; clamped into ``[0, num_heads - 1]``.

    Returns:
        The overlay image. The CLS row (excluding the CLS column) is padded or
        truncated to ``grid_size**2`` before being reshaped to the grid.
    """
    if not state:
        return None

    layer = max(0, min(int(layer_idx), state["num_layers"] - 1))
    head = max(0, min(int(head_idx), state["num_heads"] - 1))

    weights = state["attentions"][layer]
    weights = weights[0] if weights.ndim == 4 else weights
    cls_row = weights.numpy()[head, 0, 1:]

    side = state["grid_size"]
    cells = side * side
    if cls_row.shape[0] != cells:
        padded = np.zeros(cells, dtype=float)
        keep = min(cls_row.shape[0], cells)
        padded[:keep] = cls_row[:keep]
        cls_row = padded

    return make_focus_overlay(state["base_image"].convert("RGB"), cls_row.reshape(side, side), alpha=0.55)
|
|
|
|
def advanced_update_rollout(state: Dict[str, Any]):
    """Recompute attention rollout from the stored attentions and overlay it.

    Returns ``None`` when ``state`` is falsy. Otherwise:

    * stored 3-D attention tensors get a leading batch axis, matching what
      ``compute_attention_rollout`` indexes;
    * the rollout's CLS row (excluding the CLS column) is padded or truncated
      to ``grid_size**2``;
    * the result is reshaped and drawn over the base image.
    """
    if not state:
        return None

    batched = []
    for att in state["attentions"]:
        batched.append(att.unsqueeze(0) if att.ndim == 3 else att)
    cls_scores = compute_attention_rollout(batched)[0, 1:]

    side = state["grid_size"]
    cells = side * side
    if cls_scores.shape[0] != cells:
        padded = np.zeros(cells, dtype=float)
        keep = min(len(cls_scores), cells)
        padded[:keep] = cls_scores[:keep]
        cls_scores = padded

    return make_focus_overlay(state["base_image"].convert("RGB"), cls_scores.reshape(side, side), alpha=0.6)
|
|
|
|
def parse_pca_layers(txt: Optional[str], n_hidden: int, fallback: List[int]) -> List[int]:
    """Parse a comma-separated list of hidden-state indices for the PCA plot.

    Rules:

    * Negative indices count from the end, Python-style.
    * Indices outside ``[-n_hidden, n_hidden)`` and duplicates are dropped;
      order is kept.
    * If the text cannot be parsed, or nothing valid remains, the in-range,
      de-duplicated entries of ``fallback`` are returned instead.
    """
    try:
        requested = [int(tok.strip()) for tok in (txt or "").split(",") if tok.strip()]
    except ValueError:
        requested = []

    layers: List[int] = []
    for idx in requested:
        if -n_hidden <= idx < n_hidden:
            idx %= n_hidden  # normalise negatives so plot labels are absolute
            if idx not in layers:
                layers.append(idx)

    if not layers:
        layers = [i for i in dict.fromkeys(fallback) if 0 <= i < n_hidden]
    return layers


def advanced_update_pca(state: Dict[str, Any], txt: str):
    """Redraw the PCA plot for the layers listed in ``txt`` (e.g. ``"0,6,12"``).

    Returns ``None`` when no analysis state exists yet.

    Indices are validated against the number of stored hidden states, so an
    out-of-range or empty entry no longer crashes the callback. Invalid input
    falls back to ``[0, num_layers - 1]``.
    """
    if not state:
        return None
    fallback = [0, max(0, state["num_layers"] - 1)]
    layers = parse_pca_layers(txt, len(state["hidden_states"]), fallback)
    return pca_plot_from_hidden(state["hidden_states"], layers)
|
|
|
|
| |
def _simple_tab_outputs(img):
    """Run analyze_all and keep the six values the Simple tab displays.

    Returns, in order: patch grid, clusters, arrows, rollout focus,
    top-5 text, and explanation markdown.
    """
    if img is None:
        return None, None, None, None, "", None
    return tuple(analyze_all(img, True)[:6])


def _advanced_tab_outputs(img):
    """Run analyze_all and keep the values the Advanced tab displays.

    Returns, in order: CLS attention overlay, rollout overlay, PCA figure,
    top-5 text, explanation markdown, and the analysis state consumed by the
    slider/button callbacks.
    """
    if img is None:
        return None, None, None, "", None, None
    results = analyze_all(img, False)
    state = results[10]
    return results[6], advanced_update_rollout(state), results[7], results[8], results[9], state


with gr.Blocks(title="ViT Visualizer — Simple + Advanced") as demo:
    gr.Markdown("# 👀 How Vision Transformers (ViT) See Images\n"
                "Simple mode (story-style) + Advanced mode (inspect internals). Model: **google/vit-base-patch16-224**")

    with gr.Tabs():
        with gr.TabItem("Simple (for everyone)"):
            with gr.Row():
                with gr.Column(scale=1):
                    img_input = gr.Image(label="Upload an image (photo / object)", type="pil")
                    run_btn = gr.Button("🔎 Explain simply")
                    gr.Markdown("Tip: use clear images of objects, animals, scenes for best examples.")
                with gr.Column(scale=1):
                    pass

            gr.Markdown("### Step 1 — Chopped into patches")
            step1 = gr.Image(label="Patch Grid (ViT chops image into 16×16 patches)")

            gr.Markdown("### Step 2 — The model groups similar patches")
            step2 = gr.Image(label="Clustered patches (colored blocks)")

            gr.Markdown("### Step 3 — Patches talk to each other (simplified)")
            step3 = gr.Image(label="Patch-to-Patch arrows")

            gr.Markdown("### Step 4 — Model focus map and guess")
            with gr.Row():
                step4 = gr.Image(label="Focus map (where model looked most)")
                preds_simple = gr.Textbox(label="Model guesses (Top-5)", lines=4)

            explanation_simple = gr.Markdown()

            # Wire only this tab's own components. Components instantiated
            # inline in `outputs=[...]` would be rendered here as empty widgets.
            run_btn.click(
                fn=_simple_tab_outputs,
                inputs=[img_input],
                outputs=[step1, step2, step3, step4, preds_simple, explanation_simple],
            )

        with gr.TabItem("Advanced (inspect internals)"):
            with gr.Row():
                with gr.Column(scale=1):
                    img_adv = gr.Image(label="Upload image for advanced view", type="pil")
                    run_adv = gr.Button("Analyze (advanced)")
                    gr.Markdown("Use the sliders to explore attention per layer and head.")
                    layer_slider = gr.Slider(0, 11, value=11, step=1, label="Layer (0=shallow)")
                    head_slider = gr.Slider(0, 11, value=0, step=1, label="Head index")
                    rollout_btn = gr.Button("Refresh Rollout Overlay")
                    pca_txt = gr.Textbox(label="PCA layers (comma separated)", value="0,6,11")
                    pca_btn = gr.Button("Update PCA")
                with gr.Column(scale=1):
                    adv_attn = gr.Image(label="Attention overlay (layer/head CLS->patch)")
                    adv_rollout = gr.Image(label="Attention rollout overlay (aggregated)")
                    adv_pca = gr.Plot(label="PCA of patch embeddings")
                    adv_preds = gr.Textbox(label="Top-5 predictions", lines=5)
                    adv_explain = gr.Markdown()

            state_box = gr.State()

            run_adv.click(
                fn=_advanced_tab_outputs,
                inputs=[img_adv],
                outputs=[adv_attn, adv_rollout, adv_pca, adv_preds, adv_explain, state_box],
            )

            # Either slider redraws the per-layer/per-head CLS attention overlay.
            for slider in (layer_slider, head_slider):
                slider.change(
                    fn=advanced_update_attention,
                    inputs=[state_box, layer_slider, head_slider],
                    outputs=[adv_attn],
                )

            rollout_btn.click(
                fn=advanced_update_rollout,
                inputs=[state_box],
                outputs=[adv_rollout],
            )

            pca_btn.click(
                fn=advanced_update_pca,
                inputs=[state_box, pca_txt],
                outputs=[adv_pca],
            )

demo.launch()