| import os, torch, cv2, re |
| import numpy as np |
|
|
| from PIL import Image |
| import torch.nn.functional as F |
| import torchvision.transforms as T |
|
|
| |
def img2mse(x, y):
    """Return the mean squared error between tensors ``x`` and ``y`` (mean over all elements)."""
    return torch.mean((x - y) ** 2)


def mse2psnr(x):
    """Convert an MSE tensor to PSNR in dB, i.e. ``-10 * log10(x)`` (peak value 1).

    The log(10) constant is a shape-(1,) tensor, so a 0-dim input yields a
    shape-(1,) result, as in the original formula. The constant is now
    created on ``x.device``. Previously it always lived on the CPU, and a
    non-scalar CPU tensor cannot be combined with a CUDA tensor.
    """
    return -10. * torch.log(x) / torch.log(torch.tensor([10.], device=x.device))


def to8b(x):
    """Map a float array in [0, 1] to uint8 [0, 255] (clipped, then truncated)."""
    return (255 * np.clip(x, 0, 1)).astype(np.uint8)


def mse2psnr2(x):
    """NumPy counterpart of ``mse2psnr``: PSNR in dB, ``-10 * log10(x)``."""
    return -10. * np.log(x) / np.log(10.)
|
|
|
|
def get_psnr(imgs_pred, imgs_gt):
    """Compute one PSNR value per (prediction, ground-truth) image pair.

    Args:
        imgs_pred: iterable of numpy images.
        imgs_gt: iterable of ground-truth images exposing ``.cpu().numpy()``
            (presumably torch tensors — confirm against callers).

    Returns:
        numpy array of PSNR values (dB), one per zipped pair.
    """
    scores = [
        mse2psnr2(np.mean((pred - gt.cpu().numpy()) ** 2))
        for pred, gt in zip(imgs_pred, imgs_gt)
    ]
    return np.array(scores)
|
|
|
|
def init_log(log, keys):
    """Reset each of ``keys`` in ``log`` to a fresh one-element zero tensor.

    The tensors use ``dtype=float`` (torch.float64). ``log`` is modified in
    place and also returned.
    """
    log.update({name: torch.tensor([0.0], dtype=float) for name in keys})
    return log
|
|
|
|
def visualize_depth_numpy(depth, minmax=None, cmap=cv2.COLORMAP_JET):
    """Colorize a depth map with an OpenCV colormap.

    Args:
        depth: (H, W) depth array. NaNs are replaced by 0 (background).
        minmax: optional (min, max) normalization range. When None, min is
            the smallest positive depth (non-positive values are treated as
            background) and max is the largest depth.
        cmap: OpenCV colormap id passed to ``cv2.applyColorMap``.

    Returns:
        (colored, [mi, ma]): ``colored`` is the image returned by
        ``cv2.applyColorMap``; [mi, ma] is the normalization range used.
    """
    x = np.nan_to_num(depth)  # NaN -> 0 so it lands in the background bucket
    if minmax is None:
        positive = x[x > 0]
        # Fall back to the global minimum when no depth is positive, instead
        # of letting np.min raise on an empty selection.
        mi = np.min(positive) if positive.size else np.min(x)
        ma = np.max(x)
    else:
        mi, ma = minmax

    x = (x - mi) / (ma - mi + 1e-8)
    # Clip before the uint8 cast. Values outside [mi, ma] (e.g. background
    # zeros below the positive minimum, or depths outside a caller-supplied
    # range) would otherwise wrap around instead of saturating.
    x = (255 * np.clip(x, 0, 1)).astype(np.uint8)
    x_ = cv2.applyColorMap(x, cmap)
    return x_, [mi, ma]
|
|
|
|
def visualize_depth(depth, minmax=None, cmap=cv2.COLORMAP_JET):
    """Colorize a depth map and return it as a torch image tensor.

    Args:
        depth: (H, W) depth map, either a numpy array or a tensor. Tensors
            are detached and moved to CPU before conversion.
        minmax: optional (min, max) normalization range. When None, min is
            the smallest positive depth (non-positive values are treated as
            background) and max is the largest depth.
        cmap: OpenCV colormap id passed to ``cv2.applyColorMap``.

    Returns:
        (img, [mi, ma]): ``img`` is the colorized map converted with
        ``torchvision.transforms.ToTensor``; [mi, ma] is the range used.
    """
    if not isinstance(depth, np.ndarray):
        # detach() so that tensors which require grad can be converted too.
        depth = depth.detach().cpu().numpy()

    x = np.nan_to_num(depth)  # NaN -> 0 so it lands in the background bucket
    if minmax is None:
        positive = x[x > 0]
        # Fall back to the global minimum when no depth is positive, instead
        # of letting np.min raise on an empty selection.
        mi = np.min(positive) if positive.size else np.min(x)
        ma = np.max(x)
    else:
        mi, ma = minmax

    x = (x - mi) / (ma - mi + 1e-8)
    # Clip before the uint8 cast. Values outside [mi, ma] (e.g. background
    # zeros below the positive minimum) would otherwise wrap around.
    x = (255 * np.clip(x, 0, 1)).astype(np.uint8)
    # NOTE(review): OpenCV colormaps are typically BGR, while PIL interprets
    # this array as RGB, so red/blue may appear swapped — confirm intent.
    x_ = Image.fromarray(cv2.applyColorMap(x, cmap))
    x_ = T.ToTensor()(x_)
    return x_, [mi, ma]
