import cv2
import torch
from torch.utils.data import Dataset
from torchvision.transforms import Compose

from dataset.transform import Resize, NormalizeImage, PrepareForNet, Crop


class VKITTI2(Dataset):
    """Virtual KITTI 2 dataset returning RGB images and metric depth maps."""

    def __init__(self, filelist_path, mode, size=(518, 518)):
        self.mode = mode
        self.size = size

        # Each line of the file list holds "<image_path> <depth_path>".
        with open(filelist_path, 'r') as f:
            self.filelist = f.read().splitlines()

        net_w, net_h = size
        self.transform = Compose([
            Resize(
                width=net_w,
                height=net_h,
                resize_target=(mode == 'train'),  # resize the depth target only during training
                keep_aspect_ratio=True,
                ensure_multiple_of=14,  # keep spatial dims divisible by 14
                resize_method='lower_bound',
                image_interpolation_method=cv2.INTER_CUBIC,
            ),
            NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
            PrepareForNet(),
        ] + ([Crop(size[0])] if self.mode == 'train' else []))  # random crop only when training
|
    def __getitem__(self, item):
        img_path, depth_path = self.filelist[item].split(' ')[:2]

        # Load the image as RGB floats in [0, 1] (OpenCV reads BGR).
        image = cv2.imread(img_path)
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) / 255.0

        # VKITTI2 stores depth as 16-bit PNGs in centimetres; convert to metres.
        depth = cv2.imread(depth_path, cv2.IMREAD_ANYCOLOR | cv2.IMREAD_ANYDEPTH) / 100.0

        sample = self.transform({'image': image, 'depth': depth})

        sample['image'] = torch.from_numpy(sample['image'])
        sample['depth'] = torch.from_numpy(sample['depth'])

        # Depths beyond 80 m (e.g. sky) are treated as invalid.
        sample['valid_mask'] = (sample['depth'] <= 80)

        sample['image_path'] = img_path

        return sample

    def __len__(self):
        return len(self.filelist)
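

# Minimal usage sketch: wrap the dataset in a standard PyTorch DataLoader and
# inspect one batch. The file-list path below is a placeholder; point it at a
# list of "<image_path> <depth_path>" pairs for your VKITTI2 copy.
if __name__ == '__main__':
    from torch.utils.data import DataLoader

    dataset = VKITTI2('splits/vkitti2_train.txt', mode='train')
    loader = DataLoader(dataset, batch_size=2, shuffle=True, num_workers=2)

    batch = next(iter(loader))
    print(batch['image'].shape, batch['depth'].shape, batch['valid_mask'].float().mean())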