| from torch.utils.data import DataLoader |
| from torch.utils.data.distributed import DistributedSampler |
| from torchvision.transforms.v2 import Compose |
| import os, sys |
| from argparse import ArgumentParser |
| from typing import Union, Tuple |
|
|
# Put the directory one level above this file on the import path.
# Presumably this is what lets the `datasets` import further down resolve
# when this module is run from inside its own folder -- confirm against the
# project layout.
parent_dir = os.path.abspath(
    os.path.join(os.path.dirname(__file__), os.pardir)
)
sys.path.append(parent_dir)
|
|
| import datasets |
|
|
|
|
def get_dataloader(args: ArgumentParser, split: str = "train", ddp: bool = False) -> Union[Tuple[DataLoader, Union[DistributedSampler, None]], DataLoader]:
    """Build the ``DataLoader`` for one split of the crowd dataset.

    NOTE(review): ``args`` is annotated as ``ArgumentParser`` but is only ever
    read through attributes (``args.input_size``, ``args.batch_size``, ...),
    i.e. it is used like the namespace returned by ``parse_args()``.

    Args:
        args: Configuration object. For ``split == "train"`` it must expose
            ``input_size``, ``min_scale``, ``max_scale``, ``brightness``,
            ``contrast``, ``saturation``, ``hue``, ``kernel_size``,
            ``saltiness``, ``spiciness``, ``jitter_prob``, ``blur_prob``,
            ``noise_prob``, ``num_crops`` and ``batch_size``. For any other
            split it must expose ``sliding_window`` and, when that is truthy,
            ``resize_to_multiple`` (plus ``zero_pad_to_multiple`` if the
            former is falsy) and ``window_size`` / ``stride`` when one of
            them is set. ``dataset`` and ``num_workers`` are always read.
        split: Name of the split. Only the exact value ``"train"`` selects
            the training pipeline; every other value is treated as
            evaluation.
        ddp: When truthy and ``split == "train"``, the loader draws from a
            ``DistributedSampler`` instead of using ``shuffle=True``.

    Returns:
        For ``split == "train"``: a ``(data_loader, sampler)`` tuple, where
        ``sampler`` is the ``DistributedSampler`` when ``ddp`` is truthy and
        ``None`` otherwise. For any other split: the ``DataLoader`` alone
        (batch size 1, no shuffling).
    """
    is_train = split == "train"

    # Evaluation without sliding-window preprocessing uses no transform at
    # all, so start from that and only override where needed.
    transforms = None
    if is_train:
        crop = datasets.RandomResizedCrop(
            (args.input_size, args.input_size),
            scale=(args.min_scale, args.max_scale),
        )
        flip = datasets.RandomHorizontalFlip()
        jitter = datasets.ColorJitter(
            brightness=args.brightness,
            contrast=args.contrast,
            saturation=args.saturation,
            hue=args.hue,
        )
        blur = datasets.GaussianBlur(kernel_size=args.kernel_size, sigma=(0.1, 5.0))
        noise = datasets.PepperSaltNoise(saltiness=args.saltiness, spiciness=args.spiciness)
        # One probability per photometric op, in the same order as the ops.
        photometric = datasets.RandomApply(
            [jitter, blur, noise],
            p=(args.jitter_prob, args.blur_prob, args.noise_prob),
        )
        transforms = Compose([crop, flip, photometric])
    elif args.sliding_window:
        # Resizing takes precedence over zero-padding when both are set.
        if args.resize_to_multiple:
            transforms = datasets.Resize2Multiple(args.window_size, stride=args.stride)
        elif args.zero_pad_to_multiple:
            transforms = datasets.ZeroPad2Multiple(args.window_size, stride=args.stride)

    if is_train:
        crops_per_image = args.num_crops
    else:
        crops_per_image = 1

    dataset = datasets.Crowd(
        dataset=args.dataset,
        split=split,
        transforms=transforms,
        sigma=None,
        return_filename=False,
        num_crops=crops_per_image,
    )

    # Decide how samples are ordered and batched for this mode; everything
    # else about the loader is shared.
    sampler = None
    if not is_train:
        batch_size = 1
        ordering = {"shuffle": False}
    elif ddp:
        sampler = DistributedSampler(dataset)
        batch_size = args.batch_size
        ordering = {"sampler": sampler}
    else:
        batch_size = args.batch_size
        ordering = {"shuffle": True}

    data_loader = DataLoader(
        dataset,
        batch_size=batch_size,
        num_workers=args.num_workers,
        pin_memory=True,
        collate_fn=datasets.collate_fn,
        **ordering,
    )

    # Training callers receive the sampler as well (None when not using DDP);
    # evaluation callers receive just the loader.
    if is_train:
        return data_loader, sampler
    return data_loader
|
|