| from torch.utils.data import DataLoader |
| from torch.utils.data.distributed import DistributedSampler |
| from torchvision.transforms.v2 import Compose |
| import os, sys |
| from argparse import ArgumentParser |
| from typing import Union, Tuple, List, Dict |
|
|
# Put the parent of this file's directory on the module search path. This must
# run before the `import datasets` below.
# NOTE(review): presumably the project-local `datasets` package lives in that
# parent directory - confirm against the repository layout.
parent_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
sys.path.append(parent_dir)


# Supplies the transforms, dataset classes and collate_fn used by get_dataloader.
import datasets
|
|
|
|
def calc_bin_center(
    bins: List[Tuple[float, float]],
    count_stats: Dict[int, int],
) -> Tuple[List[float], List[int]]:
    """
    Calculate the representative value for each bin based on the count statistics.

    `bins` may look like: [(0, 0), (1, 1), (2, 3), (4, 6), (7, float('inf'))]
    `count_stats` may look like: {0: 10, 1: 20, 2: 30, 3: 40, 4: 50, 5: 60, 6: 70, 7: 80, 8: 90, 9: 100}
    In this example, for bin (2, 3), we have 30 samples of 2 and 40 samples of 3 that fall into this bin.
    The representative value for this bin is (30 * 2 + 40 * 3) / (30 + 40) = 180 / 70, approximately 2.571.

    Bin bounds are inclusive on both ends. Each value is credited to the first bin
    that contains it; values that fall into no bin are ignored. Keys and counts of
    `count_stats` are passed through `int()`, so numeric strings are accepted too.

    Args:
        bins: (start, end) pairs, both bounds inclusive.
        count_stats: mapping from a value to the number of samples with that value.

    Returns:
        A tuple `(bin_centers, bin_counts)`. Both lists have the same length as
        `bins`: `bin_centers[i]` is the count-weighted mean of the values that
        fall into `bins[i]`, and `bin_counts[i]` is the number of samples in it.

    Raises:
        AssertionError: if any bin receives no samples.
    """
    bin_counts = [0] * len(bins)
    bin_sums = [0] * len(bins)
    for raw_value, raw_count in count_stats.items():
        value = int(raw_value)  # convert once instead of once per bin
        for i, (start, end) in enumerate(bins):
            if start <= value <= end:
                count = int(raw_count)
                bin_counts[i] += count
                bin_sums[i] += count * value
                break  # first matching bin wins

    if not all(c > 0 for c in bin_counts):
        # Raised explicitly: an `assert` statement would be stripped under
        # `python -O`, and the division below would then fail with a
        # ZeroDivisionError that says nothing about the bin design.
        raise AssertionError(
            f"Expected all bin_counts to be greater than 0, got {bin_counts}. Consider to re-design the bins {bins}."
        )

    bin_centers = [s / c for s, c in zip(bin_sums, bin_counts)]
    return bin_centers, bin_counts
|
|
|
|
def get_dataloader(args: ArgumentParser, split: str = "train") -> Union[Tuple[DataLoader, Union[DistributedSampler, None]], DataLoader]:
    """
    Build the DataLoader for one dataset split.

    Args:
        args: configuration object. The following attributes are read:
            `nprocs`, `local_rank`, `seed`, `dataset`, `in_memory_dataset`,
            `batch_size`, `num_workers`, `num_crops`, `input_size`, the `aug_*`
            augmentation settings, and, for non-training splits,
            `sliding_window`, `resize_to_multiple`, `window_size` and `stride`.
            NOTE(review): annotated as `ArgumentParser`, but it is accessed like
            a parsed-arguments namespace - confirm against the callers.
        split: `"train"` enables augmentation, shuffling and `args.batch_size`;
            every other split is loaded in order with a batch size of 1.

    Returns:
        For `split == "train"`: a tuple `(data_loader, sampler)`, where `sampler`
        is the `DistributedSampler` when `args.nprocs > 1` and `None` otherwise.
        For any other split: the `DataLoader` alone. Only `"val"` is sharded
        with a `DistributedSampler` when `args.nprocs > 1`; other splits are
        loaded in full by every process.
    """
    ddp = args.nprocs > 1
    is_train = split == "train"

    if is_train:
        transforms = [
            datasets.RandomResizedCrop((args.input_size, args.input_size), scale=(args.aug_min_scale, args.aug_max_scale)),
            datasets.RandomHorizontalFlip(),
        ]
        # The optional augmentations are added only when their strength is non-zero.
        if args.aug_brightness > 0 or args.aug_contrast > 0 or args.aug_saturation > 0 or args.aug_hue > 0:
            transforms.append(datasets.ColorJitter(
                brightness=args.aug_brightness, contrast=args.aug_contrast, saturation=args.aug_saturation, hue=args.aug_hue
            ))
        if args.aug_blur_prob > 0 and args.aug_kernel_size > 0:
            transforms.append(datasets.RandomApply([
                datasets.GaussianBlur(kernel_size=args.aug_kernel_size),
            ], p=args.aug_blur_prob))
        if args.aug_saltiness > 0 or args.aug_spiciness > 0:
            transforms.append(datasets.PepperSaltNoise(
                saltiness=args.aug_saltiness, spiciness=args.aug_spiciness,
            ))
        transforms = Compose(transforms)
    elif args.sliding_window and args.resize_to_multiple:
        transforms = datasets.Resize2Multiple(args.window_size, stride=args.stride)
    else:
        transforms = None

    dataset_class = datasets.InMemoryCrowd if args.in_memory_dataset else datasets.Crowd
    dataset = dataset_class(
        dataset=args.dataset,
        split=split,
        transforms=transforms,
        sigma=None,
        return_filename=False,
        num_crops=args.num_crops if is_train else 1,
    )

    # Options shared by every loader below. prefetch_factor and
    # persistent_workers are only enabled when worker processes are used.
    uses_workers = args.num_workers != 0
    loader_kwargs = dict(
        num_workers=args.num_workers,
        pin_memory=True,
        collate_fn=datasets.collate_fn,
        prefetch_factor=3 if uses_workers else None,
        persistent_workers=uses_workers,
    )

    if is_train:
        if ddp:
            # The seed must be the same on every rank: each process slices its
            # shard out of the permutation generated from this seed, so
            # per-rank seeds would make shards overlap and skip samples.
            sampler = DistributedSampler(dataset, num_replicas=args.nprocs, rank=args.local_rank, shuffle=True, seed=args.seed)
            data_loader = DataLoader(dataset, batch_size=args.batch_size, sampler=sampler, **loader_kwargs)
        else:
            sampler = None
            data_loader = DataLoader(dataset, batch_size=args.batch_size, shuffle=True, **loader_kwargs)
        # NOTE(review): the sampler is returned presumably so the caller can
        # call sampler.set_epoch(epoch) each epoch - confirm in the training loop.
        return data_loader, sampler

    if ddp and split == "val":
        sampler = DistributedSampler(dataset, num_replicas=args.nprocs, rank=args.local_rank, shuffle=False)
        return DataLoader(dataset, batch_size=1, sampler=sampler, shuffle=False, **loader_kwargs)

    return DataLoader(dataset, batch_size=1, shuffle=False, **loader_kwargs)
|
|