| | from __future__ import absolute_import |
| | from __future__ import division |
| | from __future__ import print_function |
| |
|
| | import timm |
| | from timm.data import create_transform |
| |
|
| | from yacs.config import CfgNode as CN |
| | from PIL import ImageFilter |
| | import logging |
| | import random |
| |
|
| | import torch |
| | import torchvision.transforms as T |
| |
|
| |
|
| | from .autoaugment import AutoAugmentPolicy |
| | from .autoaugment import AutoAugment |
| | from .autoaugment import RandAugment |
| | from .autoaugment import TrivialAugmentWide |
| | from .threeaugment import deitIII_Solarization |
| | from .threeaugment import deitIII_gray_scale |
| | from .threeaugment import deitIII_GaussianBlur |
| |
|
| | from PIL import ImageOps |
| | from timm.data.transforms import RandomResizedCropAndInterpolation |
| |
|
# Module-level logger; the transform builders report the pipelines they create here.
logger = logging.getLogger(name=__name__)
| |
|
| |
|
class GaussianBlur(object):
    """Gaussian blur augmentation in SimCLR https://arxiv.org/abs/2002.05709

    On each call a blur radius is drawn uniformly from
    ``[sigma[0], sigma[1]]`` and applied with ``PIL.ImageFilter.GaussianBlur``.

    Args:
        sigma: two-element sequence ``(low, high)`` bounding the random radius.
    """

    def __init__(self, sigma=(.1, 2.)):
        # Immutable tuple default: a list default would be one object shared
        # by every instance created without an explicit ``sigma``.
        self.sigma = sigma

    def __call__(self, x):
        """Return ``x`` blurred with a freshly sampled radius.

        ``x`` must provide ``.filter`` (a PIL image presumably).
        """
        radius = random.uniform(self.sigma[0], self.sigma[1])
        return x.filter(ImageFilter.GaussianBlur(radius=radius))
| |
|
| |
|
def get_resolution(original_resolution):
    """Choose ``(precrop, crop)`` sizes from an image's ``(H, W)``.

    Images with fewer than 96*96 pixels get ``(160, 128)``; all larger
    images get ``(512, 480)``.
    """
    height = original_resolution[0]
    width = original_resolution[1]
    if height * width < 96 * 96:
        return 160, 128
    return 512, 480
| |
|
| |
|
# Lower-case interpolation names (as read from cfg['AUG']['INTERPOLATION'])
# mapped to the torchvision InterpolationMode enum member of the same name.
INTERPOLATION_MODES = {
    mode_name: getattr(T.InterpolationMode, mode_name.upper())
    for mode_name in ('bilinear', 'bicubic', 'nearest')
}
| |
|
| |
|
def _three_augment_transforms(img_size, remove_random_resized_crop,
                              color_jitter, normalize):
    """Build the DeiT III "3-Augment" training pipeline.

    Args:
        img_size (int): target side length passed to the crop transforms.
        remove_random_resized_crop: when truthy, use resize + reflect-padded
            random crop instead of ``RandomResizedCropAndInterpolation``.
        color_jitter: brightness/contrast/saturation strength for
            ``T.ColorJitter``; ``None`` or ``0`` disables color jitter.
        normalize: normalization transform applied after ``ToTensor``.

    Returns:
        T.Compose: crop/flip -> one of {grayscale, solarize, blur}
        -> optional color jitter -> ToTensor -> normalize.
    """
    if remove_random_resized_crop:
        primary_tfl = [
            T.Resize(img_size, interpolation=T.InterpolationMode.BICUBIC),
            T.RandomCrop(img_size, padding=4, padding_mode='reflect'),
            T.RandomHorizontalFlip(),
        ]
    else:
        primary_tfl = [
            RandomResizedCropAndInterpolation(
                img_size, scale=(0.08, 1.0), interpolation='bicubic'),
            T.RandomHorizontalFlip(),
        ]

    # RandomChoice picks exactly one of the three ops per sample; each op is
    # built with p=1.0 so the chosen one always fires.
    secondary_tfl = [T.RandomChoice([
        deitIII_gray_scale(p=1.0),
        deitIII_Solarization(p=1.0),
        deitIII_GaussianBlur(p=1.0),
    ])]
    # T.ColorJitter rejects None (TypeError) and 0 is a no-op, so skip both.
    if color_jitter:
        secondary_tfl.append(
            T.ColorJitter(color_jitter, color_jitter, color_jitter))

    final_tfl = [T.ToTensor(), normalize]
    return T.Compose(primary_tfl + secondary_tfl + final_tfl)


