import importlib.machinery
import importlib.util
import random
import types
import urllib
import urllib.request
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import cv2
import numpy as np
import torch
from numpy.typing import NDArray
from timm.data.constants import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
from timm.data.transforms import RandomResizedCropAndInterpolation
from torch.utils.data import default_collate
from torchvision import transforms
from tqdm import tqdm
from typing_extensions import TypedDict

from cpm_live.tokenizers import CPMBeeTokenizer
| |
|
| |
|
# One CPM-Bee input sample: either a raw string leaf or an arbitrarily nested
# dict of string keys mapping to further CPMBeeInputType values.
CPMBeeInputType = Union[str, Dict[str, "CPMBeeInputType"]]
| |
|
| |
|
def pad(orig_items, key, max_length=None, padding_value=0, padding_side="left"):
    """Batch the tensors stored under ``key`` into one padded tensor.

    Each entry of ``orig_items`` maps ``key`` either to a single tensor of
    shape (1, L) / (1, L, D), or to a list of such tensors (every tensor in
    the list becomes its own batch row).

    Args:
        orig_items: list of dicts holding the tensors to batch.
        key: dict key under which the tensors live.
        max_length: minimum padded length; grown to the longest item if needed.
        padding_value: value written into the padded positions.
        padding_side: "left" or "right" — which side receives the padding.

    Returns:
        Tensor of shape (batch, max_length) or (batch, max_length, D);
        1-D inputs are simply concatenated without padding.
    """
    first = orig_items[0][key]
    if isinstance(first, list):
        # Flatten: each tensor inside the per-item list becomes its own row.
        assert isinstance(first[0], torch.Tensor)
        items = [{key: t} for entry in orig_items for t in entry[key]]
    else:
        assert isinstance(first, torch.Tensor)
        items = orig_items

    batch_size = len(items)
    shape = items[0][key].shape
    dim = len(shape)
    assert dim <= 3

    lengths = [entry[key].shape[-1] for entry in items]
    max_length = max(lengths) if max_length is None else max(max_length, max(lengths))
    min_length = min(lengths)
    dtype = items[0][key].dtype

    if dim == 1:
        # 1-D tensors need no padding — just concatenate them.
        return torch.cat([entry[key] for entry in items], dim=0)
    if dim == 2:
        if max_length == min_length:
            # Uniform lengths: stacking the (1, L) rows is enough.
            return torch.cat([entry[key] for entry in items], dim=0)
        tensor = torch.zeros((batch_size, max_length), dtype=dtype) + padding_value
    else:
        tensor = torch.zeros((batch_size, max_length, shape[-1]), dtype=dtype) + padding_value

    for row, entry in enumerate(items):
        data = entry[key][0].clone()
        span = len(entry[key][0])
        if padding_side == "left":
            tensor[row, -span:] = data
        else:
            tensor[row, :span] = data

    return tensor