|
|
|
|
def abs_error_numpy(depth_pred, depth_gt, mask):
    """Absolute depth error on the pixels selected by ``mask`` (numpy inputs).

    Returns a 1-D array of ``|depth_pred - depth_gt|`` for the masked entries.
    """
    diff = depth_pred[mask] - depth_gt[mask]
    return np.abs(diff)
|
|
|
|
def abs_error(depth_pred, depth_gt, mask):
    """Absolute depth error on the pixels selected by ``mask``.

    Works for numpy arrays (uses ``np.abs``) and for tensor-like objects
    (uses ``.abs()``). The choice is made on the type of the masked
    prediction.
    """
    pred_sel = depth_pred[mask]
    diff = pred_sel - depth_gt[mask]
    if type(pred_sel) is np.ndarray:
        return np.abs(diff)
    return diff.abs()
|
|
|
|
def acc_threshold(depth_pred, depth_gt, mask, threshold):
    """Per-pixel indicator of whether the masked depth error is below ``threshold``.

    This does not average. It returns a float mask, 1.0 where
    ``|depth_pred - depth_gt| < threshold`` and 0.0 elsewhere, over the
    pixels selected by ``mask``. Take the mean of the result to get the
    accuracy fraction. Numpy inputs yield a float64 array; other inputs
    yield ``.float()`` of the comparison.
    """
    within = abs_error(depth_pred, depth_gt, mask) < threshold
    if type(depth_pred) is np.ndarray:
        return within.astype('float')
    return within.float()
|
|
|
|
def to_tensor_cuda(data, device, filter):
    """Convert every entry of ``data`` whose key is not in ``filter`` to float32 on ``device``.

    numpy arrays are wrapped with ``torch.tensor(..., dtype=torch.float32)``.
    Anything else must support ``.float().to(device)``. ``data`` is updated
    in place and returned.
    """
    wanted = [key for key in data.keys() if key not in filter]
    for key in wanted:
        value = data[key]
        data[key] = (
            torch.tensor(value, dtype=torch.float32, device=device)
            if type(value) is np.ndarray
            else value.float().to(device)
        )
    return data
|
|
|
|
def to_cuda(data, device, filter):
    """Cast every entry of ``data`` whose key is not in ``filter`` with ``.float().to(device)``.

    ``data`` is updated in place and returned.
    """
    targets = [name for name in data if name not in filter]
    for name in targets:
        data[name] = data[name].float().to(device)
    return data
|
|
|
|
def tensor_unsqueeze(data, filter):
    """Prepend a singleton leading axis (``x[None]``) to every entry not named in ``filter``.

    ``data`` is updated in place and returned.
    """
    for key in data.keys():
        if key not in filter:
            data[key] = data[key][None]
    return data
|
|
|
|
def filter_keys(dict):
    """Remove rendering-only keys from ``dict`` in place and return it.

    ``'N_samples'`` is required (a KeyError is raised if it is missing).
    ``'ndc'`` and ``'lindisp'`` are removed only when present.
    """
    dict.pop('N_samples')
    for optional in ('ndc', 'lindisp'):
        dict.pop(optional, None)
    return dict
