import cv2
import numpy as np

import torch
from dataset.range_transform import inv_im_trans, inv_lll2rgb_trans
from collections import defaultdict

from PIL import Image
from skimage import color, io

import util.functional as F


class Normalize(object):
    def __init__(self):
        pass

    def __call__(self, inputs):
        # Normalize the L channel around its Lab midpoint (mean 50) and
        # leave the ab channels as-is (mean 0, std 1).
        inputs[0:1, :, :] = F.normalize(inputs[0:1, :, :], 50, 1)
        inputs[1:3, :, :] = F.normalize(inputs[1:3, :, :], (0, 0), (1, 1))
        return inputs
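
# Usage sketch (illustrative, not executed at import time), assuming
# util.functional.normalize follows the usual (tensor, mean, std) signature:
#   lab = torch.randn(3, 256, 256)   # hypothetical 3xHxW Lab tensor, L channel first
#   lab = Normalize()(lab)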
|
|
def tensor_to_numpy(image):
    image_np = (image.numpy() * 255).astype('uint8')
    return image_np


def tensor_to_np_float(image):
    image_np = image.numpy().astype('float32')
    return image_np


def detach_to_cpu(x):
    return x.detach().cpu()


def transpose_np(x):
    return np.transpose(x, [1, 2, 0])


def tensor_to_gray_im(x):
    x = detach_to_cpu(x)
    x = tensor_to_numpy(x)
    x = transpose_np(x)
    return x


def tensor_to_im(x):
    x = detach_to_cpu(x)
    x = inv_im_trans(x).clamp(0, 1)
    x = tensor_to_numpy(x)
    x = transpose_np(x)
    return x
|
|
key_captions = {
    'im': 'Image',
    'gt': 'GT',
}
|
|
| """ |
| Return an image array with captions |
| keys in dictionary will be used as caption if not provided |
| values should contain lists of cv2 images |
| """ |
| def get_image_array(images, grid_shape, captions={}): |
| h, w = grid_shape |
| cate_counts = len(images) |
| rows_counts = len(next(iter(images.values()))) |
|
|
| font = cv2.FONT_HERSHEY_SIMPLEX |
|
|
| output_image = np.zeros([w*cate_counts, h*(rows_counts+1), 3], dtype=np.uint8) |
| col_cnt = 0 |
| for k, v in images.items(): |
|
|
| |
| caption = captions.get(k, k) |
|
|
| |
| dy = 40 |
| for i, line in enumerate(caption.split('\n')): |
| cv2.putText(output_image, line, (10, col_cnt*w+100+i*dy), |
| font, 0.8, (255,255,255), 2, cv2.LINE_AA) |
|
|
| |
| for row_cnt, img in enumerate(v): |
| im_shape = img.shape |
| if len(im_shape) == 2: |
| img = img[..., np.newaxis] |
|
|
| img = (img * 255).astype('uint8') |
|
|
| output_image[(col_cnt+0)*w:(col_cnt+1)*w, |
| (row_cnt+1)*h:(row_cnt+2)*h, :] = img |
| |
| col_cnt += 1 |
|
|
| return output_image |
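
# Usage sketch (assumed shapes, not executed here): each value is a list of
# HxWx3 float images in [0, 1] and the grid cells are 384x384, e.g.
#   grid = get_image_array({'im': [np.zeros((384, 384, 3))],
#                           'gt': [np.ones((384, 384, 3))]}, (384, 384), key_captions)
#   cv2.imwrite('grid.png', grid)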
|
|
def base_transform(im, size):
    im = tensor_to_np_float(im)
    if len(im.shape) == 3:
        im = im.transpose((1, 2, 0))
    else:
        im = im[:, :, None]

    # Resize to the target grid-cell size (size is the (w, h) tuple passed to cv2.resize)
    if im.shape[1] != size:
        im = cv2.resize(im, size, interpolation=cv2.INTER_NEAREST)

    return im.clip(0, 1)


def im_transform(im, size):
    return base_transform(inv_im_trans(detach_to_cpu(im)), size=size)


def mask_transform(mask, size):
    return base_transform(detach_to_cpu(mask), size=size)


def out_transform(mask, size):
    return base_transform(detach_to_cpu(torch.sigmoid(mask)), size=size)
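
# Usage sketch (illustrative): turn a normalized CxHxW tensor back into a
# displayable HxWxC numpy image at the grid-cell size, e.g.
#   vis = im_transform(rgb_tensor, (384, 384))   # `rgb_tensor` is a hypothetical 3xHxW tensor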
|
|
def lll2rgb_transform(mask, size):
    mask_d = detach_to_cpu(mask)

    # Zero the ab channels so only the luminance is rendered (grayscale preview)
    mask_d[1:3, :, :] = 0

    mask_d = inv_lll2rgb_trans(mask_d)

    im = tensor_to_np_float(mask_d)

    if len(im.shape) == 3:
        im = im.transpose((1, 2, 0))
    else:
        im = im[:, :, None]

    im = color.lab2rgb(im)

    # Resize to the target grid-cell size
    if im.shape[1] != size:
        im = cv2.resize(im, size, interpolation=cv2.INTER_NEAREST)

    return im.clip(0, 1)
|
|
def lab2rgb_transform(mask, size):
    mask_d = detach_to_cpu(mask)

    mask_d = inv_lll2rgb_trans(mask_d)

    im = tensor_to_np_float(mask_d)

    if len(im.shape) == 3:
        im = im.transpose((1, 2, 0))
    else:
        im = im[:, :, None]

    im = color.lab2rgb(im)

    # Resize to the target grid-cell size
    if im.shape[1] != size:
        im = cv2.resize(im, size, interpolation=cv2.INTER_NEAREST)

    return im.clip(0, 1)
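
# Usage sketch (illustrative): render a full 3xHxW Lab tensor as RGB, or only
# its luminance, e.g.
#   rgb = lab2rgb_transform(lab_tensor, (384, 384))    # Lab -> RGB
#   gray = lll2rgb_transform(lab_tensor, (384, 384))   # luminance only
# where `lab_tensor` is a hypothetical tensor in the network's normalized Lab space.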
|
|
def pool_pairs_221128_TransColorization(images, size, num_objects):
    req_images = defaultdict(list)

    b, t = images['rgb'].shape[:2]

    # Save storage -- only visualize up to two sequences from the batch
    b = min(2, b)

    # Colorization is treated as a single-object task
    max_num_objects = 1

    GT_suffix = ''
    for bi in range(b):
        GT_suffix += ' \n%s' % images['info']['name'][bi][-25:-4]

    for bi in range(b):
        for ti in range(t):
            # Grayscale input frame (L channel only)
            req_images['RGB'].append(lll2rgb_transform(images['rgb'][bi, ti], size))

            for oi in range(max_num_objects):
                if ti == 0 or oi >= num_objects[bi]:
                    # Reference frame: combine the L channel with the first-frame GT ab channels
                    req_images['Mask_%d' % oi].append(lab2rgb_transform(torch.cat(
                        [images['rgb'][bi, ti][:1, :, :], images['first_frame_gt'][bi][0, :]], dim=0), size))
                else:
                    # Later frames: combine the L channel with the predicted ab channels
                    req_images['Mask_%d' % oi].append(lab2rgb_transform(torch.cat(
                        [images['rgb'][bi, ti][:1, :, :], images['masks_%d' % ti][bi][:]], dim=0), size))

                # Ground truth: combine the L channel with the GT ab channels
                req_images['GT_%d_%s' % (oi, GT_suffix)].append(lab2rgb_transform(torch.cat(
                    [images['rgb'][bi, ti][:1, :, :], images['cls_gt'][bi, ti, :, :]], dim=0), size))

    return get_image_array(req_images, size, key_captions)
|
|
def pool_pairs_221128_TransColorization_val(images, size, num_objects):
    # Identical to the training-time variant; kept as a separate entry point for validation.
    return pool_pairs_221128_TransColorization(images, size, num_objects)
|
|
def pool_pairs(images, size, num_objects):
    req_images = defaultdict(list)

    b, t = images['rgb'].shape[:2]

    # Save storage -- only visualize up to two sequences from the batch
    b = min(2, b)

    # Find the max number of objects among the visualized sequences
    max_num_objects = max(num_objects[:b])

    GT_suffix = ''
    for bi in range(b):
        GT_suffix += ' \n%s' % images['info']['name'][bi][-25:-4]

    for bi in range(b):
        for ti in range(t):
            req_images['RGB'].append(im_transform(images['rgb'][bi, ti], size))
            for oi in range(max_num_objects):
                if ti == 0 or oi >= num_objects[bi]:
                    # First frame (or missing object): show the given first-frame mask
                    req_images['Mask_%d' % oi].append(mask_transform(images['first_frame_gt'][bi][0, oi], size))
                else:
                    # Later frames: show the predicted mask
                    req_images['Mask_%d' % oi].append(mask_transform(images['masks_%d' % ti][bi][oi], size))
                # Ground-truth mask for this object
                req_images['GT_%d_%s' % (oi, GT_suffix)].append(mask_transform(images['cls_gt'][bi, ti, 0] == (oi+1), size))

    return get_image_array(req_images, size, key_captions)
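
# Usage sketch (illustrative; `batch` and `num_objs` are hypothetical stand-ins
# for the training batch dict and per-sequence object counts used by the caller):
#   grid = pool_pairs(batch, (384, 384), num_objs)
#   cv2.imwrite('vis.png', grid)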