| |
| |
| |
| |
|
|
| import cv2 |
| import numpy as np |
| from PIL import Image, ImageDraw |
| import torch |
| import matplotlib.pyplot as plt |
| from skimage import filters |
| from IPython.display import display |
|
|
def gaussian_blur(heatmap, kernel_size=7):
    """Smooth a heatmap tensor with OpenCV's Gaussian filter.

    Args:
        heatmap: torch tensor. It is converted to a NumPy array and handed to
            ``cv2.GaussianBlur``, so it must have a layout/dtype OpenCV accepts
            (presumably a float32 (H, W) map -- confirm with callers).
        kernel_size: Side length of the square Gaussian kernel. OpenCV requires
            a positive odd value. Sigma is passed as 0, so OpenCV derives it
            from the kernel size.

    Returns:
        A new CPU tensor holding the blurred values.
    """
    # detach() so heatmaps that still track gradients can be turned into NumPy.
    array = heatmap.detach().cpu().numpy()
    blurred = cv2.GaussianBlur(array, (kernel_size, kernel_size), 0)
    return torch.from_numpy(blurred)
|
|
def show_cam_on_image(img, mask):
    """Blend a JET-colored rendering of ``mask`` onto ``img``.

    Args:
        img: Image array with the same height/width as ``mask`` and three
            channels. It is cast to float32 and added as-is, so it is presumably
            already scaled to [0, 1] (show_image_and_heatmap passes a min-max
            normalized image) -- confirm for other callers.
        mask: 2-D map, presumably in [0, 1]. It is multiplied by 255 and cast
            to uint8 before the colormap is applied.

    Returns:
        float32 array equal to (colormap / 255 + img) divided by its maximum
        value. The colormap layer is in OpenCV's BGR channel order.
    """
    mask_u8 = np.uint8(255 * mask)
    jet = np.float32(cv2.applyColorMap(mask_u8, cv2.COLORMAP_JET)) / 255
    blended = jet + np.float32(img)
    peak = np.max(blended)
    return blended / peak
|
|
def show_image_and_heatmap(heatmap: torch.Tensor, image: Image.Image, relevnace_res: int = 256, interpolation: str = 'bilinear', gassussian_kernel_size: int = 3):
    """Overlay a square relevance heatmap on an image as a JET-colored CAM.

    Args:
        heatmap: Tensor reshaped to (1, 1, side, side) with
            ``side = heatmap.shape[-1]``, so it must hold exactly side*side
            elements (e.g. a square 2-D map).
        image: PIL image to draw the overlay on. It is converted to RGB and
            resized to ``relevnace_res`` x ``relevnace_res``.
        relevnace_res: Side length in pixels of the output image.
        interpolation: ``torch.nn.functional.interpolate`` mode used to
            upsample the heatmap.
        gassussian_kernel_size: Accepted for API compatibility; not used here.

    Returns:
        RGB ``PIL.Image.Image`` of size (relevnace_res, relevnace_res).
    """
    # Force three channels: grayscale or RGBA inputs would not broadcast
    # against the 3-channel colormap built in show_cam_on_image.
    image = image.convert("RGB").resize((relevnace_res, relevnace_res))
    image = np.asarray(image, dtype=np.float64)
    img_lo, img_hi = image.min(), image.max()
    if img_hi > img_lo:
        image = (image - img_lo) / (img_hi - img_lo)
    else:
        # Constant image: avoid 0/0 -> NaN.
        image = np.zeros_like(image)

    side = heatmap.shape[-1]
    # detach() allows grad-tracking inputs; float() is needed by interpolate's
    # linear modes and keeps the later NumPy conversion valid for half types.
    heatmap = heatmap.detach().float().reshape(1, 1, side, side)
    heatmap = torch.nn.functional.interpolate(heatmap, size=relevnace_res, mode=interpolation)
    heatmap = heatmap.reshape(relevnace_res, relevnace_res).cpu()
    hm_lo, hm_hi = heatmap.min(), heatmap.max()
    if hm_hi > hm_lo:
        heatmap = (heatmap - hm_lo) / (hm_hi - hm_lo)
    else:
        # Constant heatmap: avoid 0/0 -> NaN.
        heatmap = torch.zeros_like(heatmap)

    # cv2.applyColorMap (used by show_cam_on_image) yields BGR, while the PIL
    # image is RGB. Blend with the image in BGR so both layers share one
    # channel order, then convert the result back to RGB for PIL.
    image_bgr = np.ascontiguousarray(image[..., ::-1])
    vis = show_cam_on_image(image_bgr, heatmap.numpy())
    vis = np.uint8(255 * vis)
    vis = cv2.cvtColor(vis, cv2.COLOR_BGR2RGB)
    return Image.fromarray(vis)