| |
|
| |
|
class CPMBeeCollater:
    """
    Collate function for CPM-Bee inputs; corresponds to cpm-live's
    _MixedDatasetBatchPacker.

    NOTE(review, translated from the original Chinese comment): the stock
    torch DataLoader is a poor fit for in-context learning here, and the
    upstream packer's best_fit packing (which maximizes the ratio of
    effective tokens) is not supported by this implementation.
    todo: @wangchongyi rewrite the Dataloader or BatchPacker
    """

    def __init__(self, tokenizer: CPMBeeTokenizer, max_len: int):
        self.tokenizer = tokenizer
        self._max_length = max_len
        # Sequence-shaped keys that get right-padded to _max_length.
        self.pad_keys = ['input_ids', 'input_id_subs', 'context', 'segment_ids', 'segment_rel_offset',
                         'segment_rel', 'sample_ids', 'num_segments']

    def __call__(self, batch):
        """Collate a list of per-sample dicts into one padded batch dict."""
        batch_size = len(batch)

        # Per-position prediction targets; -100 marks "no loss here".
        tgt = np.full((batch_size, self._max_length), -100, dtype=np.int32)
        # NOTE(review): span stays all zeros — presumably one span per
        # sample in this collater; confirm against the model's expectations.
        span = np.zeros((batch_size, self._max_length), dtype=np.int32)
        length = np.zeros((batch_size,), dtype=np.int32)

        # Batch-wide extension table for out-of-vocab (idx, idx_sub) pairs.
        batch_ext_table_map: Dict[Tuple[int, int], int] = {}
        batch_ext_table_ids: List[int] = []
        batch_ext_table_sub: List[int] = []
        raw_data_list: List[Any] = []

        for i in range(batch_size):
            instance_length = batch[i]['input_ids'][0].shape[0]
            length[i] = instance_length
            raw_data_list.extend(batch[i]['raw_data'])

            for j in range(instance_length):
                idx, idx_sub = batch[i]['input_ids'][0, j], batch[i]['input_id_subs'][0, j]
                tgt_idx = idx
                if idx_sub > 0:
                    # idx is an extension token: its target id is remapped
                    # past the base vocabulary.
                    if (idx, idx_sub) not in batch_ext_table_map:
                        batch_ext_table_map[(idx, idx_sub)] = len(batch_ext_table_map)
                        batch_ext_table_ids.append(idx)
                        batch_ext_table_sub.append(idx_sub)
                    tgt_idx = batch_ext_table_map[(idx, idx_sub)] + self.tokenizer.vocab_size
                # Shift targets left by one: position j-1 predicts token j,
                # but only where position j-1 is not context.
                if j > 1 and batch[i]['context'][0, j - 1] == 0:
                    if idx != self.tokenizer.bos_id:
                        tgt[i, j - 1] = tgt_idx
                    else:
                        # A <bos> at j means the previous segment ended: the
                        # previous position should predict <eos> instead.
                        tgt[i, j - 1] = self.tokenizer.eos_id
            if batch[i]['context'][0, instance_length - 1] == 0:
                tgt[i, instance_length - 1] = self.tokenizer.eos_id

        if len(batch_ext_table_map) == 0:
            # Keep the ext table non-empty so downstream code never sees a
            # zero-sized tensor.
            batch_ext_table_ids.append(0)
            batch_ext_table_sub.append(1)

        # Image tensors (if present) are stacked with torch's default collate.
        if 'pixel_values' in batch[0]:
            data = {'pixel_values': default_collate([i['pixel_values'] for i in batch])}
        else:
            data = {}

        # Token-index ranges occupied by image placeholders.
        if 'image_bound' in batch[0]:
            data['image_bound'] = default_collate([i['image_bound'] for i in batch])

        # Right-pad every sequence field to the fixed maximum length.
        for key in self.pad_keys:
            data[key] = pad(batch, key, max_length=self._max_length, padding_value=0, padding_side='right')

        data['context'] = data['context'] > 0
        data['length'] = torch.from_numpy(length)
        data['span'] = torch.from_numpy(span)
        data['target'] = torch.from_numpy(tgt)
        data['ext_table_ids'] = torch.from_numpy(np.array(batch_ext_table_ids))
        data['ext_table_sub'] = torch.from_numpy(np.array(batch_ext_table_sub))
        data['raw_data'] = raw_data_list

        return data
| |
|
| |
|
class _DictTree(TypedDict):
    """One node of the segment tree built by convert_data_to_id."""
    value: str                    # dict key text (inner node) or raw leaf string
    children: List["_DictTree"]
    depth: int                    # distance from <root> (root has depth 0)
    segment_id: int               # index of this node in the flat segment list
    need_predict: bool            # True for leaves under a top-level "<ans>" key
    is_image: bool                # True for leaves under a top-level "image" key
| |
|
| |
|
class _PrevExtTableStates(TypedDict):
    """Extension-table state carried between convert_data_to_id calls."""
    ext_table: Dict[int, str]                   # ext token id -> surface string
    token_id_table: Dict[str, Dict[int, int]]   # token name -> {ext id -> sub id}
|
| |
|
class _TransformFuncDict(TypedDict):
    """Cache entry for a dynamically loaded transform module."""
    loader: importlib.machinery.SourceFileLoader
    module: types.ModuleType
    last_m: float  # NOTE(review): presumably the source file's last mtime — confirm at the call site
