| | import numpy as np |
| | import matplotlib.pyplot as plt |
| | from PIL import Image, ImageDraw |
| | import os |
| | import torch |
| | import monai.transforms as transforms |
| |
|
def draw_result(category, image, bboxes, points, logits, gt3D):
    """Save one three-panel PNG per slice: image with prompts, ground truth, prediction.

    The volumes are resampled to a fixed (32, 256, 256) grid, the logits are
    turned into a binary mask (sigmoid > 0.5), and for every slice along the
    first axis a figure is written to ``./fig_examples/<category>/<category>_<j>.png``.

    Args:
        category: Name used for the output sub-directory and file prefix.
        image: Indexed as ``image[0, 0]`` before resampling
            (assumes a (B, C, D, H, W) tensor -- TODO confirm with caller).
        bboxes: Optional box prompts; only ``bboxes[0]`` is drawn and it must
            unpack into six values ``x1, y1, z1, x2, y2, z2``. ``None`` skips it.
        points: Optional pair ``(coordinates, labels)`` of tensors. ``None``
            skips point drawing.
        logits: Raw model output for the volume
            (assumes a 3D tensor without channel axis -- TODO confirm).
        gt3D: Ground-truth mask (assumes same layout as ``logits`` -- TODO confirm).
    """
    # NOTE(review): AddChanneld was deprecated in MONAI 0.8 and removed in 1.3;
    # newer MONAI needs EnsureChannelFirstd(..., channel_dim="no_channel").
    # Confirm against the pinned MONAI version before upgrading.
    zoom_out_transform = transforms.Compose([
        transforms.AddChanneld(keys=["image", "label", "logits"]),
        transforms.Resized(
            keys=["image", "label", "logits"],
            spatial_size=(32, 256, 256),
            mode='nearest-exact',
        ),
    ])
    print(image.shape, gt3D.shape, logits.shape)
    image = image[0, 0]
    post_item = zoom_out_transform({
        'image': image,
        'label': gt3D,
        'logits': logits,
    })
    # Drop the channel axis that the transform added.
    image = post_item['image'][0]
    gt3D = post_item['label'][0]
    logits = post_item['logits'][0]
    preds = (torch.sigmoid(logits) > 0.5).int()

    root_dir = os.path.join(f'./fig_examples/{category}/')
    # exist_ok avoids the check-then-create race of exists() + makedirs().
    os.makedirs(root_dir, exist_ok=True)

    # NOTE(review): prompt coordinates are compared against slice indices of the
    # resampled volume; presumably they are already expressed in that
    # (32, 256, 256) space -- confirm with the caller.
    if bboxes is not None:
        x1, y1, z1, x2, y2, z2 = bboxes[0].cpu().numpy()
    if points is not None:
        points_ax = points[0].cpu().numpy()
        points_label = points[1].cpu().numpy()

    # Convert each volume to NumPy once instead of once per slice.
    image_np = image.detach().cpu().numpy()
    label_np = gt3D.detach().cpu().numpy()
    preds_np = preds.detach().cpu().numpy()

    for j, (img_2d, label_2d, preds_2d) in enumerate(zip(image_np, label_np, preds_np)):
        fig, (ax1, ax2, ax3) = plt.subplots(1, 3)

        ax1.imshow(img_2d, cmap='gray')
        ax1.set_title('Image with prompt')
        ax1.axis('off')

        ax2.imshow(img_2d, cmap='gray')
        show_mask(label_2d, ax2)
        ax2.set_title('Ground truth')
        ax2.axis('off')

        ax3.imshow(img_2d, cmap='gray')
        show_mask(preds_2d, ax3)
        ax3.set_title('Prediction')
        ax3.axis('off')

        # The box is drawn on every slice inside its extent along the first axis.
        if bboxes is not None and x1 <= j <= x2:
            show_box((z1, y1, z2, y2), ax1)

        # A point is drawn only on the slice matching its first coordinate.
        if points is not None:
            for point, label in zip(points_ax, points_label):
                if j == point[0]:
                    show_points(point, label, ax1)

        fig.subplots_adjust(left=0, right=1, bottom=0, top=1, wspace=0, hspace=0)
        # Save and close this specific figure rather than pyplot's "current" one.
        fig.savefig(os.path.join(root_dir, f'{category}_{j}.png'), bbox_inches='tight')
        plt.close(fig)
| |
|
def show_mask(mask, ax):
    """Overlay a 2D mask on ``ax`` as a translucent yellow RGBA image.

    Args:
        mask: Array whose last two dimensions are (height, width); each value
            scales the overlay colour, so zeros stay fully transparent.
        ax: Axes-like object providing ``imshow``.
    """
    rgba = np.array([251/255, 252/255, 30/255, 0.6])
    height, width = mask.shape[-2:]
    # Broadcast (H, W, 1) against (1, 1, 4) to obtain an (H, W, 4) RGBA image.
    plane = mask.reshape(height, width, 1)
    overlay = plane * rgba.reshape(1, 1, -1)
    ax.imshow(overlay, alpha=0.35)
| |
|
def show_box(box, ax):
    """Draw an unfilled blue rectangle on ``ax``.

    Args:
        box: Sequence whose first four entries are the two opposite corners,
            ``(x0, y0, x1, y1)``, in the axes' data coordinates.
        ax: Axes-like object providing ``add_patch``.
    """
    left, top = box[0], box[1]
    width = box[2] - box[0]
    height = box[3] - box[1]
    outline = plt.Rectangle(
        (left, top),
        width,
        height,
        edgecolor='blue',
        facecolor=(0, 0, 0, 0),  # fully transparent fill, outline only
        lw=2,
    )
    ax.add_patch(outline)
| |
|
def show_points(points_ax, points_label, ax):
    """Plot a single prompt point on ``ax``.

    Args:
        points_ax: Indexable coordinate triple; element 2 is used as the
            horizontal position and element 1 as the vertical position.
        points_label: Point label; 0 is drawn in red, anything else in blue.
        ax: Axes-like object providing ``scatter``.
    """
    print('draw point')
    if points_label == 0:
        marker_color = 'red'
    else:
        marker_color = 'blue'
    horizontal = points_ax[2]
    vertical = points_ax[1]
    ax.scatter(horizontal, vertical, c=marker_color, marker='o', s=200)
| |
|