|
|
def show_only_heatmap(heatmap: torch.Tensor, relevnace_res: int = 256, interpolation: str = 'bilinear', gassussian_kernel_size: int = 3):
    """Render a square heatmap as a grayscale image (three identical channels).

    Args:
        heatmap: Tensor reshaped to (1, 1, side, side) with
            ``side = heatmap.shape[-1]``, so it must hold exactly side*side
            elements (e.g. a square 2-D map).
        relevnace_res: Side length in pixels of the output image.
        interpolation: ``torch.nn.functional.interpolate`` mode used to
            upsample the heatmap.
        gassussian_kernel_size: Accepted for API compatibility; not used here.

    Returns:
        ``PIL.Image.Image`` of size (relevnace_res, relevnace_res). Values are
        min-max normalized to 0..255; a constant heatmap renders as black.
    """
    side = heatmap.shape[-1]
    # detach() allows grad-tracking inputs; float() is needed by interpolate's
    # linear modes and keeps the NumPy conversion valid for half types.
    heatmap = heatmap.detach().float().reshape(1, 1, side, side)
    heatmap = torch.nn.functional.interpolate(heatmap, size=relevnace_res, mode=interpolation)
    heatmap = heatmap.reshape(relevnace_res, relevnace_res).cpu()
    lo, hi = heatmap.min(), heatmap.max()
    if hi > lo:
        heatmap = (heatmap - lo) / (hi - lo)
    else:
        # Constant heatmap: avoid 0/0 -> NaN.
        heatmap = torch.zeros_like(heatmap)

    gray = np.uint8(255 * heatmap.numpy())
    # Replicate the single channel so the result is a 3-channel image.
    three_channel = cv2.cvtColor(gray, cv2.COLOR_GRAY2BGR)
    return Image.fromarray(three_channel)
|
|
def visualize_tokens_attentions(attention, tokens, image, heatmap_interpolation="nearest", show_on_image=True, n_cols=4):
    """Plot one attention heatmap per token in a grid.

    Args:
        attention: Indexable along its first dimension; entry ``j`` is the map
            for ``tokens[j]``. Tokens beyond ``attention.shape[0]`` are skipped.
        tokens: Iterable of labels used as subplot titles.
        image: PIL image the heatmaps are overlaid on (used only when
            ``show_on_image`` is True).
        heatmap_interpolation: Upsampling mode passed to the heatmap renderers.
        show_on_image: Overlay on ``image`` if True, otherwise draw the bare
            heatmap.
        n_cols: Number of grid columns (default 4).

    Returns:
        The matplotlib ``Figure`` containing the grid.

    Raises:
        ValueError: If ``n_cols`` < 1 or there is no (token, map) pair to plot.
    """
    if n_cols < 1:
        raise ValueError(f"n_cols must be >= 1, got {n_cols}")

    token_vis = []
    # zip stops at the shorter input, so surplus tokens without a map are skipped.
    for token, token_attention in zip(tokens, attention):
        if show_on_image:
            vis = show_image_and_heatmap(token_attention, image, relevnace_res=512, interpolation=heatmap_interpolation)
        else:
            vis = show_only_heatmap(token_attention, relevnace_res=512, interpolation=heatmap_interpolation)
        token_vis.append((token, vis))

    if not token_vis:
        raise ValueError("No (token, attention map) pairs to visualize")

    n_rows = (len(token_vis) + n_cols - 1) // n_cols
    # squeeze=False always yields a 2-D axes array, so no shape special-casing.
    fig, axs = plt.subplots(n_rows, n_cols, figsize=(n_cols * 5, n_rows * 5), squeeze=False)
    flat_axes = axs.ravel()

    for ax, (token, vis) in zip(flat_axes, token_vis):
        ax.imshow(vis)
        ax.set_title(token)
        ax.axis("off")

    # Hide the unused trailing cells of the last row.
    for ax in flat_axes[len(token_vis):]:
        ax.axis("off")

    plt.tight_layout()
    return fig