|
| |
|
# Pluggable sample transform: called with the sample, an int argument
# (NOTE(review): likely an epoch/sample counter — the call site is not in
# this chunk, confirm there) and a seeded random.Random instance.
_TransformFunction = Callable[[CPMBeeInputType, int, random.Random], CPMBeeInputType]
| |
|
| |
|
class CPMBeeBatch(TypedDict):
    """Numpy arrays making up one packed CPM-Bee batch."""
    inputs: NDArray[np.int32]
    inputs_sub: NDArray[np.int32]
    length: NDArray[np.int32]          # true (unpadded) length per sample
    context: NDArray[np.bool_]         # True where tokens are given context
    sample_ids: NDArray[np.int32]
    num_segments: NDArray[np.int32]
    segment_ids: NDArray[np.int32]
    segment_rel_offset: NDArray[np.int32]
    segment_rel: NDArray[np.int32]
    spans: NDArray[np.int32]
    target: NDArray[np.int32]          # -100 marks positions without a loss
    ext_ids: NDArray[np.int32]
    ext_sub: NDArray[np.int32]
    task_ids: NDArray[np.int32]
    task_names: List[str]
    raw_data: List[Any]
| |
|
| |
|
def rel_to_bucket(n_up: int, n_down: int, max_depth: int = 8):
    """Map an (up, down) tree distance to a relation bucket id.

    Bucket 0 is the "same segment" relation (n_up == n_down == 0); every
    other distance is shifted by one because bucket 1 is reserved
    (e.g. for in-context samples).
    """
    bucket = n_up * max_depth + n_down
    return bucket if bucket == 0 else bucket + 1
