| import cv2 |
| import numpy as np |
| |
| import random |
| import torch |
| import torchvision.transforms as transforms |
| |
| from pathlib import Path |
| from PIL import Image |
| from torch.utils.data import Dataset |
| from ..utils import letterbox, augment_hsv, random_perspective, xyxy2xywh, cutout |
|
|
|
|
class AutoDriveDataset(Dataset):
    """
    A general Dataset providing the common loading/augmentation pipeline for
    multi-task autonomous-driving data: object detection labels, a drivable-area
    segmentation mask and a lane-line mask per image.

    Child classes must implement `_get_db` (to fill `self.db`) and `evaluate`.
    """
    def __init__(self, cfg, is_train, inputsize=640, transform=None):
        """
        Initialize dataset roots and augmentation settings.

        Inputs:
        -cfg: configurations (expects cfg.DATASET.* entries used below)
        -is_train(bool): whether train set or not
        -inputsize(int or list): target network input size (longest side)
        -transform: ToTensor and Normalize applied to the image at the end
            of __getitem__

        Returns:
        None
        """
        self.is_train = is_train
        self.cfg = cfg
        self.transform = transform
        self.inputsize = inputsize
        self.Tensor = transforms.ToTensor()
        img_root = Path(cfg.DATASET.DATAROOT)
        label_root = Path(cfg.DATASET.LABELROOT)
        mask_root = Path(cfg.DATASET.MASKROOT)
        lane_root = Path(cfg.DATASET.LANEROOT)
        if is_train:
            indicator = cfg.DATASET.TRAIN_SET
        else:
            indicator = cfg.DATASET.TEST_SET
        self.img_root = img_root / indicator
        self.label_root = label_root / indicator
        self.mask_root = mask_root / indicator
        self.lane_root = lane_root / indicator

        # BUGFIX: Path.iterdir() returns a single-use generator — a second
        # iteration would silently yield nothing. Materialize it once.
        self.mask_list = list(self.mask_root.iterdir())

        # Sample database; filled by the child class via _get_db().
        # Each entry is a dict with keys 'image', 'mask', 'lane', 'label'.
        self.db = []

        self.data_format = cfg.DATASET.DATA_FORMAT

        self.scale_factor = cfg.DATASET.SCALE_FACTOR
        self.rotation_factor = cfg.DATASET.ROT_FACTOR
        self.flip = cfg.DATASET.FLIP
        self.color_rgb = cfg.DATASET.COLOR_RGB

        # Original (h, w) of the source images, as configured.
        self.shapes = np.array(cfg.DATASET.ORG_IMG_SIZE)

    def _get_db(self):
        """
        Build and return the sample database.

        Finished on children Dataset (for a dataset which is not in BDD100K
        format, rewrite in the child Dataset).
        """
        raise NotImplementedError

    def evaluate(self, cfg, preds, output_dir):
        """
        Evaluate predictions; finished on children dataset.
        """
        raise NotImplementedError

    def __len__(self,):
        """
        Number of objects (samples) in the dataset.
        """
        return len(self.db)

    def __getitem__(self, idx):
        """
        Get input and ground-truth from database & add data augmentation on input.

        Inputs:
        -idx: the index of image in self.db (database, a list)
              self.db(list) [a, b, c, ...]
              a: (dictionary){'image':, 'mask':, 'lane':, 'label':}

        Returns:
        -img: transformed image — first passed through the numpy-level data
              augmentation below, then through self.transform
        -target: [labels_out, seg_label, lane_label]
        -data["image"]: path of the source image
        -shapes: ((h0, w0), ((h/h0, w/w0), pad)) — needed to rescale
                 predictions back to the original image
        """
        data = self.db[idx]
        img = cv2.imread(data["image"], cv2.IMREAD_COLOR | cv2.IMREAD_IGNORE_ORIENTATION)
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

        # Segmentation mask: 3-channel image for 3 seg classes, else grayscale.
        if self.cfg.num_seg_class == 3:
            seg_label = cv2.imread(data["mask"])
        else:
            seg_label = cv2.imread(data["mask"], 0)
        lane_label = cv2.imread(data["lane"], 0)

        # Resize so the longest side matches the requested input size.
        resized_shape = self.inputsize
        if isinstance(resized_shape, list):
            resized_shape = max(resized_shape)
        h0, w0 = img.shape[:2]  # original hw
        r = resized_shape / max(h0, w0)  # resize ratio
        if r != 1:
            # NOTE(review): masks are resized with the same area/linear
            # interpolation as the image; the thresholding further down
            # re-binarizes them, which absorbs interpolation artifacts.
            interp = cv2.INTER_AREA if r < 1 else cv2.INTER_LINEAR
            img = cv2.resize(img, (int(w0 * r), int(h0 * r)), interpolation=interp)
            seg_label = cv2.resize(seg_label, (int(w0 * r), int(h0 * r)), interpolation=interp)
            lane_label = cv2.resize(lane_label, (int(w0 * r), int(h0 * r)), interpolation=interp)
        h, w = img.shape[:2]

        # Pad to a stride-friendly rectangle; ratio/pad are needed to map boxes.
        (img, seg_label, lane_label), ratio, pad = letterbox((img, seg_label, lane_label), resized_shape, auto=True, scaleup=self.is_train)
        shapes = (h0, w0), ((h / h0, w / w0), pad)  # for rescaling predictions

        det_label = data["label"]
        labels = []

        if det_label.size > 0:
            # Normalized xywh -> pixel xyxy in the letterboxed image.
            labels = det_label.copy()
            labels[:, 1] = ratio[0] * w * (det_label[:, 1] - det_label[:, 3] / 2) + pad[0]
            labels[:, 2] = ratio[1] * h * (det_label[:, 2] - det_label[:, 4] / 2) + pad[1]
            labels[:, 3] = ratio[0] * w * (det_label[:, 1] + det_label[:, 3] / 2) + pad[0]
            labels[:, 4] = ratio[1] * h * (det_label[:, 2] + det_label[:, 4] / 2) + pad[1]

        if self.is_train:
            # Geometric augmentation applied jointly to image and both masks
            # so they stay aligned.
            combination = (img, seg_label, lane_label)
            (img, seg_label, lane_label), labels = random_perspective(
                combination=combination,
                targets=labels,
                degrees=self.cfg.DATASET.ROT_FACTOR,
                translate=self.cfg.DATASET.TRANSLATE,
                scale=self.cfg.DATASET.SCALE_FACTOR,
                shear=self.cfg.DATASET.SHEAR
            )

            # Color augmentation (in place, image only).
            augment_hsv(img, hgain=self.cfg.DATASET.HSV_H, sgain=self.cfg.DATASET.HSV_S, vgain=self.cfg.DATASET.HSV_V)

            if len(labels):
                # Pixel xyxy -> normalized xywh (relative to augmented image).
                labels[:, 1:5] = xyxy2xywh(labels[:, 1:5])
                labels[:, [2, 4]] /= img.shape[0]  # y / height
                labels[:, [1, 3]] /= img.shape[1]  # x / width

            # Random left-right flip.
            lr_flip = True
            if lr_flip and random.random() < 0.5:
                img = np.fliplr(img)
                seg_label = np.fliplr(seg_label)
                lane_label = np.fliplr(lane_label)
                if len(labels):
                    labels[:, 1] = 1 - labels[:, 1]

            # Random up-down flip (currently disabled).
            # BUGFIX: the original called the non-existent np.filpud, which
            # would raise AttributeError if this branch were ever enabled.
            ud_flip = False
            if ud_flip and random.random() < 0.5:
                img = np.flipud(img)
                seg_label = np.flipud(seg_label)
                lane_label = np.flipud(lane_label)
                if len(labels):
                    labels[:, 2] = 1 - labels[:, 2]

        else:
            if len(labels):
                # Pixel xyxy -> normalized xywh (no augmentation at eval time).
                labels[:, 1:5] = xyxy2xywh(labels[:, 1:5])
                labels[:, [2, 4]] /= img.shape[0]  # y / height
                labels[:, [1, 3]] /= img.shape[1]  # x / width

        # Column 0 is left zero for the batch index, filled in collate_fn.
        labels_out = torch.zeros((len(labels), 6))
        if len(labels):
            labels_out[:, 1:] = torch.from_numpy(labels)

        # fliplr/flipud return views with negative strides; make contiguous.
        img = np.ascontiguousarray(img)

        # Binarize masks into per-class channels.
        if self.cfg.num_seg_class == 3:
            _, seg0 = cv2.threshold(seg_label[:, :, 0], 128, 255, cv2.THRESH_BINARY)
            _, seg1 = cv2.threshold(seg_label[:, :, 1], 1, 255, cv2.THRESH_BINARY)
            _, seg2 = cv2.threshold(seg_label[:, :, 2], 1, 255, cv2.THRESH_BINARY)
        else:
            _, seg1 = cv2.threshold(seg_label, 1, 255, cv2.THRESH_BINARY)
            _, seg2 = cv2.threshold(seg_label, 1, 255, cv2.THRESH_BINARY_INV)
        _, lane1 = cv2.threshold(lane_label, 1, 255, cv2.THRESH_BINARY)
        _, lane2 = cv2.threshold(lane_label, 1, 255, cv2.THRESH_BINARY_INV)

        if self.cfg.num_seg_class == 3:
            seg0 = self.Tensor(seg0)
        seg1 = self.Tensor(seg1)
        seg2 = self.Tensor(seg2)

        lane1 = self.Tensor(lane1)
        lane2 = self.Tensor(lane2)

        # Stack channels: (background, foreground) in the 2-class case.
        if self.cfg.num_seg_class == 3:
            seg_label = torch.stack((seg0[0], seg1[0], seg2[0]), 0)
        else:
            seg_label = torch.stack((seg2[0], seg1[0]), 0)

        lane_label = torch.stack((lane2[0], lane1[0]), 0)

        target = [labels_out, seg_label, lane_label]
        img = self.transform(img)

        return img, target, data["image"], shapes

    def select_data(self, db):
        """
        You can use this function to filter useless images in the dataset.

        Inputs:
        -db: (list)database

        Returns:
        -db_selected: (list)filtered dataset

        NOTE(review): placeholder — currently returns Ellipsis; child classes
        or callers must override/replace this before relying on it.
        """
        db_selected = ...
        return db_selected

    @staticmethod
    def collate_fn(batch):
        """
        Merge a list of samples into a batch.

        Images and masks are stacked; detection labels are concatenated with
        their column 0 set to the sample's index within the batch.
        """
        img, label, paths, shapes = zip(*batch)
        label_det, label_seg, label_lane = [], [], []
        for i, l in enumerate(label):
            l_det, l_seg, l_lane = l
            l_det[:, 0] = i  # batch index for per-image box association
            label_det.append(l_det)
            label_seg.append(l_seg)
            label_lane.append(l_lane)
        return torch.stack(img, 0), [torch.cat(label_det, 0), torch.stack(label_seg, 0), torch.stack(label_lane, 0)], paths, shapes
|
|
|
|