| from .dataset import Dataset, ValDataset, TestDataset |
| from torch.utils.data import DataLoader |
|
|
def find_dataset_using_name(name):
    """Return the dataset class registered under ``name``.

    Recognized names are ``"Video"`` (``Dataset``), ``"VideoVal"``
    (``ValDataset``) and ``"VideoTest"`` (``TestDataset``).

    Args:
        name: Registry key identifying the dataset class.

    Returns:
        The dataset class itself (not an instance) registered under ``name``.

    Raises:
        ValueError: If ``name`` is not one of the recognized keys.
    """
    # Built per call so the project classes are only looked up when needed.
    registry = {
        "Video": Dataset,
        "VideoVal": ValDataset,
        "VideoTest": TestDataset,
    }
    if name not in registry:
        raise ValueError(f"Fail to find dataset {name}")
    return registry[name]
|
|
|
|
def create_dataset(metainfo, split):
    """Build the dataset selected by ``split`` and wrap it in a DataLoader.

    ``split.type`` is resolved to a dataset class through
    ``find_dataset_using_name``. That class is instantiated as
    ``cls(metainfo, split)``, and the result is handed to a
    ``torch.utils.data.DataLoader``.

    Args:
        metainfo: Forwarded unchanged to the dataset constructor.
        split: Config object read for ``type``, ``batch_size``, ``drop_last``,
            ``shuffle`` and ``worker`` (presumably a per-split config
            namespace; confirm against callers).

    Returns:
        A ``DataLoader`` over the constructed dataset, with ``pin_memory``
        always enabled.

    Raises:
        ValueError: Propagated from ``find_dataset_using_name`` when
            ``split.type`` is not a recognized dataset name.
    """
    # Construct the dataset first: the original read the loader options only
    # afterwards, and that ordering is kept here.
    dataset = find_dataset_using_name(split.type)(metainfo, split)
    loader_options = dict(
        batch_size=split.batch_size,
        drop_last=split.drop_last,
        shuffle=split.shuffle,
        num_workers=split.worker,
        # Asks the DataLoader to place fetched batches in pinned (page-locked) memory.
        pin_memory=True,
    )
    return DataLoader(dataset, **loader_options)