| import torch |
| from torch import nn |
| import numpy as np |
| from cv2 import resize |
| import cv2 |
| from pathlib import Path |
|
|
| from network import EfficientViT_l1_r224 |
| from losses import IISLoss, activate |
| from utils import minmaxnorm, load_from_ckpt |
|
|
|
|
class Busam:
    """Inference wrapper around an ``EfficientViT_l1_r224`` network.

    The network output is treated as a per-pixel feature map with a batch
    size of one. From it the class can select a mask for a clicked pixel
    (``get_mask``) and compute Sobel / Canny edge maps (``sobel_from_pred``,
    ``canny_from_pred``).
    """

    def __init__(self, checkpoint, device, side=224):
        """Build the network, run it through ``load_from_ckpt`` and set eval mode.

        Args:
            checkpoint: Forwarded unchanged to ``load_from_ckpt`` (presumably a
                checkpoint path -- confirm against ``utils.load_from_ckpt``).
            device: Device the network and every prepared input are moved to.
            side: Edge length in pixels of the square image fed to the network.
        """
        net = EfficientViT_l1_r224(
            out_channels=16, use_norm_params=False, pretrained=False
        )
        net = load_from_ckpt(net, checkpoint)
        net = net.to(device)
        net.eval()
        self.net = net
        self.device = device
        self.side = side

    def prepare_img(self, img):
        """Turn an ``(H, W, 3)`` uint8 image into a ``(1, 3, side, side)`` tensor.

        The image is resized to ``side x side``, scaled to ``[0, 1]``, shifted
        by ``-0.5`` (so values lie in ``[-0.5, 0.5]``), converted to float32
        and moved to ``self.device``.

        Raises:
            AssertionError: If ``img`` is not a 3-channel uint8 array.
        """
        assert len(img.shape) == 3, "should be H, W, 3 but is " + str(img.shape)
        assert img.shape[2] == 3, "should be H, W, 3 but is " + str(img.shape)
        assert img.min() >= 0, "min should be more than 0 but is " + str(img.min())
        assert img.max() <= 255, "max should be less than 255 but is " + str(img.max())
        assert img.dtype == np.uint8, "dtype should be np.uint8 but is " + str(
            img.dtype
        )
        resized = resize(img, (self.side, self.side))
        # HWC -> CHW, centre around zero, then add the batch dimension.
        chw = torch.from_numpy(resized / 255).permute(2, 0, 1) - 0.5
        return chw.float()[None].to(self.device)

    def process_image(self, img, do_activate=False):
        """Run the network on ``img`` without tracking gradients.

        Args:
            img: ``(H, W, 3)`` uint8 image (validated by ``prepare_img``).
            do_activate: When true, the raw output is passed through
                ``losses.activate`` in ``"symlog"`` mode before being returned.
                The reshape used for that call is only valid for batch size 1.

        Returns:
            ``(pred, (H, W))``: the network output and the original image size.
        """
        with torch.no_grad():
            pred = self.net(self.prepare_img(img))
            H, W = img.shape[:2]
            if do_activate:
                B, C, pH, pW = pred.shape
                features, _, _, _ = activate(
                    pred.view(C, pH * pW), None, "symlog", False, False, False
                )
                pred = features.view(B, C, pH, pW)
            return pred, (H, W)

    def get_mask(self, aux, click):
        """Return the boolean mask selected by ``click``.

        Args:
            aux: ``(pred, (H, W))`` pair, as returned by ``process_image``.
                Only the first batch element of ``pred`` is used.
            click: ``(row, col)`` position in original-image pixels.

        Returns:
            Boolean ``numpy`` array of the original image size.
        """
        pred = aux[0][0]
        oH, oW = aux[1]
        C, H, W = pred.shape
        features = pred.view(C, H * W)
        # Map the click onto the feature grid. Clamp so that a click on (or
        # just past) the image border cannot index a neighbouring row or fall
        # outside the flattened feature map; in-range clicks are unchanged.
        row = min(max(click[0] * H // oH, 0), H - 1)
        col = min(max(click[1] * W // oW, 0), W - 1)
        sindex = row * W + col
        mask = IISLoss.get_mask_from_query(features, sindex).reshape(H, W)
        # Upscale via an 8-bit image (cv2 takes (width, height)) and re-binarise.
        mask_u8 = (mask.cpu().numpy() * 255).astype(np.uint8)
        return resize(mask_u8, (oW, oH)) > 100

    def get_gradients(self, pred, size):
        """Compute per-pixel Sobel gradient magnitudes of a feature map.

        Each channel is filtered independently (depthwise convolution, zero
        padding of 1) and the responses are reduced with an L2 norm across
        channels.

        Args:
            pred: ``(1, C, H, W)`` floating-point tensor; batch size must be 1.
            size: Unused; kept for interface compatibility.

        Returns:
            ``(edge_x, edge_y)``: two non-negative ``(H, W)`` tensors.
        """
        channels, height, width = pred[0].shape
        # Build the kernel in the dtype/device of ``pred`` so float64 or half
        # features do not hit a dtype mismatch inside conv2d.
        sobel_x = torch.tensor(
            [[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]],
            dtype=pred.dtype,
            device=pred.device,
        )
        sobel_y = sobel_x.T
        weight_x = sobel_x.repeat(channels, 1, 1, 1)
        weight_y = sobel_y.repeat(channels, 1, 1, 1)
        grad_x = torch.nn.functional.conv2d(
            pred, weight_x, padding=1, groups=channels
        ).view(channels, height, width)
        grad_y = torch.nn.functional.conv2d(
            pred, weight_y, padding=1, groups=channels
        ).view(channels, height, width)
        edge_x = torch.norm(grad_x, dim=0, p=2)
        edge_y = torch.norm(grad_y, dim=0, p=2)
        return edge_x, edge_y

    def sobel_from_pred(self, pred, size):
        """Return the ``(H, W)`` Sobel edge strength ``sqrt(edge_x^2 + edge_y^2)``.

        ``pred`` and ``size`` are forwarded to ``get_gradients``.
        """
        edge_x, edge_y = self.get_gradients(pred, size)
        return torch.sqrt(edge_x**2 + edge_y**2)

    def canny_from_pred(self, pred, size, th_low=10000, th_high=20000):
        """Run ``cv2.Canny`` on the feature-map gradients.

        The two gradient maps are jointly rescaled to ``[0, 1]``, converted to
        int16 with ``cast_to_int16`` and passed to the derivative-based
        ``cv2.Canny`` overload.

        Args:
            pred: ``(1, C, H, W)`` tensor, forwarded to ``get_gradients``.
            size: Forwarded to ``get_gradients`` (unused there).
            th_low: Lower hysteresis threshold; a falsy value (``None`` or 0)
                falls back to ``th_high``.
            th_high: Upper hysteresis threshold; a falsy value falls back to
                ``th_low``.

        Returns:
            The edge map produced by ``cv2.Canny``.
        """
        th_low = th_low or th_high
        th_high = th_high or th_low

        edge_x, edge_y = self.get_gradients(pred, size)
        lowest = torch.minimum(edge_x.min(), edge_y.min())
        highest = torch.maximum(edge_x.max(), edge_y.max())
        span = highest - lowest
        # A flat gradient field has zero range; dividing by it would produce
        # NaNs and an undefined int16 cast. Treat it as "no edges" instead.
        if span == 0:
            span = torch.ones_like(span)
        edge_x = (edge_x - lowest) / span
        edge_y = (edge_y - lowest) / span
        return cv2.Canny(
            cast_to_int16(edge_x), cast_to_int16(edge_y), th_low, th_high
        )
|
|
|
|
def cast_to_int16(x):
    """Scale values in ``[-1, 1]`` to the int16 range and return a NumPy array.

    Args:
        x: ``torch.Tensor`` (any device, with or without grad) or array-like.

    Returns:
        ``numpy`` int16 array holding ``x * 32767`` truncated toward zero.
        Values outside ``[-1, 1]`` saturate at the int16 limits instead of
        overflowing.
    """
    if isinstance(x, torch.Tensor):
        # detach() so tensors that require grad can be converted as well.
        x = x.detach().cpu().numpy()
    scaled = np.asarray(x) * 32767
    # Saturate first: casting an out-of-range float to int16 is undefined.
    return np.clip(scaled, -32768, 32767).astype(np.int16)
|
|
|
|
| |
| |
| |
| |
|
|
| |
| |
| |
| |
|
|
| |
| |
| |
|
|
| |
| |
| |
| |
| |
| |
| |
|
|