| |
|
| |
|
def convert_data_to_id(
    tokenizer: CPMBeeTokenizer,
    data: Any,
    prev_ext_states: Optional[_PrevExtTableStates] = None,
    shuffle_answer: bool = True,
    max_depth: int = 8
):
    """Flatten a nested CPM-Bee input into model-ready id arrays.

    Builds a segment tree from ``data`` (each dict key and each string leaf
    becomes one segment), tokenizes every segment, remaps out-of-vocab
    "<name_x>"-style extension tokens to (base token id, sub id) pairs, and
    fills a dense segment-to-segment relation bucket matrix.

    Args:
        tokenizer: CPM-Bee tokenizer; provides encode(), bos_id, eos_id, encoder.
        data: nested dict / string sample (see CPMBeeInputType).
        prev_ext_states: extension-table state from a previous call so ids
            stay consistent across related samples.
        shuffle_answer: randomly permute dict items inside the answer subtree.
        max_depth: clamp for tree distances fed to rel_to_bucket.

    Returns:
        Tuple of (ids, id_subs, context, segs, segment_rel, num_segments,
        curr_ext_table_states, image_bound).
    """
    # Sentinel root; every real segment hangs below it.
    root: _DictTree = {
        "value": "<root>",
        "children": [],
        "depth": 0,
        "segment_id": 0,
        "need_predict": False,
        "is_image": False
    }

    segments = [root]

    def _build_dict_tree(data: CPMBeeInputType, depth: int, need_predict: bool, is_image: bool) -> List[_DictTree]:
        # Depth-first construction; every node is appended to `segments`
        # in visit order, so segment_id == index into `segments`.
        if isinstance(data, dict):
            ret_list: List[_DictTree] = []
            curr_items = list(data.items())
            if need_predict and shuffle_answer:
                # Shuffle answer fields so the model cannot rely on a
                # fixed key ordering.
                access_idx = np.arange(len(curr_items))
                np.random.shuffle(access_idx)
                curr_items = [curr_items[idx] for idx in access_idx]
            for k, v in curr_items:
                child_info: _DictTree = {
                    "value": k,
                    "children": [],
                    "depth": depth,
                    "segment_id": len(segments),
                    "need_predict": False,  # key nodes themselves are never predicted
                    "is_image": False,
                }
                segments.append(child_info)
                child_info["children"] = _build_dict_tree(
                    v, depth + 1,
                    # Everything under a top-level "<ans>" key is predicted.
                    need_predict=need_predict or (depth == 1 and k == "<ans>"),
                    # Everything under a top-level "image" key is image data.
                    is_image=is_image or (depth == 1 and k == "image")
                )

                ret_list.append(child_info)
            return ret_list
        else:
            assert isinstance(data, str), "Invalid data {}".format(data)
            ret: _DictTree = {
                "value": data,
                "children": [],
                "depth": depth,
                "segment_id": len(segments),
                "need_predict": need_predict,
                "is_image": is_image,
            }
            segments.append(ret)
            return [ret]

    root["children"] = _build_dict_tree(data, 1, False, False)

    num_segments = len(segments)
    # Dense (num_segments x num_segments) relation matrix, stored flattened.
    segment_rel = np.zeros((num_segments * num_segments,), dtype=np.int32)

    def _build_segment_rel(node: _DictTree) -> List[Tuple[int, int]]:
        # Returns every (segment_id, depth) in this subtree; each descendant
        # pair is related through its lowest common ancestor `node`.
        ret: List[Tuple[int, int]] = [(node["segment_id"], node["depth"])]
        for child in node["children"]:
            sub = _build_segment_rel(child)
            for seg_id_1, depth_1 in sub:
                for seg_id_2, depth_2 in ret:
                    # Clamped distances from each segment up to the ancestor.
                    n_up = min(depth_1 - node["depth"], max_depth - 1)
                    n_down = min(depth_2 - node["depth"], max_depth - 1)
                    segment_rel[seg_id_1 * num_segments + seg_id_2] = rel_to_bucket(
                        n_up, n_down, max_depth=max_depth
                    )
                    segment_rel[seg_id_2 * num_segments + seg_id_1] = rel_to_bucket(
                        n_down, n_up, max_depth=max_depth
                    )
            ret.extend(sub)
        return ret

    _build_segment_rel(root)

    input_ids: List[int] = []
    input_id_subs: List[int] = []
    segment_bound: List[Tuple[int, int]] = []
    image_bound: List[Tuple[int, int]] = []

    ext_table: Dict[int, str] = {}
    token_id_table: Dict[str, Dict[int, int]] = {}

    if prev_ext_states is not None:
        # Continue numbering from the previous sample's tables.
        ext_table = prev_ext_states["ext_table"]
        token_id_table = prev_ext_states["token_id_table"]

    for seg in segments:
        tokens, ext_table = tokenizer.encode(seg["value"], ext_table)

        token_id_subs = []
        reid_token_ids = []
        for idx in tokens:
            if idx in ext_table:
                # Out-of-vocab special token: remap "<name_x>" to the base
                # "<name>" vocabulary entry plus a per-name sub index.
                token = ext_table[idx]
                if token.startswith("<") and token.endswith(">"):
                    # "<name_123>" -> base name "name"; "<name>" stays as-is.
                    if "_" in token:
                        token_name = token[1:-1].split("_", maxsplit=1)[0]
                    else:
                        token_name = token[1:-1]
                    token_name = "<{}>".format(token_name)
                else:
                    token_name = "<unk>"

                if token_name not in token_id_table:
                    token_id_table[token_name] = {}
                if idx not in token_id_table[token_name]:
                    token_id_table[token_name][idx] = len(token_id_table[token_name])
                if token_name not in tokenizer.encoder:
                    raise ValueError("Invalid token {}".format(token))
                reid_token_ids.append(tokenizer.encoder[token_name])
                token_id_subs.append(token_id_table[token_name][idx])
            else:
                reid_token_ids.append(idx)
                token_id_subs.append(0)
        tokens = [tokenizer.bos_id] + reid_token_ids
        token_id_subs = [0] + token_id_subs
        if not seg["need_predict"]:
            tokens = tokens + [tokenizer.eos_id]
            token_id_subs = token_id_subs + [0]
        else:
            # Answer segments get no <eos>: the model must generate it.
            pass
        begin = len(input_ids)
        input_ids.extend(tokens)
        input_id_subs.extend(token_id_subs)
        end = len(input_ids)
        segment_bound.append((begin, end))

    ids = np.array(input_ids, dtype=np.int32)
    id_subs = np.array(input_id_subs, dtype=np.int32)
    segs = np.zeros((ids.shape[0],), dtype=np.int32)
    # context == 1 marks given (non-predicted) tokens.
    context = np.zeros((ids.shape[0],), dtype=np.int8)
    for i, (begin, end) in enumerate(segment_bound):
        if not segments[i]["need_predict"]:
            context[begin:end] = 1
        if segments[i]["is_image"]:
            # Trim the surrounding <bos>/<eos> from the image token span.
            image_bound.append((begin+1, end-1))
        segs[begin:end] = i

    curr_ext_table_states: _PrevExtTableStates = {
        "ext_table": ext_table,
        "token_id_table": token_id_table,
    }
    image_bound = np.array(image_bound, dtype=np.int32)
    return ids, id_subs, context, segs, segment_rel, num_segments, curr_ext_table_states, image_bound
| |
|
| |
|
| | |
def identity_func(img):
    """No-op augmentation: return the image unchanged."""
    return img
