| |
| |
| """ |
| Misc functions, including distributed helpers. |
| |
| Mostly copy-paste from torchvision references. |
| """ |
| from typing import List, Optional |
| from collections import OrderedDict |
| from scipy.io import loadmat |
| import numpy as np |
| import csv |
| from PIL import Image |
| import matplotlib.pyplot as plt |
| import torch |
| import torch.distributed as dist |
| import torchvision |
| from torch import Tensor |
|
|
|
|
| def _max_by_axis(the_list): |
| |
| maxes = the_list[0] |
| for sublist in the_list[1:]: |
| for index, item in enumerate(sublist): |
| maxes[index] = max(maxes[index], item) |
| return maxes |
|
|
def get_world_size() -> int:
    """Number of processes in the default group; 1 when not distributed."""
    if dist.is_available() and dist.is_initialized():
        return dist.get_world_size()
    return 1
|
|
def reduce_dict(input_dict, average=True):
    """
    Args:
        input_dict (dict): all the values will be reduced
        average (bool): whether to do average or sum
    Reduce the values in the dictionary from all processes so that all processes
    have the averaged results. Returns a dict with the same fields as
    input_dict, after reduction.
    """
    world_size = get_world_size()
    # Single-process run: nothing to reduce, hand back the input untouched.
    if world_size < 2:
        return input_dict
    with torch.no_grad():
        # Sort keys so every rank stacks values in the same order.
        keys = sorted(input_dict.keys())
        stacked = torch.stack([input_dict[k] for k in keys], dim=0)
        dist.all_reduce(stacked)
        if average:
            stacked /= world_size
        return dict(zip(keys, stacked))
|
|
class NestedTensor(object):
    """Container pairing a padded batch tensor with its padding mask (may be None)."""

    def __init__(self, tensors, mask: Optional[Tensor]):
        self.tensors = tensors
        self.mask = mask

    def to(self, device):
        """Return a new NestedTensor with both members moved to `device`."""
        moved_tensors = self.tensors.to(device)
        # The mask is optional; only move it when present.
        moved_mask = None if self.mask is None else self.mask.to(device)
        return NestedTensor(moved_tensors, moved_mask)

    def decompose(self):
        """Return the (tensors, mask) pair."""
        return self.tensors, self.mask

    def __repr__(self):
        return str(self.tensors)
|
|
def nested_tensor_from_tensor_list(tensor_list: List[Tensor]):
    """Pad a list of CHW tensors to a common shape and return a NestedTensor.

    The mask is True on padded positions and False on valid pixels.
    Only 3-D (C, H, W) inputs are supported.
    """
    if tensor_list[0].ndim != 3:
        raise ValueError("not supported")
    if torchvision._is_tracing():
        # nested_tensor_from_tensor_list() does not export well to ONNX;
        # delegate to the trace-friendly variant.
        return _onnx_nested_tensor_from_tensor_list(tensor_list)

    max_size = _max_by_axis([list(img.shape) for img in tensor_list])
    batch, _, height, width = [len(tensor_list)] + max_size
    first = tensor_list[0]
    tensor = torch.zeros([batch] + max_size, dtype=first.dtype, device=first.device)
    mask = torch.ones((batch, height, width), dtype=torch.bool, device=first.device)
    # Copy each image into the top-left corner of its padded slot and
    # mark the valid region as unmasked.
    for src, dst, m in zip(tensor_list, tensor, mask):
        c, h, w = src.shape
        dst[:c, :h, :w].copy_(src)
        m[:h, :w] = False
    return NestedTensor(tensor, mask)
|
|
| |
| |
@torch.jit.unused
def _onnx_nested_tensor_from_tensor_list(tensor_list: List[Tensor]) -> NestedTensor:
    # ONNX-trace-friendly variant of nested_tensor_from_tensor_list: all size
    # arithmetic is kept as torch ops so shapes stay symbolic during tracing.
    max_size = []
    for i in range(tensor_list[0].dim()):
        # Under tracing each img.shape[i] is a 0-d tensor; stack them and take
        # the max extent along this axis (float round-trip keeps tracing happy).
        max_size_i = torch.max(
            torch.stack([img.shape[i] for img in tensor_list]).to(torch.float32)
        ).to(torch.int64)
        max_size.append(max_size_i)
    max_size = tuple(max_size)

    # Pad every image up to max_size (padding added on the bottom/right of
    # each dimension) and build a matching mask that is True where padded.
    padded_imgs = []
    padded_masks = []
    for img in tensor_list:
        padding = [(s1 - s2) for s1, s2 in zip(max_size, tuple(img.shape))]
        # F.pad takes pairs from the last dimension backwards: (W, H, C).
        padded_img = torch.nn.functional.pad(img, (0, padding[2], 0, padding[1], 0, padding[0]))
        padded_imgs.append(padded_img)

        # Zeros over the valid H x W area, constant 1 in the padded border.
        m = torch.zeros_like(img[0], dtype=torch.int, device=img.device)
        padded_mask = torch.nn.functional.pad(m, (0, padding[2], 0, padding[1]), "constant", 1)
        padded_masks.append(padded_mask.to(torch.bool))

    tensor = torch.stack(padded_imgs)
    mask = torch.stack(padded_masks)

    return NestedTensor(tensor, mask=mask)
|
|
def is_dist_avail_and_initialized():
    """True when torch.distributed is available and the default group is set up."""
    return dist.is_available() and dist.is_initialized()