|
|
|
|
def sub_selete_data(data_batch, device, idx, filtKey=(),
                    filtIndex=('view_ids_all', 'c2ws_all', 'scan', 'bbox', 'w2ref', 'ref2w', 'light_id', 'ckpt',
                               'idx')):
    """Select entries ``idx`` along axis 1 of a batched dict and move tensors to ``device``.

    Args:
        data_batch: dict of batch entries.
        device: target device for tensors not already on a CUDA device.
        idx: index (int, slice, list or tensor) applied as ``value[:, idx]``
            to every tensor with more than two dimensions whose key is not in
            ``filtIndex``. Presumably axis 1 is the view axis — TODO confirm.
        filtKey: unused; kept for backward compatibility.
        filtIndex: keys whose tensors are copied whole rather than indexed.

    Returns:
        New dict with the same keys. Tensors are cast with ``.float()``.
        Non-tensor values are passed through unchanged.
    """
    data_sub_selete = {}
    for key, value in data_batch.items():
        if not torch.is_tensor(value):
            # .float() is undefined for non-tensors (e.g. lists of names), so
            # pass them through rather than crash.
            data_sub_selete[key] = value
            continue
        # Test the *value*, not the key. The previous version called
        # torch.is_tensor on the (string) key, which is always False, so the
        # idx sub-selection never happened.
        if key not in filtIndex and value.dim() > 2:
            value = value[:, idx]
        value = value.float()
        if not value.is_cuda:
            value = value.to(device)
        data_sub_selete[key] = value
    return data_sub_selete
|
|
|
|
def detach_data(dictionary):
    """Return a new dict holding a detached copy (``.detach().clone()``) of every value."""
    return {name: tensor.detach().clone() for name, tensor in dictionary.items()}
|
|
|
|
def read_pfm(filename):
    """Read a PFM (Portable Float Map) file.

    Args:
        filename: path to the ``.pfm`` file.

    Returns:
        (data, scale):
            ``data`` is a 32-bit float array in the file's byte order. Its
            shape is (height, width, 3) for 'PF' (colour) files and
            (height, width) for 'Pf' (greyscale) files. It is flipped
            vertically with ``np.flipud`` because PFM stores scanlines
            bottom-to-top.
            ``scale`` is the absolute value of the header's scale field.

    Raises:
        ValueError: if the magic line or the dimension line is malformed.
    """
    # ``with`` guarantees the handle is closed on every path, including the
    # header-validation errors, which previously leaked it.
    with open(filename, 'rb') as file:
        header = file.readline().decode('utf-8').rstrip()
        if header == 'PF':
            color = True
        elif header == 'Pf':
            color = False
        else:
            raise ValueError('Not a PFM file.')

        dim_match = re.match(r'^(\d+)\s+(\d+)\s*$', file.readline().decode('utf-8'))
        if not dim_match:
            raise ValueError('Malformed PFM header.')
        width, height = map(int, dim_match.groups())

        # The sign of the scale field encodes byte order: negative = little-endian.
        scale = float(file.readline().decode('utf-8').rstrip())
        if scale < 0:
            endian = '<'
            scale = -scale
        else:
            endian = '>'

        data = np.fromfile(file, endian + 'f')

    shape = (height, width, 3) if color else (height, width)
    data = np.reshape(data, shape)
    data = np.flipud(data)
    return data, scale
|
|
|
|
| from torch.optim.lr_scheduler import CosineAnnealingLR, MultiStepLR |
|
|
|
|
| |
def get_scheduler(hparams, optimizer):
    """Build the learning-rate scheduler named by ``hparams.lr_scheduler``.

    'steplr' -> MultiStepLR(milestones=hparams.decay_step, gamma=hparams.decay_gamma).
    'cosine' -> CosineAnnealingLR(T_max=hparams.num_epochs, eta_min=1e-8).

    Raises:
        ValueError: for any other scheduler name.
    """
    min_lr = 1e-8
    kind = hparams.lr_scheduler
    if kind == 'steplr':
        return MultiStepLR(optimizer, milestones=hparams.decay_step,
                           gamma=hparams.decay_gamma)
    if kind == 'cosine':
        return CosineAnnealingLR(optimizer, T_max=hparams.num_epochs, eta_min=min_lr)
    raise ValueError('scheduler not recognized!')
|
|
|
|
| |
def get_nearest_pose_ids(tar_pose, ref_poses, num_select):
    """For each target pose, pick the reference poses with the closest translation.

    Args:
        tar_pose: target poses [N, 4, 4].
        ref_poses: reference poses [M, 4, 4].
        num_select: number of nearest reference views to keep per target.

    Returns:
        (N, num_select) array of indices into ``ref_poses``, ordered by
        increasing Euclidean distance between the translation columns
        (``[:3, 3]``).
    """
    tar_centers = tar_pose[:, :3, 3]    # (N, 3)
    ref_centers = ref_poses[:, :3, 3]   # (M, 3)
    offsets = tar_centers[:, np.newaxis, :] - ref_centers[np.newaxis, :, :]  # (N, M, 3)
    dists = np.linalg.norm(offsets, axis=-1)
    return np.argsort(dists, axis=-1)[:, :num_select]
|
|