def _torchvision_train_transforms(cfg, normalize):
    """Build the training pipeline configured by ``cfg['AUG']['TORCHVISION_AUG']``.

    Order: RandomResizedCrop -> optional horizontal flip -> optional
    RandAugment ("ra") / TrivialAugmentWide ("ta_wide") / AutoAugment (any
    other policy string) -> ToTensor -> normalize -> optional RandomErasing
    -> optional RandomRotation.
    """
    tv_cfg = cfg['AUG']['TORCHVISION_AUG']
    interpolation = INTERPOLATION_MODES[cfg['AUG']['INTERPOLATION']]
    trans = [
        T.RandomResizedCrop(
            cfg['IMAGE_ENCODER']['IMAGE_SIZE'][0],
            scale=cfg['AUG']['SCALE'],
            ratio=cfg['AUG']['RATIO'],
            interpolation=interpolation,
        )
    ]

    hflip_prob = tv_cfg['HFLIP']
    if hflip_prob > 0:
        trans.append(T.RandomHorizontalFlip(hflip_prob))

    auto_augment_policy = tv_cfg.get('AUTO_AUGMENT', None)
    if auto_augment_policy == "ra":
        trans.append(RandAugment(interpolation=interpolation))
    elif auto_augment_policy == "ta_wide":
        trans.append(TrivialAugmentWide(interpolation=interpolation))
    elif auto_augment_policy is not None:
        aa_policy = AutoAugmentPolicy(auto_augment_policy)
        trans.append(AutoAugment(policy=aa_policy, interpolation=interpolation))

    trans.extend([T.ToTensor(), normalize])

    random_erase_prob = tv_cfg['RE_PROB']
    random_erase_scale = tv_cfg.get('RE_SCALE', 0.33)
    if random_erase_prob > 0:
        trans.append(T.RandomErasing(
            p=random_erase_prob, scale=(0.02, random_erase_scale)))

    # NOTE(review): rotation runs after normalization (and erasing), so the
    # exposed corners are filled in normalized space. Order kept unchanged to
    # preserve existing training behaviour.
    rotation = tv_cfg.get('ROTATION', 0.0)
    if rotation > 0.0:
        trans.append(T.RandomRotation(
            rotation, interpolation=T.InterpolationMode.BILINEAR))
        logger.info(" TORCH AUG: Rotation: " + str(rotation))

    return T.Compose(trans)


def _eval_transforms(cfg, normalize):
    """Build the deterministic evaluation pipeline.

    With ``cfg['TEST']['CENTER_CROP']`` the shorter side is resized to
    ``IMAGE_SIZE[0] / 0.875`` and a centered ``IMAGE_SIZE[0]`` crop is taken
    (crop = 87.5% of the resized side). Otherwise the image is resized
    directly to ``(IMAGE_SIZE[1], IMAGE_SIZE[0])``.
    """
    size = cfg['IMAGE_ENCODER']['IMAGE_SIZE']
    mode = INTERPOLATION_MODES[cfg['AUG']['INTERPOLATION']]
    if cfg['TEST']['CENTER_CROP']:
        return T.Compose([
            T.Resize(int(size[0] / 0.875), interpolation=mode),
            T.CenterCrop(size[0]),
            T.ToTensor(),
            normalize,
        ])
    return T.Compose([
        T.Resize((size[1], size[0]), interpolation=mode),
        T.ToTensor(),
        normalize,
    ])