|
|
def show_images(images, titles=None, size=1024, max_row_length=5, figsize=None, col_height=10, save_path=None):
    """Display one or more PIL images with matplotlib, optionally saving the figure.

    Every image is resized to ``size`` x ``size`` before plotting.

    Args:
        images: A single ``PIL.Image.Image`` or an iterable of them. An empty
            iterable draws nothing.
        titles: Optional titles, one per image. With more than one image the
            count must match the number of images.
        size: Side length in pixels each image is resized to.
        max_row_length: Maximum number of images per grid row.
        figsize: Matplotlib figure size. If None, the single-image case uses
            matplotlib's default and the grid uses ``col_height`` inches per cell.
        col_height: Inches per grid cell, used only when ``figsize`` is None.
        save_path: If given, the figure is saved there (dpi=150, tight bbox)
            before ``plt.show()``.
    """
    if isinstance(images, Image.Image):
        images = [images]
    images = [img.resize((size, size)) for img in images]

    if not images:
        # Nothing to draw; the grid arithmetic below would divide by zero.
        return

    if len(images) == 1:
        if figsize is not None:
            plt.figure(figsize=figsize)
        plt.imshow(images[0])
        plt.axis('off')
        if titles is not None:
            plt.title(titles[0])
    else:
        if titles is not None:
            assert len(images) == len(titles), "Number of titles should match the number of images"

        n_images = len(images)
        n_cols = min(n_images, max_row_length)
        n_rows = (n_images + n_cols - 1) // n_cols
        if figsize is None:
            figsize = (n_cols * col_height, n_rows * col_height)

        # squeeze=False always yields a 2-D axes array, so no shape special-casing.
        _, axs = plt.subplots(n_rows, n_cols, figsize=figsize, squeeze=False)
        axs = axs.ravel()

        for i, img in enumerate(images):
            axs[i].imshow(img)
            if titles is not None:
                axs[i].set_title(titles[i])
            axs[i].axis("off")

        # Hide the unused trailing cells of the last row.
        for ax in axs[n_images:]:
            ax.axis("off")

    if save_path:
        plt.savefig(save_path, bbox_inches='tight', dpi=150)
    plt.show()
|
|
def show_tensors(tensors, titles=None, size=None, max_row_length=5):
    """Display 2-D tensors as images with matplotlib.

    Args:
        tensors: A single 2-D tensor, or an iterable of 2-D tensors (a 3-D
            tensor is iterated along its first dimension). An empty iterable
            draws nothing.
        titles: Optional titles, one per tensor. With more than one tensor the
            count must match the number of tensors.
        size: If given, each tensor is bilinearly resized to ``size`` x ``size``.
        max_row_length: Maximum number of tensors per grid row.
    """
    # A lone 2-D map would otherwise be iterated row by row, producing 1-D
    # slices that imshow cannot draw; mirror show_images' single-item handling.
    if isinstance(tensors, torch.Tensor) and tensors.dim() == 2:
        tensors = [tensors]

    maps = []
    for t in tensors:
        # detach() so grad-tracking tensors can be converted to NumPy.
        t = t.detach()
        if size is not None:
            if not t.is_floating_point():
                t = t.float()  # bilinear interpolation needs floating point
            # Index [0, 0] rather than squeeze() so size=1 still yields a 2-D map.
            t = torch.nn.functional.interpolate(t[None, None], size=(size, size), mode='bilinear')[0, 0]
        maps.append(t.cpu().numpy())

    if not maps:
        # Nothing to draw; the grid arithmetic below would divide by zero.
        return

    if len(maps) == 1:
        plt.imshow(maps[0])
        plt.axis('off')
        if titles is not None:
            plt.title(titles[0])
    else:
        if titles is not None:
            assert len(maps) == len(titles), "Number of titles should match the number of images"

        n_tensors = len(maps)
        n_cols = min(n_tensors, max_row_length)
        n_rows = (n_tensors + n_cols - 1) // n_cols

        # squeeze=False always yields a 2-D axes array, so no shape special-casing.
        _, axs = plt.subplots(n_rows, n_cols, figsize=(n_cols * 10, n_rows * 10), squeeze=False)
        axs = axs.ravel()

        for i, array in enumerate(maps):
            axs[i].imshow(array)
            if titles is not None:
                axs[i].set_title(titles[i])
            axs[i].axis("off")

        # Hide the unused trailing cells of the last row.
        for ax in axs[n_tensors:]:
            ax.axis("off")

    plt.show()
|
|
def draw_bboxes_on_image(image, bboxes, color="red", thickness=2):
    """Return a copy of ``image`` with every box in ``bboxes`` outlined.

    The input image is not modified.

    Args:
        image: PIL image to annotate.
        bboxes: Iterable of boxes, each in any form accepted by
            ``ImageDraw.rectangle``.
        color: Outline color.
        thickness: Outline width in pixels.
    """
    annotated = image.copy()
    pen = ImageDraw.Draw(annotated)
    for box in bboxes:
        pen.rectangle(box, outline=color, width=thickness)
    return annotated
|
|
def draw_points_on_pil_image(pil_image, point_coords, point_color="red", radius=5):
    """
    Return a copy of a PIL image with a filled circle drawn at each point.

    :param pil_image: PIL Image to annotate (left unmodified).
    :param point_coords: Iterable of (x, y) pairs, e.g. an array of shape (N, 2).
    :param point_color: Fill and outline color of each circle (default 'red').
    :param radius: Circle radius in pixels.
    :return: New PIL Image with the points drawn.
    """
    canvas = pil_image.copy()
    painter = ImageDraw.Draw(canvas)

    for cx, cy in point_coords:
        # Each circle is the ellipse inscribed in the square centered on (cx, cy).
        top_left = (cx - radius, cy - radius)
        bottom_right = (cx + radius, cy + radius)
        painter.ellipse([top_left, bottom_right], fill=point_color, outline=point_color)

    return canvas