|
|
def load_parallal_model(model, state_dict_):
    """Load a checkpoint into `model`, tolerating DataParallel prefixes and
    shape mismatches.

    'module.' prefixes (from nn.DataParallel) are stripped, parameters with
    mismatched shapes or unknown names are reported and replaced by the
    model's own values, and loading is done with strict=False.
    Returns the same model instance.
    """
    state_dict = OrderedDict()
    for key, value in state_dict_.items():
        # nn.DataParallel checkpoints prefix every key with 'module.'; drop it
        # (but leave genuine 'module_list...' keys untouched).
        if key.startswith('module') and not key.startswith('module_list'):
            state_dict[key[7:]] = value
        else:
            state_dict[key] = value

    model_state_dict = model.state_dict()
    for key in state_dict:
        if key not in model_state_dict:
            print('Drop parameter {}.'.format(key))
            continue
        if state_dict[key].shape != model_state_dict[key].shape:
            # Keep the model's own tensor when the checkpoint shape disagrees.
            print('Skip loading parameter {}, required shape{}, loaded shape{}.'.format(
                key, model_state_dict[key].shape, state_dict[key].shape))
            state_dict[key] = model_state_dict[key]
    for key in model_state_dict:
        if key not in state_dict:
            print('No param {}.'.format(key))
            state_dict[key] = model_state_dict[key]
    model.load_state_dict(state_dict, strict=False)

    return model
|
|
class ADEVisualize(object):
    """Visualize ADE20K semantic-segmentation predictions as color overlays.

    Loads the 150-class color palette and class names from hard-coded paths
    under 'dataset/' (relative to the current working directory).
    """

    def __init__(self):
        # colors: per-class RGB palette loaded from the ADE20K .mat file.
        self.colors = loadmat('dataset/color150.mat')['colors']
        self.names = {}
        with open('dataset/object150_info.csv') as f:
            reader = csv.reader(f)
            next(reader)  # skip the CSV header row
            for row in reader:
                # Column 5 holds ';'-separated synonyms; keep only the first.
                self.names[int(row[0])] = row[5].split(";")[0]

    def unique(self, ar, return_index=False, return_inverse=False, return_counts=False):
        """Re-implementation of np.unique over the flattened input.

        Args:
            ar: array-like; flattened before processing.
            return_index: also return the indices of first occurrences.
            return_inverse: also return indices reconstructing `ar`.
            return_counts: also return the count of each unique value.

        Returns:
            The sorted unique values, or a tuple starting with them when any
            optional output is requested (same contract as np.unique).
        """
        ar = np.asanyarray(ar).flatten()

        optional_indices = return_index or return_inverse
        optional_returns = optional_indices or return_counts

        if ar.size == 0:
            if not optional_returns:
                ret = ar
            else:
                ret = (ar,)
                # Fix: the original used np.empty(0, np.bool); the np.bool
                # alias was removed in NumPy 1.24 and index arrays should be
                # np.intp anyway (matching np.unique's empty-input behavior).
                if return_index:
                    ret += (np.empty(0, np.intp),)
                if return_inverse:
                    ret += (np.empty(0, np.intp),)
                if return_counts:
                    ret += (np.empty(0, np.intp),)
            return ret
        if optional_indices:
            # mergesort is stable, so return_index yields first occurrences.
            perm = ar.argsort(kind='mergesort' if return_index else 'quicksort')
            aux = ar[perm]
        else:
            ar.sort()
            aux = ar
        # flag marks the first element of each run of equal values.
        flag = np.concatenate(([True], aux[1:] != aux[:-1]))

        if not optional_returns:
            ret = aux[flag]
        else:
            ret = (aux[flag],)
            if return_index:
                ret += (perm[flag],)
            if return_inverse:
                iflag = np.cumsum(flag) - 1
                inv_idx = np.empty(ar.shape, dtype=np.intp)
                inv_idx[perm] = iflag
                ret += (inv_idx,)
            if return_counts:
                idx = np.concatenate(np.nonzero(flag) + ([ar.size],))
                ret += (np.diff(idx),)
        return ret

    def colorEncode(self, labelmap, colors, mode='RGB'):
        """Map an (H, W) integer label map to an (H, W, 3) uint8 color image.

        Negative labels are left black; `mode='BGR'` flips the channel order.
        """
        labelmap = labelmap.astype('int')
        labelmap_rgb = np.zeros((labelmap.shape[0], labelmap.shape[1], 3),
                                dtype=np.uint8)
        for label in self.unique(labelmap):
            if label < 0:
                continue
            # Paint every pixel of this class with its palette color.
            labelmap_rgb += (labelmap == label)[:, :, np.newaxis] * \
                np.tile(colors[label],
                        (labelmap.shape[0], labelmap.shape[1], 1))

        if mode == 'BGR':
            return labelmap_rgb[:, :, ::-1]
        else:
            return labelmap_rgb

    def show_result(self, img, pred, save_path=None):
        """Blend the colorized prediction over `img` (a PIL image).

        Saves to `save_path` when given, otherwise displays via matplotlib.
        """
        pred = np.int32(pred)
        pred_color = self.colorEncode(pred, self.colors)
        pil_img = img.convert('RGBA')
        pred_color = Image.fromarray(pred_color).convert('RGBA')
        # 0.6 alpha favors the prediction overlay over the input image.
        im_vis = Image.blend(pil_img, pred_color, 0.6)
        if save_path is not None:
            im_vis.save(save_path)
        else:
            plt.imshow(im_vis)