def build_transforms(cfg, is_train=True):
    """Build the image preprocessing pipeline selected by ``cfg``.

    For training, the first matching option wins:
      1. ``'THREE_AUG' in cfg['AUG']`` -> DeiT III 3-Augment
         (options ``THREE_AUG.SRC`` / ``THREE_AUG.COLOR_JITTER``).
      2. ``cfg['AUG']['TIMM_AUG']['USE_TRANSFORM']`` -> timm ``create_transform``.
      3. ``'TORCHVISION_AUG' in cfg['AUG']`` -> hand-built torchvision pipeline.
      4. ``cfg['AUG']['RANDOM_CENTER_CROP']`` -> resize + random crop + flip.
      5. ``cfg['AUG']['MAE_FINETUNE_AUG']`` -> timm MAE fine-tuning recipe.
      6. ``cfg['AUG']['MAE_PRETRAIN_AUG']`` -> MAE pre-training recipe.
      7. ``cfg['AUG']['ThreeAugment']`` -> DeiT III 3-Augment
         (options ``AUG.src`` / ``AUG.COLOR_JITTER``).
    If none matches, a warning is logged and ``None`` is returned.

    All branches normalize with ``IMAGE_ENCODER.IMAGE_MEAN`` /
    ``IMAGE_ENCODER.IMAGE_STD`` so train and eval inputs share statistics.

    Args:
        cfg: nested config mapping (presumably a yacs ``CfgNode``) with
            ``IMAGE_ENCODER`` and ``AUG`` sections (and ``TEST`` for eval).
        is_train: build the training pipeline when True, evaluation otherwise.

    Returns:
        The composed transform, or ``None`` for an unmatched training config.
    """
    normalize = T.Normalize(
        mean=cfg['IMAGE_ENCODER']['IMAGE_MEAN'],
        std=cfg['IMAGE_ENCODER']['IMAGE_STD'],
    )

    if not is_train:
        transforms = _eval_transforms(cfg, normalize)
        logger.info('=> testing transformers: %s', transforms)
        return transforms

    aug_cfg = cfg['AUG']
    img_size = cfg['IMAGE_ENCODER']['IMAGE_SIZE'][0]
    transforms = None

    if 'THREE_AUG' in aug_cfg:
        three_cfg = aug_cfg['THREE_AUG']
        transforms = _three_augment_transforms(
            img_size,
            three_cfg.get('SRC', False),
            three_cfg.get('COLOR_JITTER'),
            normalize,
        )
    elif 'TIMM_AUG' in aug_cfg and aug_cfg['TIMM_AUG']['USE_TRANSFORM']:
        logger.info('=> use timm transform for training')
        timm_cfg = aug_cfg['TIMM_AUG']
        transforms = create_transform(
            input_size=img_size,
            is_training=True,
            use_prefetcher=False,
            no_aug=False,
            re_prob=timm_cfg.get('RE_PROB', 0.),
            re_mode=timm_cfg.get('RE_MODE', 'const'),
            re_count=timm_cfg.get('RE_COUNT', 1),
            # A falsy/missing RE_SPLITS means "no splits".
            re_num_splits=timm_cfg.get('RE_SPLITS', False) or 0,
            scale=aug_cfg.get('SCALE', None),
            ratio=aug_cfg.get('RATIO', None),
            hflip=timm_cfg.get('HFLIP', 0.5),
            vflip=timm_cfg.get('VFLIP', 0.),
            color_jitter=timm_cfg.get('COLOR_JITTER', 0.4),
            auto_augment=timm_cfg.get('AUTO_AUGMENT', None),
            interpolation=aug_cfg['INTERPOLATION'],
            mean=cfg['IMAGE_ENCODER']['IMAGE_MEAN'],
            std=cfg['IMAGE_ENCODER']['IMAGE_STD'],
        )
    elif 'TORCHVISION_AUG' in aug_cfg:
        logger.info('=> use torchvision transform for training')
        transforms = _torchvision_train_transforms(cfg, normalize)
    elif aug_cfg.get('RANDOM_CENTER_CROP', False):
        logger.info('=> use random center crop data augmentation')
        padding = aug_cfg.get('RANDOM_CENTER_CROP_PADDING', 32)
        precrop = img_size + padding
        mode = INTERPOLATION_MODES[aug_cfg['INTERPOLATION']]
        transforms = T.Compose([
            T.Resize((precrop, precrop), interpolation=mode),
            T.RandomCrop((img_size, img_size)),
            T.RandomHorizontalFlip(),
            T.ToTensor(),
            normalize,
        ])
    elif aug_cfg.get('MAE_FINETUNE_AUG', False):
        transforms = create_transform(
            input_size=img_size,
            is_training=True,
            color_jitter=aug_cfg.get('COLOR_JITTER', None),
            auto_augment=aug_cfg.get('AUTO_AUGMENT', 'rand-m9-mstd0.5-inc1'),
            interpolation='bicubic',
            re_prob=aug_cfg.get('RE_PROB', 0.25),
            re_mode=aug_cfg.get('RE_MODE', "pixel"),
            re_count=aug_cfg.get('RE_COUNT', 1),
            mean=cfg['IMAGE_ENCODER']['IMAGE_MEAN'],
            std=cfg['IMAGE_ENCODER']['IMAGE_STD'],
        )
    elif aug_cfg.get('MAE_PRETRAIN_AUG', False):
        transforms = T.Compose([
            T.RandomResizedCrop(
                img_size,
                scale=tuple(aug_cfg['SCALE']),
                interpolation=INTERPOLATION_MODES["bicubic"],
            ),
            T.RandomHorizontalFlip(),
            T.ToTensor(),
            normalize,
        ])
    elif aug_cfg.get('ThreeAugment', False):
        transforms = _three_augment_transforms(
            img_size,
            aug_cfg.get('src', False),
            aug_cfg.get('COLOR_JITTER'),
            normalize,
        )

    if transforms is None:
        logger.warning(
            '=> no training augmentation matched cfg["AUG"]; returning None')
    logger.info('=> training transformers: %s', transforms)
    return transforms
| |
|
| |
|