| |
|
| |
|
def autocontrast_func(img, cutoff=0):
    '''
    same output as PIL.ImageOps.autocontrast

    Stretches each channel so that its value range (after trimming
    ``cutoff`` percent of pixels from each histogram end) spans [0, 255].
    '''
    n_bins = 256

    def tune_channel(ch):
        # Remap one channel via a 256-entry lookup table.
        n = ch.size
        cut = cutoff * n // 100
        if cut == 0:
            high, low = ch.max(), ch.min()
        else:
            # Trim `cut` pixels from each end of the histogram before
            # picking the low/high anchor values.
            hist = cv2.calcHist([ch], [0], None, [n_bins], [0, n_bins])
            low = np.argwhere(np.cumsum(hist) > cut)
            low = 0 if low.shape[0] == 0 else low[0]
            high = np.argwhere(np.cumsum(hist[::-1]) > cut)
            high = n_bins - 1 if high.shape[0] == 0 else n_bins - 1 - high[0]
        if high <= low:
            # Degenerate range: identity mapping.
            table = np.arange(n_bins)
        else:
            # Linear stretch: value v maps to (v - low) * scale.
            scale = (n_bins - 1) / (high - low)
            table = np.arange(n_bins) * scale - low * scale
            table[table < 0] = 0
            table[table > n_bins - 1] = n_bins - 1
        table = table.clip(0, 255).astype(np.uint8)
        return table[ch]

    channels = [tune_channel(ch) for ch in cv2.split(img)]
    out = cv2.merge(channels)
    return out
