import math
import os.path as osp

import pytest
from torch.utils.data import (DistributedSampler, RandomSampler,
                              SequentialSampler)

from mmseg.datasets import (DATASETS, ConcatDataset, build_dataloader,
                            build_dataset)


@DATASETS.register_module()
class ToyDataset(object):

    def __init__(self, cnt=0):
        self.cnt = cnt

    def __getitem__(self, idx):
        return idx

    def __len__(self):
        return 100
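
# A registered dataset can be built straight from a config dict, e.g.
# build_dataset(dict(type='ToyDataset', cnt=2)) looks up 'ToyDataset' in the
# DATASETS registry and is roughly equivalent to calling ToyDataset(cnt=2).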


def test_build_dataset():
    cfg = dict(type='ToyDataset')
    dataset = build_dataset(cfg)
    assert isinstance(dataset, ToyDataset)
    assert dataset.cnt == 0
    dataset = build_dataset(cfg, default_args=dict(cnt=1))
    assert isinstance(dataset, ToyDataset)
    assert dataset.cnt == 1
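    # default_args are extra constructor kwargs merged with cfg when the
    # dataset is built, which is how cnt=1 reaches ToyDataset.__init__ here.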

    data_root = osp.join(osp.dirname(__file__), '../data/pseudo_dataset')
    img_dir = 'imgs/'
    ann_dir = 'gts/'

    # The same directory is used twice for simplicity. List-valued
    # img_dir/ann_dir build one CustomDataset per entry and wrap them in a
    # ConcatDataset (2 x 5 images = 10).
    cfg = dict(
        type='CustomDataset',
        pipeline=[],
        data_root=data_root,
        img_dir=[img_dir, img_dir],
        ann_dir=[ann_dir, ann_dir])
    dataset = build_dataset(cfg)
    assert isinstance(dataset, ConcatDataset)
    assert len(dataset) == 10

    # Single img_dir/ann_dir with two split files: one dataset is built per
    # split and the results are concatenated.
    cfg = dict(
        type='CustomDataset',
        pipeline=[],
        data_root=data_root,
        img_dir=img_dir,
        ann_dir=ann_dir,
        split=['splits/train.txt', 'splits/val.txt'])
    dataset = build_dataset(cfg)
    assert isinstance(dataset, ConcatDataset)
    assert len(dataset) == 5
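    # The length assertions imply the pseudo dataset layout: imgs/ holds 5
    # images, splits/train.txt lists 4 of them and splits/val.txt lists 1.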

    # ann_dir may also be a list, as long as its length matches split.
    cfg = dict(
        type='CustomDataset',
        pipeline=[],
        data_root=data_root,
        img_dir=img_dir,
        ann_dir=[ann_dir, ann_dir],
        split=['splits/train.txt', 'splits/val.txt'])
    dataset = build_dataset(cfg)
    assert isinstance(dataset, ConcatDataset)
    assert len(dataset) == 5

    # test mode: annotations are not required.
    cfg = dict(
        type='CustomDataset',
        pipeline=[],
        data_root=data_root,
        img_dir=[img_dir, img_dir],
        test_mode=True)
    dataset = build_dataset(cfg)
    assert isinstance(dataset, ConcatDataset)
    assert len(dataset) == 10

    # test mode with split files: each copy of splits/val.txt contributes a
    # single image.
    cfg = dict(
        type='CustomDataset',
        pipeline=[],
        data_root=data_root,
        img_dir=[img_dir, img_dir],
        split=['splits/val.txt', 'splits/val.txt'],
        test_mode=True)
    dataset = build_dataset(cfg)
    assert isinstance(dataset, ConcatDataset)
    assert len(dataset) == 2

    # A list-valued ann_dir must match img_dir in length.
    with pytest.raises(AssertionError):
        cfg = dict(
            type='CustomDataset',
            pipeline=[],
            data_root=data_root,
            img_dir=[img_dir, img_dir],
            ann_dir=[ann_dir, ann_dir, ann_dir])
        build_dataset(cfg)

    # A list-valued split must match img_dir in length as well.
    with pytest.raises(AssertionError):
        cfg = dict(
            type='CustomDataset',
            pipeline=[],
            data_root=data_root,
            img_dir=[img_dir, img_dir],
            split=['splits/val.txt', 'splits/val.txt', 'splits/val.txt'])
        build_dataset(cfg)

    # With a single img_dir, the lengths of ann_dir and split still have to
    # agree with each other.
    with pytest.raises(AssertionError):
        cfg = dict(
            type='CustomDataset',
            pipeline=[],
            data_root=data_root,
            img_dir=img_dir,
            ann_dir=[ann_dir, ann_dir],
            split=['splits/val.txt', 'splits/val.txt', 'splits/val.txt'])
        build_dataset(cfg)
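
    # Taken together, the pytest.raises cases check that list-valued img_dir,
    # ann_dir and split must have consistent lengths before the concatenated
    # dataset is assembled.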


def test_build_dataloader():
    dataset = ToyDataset()
    samples_per_gpu = 3
    # dist=True, shuffle=True, 1 GPU
    dataloader = build_dataloader(
        dataset, samples_per_gpu=samples_per_gpu, workers_per_gpu=2)
    assert dataloader.batch_size == samples_per_gpu
    assert len(dataloader) == int(math.ceil(len(dataset) / samples_per_gpu))
    assert isinstance(dataloader.sampler, DistributedSampler)
    assert dataloader.sampler.shuffle
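    # No dist/shuffle flags are passed above, so these are build_dataloader's
    # defaults; the DistributedSampler assertion confirms dist defaults to
    # True.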

    # dist=True, shuffle=False, 1 GPU
    dataloader = build_dataloader(
        dataset,
        samples_per_gpu=samples_per_gpu,
        workers_per_gpu=2,
        shuffle=False)
    assert dataloader.batch_size == samples_per_gpu
    assert len(dataloader) == int(math.ceil(len(dataset) / samples_per_gpu))
    assert isinstance(dataloader.sampler, DistributedSampler)
    assert not dataloader.sampler.shuffle

    # dist=True, shuffle=True, 8 GPUs: in distributed mode the per-process
    # batch size and worker count are unchanged.
    dataloader = build_dataloader(
        dataset,
        samples_per_gpu=samples_per_gpu,
        workers_per_gpu=2,
        num_gpus=8)
    assert dataloader.batch_size == samples_per_gpu
    assert len(dataloader) == int(math.ceil(len(dataset) / samples_per_gpu))
    assert dataloader.num_workers == 2

    # dist=False, shuffle=True, 1 GPU
    dataloader = build_dataloader(
        dataset,
        samples_per_gpu=samples_per_gpu,
        workers_per_gpu=2,
        dist=False)
    assert dataloader.batch_size == samples_per_gpu
    assert len(dataloader) == int(math.ceil(len(dataset) / samples_per_gpu))
    assert isinstance(dataloader.sampler, RandomSampler)
    assert dataloader.num_workers == 2

    # dist=False, shuffle=False, 1 GPU
    dataloader = build_dataloader(
        dataset,
        samples_per_gpu=3,
        workers_per_gpu=2,
        shuffle=False,
        dist=False)
    assert dataloader.batch_size == samples_per_gpu
    assert len(dataloader) == int(math.ceil(len(dataset) / samples_per_gpu))
    assert isinstance(dataloader.sampler, SequentialSampler)
    assert dataloader.num_workers == 2

    # dist=False, shuffle=True, 8 GPUs
    dataloader = build_dataloader(
        dataset, samples_per_gpu=3, workers_per_gpu=2, num_gpus=8, dist=False)
    assert dataloader.batch_size == samples_per_gpu * 8
    assert len(dataloader) == int(
        math.ceil(len(dataset) / samples_per_gpu / 8))
    assert isinstance(dataloader.sampler, RandomSampler)
    assert dataloader.num_workers == 16
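    # In non-distributed mode a single loader feeds all GPUs, so batch_size
    # scales to samples_per_gpu * num_gpus (3 * 8 = 24) and num_workers to
    # workers_per_gpu * num_gpus (2 * 8 = 16).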