| | from copy import deepcopy |
| | from pathlib import Path |
| | from typing import Any, Dict, List |
| |
|
| | import numpy as np |
| | import torch |
| | import torch.utils.data as torchdata |
| | import torchvision.transforms as tvf |
| | from PIL import Image |
| | from pathlib import Path |
| |
|
| | from ...models.utils import deg2rad, rotmat2d |
| | from ...utils.io import read_image |
| | from ...utils.wrappers import Camera |
| | from ..image import pad_image, rectify_image, resize_image |
| | from ..utils import decompose_rotmat |
| | from ..schema import MIADataConfiguration |
| |
|
| |
|
class MapLocDataset(torchdata.Dataset):
    """Map-localization dataset: images plus segmentation/flood masks.

    Each item is a dict with the (augmented, rectified, resized) image,
    its camera model, a validity mask, roll/pitch/yaw angles, a cropped
    segmentation mask, a flood mask, and a confidence map derived from
    the flood mask. Photometric augmentations run only for the "train"
    stage.
    """

    def __init__(
        self,
        stage: str,
        cfg: "MIADataConfiguration",
        names: List[str],
        data: Dict[str, Any],
        image_dirs: Dict[str, Path],
        seg_mask_dirs: Dict[str, Path],
        flood_masks_dirs: Dict[str, Path],
        image_ext: str = "",
    ):
        """Store dataset tables and build the augmentation pipeline.

        Args:
            stage: "train" enables augmentation; anything else disables it.
            cfg: dataset configuration (deep-copied; mutated locally).
            names: per-item (scene, sequence, name) triplets.
            data: per-item tables (cameras, poses, chunk indices, ...).
            image_dirs: per-scene directories containing input images.
            seg_mask_dirs: per-scene directories with .npy segmentation masks.
            flood_masks_dirs: per-scene directories with .npy flood masks.
            image_ext: optional suffix appended to image file names.
        """
        self.stage = stage
        # Deep copy: get_augmentations() mutates cfg (random_flip) for
        # non-train stages; keep the caller's config object untouched.
        self.cfg = deepcopy(cfg)
        self.data = data
        self.image_dirs = image_dirs
        self.seg_mask_dirs = seg_mask_dirs
        self.flood_masks_dirs = flood_masks_dirs
        self.names = names
        self.image_ext = image_ext

        # Tensor-space transforms applied after rectification/resizing
        # (currently empty; kept as an extension point).
        self.tfs = tvf.Compose([])
        # PIL-space photometric augmentations applied right after load.
        self.augmentations = self.get_augmentations()

    def __len__(self):
        return len(self.names)

    def __getitem__(self, idx):
        # Fully random only while training with cfg.random; otherwise
        # derive a deterministic per-item seed for reproducibility.
        if self.stage == "train" and self.cfg.random:
            seed = None
        else:
            seed = [self.cfg.seed, idx]
        (seed,) = np.random.SeedSequence(seed).generate_state(1)

        scene, seq, name = self.names[idx]
        return self.get_view(idx, scene, seq, name, seed)

    def get_augmentations(self):
        """Build the PIL-image augmentation pipeline for this stage."""
        if self.stage != "train" or not self.cfg.augmentations.enabled:
            print("No Augmentation!", "\n" * 10)
            # Also disable the horizontal flip used later in get_view().
            self.cfg.augmentations.random_flip = 0.0
            return tvf.Compose([])

        print("Augmentation!", "\n" * 10)
        augmentations = [
            tvf.ColorJitter(
                brightness=self.cfg.augmentations.brightness,
                contrast=self.cfg.augmentations.contrast,
                saturation=self.cfg.augmentations.saturation,
                hue=self.cfg.augmentations.hue,
            )
        ]

        if self.cfg.augmentations.random_resized_crop:
            # BUGFIX: RandomResizedCrop requires an output `size`; the
            # original call omitted it and raised a TypeError whenever
            # this augmentation was enabled. Crop to the configured
            # resize target. TODO(review): confirm cfg.resize_image is
            # always set when random_resized_crop is enabled.
            augmentations.append(
                tvf.RandomResizedCrop(size=self.cfg.resize_image, scale=(0.8, 1.0))
            )

        if self.cfg.augmentations.gaussian_noise.enabled:
            # NOTE(review): tvf.GaussianNoise exists only in
            # torchvision.transforms.v2 (>= 0.20) and operates on
            # tensors, while this pipeline feeds PIL images — verify
            # the torchvision version and the input type here.
            augmentations.append(
                tvf.GaussianNoise(
                    mean=self.cfg.augmentations.gaussian_noise.mean,
                    std=self.cfg.augmentations.gaussian_noise.std,
                )
            )

        if self.cfg.augmentations.brightness_contrast.enabled:
            augmentations.append(
                tvf.ColorJitter(
                    brightness=self.cfg.augmentations.brightness_contrast.brightness_factor,
                    contrast=self.cfg.augmentations.brightness_contrast.contrast_factor,
                    saturation=0,
                    hue=0,
                )
            )

        return tvf.Compose(augmentations)

    def random_flip(self, image, cam, valid, seg_mask, flood_mask, conf_mask):
        """Jointly horizontal-flip image, camera, and all masks with
        probability cfg.augmentations.random_flip; otherwise return the
        inputs unchanged."""
        if torch.rand(1) < self.cfg.augmentations.random_flip:
            image = torch.flip(image, [-1])
            cam = cam.flip()
            valid = torch.flip(valid, [-1])
            # seg_mask is (H, W, C) — width is dim 1, not the last dim.
            seg_mask = torch.flip(seg_mask, [1])
            flood_mask = torch.flip(flood_mask, [-1])
            conf_mask = torch.flip(conf_mask, [-1])

        return image, cam, valid, seg_mask, flood_mask, conf_mask

    def get_view(self, idx, scene, seq, name, seed):
        """Load and assemble one sample (image, camera, masks) as a dict."""
        data = {
            "index": idx,
            "name": name,
            "scene": scene,
            "sequence": seq,
        }
        cam_dict = self.data["cameras"][scene][seq][self.data["camera_id"][idx]]
        cam = Camera.from_dict(cam_dict).float()

        # Prefer precomputed angles; otherwise derive them from R_c2w.
        if "roll_pitch_yaw" in self.data:
            roll, pitch, yaw = self.data["roll_pitch_yaw"][idx].numpy()
        else:
            roll, pitch, yaw = decompose_rotmat(self.data["R_c2w"][idx].numpy())

        image = read_image(self.image_dirs[scene] / (name + self.image_ext))
        # Photometric augmentations operate on PIL images.
        image = self.augmentations(Image.fromarray(image))
        image = np.array(image)

        if "plane_params" in self.data:
            # Rotate the plane's in-plane components into the yaw frame.
            plane_w = self.data["plane_params"][idx]
            data["ground_plane"] = torch.cat(
                [rotmat2d(deg2rad(torch.tensor(yaw))) @ plane_w[:2], plane_w[2:]]
            )

        image, valid, cam, roll, pitch = self.process_image(
            image, cam, roll, pitch, seed
        )

        if "chunk_index" in self.data:
            data["chunk_id"] = (scene, seq, self.data["chunk_index"][idx])

        # Masks are stored per frame; the frame id is the part of the
        # name before the first underscore.
        seg_mask_path = self.seg_mask_dirs[scene] / (name.split("_")[0] + ".npy")
        seg_masks_ours = np.load(seg_mask_path)
        mask_center = (seg_masks_ours.shape[0] // 2, seg_masks_ours.shape[1] // 2)

        # Crop a 100x100 window ending at the mask center
        # (rows center-100..center, cols center-50..center+50).
        seg_masks_ours = seg_masks_ours[
            mask_center[0] - 100 : mask_center[0],
            mask_center[1] - 50 : mask_center[1] + 50,
        ]

        if self.cfg.num_classes == 6:
            # Keep a 6-channel subset of the stored class channels.
            seg_masks_ours = seg_masks_ours[..., [0, 1, 2, 4, 6, 7]]

        flood_mask_path = self.flood_masks_dirs[scene] / (name.split("_")[0] + ".npy")
        flood_mask = np.load(flood_mask_path)
        # Same crop window as the segmentation mask.
        flood_mask = flood_mask[
            mask_center[0] - 100 : mask_center[0],
            mask_center[1] - 50 : mask_center[1] + 50,
        ]

        # Min-max normalize the flood mask into a [0, 1] confidence map;
        # the epsilon guards against a constant-valued mask.
        confidence_map = flood_mask.copy()
        confidence_map = (confidence_map - confidence_map.min()) / (
            confidence_map.max() - confidence_map.min() + 1e-6
        )

        seg_masks_ours = torch.from_numpy(seg_masks_ours).float()
        flood_mask = torch.from_numpy(flood_mask).float()
        confidence_map = torch.from_numpy(confidence_map).float()

        # Seed the fork so the flip decision is reproducible outside of
        # fully-random training.
        with torch.random.fork_rng(devices=[]):
            torch.manual_seed(seed)
            image, cam, valid, seg_masks_ours, flood_mask, confidence_map = (
                self.random_flip(
                    image, cam, valid, seg_masks_ours, flood_mask, confidence_map
                )
            )

        return {
            **data,
            "image": image,
            "valid": valid,
            "camera": cam,
            "seg_masks": seg_masks_ours,
            "flood_masks": flood_mask,
            "roll_pitch_yaw": torch.tensor((roll, pitch, yaw)).float(),
            "confidence_map": confidence_map,
        }

    def process_image(self, image, cam, roll, pitch, seed):
        """Rectify, resize, and pad the image.

        Returns the updated (image, valid, cam, roll, pitch), where the
        returned roll/pitch are zeroed when compensated by rectification.
        """
        # HWC uint8 -> CHW float in [0, 1].
        image = (
            torch.from_numpy(np.ascontiguousarray(image))
            .permute(2, 0, 1)
            .float()
            .div_(255)
        )

        if not self.cfg.gravity_align:
            # No gravity alignment: rectify with zeroed angles.
            roll = 0.0
            pitch = 0.0
            image, valid = rectify_image(image, cam, roll, pitch)
        else:
            image, valid = rectify_image(
                image, cam, roll, pitch if self.cfg.rectify_pitch else None
            )
            roll = 0.0
            if self.cfg.rectify_pitch:
                pitch = 0.0

        if self.cfg.target_focal_length is not None:
            # Rescale so the focal length matches the configured target.
            factor = self.cfg.target_focal_length / cam.f.numpy()
            size = (np.array(image.shape[-2:][::-1]) * factor).astype(int)
            image, _, cam, valid = resize_image(image, size, camera=cam, valid=valid)
            size_out = self.cfg.resize_image
            if size_out is None:
                # Pad to a multiple of the configured stride.
                stride = self.cfg.pad_to_multiple
                size_out = (np.ceil((size / stride)) * stride).astype(int)
            image, valid, cam = pad_image(
                image, size_out, cam, valid, crop_and_center=True
            )
        elif self.cfg.resize_image is not None:
            image, _, cam, valid = resize_image(
                image, self.cfg.resize_image, fn=max, camera=cam, valid=valid
            )
            if self.cfg.pad_to_square:
                image, valid, cam = pad_image(image, self.cfg.resize_image, cam, valid)

        if self.cfg.reduce_fov is not None:
            # Crop the width to a reduced horizontal field of view.
            h, w = image.shape[-2:]
            f = float(cam.f[0])
            fov = np.arctan(w / f / 2)
            w_new = round(2 * f * np.tan(self.cfg.reduce_fov * fov))
            image, valid, cam = pad_image(
                image, (w_new, h), cam, valid, crop_and_center=True
            )

        # Seeded so any stochastic tensor transforms are reproducible.
        with torch.random.fork_rng(devices=[]):
            torch.manual_seed(seed)
            image = self.tfs(image)

        return image, valid, cam, roll, pitch
| |
|