| | |
| | |
| |
|
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | import os |
| | import PIL |
| |
|
| | from torchvision import datasets, transforms |
| |
|
| | from timm.data import create_transform |
| | from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD |
| |
|
| |
|
def build_dataset(is_train, args):
    """Build an ``ImageFolder`` dataset for the training or validation split.

    Images are read from ``<args.data_path>/train`` when ``is_train`` is true,
    otherwise from ``<args.data_path>/val``. The per-sample transform comes
    from ``build_transform`` for the same split.

    Args:
        is_train: Selects the ``train`` split (True) or the ``val`` split (False).
        args: Options object; ``data_path`` is read here, and the remaining
            fields are consumed by ``build_transform``.

    Returns:
        The constructed ``datasets.ImageFolder`` instance. Its repr is printed
        before it is returned.
    """
    image_transform = build_transform(is_train, args)
    split = 'train' if is_train else 'val'

    dataset = datasets.ImageFolder(
        os.path.join(args.data_path, split),
        transform=image_transform,
    )

    # Echo the dataset's repr so the run log records which data was loaded.
    print(dataset)
    return dataset
| |
|
| |
|
def _bicubic_interpolation():
    """Return a bicubic interpolation value that ``transforms.Resize`` accepts.

    Recent torchvision releases expect the ``transforms.InterpolationMode``
    enum and warn when handed a raw PIL integer constant. Older releases have
    no such enum and only understand the PIL constant. Prefer the enum when it
    exists, otherwise fall back to PIL.
    """
    interpolation_mode = getattr(transforms, 'InterpolationMode', None)
    if interpolation_mode is not None:
        return interpolation_mode.BICUBIC
    # ``import PIL`` alone does not load the ``PIL.Image`` submodule, so import
    # it explicitly rather than relying on another module having done so.
    from PIL import Image
    return Image.BICUBIC


def build_transform(is_train, args):
    """Build the image transform pipeline for training or evaluation.

    Args:
        is_train: When true, return timm's training transform configured with
            the augmentation options from ``args``. Otherwise return a
            deterministic resize / center-crop / normalize pipeline.
        args: Options object. Reads ``input_size`` in both branches. The
            training branch also reads ``color_jitter``, ``aa``, ``reprob``,
            ``remode`` and ``recount``.

    Returns:
        A callable transform. Both branches normalize with the ImageNet
        default mean/std from timm.
    """
    mean = IMAGENET_DEFAULT_MEAN
    std = IMAGENET_DEFAULT_STD

    if is_train:
        # Delegate the training augmentation pipeline entirely to timm.
        return create_transform(
            input_size=args.input_size,
            is_training=True,
            color_jitter=args.color_jitter,
            auto_augment=args.aa,
            interpolation='bicubic',
            re_prob=args.reprob,
            re_mode=args.remode,
            re_count=args.recount,
            mean=mean,
            std=std,
        )

    # Evaluation: resize, then center-crop to input_size.
    # - input_size <= 224: enlarge the resize target by 256/224 so the crop
    #   keeps 224/256 of the resized image (the usual ImageNet 256 -> 224
    #   protocol).
    # - larger inputs: resize straight to input_size (crop_pct = 1.0).
    crop_pct = 224 / 256 if args.input_size <= 224 else 1.0
    resize_size = int(args.input_size / crop_pct)

    return transforms.Compose([
        transforms.Resize(resize_size, interpolation=_bicubic_interpolation()),
        transforms.CenterCrop(args.input_size),
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ])
| |
|