| |
|
| |
|
def equalize_func(img):
    '''
    same output as PIL.ImageOps.equalize
    PIL's implementation is different from cv2.equalize
    '''
    n_bins = 256

    def tune_channel(ch):
        hist = cv2.calcHist([ch], [0], None, [n_bins], [0, n_bins])
        non_zero_hist = hist[hist != 0].reshape(-1)
        # PIL's step size: pixel count (excluding the last non-zero bin)
        # spread over the 255 output levels.
        step = np.sum(non_zero_hist[:-1]) // (n_bins - 1)
        if step == 0:
            # Flat/degenerate histogram: leave the channel untouched.
            return ch
        # Cumulative histogram shifted by one bin, seeded with step//2,
        # then quantized by `step` — mirrors PIL's equalize LUT.
        n = np.empty_like(hist)
        n[0] = step // 2
        n[1:] = hist[:-1]
        table = (np.cumsum(n) // step).clip(0, 255).astype(np.uint8)
        return table[ch]

    channels = [tune_channel(ch) for ch in cv2.split(img)]
    out = cv2.merge(channels)
    return out
| |
|
| |
|
def rotate_func(img, degree, fill=(0, 0, 0)):
    """Rotate ``img`` about its center by ``degree`` (degrees, like PIL —
    not radians), filling uncovered pixels with ``fill``."""
    height, width = img.shape[:2]
    matrix = cv2.getRotationMatrix2D((width / 2, height / 2), degree, 1)
    return cv2.warpAffine(img, matrix, (width, height), borderValue=fill)
| |
|
| |
|
def solarize_func(img, thresh=128):
    """Invert all pixel values >= ``thresh``; same output as
    PIL.ImageOps.solarize."""
    levels = np.arange(256)
    lut = np.where(levels < thresh, levels, 255 - levels)
    lut = lut.clip(0, 255).astype(np.uint8)
    return lut[img]
| |
|
| |
|
def color_func(img, factor):
    """Blend between the grayscale version (factor = 0) and the original
    image (factor = 1); same output as PIL.ImageEnhance.Color for a
    channel-last image."""
    # Base matrix pulls each channel toward itself minus the luminance
    # weights; adding the weight column restores luminance, so factor = 1
    # is the identity and factor = 0 collapses to weighted gray.
    base = np.float32([
        [0.886, -0.114, -0.114],
        [-0.587, 0.413, -0.587],
        [-0.299, -0.299, 0.701],
    ])
    offset = np.float32([[0.114], [0.587], [0.299]])
    matrix = base * factor + offset
    return np.matmul(img, matrix).clip(0, 255).astype(np.uint8)
| |
|
| |
|
def contrast_func(img, factor):
    """Scale pixel contrast about the luminance mean; same output as
    PIL.ImageEnhance.Contrast."""
    # Weighted channel means give the image's overall luminance.
    luma_mean = np.sum(np.mean(img, axis=(0, 1)) * np.array([0.114, 0.587, 0.299]))
    levels = np.arange(256)
    lut = ((levels - luma_mean) * factor + luma_mean).clip(0, 255).astype(np.uint8)
    return lut[img]
| |
|
| |
|
def brightness_func(img, factor):
    """Scale pixel brightness by ``factor`` (PIL.ImageEnhance.Brightness);
    results are clipped to [0, 255]."""
    lut = (np.arange(256, dtype=np.float32) * factor).clip(0, 255).astype(np.uint8)
    return lut[img]
| |
|
| |
|
def sharpness_func(img, factor):
    """Interpolate between a 3x3-smoothed copy (factor = 0) and the original
    image (factor = 1). Differs from PIL only on the 4 border rows/cols —
    the center area matches."""
    kernel = np.float32([[1, 1, 1], [1, 5, 1], [1, 1, 1]]) / 13
    blurred = cv2.filter2D(img, -1, kernel)
    if factor == 0.0:
        return blurred
    if factor == 1.0:
        return img
    out = img.astype(np.float32)
    core = blurred.astype(np.float32)[1:-1, 1:-1, :]
    # Only the interior is blended; the 1-pixel border keeps the original.
    out[1:-1, 1:-1, :] = core + factor * (out[1:-1, 1:-1, :] - core)
    return out.astype(np.uint8)
| |
|
| |
|
def shear_x_func(img, factor, fill=(0, 0, 0)):
    """Shear the image horizontally by ``factor`` via an affine warp with
    linear interpolation; uncovered pixels take ``fill``."""
    height, width = img.shape[:2]
    matrix = np.float32([[1, factor, 0], [0, 1, 0]])
    warped = cv2.warpAffine(img, matrix, (width, height), borderValue=fill, flags=cv2.INTER_LINEAR)
    return warped.astype(np.uint8)
| |
|
| |
|
def translate_x_func(img, offset, fill=(0, 0, 0)):
    """Shift the image horizontally by ``offset`` pixels (positive offset
    moves content toward the left edge); same output as PIL.Image.transform."""
    height, width = img.shape[:2]
    matrix = np.float32([[1, 0, -offset], [0, 1, 0]])
    warped = cv2.warpAffine(img, matrix, (width, height), borderValue=fill, flags=cv2.INTER_LINEAR)
    return warped.astype(np.uint8)
| |
|
| |
|
def translate_y_func(img, offset, fill=(0, 0, 0)):
    """Shift the image vertically by ``offset`` pixels (positive offset
    moves content toward the top edge); same output as PIL.Image.transform."""
    height, width = img.shape[:2]
    matrix = np.float32([[1, 0, 0], [0, 1, -offset]])
    warped = cv2.warpAffine(img, matrix, (width, height), borderValue=fill, flags=cv2.INTER_LINEAR)
    return warped.astype(np.uint8)
| |
|
| |
|
def posterize_func(img, bits):
    '''
    same output as PIL.ImageOps.posterize

    Keeps only the top ``bits`` bits of every channel value; ``bits``
    is expected in [0, 8] (bits == 0 zeroes the image).
    '''
    # Mask within the uint8 range explicitly: 255 << (8 - bits) exceeds 255
    # for bits < 4 (e.g. bits == 0 gives 65280), and NumPy >= 2 raises
    # OverflowError when converting an out-of-range Python int to uint8
    # instead of wrapping as older NumPy did.
    mask = np.uint8((255 << (8 - bits)) & 0xFF)
    out = np.bitwise_and(img, mask)
    return out
| |
|
| |
|
def shear_y_func(img, factor, fill=(0, 0, 0)):
    """Shear the image vertically by ``factor`` via an affine warp with
    linear interpolation; uncovered pixels take ``fill``."""
    height, width = img.shape[:2]
    matrix = np.float32([[1, 0, 0], [factor, 1, 0]])
    warped = cv2.warpAffine(img, matrix, (width, height), borderValue=fill, flags=cv2.INTER_LINEAR)
    return warped.astype(np.uint8)
| |
|
| |
|
def cutout_func(img, pad_size, replace=(0, 0, 0)):
    """Blank out a randomly centered square patch (side ~``pad_size``)
    with the ``replace`` color; the patch is clipped at the image edges."""
    fill = np.array(replace, dtype=np.uint8)
    height, width = img.shape[:2]
    rel_h, rel_w = np.random.random(2)
    half = pad_size // 2
    center_h, center_w = int(rel_h * height), int(rel_w * width)
    top, bottom = max(center_h - half, 0), min(center_h + half, height)
    left, right = max(center_w - half, 0), min(center_w + half, width)
    out = img.copy()
    out[top:bottom, left:right, :] = fill
    return out
| |
|
| |
|
| | |
def enhance_level_to_args(MAX_LEVEL):
    """Build an argument mapper for enhancement ops: level -> a single
    factor in [0.1, 1.9]."""
    def level_to_args(level):
        factor = (level / MAX_LEVEL) * 1.8 + 0.1
        return (factor,)
    return level_to_args
| |
|
| |
|
def shear_level_to_args(MAX_LEVEL, replace_value):
    """Build an argument mapper for shear ops: level -> (factor, fill),
    where |factor| <= 0.3 and the sign is random."""
    def level_to_args(level):
        magnitude = (level / MAX_LEVEL) * 0.3
        sign = -1.0 if np.random.random() > 0.5 else 1.0
        return (sign * magnitude, replace_value)

    return level_to_args
| |
|
| |
|
def translate_level_to_args(translate_const, MAX_LEVEL, replace_value):
    """Build an argument mapper for translate ops: level -> (offset, fill),
    where |offset| <= translate_const and the sign is random."""
    def level_to_args(level):
        offset = (level / MAX_LEVEL) * float(translate_const)
        if np.random.random() > 0.5:
            offset = -offset
        return (offset, replace_value)

    return level_to_args
| |
|
| |
|
def cutout_level_to_args(cutout_const, MAX_LEVEL, replace_value):
    """Build an argument mapper for cutout: level -> (pad_size, fill)."""
    def level_to_args(level):
        pad_size = int((level / MAX_LEVEL) * cutout_const)
        return (pad_size, replace_value)

    return level_to_args
| |
|
| |
|
def solarize_level_to_args(MAX_LEVEL):
    """Build an argument mapper for solarize: level -> (threshold,)."""
    def level_to_args(level):
        threshold = int((level / MAX_LEVEL) * 256)
        return (threshold,)
    return level_to_args
| |
|
| |
|
def none_level_to_args(level):
    """Argument mapper for ops that ignore the magnitude level."""
    return ()
| |
|
| |
|
def posterize_level_to_args(MAX_LEVEL):
    """Build an argument mapper for posterize: level -> (bits,), bits in [0, 4]."""
    def level_to_args(level):
        bits = int((level / MAX_LEVEL) * 4)
        return (bits,)
    return level_to_args
| |
|
| |
|
def rotate_level_to_args(MAX_LEVEL, replace_value):
    """Build an argument mapper for rotate: level -> (degrees, fill),
    where |degrees| <= 30 and the sign is random."""
    def level_to_args(level):
        degrees = (level / MAX_LEVEL) * 30
        if np.random.random() < 0.5:
            degrees = -degrees
        return (degrees, replace_value)

    return level_to_args
| |
|
| |
|
# Op name -> augmentation callable invoked as func(img, *args).
func_dict = {
    'Identity': identity_func,
    'AutoContrast': autocontrast_func,
    'Equalize': equalize_func,
    'Rotate': rotate_func,
    'Solarize': solarize_func,
    'Color': color_func,
    'Contrast': contrast_func,
    'Brightness': brightness_func,
    'Sharpness': sharpness_func,
    'ShearX': shear_x_func,
    'TranslateX': translate_x_func,
    'TranslateY': translate_y_func,
    'Posterize': posterize_func,
    'ShearY': shear_y_func,
}
| |
|
# Shared RandAugment configuration.
translate_const = 10   # max translation (pixels) at full magnitude
MAX_LEVEL = 10         # magnitude scale: level / MAX_LEVEL lies in [0, 1]
replace_value = (128, 128, 128)  # mid-gray fill for geometric ops
# Op name -> builder mapping a magnitude level to the op's extra arguments.
arg_dict = {
    'Identity': none_level_to_args,
    'AutoContrast': none_level_to_args,
    'Equalize': none_level_to_args,
    'Rotate': rotate_level_to_args(MAX_LEVEL, replace_value),
    'Solarize': solarize_level_to_args(MAX_LEVEL),
    'Color': enhance_level_to_args(MAX_LEVEL),
    'Contrast': enhance_level_to_args(MAX_LEVEL),
    'Brightness': enhance_level_to_args(MAX_LEVEL),
    'Sharpness': enhance_level_to_args(MAX_LEVEL),
    'ShearX': shear_level_to_args(MAX_LEVEL, replace_value),
    'TranslateX': translate_level_to_args(
        translate_const, MAX_LEVEL, replace_value
    ),
    'TranslateY': translate_level_to_args(
        translate_const, MAX_LEVEL, replace_value
    ),
    'Posterize': posterize_level_to_args(MAX_LEVEL),
    'ShearY': shear_level_to_args(MAX_LEVEL, replace_value),
}
| |
|
| |
|
class RandomAugment(object):
    """RandAugment-style augmentation: apply N randomly chosen ops, each
    firing with probability 0.5, at magnitude level M.

    Args:
        N: number of ops sampled per image.
        M: magnitude level passed to each op's argument builder.
        isPIL: if True, convert the incoming PIL image to a numpy array first.
        augs: op names to sample from; defaults to every op in ``arg_dict``.
    """

    def __init__(self, N=2, M=10, isPIL=False, augs=None):
        self.N = N
        self.M = M
        self.isPIL = isPIL
        # Fix the mutable-default-argument pitfall: the original signature
        # used ``augs=[]``. An empty/None argument still selects all ops,
        # so behavior is unchanged; a provided list is copied defensively.
        self.augs = list(augs) if augs else list(arg_dict.keys())

    def get_random_ops(self):
        """Sample N (name, prob, level) op triples."""
        sampled_ops = np.random.choice(self.augs, self.N)
        return [(op, 0.5, self.M) for op in sampled_ops]

    def __call__(self, img):
        if self.isPIL:
            img = np.array(img)
        ops = self.get_random_ops()
        for name, prob, level in ops:
            # Each sampled op fires independently with probability ``prob``.
            if np.random.random() > prob:
                continue
            args = arg_dict[name](level)
            img = func_dict[name](img, *args)
        return img
| |
|
| |
|
def build_transform(is_train, randaug=True, input_size=224, interpolation='bicubic'):
    """Build the image preprocessing pipeline.

    Args:
        is_train: if True, random crop / flip / (optional) RandAugment;
            otherwise a plain resize.
        randaug: whether to append RandomAugment to the training pipeline.
        input_size: output image side length.
        interpolation: resize interpolation name ('bicubic', 'bilinear', ...).

    Returns:
        A torchvision ``transforms.Compose`` pipeline ending in ToTensor +
        Inception-style normalization.
    """
    # Honor the ``interpolation`` argument: it was previously accepted but
    # ignored (BICUBIC was hard-coded). The default value keeps the old
    # behavior for existing callers.
    interp_mode = transforms.InterpolationMode(interpolation)
    if is_train:
        t = [
            RandomResizedCropAndInterpolation(
                input_size, scale=(0.5, 1.0), interpolation=interp_mode),
            transforms.RandomHorizontalFlip(),
        ]
        if randaug:
            t.append(
                RandomAugment(
                    2, 7, isPIL=True,
                    augs=[
                        'Identity', 'AutoContrast', 'Equalize', 'Brightness', 'Sharpness',
                        'ShearX', 'ShearY', 'TranslateX', 'TranslateY', 'Rotate',
                    ]))
        t += [
            transforms.ToTensor(),
            transforms.Normalize(mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD),
        ]
        t = transforms.Compose(t)
    else:
        t = transforms.Compose([
            transforms.Resize((input_size, input_size),
                              interpolation=interp_mode),
            transforms.ToTensor(),
            transforms.Normalize(mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD)
        ])

    return t
| |
|
| |
|
def _urlretrieve(url: str, filename: str, chunk_size: int = 1024) -> None:
    """Download ``url`` to ``filename`` with a tqdm progress bar.

    Args:
        url: remote resource to fetch.
        filename: destination path (overwritten if it exists).
        chunk_size: number of bytes requested per read.
    """
    request = urllib.request.Request(url, headers={"User-Agent": "vissl"})
    with open(filename, "wb") as fh:
        with urllib.request.urlopen(request) as response:
            # response.length may be None (no Content-Length header); tqdm
            # then renders an unbounded bar.
            with tqdm(total=response.length) as pbar:
                # The sentinel must be b"": read() returns bytes, so the
                # original "" sentinel never matched and the loop relied
                # solely on an explicit break.
                for chunk in iter(lambda: response.read(chunk_size), b""):
                    fh.write(chunk)
                    # Count actual bytes: updating by chunk_size
                    # over-reported progress on the final, short chunk.
                    pbar.update(len(chunk))
| |
|