Add files using upload-large-folder tool
This view is limited to 50 files because the commit contains too many changes. See the raw diff for the full change list.
- .gitattributes +3 -0
- Tipsomaly/.gitignore +2 -0
- Tipsomaly/imgs/Models_Architecture_page-0001.jpg +3 -0
- Tipsomaly/imgs/Qualitative_results_page-0001.jpg +3 -0
- Tipsomaly/imgs/results-table.png +3 -0
- Tipsomaly/model/big_vision/__pycache__/__init__.cpython-39.pyc +0 -0
- Tipsomaly/model/big_vision/__pycache__/load_siglip.cpython-39.pyc +0 -0
- Tipsomaly/model/big_vision/__pycache__/utils.cpython-39.pyc +0 -0
- Tipsomaly/model/big_vision/configs/__init__.py +0 -0
- Tipsomaly/model/big_vision/configs/bit_i1k.py +102 -0
- Tipsomaly/model/big_vision/configs/bit_i21k.py +85 -0
- Tipsomaly/model/big_vision/configs/common.py +188 -0
- Tipsomaly/model/big_vision/configs/common_fewshot.py +60 -0
- Tipsomaly/model/big_vision/configs/load_and_eval.py +143 -0
- Tipsomaly/model/big_vision/configs/mlp_mixer_i1k.py +120 -0
- Tipsomaly/model/big_vision/configs/proj/paligemma/transfers/vertexai_l4.py +115 -0
- Tipsomaly/model/big_vision/configs/proj/paligemma/transfers/vqav2.py +160 -0
- Tipsomaly/model/big_vision/configs/transfer.py +186 -0
- Tipsomaly/model/big_vision/configs/vit_i1k.py +177 -0
- Tipsomaly/model/big_vision/configs/vit_i21k.py +145 -0
- Tipsomaly/model/big_vision/configs/vit_s16_i1k.py +105 -0
- Tipsomaly/model/big_vision/datasets/core.py +77 -0
- Tipsomaly/model/big_vision/datasets/jsonl.py +177 -0
- Tipsomaly/model/big_vision/datasets/sequence_packing.py +77 -0
- Tipsomaly/model/big_vision/datasets/tfds.py +94 -0
- Tipsomaly/model/big_vision/evaluators/__init__.py +0 -0
- Tipsomaly/model/big_vision/evaluators/classification.py +76 -0
- Tipsomaly/model/big_vision/evaluators/common.py +228 -0
- Tipsomaly/model/big_vision/evaluators/fewshot_lsr.py +245 -0
- Tipsomaly/model/big_vision/evaluators/mean.py +80 -0
- Tipsomaly/model/big_vision/evaluators/save.py +121 -0
- Tipsomaly/model/big_vision/models/__init__.py +0 -0
- Tipsomaly/model/big_vision/models/bit.py +162 -0
- Tipsomaly/model/big_vision/models/bit_paper.py +260 -0
- Tipsomaly/model/big_vision/models/common.py +133 -0
- Tipsomaly/model/big_vision/models/mlp_mixer.py +177 -0
- Tipsomaly/model/big_vision/models/vit.py +505 -0
- Tipsomaly/model/big_vision/pp/__init__.py +0 -0
- Tipsomaly/model/big_vision/pp/autoaugment.py +700 -0
- Tipsomaly/model/big_vision/pp/builder.py +85 -0
- Tipsomaly/model/big_vision/pp/builder_test.py +72 -0
- Tipsomaly/model/big_vision/pp/ops_general.py +468 -0
- Tipsomaly/model/big_vision/pp/ops_general_test.py +236 -0
- Tipsomaly/model/big_vision/pp/ops_image.py +361 -0
- Tipsomaly/model/big_vision/pp/ops_image_test.py +82 -0
- Tipsomaly/model/big_vision/pp/ops_text.py +411 -0
- Tipsomaly/model/big_vision/pp/ops_text_test.py +200 -0
- Tipsomaly/model/big_vision/pp/registry.py +163 -0
- Tipsomaly/model/big_vision/pp/registry_test.py +128 -0
- Tipsomaly/model/big_vision/pp/tokenizer.py +103 -0
.gitattributes
CHANGED

@@ -33,3 +33,6 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+Tipsomaly/imgs/Models_Architecture_page-0001.jpg filter=lfs diff=lfs merge=lfs -text
+Tipsomaly/imgs/results-table.png filter=lfs diff=lfs merge=lfs -text
+Tipsomaly/imgs/Qualitative_results_page-0001.jpg filter=lfs diff=lfs merge=lfs -text

Tipsomaly/.gitignore
ADDED

@@ -0,0 +1,2 @@
+__pycache__/
+*.pyc

Tipsomaly/imgs/Models_Architecture_page-0001.jpg
ADDED (stored with Git LFS)

Tipsomaly/imgs/Qualitative_results_page-0001.jpg
ADDED (stored with Git LFS)

Tipsomaly/imgs/results-table.png
ADDED (stored with Git LFS)

Tipsomaly/model/big_vision/__pycache__/__init__.cpython-39.pyc
ADDED (binary file, 163 Bytes)

Tipsomaly/model/big_vision/__pycache__/load_siglip.cpython-39.pyc
ADDED (binary file, 5.4 kB)

Tipsomaly/model/big_vision/__pycache__/utils.cpython-39.pyc
ADDED (binary file, 52.4 kB)

Tipsomaly/model/big_vision/configs/__init__.py
ADDED (empty file)

Tipsomaly/model/big_vision/configs/bit_i1k.py
ADDED

@@ -0,0 +1,102 @@
+# Copyright 2024 Big Vision Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# pylint: disable=line-too-long
+r"""Pre-training BiT on ILSVRC-2012 as in https://arxiv.org/abs/1912.11370
+
+Run training of a BiT-ResNet-50x1 variant, which takes ~32min on v3-128:
+
+big_vision.train \
+    --config big_vision/configs/bit_i1k.py \
+    --workdir gs://[your_bucket]/big_vision/`date '+%m-%d_%H%M'` \
+    --config.model.depth 50 --config.model.width 1
+"""
+
+# from big_vision.configs.common_fewshot import get_fewshot_lsr
+import ml_collections as mlc
+
+
+def get_config(runlocal=False):
+  """Config for training on ImageNet-1k."""
+  config = mlc.ConfigDict()
+
+  config.seed = 0
+  config.total_epochs = 90
+  config.num_classes = 1000
+  config.loss = 'softmax_xent'
+
+  config.input = dict()
+  config.input.data = dict(
+      name='imagenet2012',
+      split='train[:99%]',
+  )
+  config.input.batch_size = 4096
+  config.input.cache_raw = True  # Needs up to 120GB of RAM!
+  config.input.shuffle_buffer_size = 250_000  # Per host.
+
+  pp_common = '|onehot(1000, key="{lbl}", key_result="labels")'
+  pp_common += '|value_range(-1, 1)|keep("image", "labels")'
+  config.input.pp = 'decode_jpeg_and_inception_crop(224)|flip_lr' + pp_common.format(lbl='label')
+  pp_eval = 'decode|resize_small(256)|central_crop(224)' + pp_common
+
+  config.log_training_steps = 50
+  config.ckpt_steps = 1000
+
+  # Model section
+  config.model_name = 'bit'
+  config.model = dict(
+      depth=50,  # You can also pass e.g. [3, 5, 10, 2]
+      width=1.0,
+  )
+
+  # Optimizer section
+  config.optax_name = 'big_vision.momentum_hp'
+  config.grad_clip_norm = 1.0
+
+  # linear scaling rule. Don't forget to sweep if sweeping batch_size.
+  config.wd = (1e-4 / 256) * config.input.batch_size
+  config.lr = (0.1 / 256) * config.input.batch_size
+  config.schedule = dict(decay_type='cosine', warmup_steps=1000)
+
+  # Eval section
+  def get_eval(split, dataset='imagenet2012'):
+    return dict(
+        type='classification',
+        data=dict(name=dataset, split=split),
+        pp_fn=pp_eval.format(lbl='label'),
+        loss_name=config.loss,
+        log_steps=1000,  # Very fast O(seconds) so it's fine to run it often.
+        cache='final_data',
+    )
+  config.evals = {}
+  config.evals.train = get_eval('train[:2%]')
+  config.evals.minival = get_eval('train[99%:]')
+  config.evals.val = get_eval('validation')
+  config.evals.v2 = get_eval('test', dataset='imagenet_v2')
+  config.evals.real = get_eval('validation', dataset='imagenet2012_real')
+  config.evals.real.pp_fn = pp_eval.format(lbl='real_label')
+
+  # config.evals.fewshot = get_fewshot_lsr(runlocal=runlocal)
+  # config.evals.fewshot.log_steps = 1000
+
+  if runlocal:
+    config.input.batch_size = 32
+    config.input.cache_raw = False
+    config.input.shuffle_buffer_size = 100
+
+    local_eval = config.evals.val
+    config.evals = {'val': local_eval}
+    config.evals.val.cache = 'none'
+
+  return config

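For reference, the `pp_common.format(lbl=...)` composition and the linear scaling rule in this config resolve to concrete values. This is a standalone sketch (plain Python, not part of the committed file) that reproduces them:

pp_common = '|onehot(1000, key="{lbl}", key_result="labels")'
pp_common += '|value_range(-1, 1)|keep("image", "labels")'
pp = 'decode_jpeg_and_inception_crop(224)|flip_lr' + pp_common.format(lbl='label')
assert pp == ('decode_jpeg_and_inception_crop(224)|flip_lr'
              '|onehot(1000, key="label", key_result="labels")'
              '|value_range(-1, 1)|keep("image", "labels")')

# Linear scaling rule at batch_size=4096:
assert (0.1 / 256) * 4096 == 1.6       # config.lr
assert (1e-4 / 256) * 4096 == 0.0016   # config.wd
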
Tipsomaly/model/big_vision/configs/bit_i21k.py
ADDED
|
@@ -0,0 +1,85 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024 Big Vision Authors.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
# pylint: disable=line-too-long
|
| 16 |
+
r"""A config for pre-training BiT on ImageNet-21k.
|
| 17 |
+
|
| 18 |
+
This config relies on the Imagenet-21k tfds dataset, which is not yet
|
| 19 |
+
available publicly in TFDS. We intend to add the dataset to public TFDS soon,
|
| 20 |
+
and this config will then be runnable.
|
| 21 |
+
"""
|
| 22 |
+
|
| 23 |
+
from big_vision.configs.common_fewshot import get_fewshot_lsr
|
| 24 |
+
import ml_collections as mlc
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
def get_config():
|
| 28 |
+
"""Config for training on imagenet-21k."""
|
| 29 |
+
config = mlc.ConfigDict()
|
| 30 |
+
|
| 31 |
+
config.seed = 0
|
| 32 |
+
config.total_epochs = 90
|
| 33 |
+
config.num_classes = 21843
|
| 34 |
+
config.init_head_bias = -10.0
|
| 35 |
+
config.loss = 'sigmoid_xent'
|
| 36 |
+
|
| 37 |
+
config.input = dict()
|
| 38 |
+
config.input.data = dict(
|
| 39 |
+
name='imagenet21k',
|
| 40 |
+
split='full[51200:]',
|
| 41 |
+
)
|
| 42 |
+
config.input.batch_size = 4096
|
| 43 |
+
config.input.shuffle_buffer_size = 250_000 # Per host, so small-ish is ok.
|
| 44 |
+
|
| 45 |
+
pp_common = '|value_range(-1, 1)|onehot({onehot_args})|keep("image", "labels")'
|
| 46 |
+
pp_common_i21k = pp_common.format(onehot_args=f'{config.num_classes}')
|
| 47 |
+
pp_common_i1k = pp_common.format(onehot_args='1000, key="label", key_result="labels"')
|
| 48 |
+
config.input.pp = 'decode_jpeg_and_inception_crop(224)|flip_lr' + pp_common_i21k
|
| 49 |
+
pp_eval = 'decode|resize_small(256)|central_crop(224)'
|
| 50 |
+
|
| 51 |
+
config.log_training_steps = 50
|
| 52 |
+
config.ckpt_steps = 1000
|
| 53 |
+
|
| 54 |
+
# Model section
|
| 55 |
+
config.model_name = 'bit_paper'
|
| 56 |
+
config.model = dict(depth=50, width=1.0)
|
| 57 |
+
|
| 58 |
+
# Optimizer section
|
| 59 |
+
config.optax_name = 'big_vision.momentum_hp'
|
| 60 |
+
config.grad_clip_norm = 1.0
|
| 61 |
+
|
| 62 |
+
# linear scaling rule. Don't forget to sweep if sweeping batch_size.
|
| 63 |
+
config.lr = (0.03 / 256) * config.input.batch_size
|
| 64 |
+
config.wd = (3e-5 / 256) * config.input.batch_size
|
| 65 |
+
config.schedule = dict(decay_type='cosine', warmup_steps=5000)
|
| 66 |
+
|
| 67 |
+
# Evaluations on i21k itself.
|
| 68 |
+
def eval_i21k(split):
|
| 69 |
+
return dict(
|
| 70 |
+
type='classification',
|
| 71 |
+
data={**config.input.data, 'split': split},
|
| 72 |
+
pp_fn=pp_eval + pp_common_i21k,
|
| 73 |
+
loss_name=config.loss,
|
| 74 |
+
log_steps=1000, # Very fast O(seconds) so it's fine to run it often.
|
| 75 |
+
)
|
| 76 |
+
config.evals = {}
|
| 77 |
+
config.evals.test = eval_i21k('full[:25_600]')
|
| 78 |
+
config.evals.val = eval_i21k('full[25_600:51_200]')
|
| 79 |
+
config.evals.train = eval_i21k('full[51_200:76_800]')
|
| 80 |
+
|
| 81 |
+
# Few-shot evaluators
|
| 82 |
+
config.evals.fewshot = get_fewshot_lsr()
|
| 83 |
+
config.evals.fewshot.log_steps = 25_000
|
| 84 |
+
|
| 85 |
+
return config
|
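A note on `init_head_bias = -10.0` in this file: with a sigmoid head over 21843 classes, a bias near -10 starts every class's output close to its prior probability, which is the usual motivation for such an initialization. A small arithmetic check (standalone, not part of the committed file):

import math

prior = 1 / 21843                           # ≈ 4.58e-5 per class
sigmoid_at_bias = 1 / (1 + math.exp(10.0))  # sigmoid(-10) ≈ 4.54e-5
assert abs(sigmoid_at_bias - prior) / prior < 0.01
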
Tipsomaly/model/big_vision/configs/common.py
ADDED

@@ -0,0 +1,188 @@
+# Copyright 2024 Big Vision Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""A few things commonly used across A LOT of config files."""
+
+import string
+
+import ml_collections as mlc
+
+
+def input_for_quicktest(config_input, quicktest):
+  if quicktest:
+    config_input.batch_size = 8
+    config_input.shuffle_buffer_size = 10
+    config_input.cache_raw = False
+
+
+def parse_arg(arg, lazy=False, **spec):
+  """Makes ConfigDict's get_config single-string argument more usable.
+
+  Example use in the config file:
+
+    import big_vision.configs.common as bvcc
+    def get_config(arg):
+      arg = bvcc.parse_arg(arg,
+          res=(224, int),
+          runlocal=False,
+          schedule='short',
+      )
+
+      # ...
+
+      config.shuffle_buffer = 250_000 if not arg.runlocal else 50
+
+  Ways that values can be passed when launching:
+
+    --config amazing.py:runlocal,schedule=long,res=128
+    --config amazing.py:res=128
+    --config amazing.py:runlocal  # A boolean needs no value for "true".
+    --config amazing.py:runlocal=False  # Explicit false boolean.
+    --config amazing.py:128  # The first spec entry may be passed unnamed alone.
+
+  Uses strict bool conversion (converting 'True', 'true' to True, and 'False',
+  'false', '' to False).
+
+  Args:
+    arg: the string argument that's passed to get_config.
+    lazy: allow lazy parsing of arguments, which are not in spec. For these,
+      the type is auto-extracted in dependence of most complex possible type.
+    **spec: the name and default values of the expected options.
+      If the value is a tuple, the value's first element is the default value,
+      and the second element is a function called to convert the string.
+      Otherwise the type is automatically extracted from the default value.
+
+  Returns:
+    ConfigDict object with extracted type-converted values.
+  """
+  # Normalize arg and spec layout.
+  arg = arg or ''  # Normalize None to empty string
+  spec = {k: get_type_with_default(v) for k, v in spec.items()}
+
+  result = mlc.ConfigDict(type_safe=False)  # For convenient dot-access only.
+
+  # Expand convenience-cases for a single parameter without = sign.
+  if arg and ',' not in arg and '=' not in arg:
+    # (think :runlocal) If it's the name of sth in the spec (or there is no
+    # spec), it's that in bool.
+    if arg in spec or not spec:
+      arg = f'{arg}=True'
+    # Otherwise, it is the value for the first entry in the spec.
+    else:
+      arg = f'{list(spec.keys())[0]}={arg}'
+      # Yes, we rely on Py3.7 insertion order!
+
+  # Now, expand the `arg` string into a dict of keys and values:
+  raw_kv = {raw_arg.split('=')[0]:
+                raw_arg.split('=', 1)[-1] if '=' in raw_arg else 'True'
+            for raw_arg in arg.split(',') if raw_arg}
+
+  # And go through the spec, using provided or default value for each:
+  for name, (default, type_fn) in spec.items():
+    val = raw_kv.pop(name, None)
+    result[name] = type_fn(val) if val is not None else default
+
+  if raw_kv:
+    if lazy:  # Process args which are not in spec.
+      for k, v in raw_kv.items():
+        result[k] = autotype(v)
+    else:
+      raise ValueError(f'Unhandled config args remain: {raw_kv}')
+
+  return result
+
+
+def get_type_with_default(v):
+  """Returns (v, string_to_v_type) with lenient bool parsing."""
+  # For bool, do safe string conversion.
+  if isinstance(v, bool):
+    def strict_bool(x):
+      assert x.lower() in {'true', 'false', ''}
+      return x.lower() == 'true'
+    return (v, strict_bool)
+  # If already a (default, type) tuple, use that.
+  if isinstance(v, (tuple, list)):
+    assert len(v) == 2 and isinstance(v[1], type), (
+        'List or tuple types are currently not supported because we use `,` as'
+        ' dumb delimiter. Contributions (probably using ast) welcome. You can'
+        ' unblock by using a string with eval(s.replace(";", ",")) or similar')
+    return (v[0], v[1])
+  # Otherwise, derive the type from the default value.
+  return (v, type(v))
+
+
+def autotype(x):
+  """Auto-converts string to bool/int/float if possible."""
+  assert isinstance(x, str)
+  if x.lower() in {'true', 'false'}:
+    return x.lower() == 'true'  # Returns as bool.
+  try:
+    return int(x)  # Returns as int.
+  except ValueError:
+    try:
+      return float(x)  # Returns as float.
+    except ValueError:
+      return x  # Returns as str.
+
+
+def pack_arg(**kw):
+  """Packs key-word args as a string to be parsed by `parse_arg()`."""
+  for v in kw.values():
+    assert ',' not in f'{v}', f"Can't use `,` in config_arg value: {v}"
+  return ','.join([f'{k}={v}' for k, v in kw.items()])
+
+
+def arg(**kw):
+  """Use like `add(**bvcc.arg(res=256, foo=bar), lr=0.1)` to pass config_arg."""
+  return {'config_arg': pack_arg(**kw), **kw}
+
+
+def _get_field_ref(config_dict, field_name):
+  path = field_name.split('.')
+  for field in path[:-1]:
+    config_dict = getattr(config_dict, field)
+  return config_dict.get_ref(path[-1])
+
+
+def format_str(format_string, config):
+  """Format string with reference fields from config.
+
+  This makes it easy to build preprocess strings that contain references to
+  fields that are edited after. E.g.:
+
+  ```
+  config = mlc.ConfigDict()
+  config.res = (256, 256)
+  config.pp = bvcc.format_str('resize({res})', config)
+  ...
+  # if config.res is modified (e.g. via sweeps) it will propagate to pp field:
+  config.res = (512, 512)
+  assert config.pp == 'resize((512, 512))'
+  ```
+
+  Args:
+    format_string: string to format with references.
+    config: ConfigDict to get references to format the string.
+
+  Returns:
+    A reference field which renders a string using references to config fields.
+  """
+  output = ''
+  parts = string.Formatter().parse(format_string)
+  for (literal_text, field_name, format_spec, conversion) in parts:
+    assert not format_spec and not conversion
+    output += literal_text
+    if field_name:
+      output += _get_field_ref(config, field_name).to_str()
+  return output

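A minimal usage sketch of `parse_arg()`, mirroring its docstring above (not part of the committed file; assumes the big_vision package is importable):

import big_vision.configs.common as bvcc

cfg = bvcc.parse_arg('runlocal,res=128',
                     res=(224, int), runlocal=False, schedule='short')
assert cfg.res == 128           # converted via the `int` in the spec tuple
assert cfg.runlocal is True     # bare flag parses as True
assert cfg.schedule == 'short'  # not given, so the default is kept

# The first spec entry may be passed unnamed:
assert bvcc.parse_arg('128', res=(224, int)).res == 128
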
Tipsomaly/model/big_vision/configs/common_fewshot.py
ADDED

@@ -0,0 +1,60 @@
+# Copyright 2024 Big Vision Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Most common few-shot eval configuration."""
+
+import ml_collections as mlc
+
+
+def get_fewshot_lsr(target_resolution=224, resize_resolution=256,
+                    runlocal=False, pp=None, **kw):
+  """Returns a standard-ish fewshot eval configuration."""
+  kw.setdefault('representation_layer', 'pre_logits')
+  kw.setdefault('shots', (1, 5, 10, 25))
+  kw.setdefault('l2_reg', 2.0 ** 10)
+  kw.setdefault('num_seeds', 3)
+  kw.setdefault('prefix', '')  # No prefix as we already use a/ z/ and zz/
+
+  # Backward-compatible default:
+  if not any(f'log_{x}' in kw for x in ['steps', 'percent', 'examples', 'epochs']):  # pylint: disable=line-too-long
+    kw['log_steps'] = 25_000
+
+  config = mlc.ConfigDict(kw)
+  config.type = 'fewshot_lsr'
+  config.datasets = {
+      'caltech': ('caltech101', 'train', 'test'),  # copybara:strip
+      'cars': ('cars196:2.1.0', 'train', 'test'),
+      'cifar100': ('cifar100', 'train', 'test'),
+      'dtd': ('dtd', 'train', 'test'),
+      # The first 65000 ImageNet samples have at least 30 shots per any class.
+      # Commented out by default because needs manual download.
+      # 'imagenet': ('imagenet2012', 'train[:65000]', 'validation'),
+      'pets': ('oxford_iiit_pet', 'train', 'test'),
+      'uc_merced': ('uc_merced', 'train[:1000]', 'train[1000:]'),
+  } if not runlocal else {
+      'pets': ('oxford_iiit_pet', 'train', 'test'),
+  }
+
+  pp = pp or '|'.join([
+      'decode',
+      f'resize({resize_resolution})',
+      f'central_crop({target_resolution})',
+      'value_range(-1,1)',
+  ])
+  pp += '|keep("image", "label")'
+  config.pp_train = pp
+  config.pp_eval = pp
+  config.display_first = [('imagenet', 10)] if not runlocal else [('pets', 10)]
+
+  return config

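With the defaults above, both preprocessing strings come out identical. A quick check (not part of the committed file; assumes big_vision is importable):

from big_vision.configs.common_fewshot import get_fewshot_lsr

c = get_fewshot_lsr()
assert c.pp_train == c.pp_eval == (
    'decode|resize(256)|central_crop(224)|value_range(-1,1)'
    '|keep("image", "label")')
assert c.log_steps == 25_000  # the backward-compatible default
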
Tipsomaly/model/big_vision/configs/load_and_eval.py
ADDED

@@ -0,0 +1,143 @@
+# Copyright 2024 Big Vision Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# pytype: disable=not-writable,attribute-error
+# pylint: disable=line-too-long,missing-function-docstring
+r"""A config to load and eval key model using the core train.py.
+
+The runtime varies widely depending on the model, but each one should reproduce
+the corresponding paper's numbers.
+This configuration makes use of the "arg" to get_config to select which model
+to run, so a few examples are given below:
+
+Run and evaluate a BiT-M ResNet-50x1 model that was transferred to i1k:
+
+big_vision.train \
+    --config big_vision/configs/load_and_eval.py:name=bit_paper,batch_size=8 \
+    --config.model_init M-imagenet2012 --config.model.width 1 --config.model.depth 50
+
+Run and evaluate the recommended ViT-B/32 from "how to train your vit" paper:
+
+big_vision.train \
+    --config big_vision/configs/load_and_eval.py:name=vit_i21k,batch_size=8 \
+    --config.model.variant B/32 --config.model_init howto-i21k-B/32
+"""
+
+import big_vision.configs.common as bvcc
+from big_vision.configs.common_fewshot import get_fewshot_lsr
+
+
+def eval_only(config, batch_size, spec_for_init):
+  """Set a few configs that turn trainer into (almost) eval-only."""
+  config.total_steps = 0
+  config.input = {}
+  config.input.batch_size = batch_size
+  config.input.data = dict(name='bv:dummy', spec=spec_for_init)
+  config.optax_name = 'identity'
+  config.lr = 0.0
+
+  config.mesh = [('data', -1)]
+  config.sharding_strategy = [('params/.*', 'fsdp(axis="data")')]
+  config.sharding_rules = [('act_batch', ('data',))]
+
+  return config
+
+
+def get_config(arg=''):
+  config = bvcc.parse_arg(arg, name='bit_paper', batch_size=4)
+
+  # Make the config eval-only by setting some dummies.
+  eval_only(config, config.batch_size, spec_for_init=dict(
+      image=dict(shape=(224, 224, 3), dtype='float32'),
+  ))
+
+  config.evals = dict(fewshot=get_fewshot_lsr())
+
+  # Just calls the function with the name given as `config`.
+  # Could also be a giant if-block if you're into that kind of thing.
+  globals()[config.name](config)
+  return config
+
+
+def bit_paper(config):
+  config.num_classes = 1000
+
+  config.model_name = 'bit_paper'
+  config.model_init = 'M-imagenet2012'  # M = i21k, -imagenet2012 = fine-tuned
+  config.model = dict(width=1, depth=50)
+
+  def get_eval(split, lbl, dataset='imagenet2012_real'):
+    return dict(
+        type='classification',
+        data=dict(name=dataset, split=split),
+        loss_name='softmax_xent',
+        cache='none',  # Only run once, on low-mem machine.
+        pp_fn=(
+            'decode|resize(384)|value_range(-1, 1)'
+            f'|onehot(1000, key="{lbl}", key_result="labels")'
+            '|keep("image", "labels")'
+        ),
+    )
+  config.evals.test = get_eval('validation', 'original_label')
+  config.evals.real = get_eval('validation', 'real_label')
+  config.evals.v2 = get_eval('test', 'label', 'imagenet_v2')
+
+
+def vit_i1k(config):
+  config.num_classes = 1000
+
+  config.model_name = 'vit'
+  config.model_init = ''  # Will be set in sweep.
+  config.model = dict(variant='S/16', pool_type='gap', posemb='sincos2d',
+                      rep_size=True)
+
+  config.evals.val = dict(
+      type='classification',
+      data=dict(name='imagenet2012', split='validation'),
+      pp_fn='decode|resize_small(256)|central_crop(224)|value_range(-1, 1)|onehot(1000, key="label", key_result="labels")|keep("image", "labels")',
+      loss_name='softmax_xent',
+      cache='none',  # Only run once, on low-mem machine.
+  )
+
+
+def mlp_mixer_i1k(config):
+  config.num_classes = 1000
+
+  config.model_name = 'mlp_mixer'
+  config.model_init = ''  # Will be set in sweep.
+  config.model = dict(variant='L/16')
+
+  config.evals.val = dict(
+      type='classification',
+      data=dict(name='imagenet2012', split='validation'),
+      pp_fn='decode|resize_small(256)|central_crop(224)|value_range(-1, 1)|onehot(1000, key="label", key_result="labels")|keep("image", "labels")',
+      loss_name='softmax_xent',
+      cache='none',  # Only run once, on low-mem machine.
+  )
+
+
+def vit_i21k(config):
+  config.num_classes = 21843
+
+  config.model_name = 'vit'
+  config.model_init = ''  # Will be set in sweep.
+  config.model = dict(variant='B/32', pool_type='tok')
+
+  config.evals.val = dict(
+      type='classification',
+      data=dict(name='imagenet21k', split='full[:51200]'),
+      pp_fn='decode|resize_small(256)|central_crop(224)|value_range(-1, 1)|onehot(21843)|keep("image", "labels")',
+      loss_name='sigmoid_xent',
+      cache='none',  # Only run once, on low-mem machine.
+  )

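How the `:name=...,batch_size=...` launch argument maps onto this config, as a sketch (not part of the committed file; assumes big_vision is importable):

from big_vision.configs.load_and_eval import get_config

c = get_config('name=vit_i21k,batch_size=8')
assert c.total_steps == 0         # eval_only(): no training steps
assert c.input.batch_size == 8
assert c.model.variant == 'B/32'  # set by the vit_i21k() branch
assert c.num_classes == 21843
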
Tipsomaly/model/big_vision/configs/mlp_mixer_i1k.py
ADDED

@@ -0,0 +1,120 @@
+# Copyright 2024 Big Vision Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# pylint: disable=line-too-long
+r"""A config for training MLP-Mixer-B/16 model on ILSVRC-2012 ("ImageNet-1k").
+
+Achieves 76.3% top-1 accuracy on the test split in 2h11m on TPU v3-128
+with 300 epochs. A shorter 60 epochs run is expected to get to 70.5% in 27m.
+
+big_vision.train \
+    --config big_vision/configs/mlp_mixer_i1k.py \
+    --workdir gs://[your_bucket]/big_vision/`date '+%m-%d_%H%M'` \
+"""
+
+from big_vision.configs.common_fewshot import get_fewshot_lsr
+import ml_collections as mlc
+
+
+def get_config(mode=None):
+  """Config for training Mixer on i1k."""
+  config = mlc.ConfigDict()
+
+  config.seed = 0
+  config.total_epochs = 300
+  config.num_classes = 1000
+  config.loss = 'sigmoid_xent'
+  config.init_head_bias = -6.9
+
+  config.input = dict()
+  config.input.data = dict(
+      name='imagenet2012',
+      split='train[:99%]',
+  )
+  config.input.batch_size = 4096
+  config.input.cache_raw = True  # Needs up to 120GB of RAM!
+  config.input.shuffle_buffer_size = 250_000
+
+  config.input.pp = (
+      'decode_jpeg_and_inception_crop(224)'
+      '|flip_lr'
+      '|randaug(2,15)'
+      '|value_range(-1, 1)'
+      '|onehot(1000, key="label", key_result="labels")'
+      '|keep("image", "labels")'
+  )
+  pp_eval = (
+      'decode'
+      '|resize_small(256)|central_crop(224)'
+      '|value_range(-1, 1)'
+      '|onehot(1000, key="{lbl}", key_result="labels")'
+      '|keep("image", "labels")'
+  )
+
+  # To continue using the near-defunct randaug op.
+  config.pp_modules = ['ops_general', 'ops_image', 'ops_text', 'archive.randaug']
+
+  config.log_training_steps = 50
+  config.ckpt_steps = 1000
+
+  config.prefetch_to_device = 2
+
+  # Model section
+  config.model_name = 'mlp_mixer'
+  config.model = dict()
+  config.model.variant = 'B/16'
+  config.model.stoch_depth = 0.1
+
+  config.mixup = dict(fold_in=None, p=0.5)
+
+  # Optimizer section
+  config.optax_name = 'scale_by_adam'
+  config.grad_clip_norm = 1.
+
+  config.lr = 0.001
+  config.wd = 1e-4
+  config.schedule = dict(
+      decay_type='linear',
+      warmup_steps=10_000,
+      linear_end=1e-5,
+  )
+
+  # Eval section
+  def get_eval(split, dataset='imagenet2012'):
+    return dict(
+        type='classification',
+        data=dict(name=dataset, split=split),
+        pp_fn=pp_eval.format(lbl='label'),
+        loss_name=config.loss,
+        log_steps=2500,  # Very fast O(seconds) so it's fine to run it often.
+        cache_final=mode != 'gpu8',
+    )
+  config.evals = {}
+  config.evals.train = get_eval('train[:2%]')
+  config.evals.minival = get_eval('train[99%:]')
+  config.evals.val = get_eval('validation')
+  config.evals.v2 = get_eval('test', dataset='imagenet_v2')
+  config.evals.real = get_eval('validation', dataset='imagenet2012_real')
+  config.evals.real.pp_fn = pp_eval.format(lbl='real_label')
+
+  config.fewshot = get_fewshot_lsr()
+
+  if mode == 'gpu8':
+    config.total_epochs = 60
+    config.input.batch_size = 512
+    config.input.cache_raw = False
+  if mode == 'regression_test':
+    config.total_epochs = 60
+
+  return config

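The optional `mode` argument exists to shrink the run for small hardware; a sketch of the `gpu8` variant (not part of the committed file; assumes big_vision is importable):

from big_vision.configs.mlp_mixer_i1k import get_config

c = get_config('gpu8')
assert c.total_epochs == 60
assert c.input.batch_size == 512
assert not c.input.cache_raw  # skip the ~120GB host-RAM requirement
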
Tipsomaly/model/big_vision/configs/proj/paligemma/transfers/vertexai_l4.py
ADDED

@@ -0,0 +1,115 @@
+# Copyright 2024 Big Vision Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# pylint: disable=line-too-long
+r"""PaliGemma transfer to a task stored in JSON-L, designed to fit on an L4 GPU.
+"""
+
+import big_vision.configs.common as bvcc
+
+
+def training_data(res, text_len):
+  """Creates training data config."""
+  c = bvcc.parse_arg('')  # Just make a configdict without extra import.
+  c.data = dict(
+      name='bv:jsonl',
+      fname='gs://longcap100/data_train90.jsonl',
+      fopen_keys={'image': 'gs://longcap100/'},
+      # See docstring in datasets/jsonl.py for further details.
+      # download_keys=['image'],  # If jsonl contains external paths.
+  )
+  c.pp = '|'.join([
+      # Read and prepare the image by just resizing it:
+      f'decode|resize({res}, antialias=True)|value_range(-1, 1)',
+      # The texts are already prepared in `prefix` and `suffix` keys.
+      'strfmt("caption en", outkey="prefix")',
+      combine_and_keep(text_len),
+  ])
+  # Keep the whole dataset in RAM after first pass. Useful optimization for
+  # small/mid-size datasets, but risks a host OOM for large datasets.
+  c.cache_raw = True
+  return c
+
+
+def get_config(arg=None):
+  """Config for training."""
+  # You probably do NOT want to add settings here. The `arg` way of settings is
+  # really only for things you'd want to sweep and which affect MULTIPLE config
+  # settings at once or go into the pp string.
+  c = bvcc.parse_arg(arg, res=224, text_len=128, batch_size=4,
+                     freeze_vit=False, freeze_llm=False)
+
+  c.input = training_data(c.res, c.text_len)
+
+  # These settings are suited for fitting in a single L4.
+  c.total_epochs = 1
+  c.input.batch_size = c.batch_size
+  c.optax_name = 'big_vision.sgd'  # Without momentum, so really low-memory.
+  c.lr = 0.1
+  c.wd = 0.0
+  c.grad_clip_norm = 1.0
+  c.label_smoothing = 0.0
+
+  # Learning-rate schedule. Probably is fine like this.
+  sched = dict(decay_type='cosine', warmup_percent=0.05)
+  c.schedule = [
+      ('img/.*', None if c.freeze_vit else sched),
+      ('llm/.*', None if c.freeze_llm else sched),
+  ]
+
+  c.evals = {}
+
+  # Model section.
+  c.model_name = 'proj.paligemma.paligemma'
+  c.model = {}
+  # TODO: b/lbeyer - no scan and no remat might be better on 1-GPU machines?
+  c.model.img = dict(variant='So400m/14', pool_type='none', scan=True)
+  c.model.llm = dict(vocab_size=256_000 + 1024 + 128, dropout=0.0)
+  c.model_init = f'pt_{c.res}'
+
+  # FSDP strategy.
+  c.mesh = [('data', -1)]
+  c.sharding_strategy = [('.*', 'fsdp(axis="data")')]
+  c.sharding_rules = [('act_batch', ('data',))]
+
+  c.input.shuffle_buffer_size = 1000
+  c.log_training_steps = 1
+  c.ckpt_steps = 1_000
+  c.pp_modules = ['ops_general', 'ops_image', 'ops_text', 'proj.paligemma.ops']
+
+  c.seed = 0
+  return c
+
+
+def tok(**kw):
+  """Creates the tokenization preprocessing string."""
+  # Single entry point so that it's consistent everywhere and easier to switch.
+  kw.setdefault('model', 'gemma(tokensets=("loc", "seg"))')
+  kw = ', '.join(f'{k}={repr(v)}' for k, v in kw.items())
+  return f'tok({kw})'
+
+
+def combine_and_keep(text_len):
+  return '|'.join([
+      tok(key='prefix', bos='yes'),
+      tok(key='suffix', eos='yes'),
+      tok(key='septok', text='\n'),
+      # If masks confuse you, see (internal link)
+      'masked_concat(["prefix", "septok", "suffix"], mask_ar=[0, 0, 1], mask_loss=[0, 0, 1])',
+      # For training, we +1 because the trainer removes EOS.
+      f'tolen({text_len+1}, pad_value=0, key="text")',  # For text, value doesn't matter.
+      f'tolen({text_len+1}, pad_value=1, key="mask_ar")',
+      f'tolen({text_len+1}, pad_value=0, key="mask_loss")',
+      'keep("image", "text", "mask_ar", "mask_loss")',
+  ])

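Since `tok()` and `combine_and_keep()` only build preprocessing strings, their output can be inspected directly; for example (not part of the committed file; assumes big_vision is importable):

from big_vision.configs.proj.paligemma.transfers.vertexai_l4 import tok

# setdefault() appends the shared tokenizer model after the caller's kwargs:
print(tok(key='prefix', bos='yes'))
# -> tok(key='prefix', bos='yes', model='gemma(tokensets=("loc", "seg"))')
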
Tipsomaly/model/big_vision/configs/proj/paligemma/transfers/vqav2.py
ADDED

@@ -0,0 +1,160 @@
+# Copyright 2024 Big Vision Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# pylint: disable=line-too-long
+r"""PaliGemma transfer to VQAv2.
+"""
+
+import big_vision.configs.common as bvcc
+from big_vision.configs.proj.paligemma.transfers.common import combine_and_keep_train, combine_and_keep_eval, TOKENIZER
+
+
+def training_data(res, final_split, text_len=32):
+  """Creates training data config.
+
+  See (internal link)
+  You can add more arguments beside `res`, but give them good defaults.
+
+  Args:
+    res: The requested image resolution (eg 224).
+    final_split: Whether to use all of the validation data.
+    text_len: sequence length
+
+  Returns:
+    The ConfigDict for the input section.
+  """
+  c = bvcc.parse_arg('')  # Just make a configdict without extra import.
+  c.data = dict(
+      name='vqa',
+      split='train + validation' if final_split else 'train + validation[:-10240]',
+  )
+  c.pp = '|'.join([
+      f'decode|resize({res}, antialias=True)|value_range(-1, 1)',
+      'strfmt("answer en {question_text}", outkey="prefix")',
+      'choice_no_replacement(inkey="answers", outkey="suffix")',
+      combine_and_keep_train(text_len),
+  ])
+  return c
+
+
+def add_eval(c, res, text_len=32, **kw):
+  """VQAv2 evaluators."""
+  pp = '|'.join([
+      f'decode|resize({res}, antialias=True)|value_range(-1, 1)',
+      'strfmt("answer en {question_text}", outkey="prefix")',
+      combine_and_keep_eval(text_len, keep=('answers', 'answer_type', 'question_type', 'question_id')),
+  ])
+
+  for freq, name, split in [
+      (1/4, 'minitrain', 'train[:5120]'),  # To gauge memorization.
+      (1/4, 'minival', 'validation[-10240:]'),  # To tune hparams.
+      # To generate final predictions. Test sets combined since 2021 challenge.
+      (1.0, 'test', 'test + test-dev'),
+  ]:
+    c.evals[f'vqav2/{name}'] = dict(
+        type='proj.paligemma.transfers.vqav2',
+        pred='decode', pred_kw={'max_decode_len': text_len},
+        outfile=f'{{workdir}}/vqav2_{name}.json',
+        data={**training_data(res, True, text_len).data, 'split': split},
+        log_percent=freq, skip_first=freq == 1, tokenizer=TOKENIZER, pp_fn=pp)
+    c.evals[f'vqav2/{name}'].update(kw)
+
+
+def add_eval_pplx(c, res, text_len=32):
+  """Perplexity evaluator to test runs before implementing the real deal."""
+  c_train = training_data(res, True, text_len)  # Use mostly same settings as training.
+
+  for name, split in [
+      ('minitrain', 'train[:20_864]'),  # To gauge memorization.
+      ('minival', 'validation[-10240:]'),  # To tune hparams.
+  ]:
+    c.evals[f'vqav2/{name}/pplx'] = dict(
+        type='proj.paligemma.perplexity', pred='logits',
+        key='text', shift_labels=True,
+        log_percent=1/4,  # Not too cheap, do 4x per run.
+        data={**c_train.data, 'split': split},
+        pp_fn=c_train.pp,
+    )
+
+
+def sweep_best(add, arg=None):
+  """Train with best hyper-params."""
+  c = bvcc.parse_arg(arg, final_split=False)
+  # NOTE: lr was highest in sweep.
+  add(total_epochs=10, lr=1e-5, wd=1e-6, **bvcc.arg(res=224, **c))
+  add(total_epochs=10, lr=1e-5, wd=0.00, **bvcc.arg(res=448, **c))
+
+
+sweep = sweep_best  # Choose which sweep to run.
+
+
+def get_config(arg=None):
+  """Config for training."""
+  c = bvcc.parse_arg(arg, mode='xm', res=224, final_split=False)
+
+  c.input = training_data(c.res, c.final_split)
+
+  # Instead of epochs, you can also use `total_examples` or `total_steps`.
+  c.total_epochs = 10
+  c.input.batch_size = 256
+  c.optax_name = 'scale_by_adam'
+  c.lr = 3e-6
+  c.wd = 3e-7
+  c.grad_clip_norm = 1.0
+  c.label_smoothing = 0.0
+  c.schedule = dict(decay_type='cosine', warmup_percent=0.05)
+
+  # Add evaluators.
+  c.evals = {}
+  add_eval(c, c.res, batch_size=1024)
+  add_eval_pplx(c, c.res)
+
+  # Model section.
+  c.model_name = 'proj.paligemma.paligemma'
+  c.model = {}
+  c.model.img = dict(variant='So400m/14', pool_type='none', scan=True)
+  c.model.llm = dict(vocab_size=256_000 + 1024 + 128, dropout=0.0)
+  c.model_init = f'pt_{c.res}'
+
+  # FSDP strategy.
+  c.mesh = [('data', -1)]
+  c.sharding_strategy = [('.*', 'fsdp(axis="data")')]
+  c.sharding_rules = [('act_batch', ('data',))]
+
+  # These probably do not need any change/tuning
+  c.input.shuffle_buffer_size = 50_000
+  c.log_training_steps = 50
+  c.ckpt_steps = 1_000
+  c.pp_modules = ['ops_general', 'ops_image', 'ops_text', 'proj.paligemma.ops']
+
+  # Update configs for quicker local runs and avoid swapping.
+  if c.mode in ('runlocal', 'mock'):
+    c.input.shuffle_buffer_size = None
+    for ev in c.evals.values():
+      ev.data.split = ev.data.split.split('[')[0] + '[:16]'
+
+  if c.mode == 'runlocal':
+    c.log_training_steps = 1
+    c.input.batch_size = 2
+
+  c.seed = 0
+  return c
+
+
+def metrics(arg=None):  # pylint: disable=unused-argument
+  m = ['training_loss']
+  for split in ('minival', 'minitrain'):
+    m.append(f'vqav2/{split}/acc')
+    m.append(f'vqav2/{split}/pplx/avg')
+  return m

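The `metrics()` helper above just enumerates the metric names tracked for this transfer; expanded, it yields (a standalone check, not part of the committed file):

from big_vision.configs.proj.paligemma.transfers.vqav2 import metrics

assert metrics() == [
    'training_loss',
    'vqav2/minival/acc', 'vqav2/minival/pplx/avg',
    'vqav2/minitrain/acc', 'vqav2/minitrain/pplx/avg',
]
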
Tipsomaly/model/big_vision/configs/transfer.py
ADDED

@@ -0,0 +1,186 @@
+# Copyright 2024 Big Vision Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# pylint: disable=line-too-long,missing-function-docstring
+r"""A config for transferring vit-augreg.
+
+Best HP selected on (mini)val, expected test results (repeated 5 times):
+
+ViT-Augreg-B/32:
+  Dataset, crop, learning rate, mean (%), range (%)
+  - ImageNet, inception_crop, 0.03, 83.27, [83.22...83.33]
+  - Cifar10, resmall_crop, 0.003, 98.55, [98.46...98.6]
+  - Cifar100, resmall_crop, 0.01, 91.35, [91.09...91.62]
+  - Pets, inception_crop, 0.003, 93.78, [93.62...94.00]
+  - Flowers, inception_crop, 0.003, 99.43, [99.42...99.45]
+
+
+Command to run:
+big_vision.train \
+    --config big_vision/configs/transfer.py:model=vit-i21k-augreg-b/32,dataset=cifar10,crop=resmall_crop \
+    --workdir gs://$GS_BUCKET_NAME/big_vision/workdir/`date '+%m-%d_%H%M'` --config.lr=0.03
+"""
+
+import big_vision.configs.common as bvcc
+import ml_collections as mlc
+
+
+def _set_model(config, model):
+  """Load pre-trained models: vit or bit."""
+  # Reset the head to init (of zeros) when transferring.
+  config.model_load = dict(dont_load=['head/kernel', 'head/bias'])
+
+  if model == 'vit-i21k-augreg-b/32':
+    # Load "recommended" upstream B/32 from https://arxiv.org/abs/2106.10270
+    config.model_name = 'vit'
+    config.model_init = 'howto-i21k-B/32'
+    config.model = dict(variant='B/32', pool_type='tok')
+  elif model == 'vit-i21k-augreg-l/16':
+    config.model_name = 'vit'
+    config.model_init = 'howto-i21k-L/16'
+    config.model = dict(variant='L/16', pool_type='tok')
+  elif model == 'vit-s16':
+    config.model_name = 'vit'
+    config.model_init = 'i1k-s16-300ep'
+    config.model = dict(variant='S/16', pool_type='gap', posemb='sincos2d',
+                        rep_size=True)
+  elif model == 'bit-m-r50x1':
+    config.model_name = 'bit_paper'
+    config.model_init = 'M'
+    config.model = dict(depth=50, width=1)
+  else:
+    raise ValueError(f'Unknown model: {model}, please define customized model.')
+
+
+def _set_dataset(config, dataset, crop='inception_crop', h_res=448, l_res=384):
+  if dataset == 'cifar10':
+    _set_task(config, 'cifar10', 'train[:98%]', 'train[98%:]', 'test', 10, steps=10_000, warmup=500, crop=crop, h_res=h_res, l_res=l_res)
+  elif dataset == 'cifar100':
+    _set_task(config, 'cifar100', 'train[:98%]', 'train[98%:]', 'test', 100, steps=10_000, warmup=500, crop=crop, h_res=h_res, l_res=l_res)
+  elif dataset == 'imagenet2012':
+    _set_task(config, 'imagenet2012', 'train[:99%]', 'train[99%:]', 'validation', 1000, steps=20_000, warmup=500, crop=crop, h_res=h_res, l_res=l_res)
+    _set_imagenet_variants(config)
+  elif dataset == 'oxford_iiit_pet':
+    _set_task(config, 'oxford_iiit_pet', 'train[:90%]', 'train[90%:]', 'test', 37, steps=500, warmup=100, crop=crop, h_res=h_res, l_res=l_res)
+  elif dataset == 'oxford_flowers102':
+    _set_task(config, 'oxford_flowers102', 'train[:90%]', 'train[90%:]', 'test', 102, steps=500, warmup=100, crop=crop, h_res=h_res, l_res=l_res)
+  else:
+    raise ValueError(
+        f'Unknown dataset: {dataset}, please define customized dataset.')
+
+
+def _set_task(config, dataset, train, val, test, n_cls,
+              steps=20_000, warmup=500, lbl='label', crop='resmall_crop',
+              flip=True, h_res=448, l_res=384):
+  """Vision task with val and test splits."""
+  config.total_steps = steps
+  config.schedule = dict(
+      warmup_steps=warmup,
+      decay_type='cosine',
+  )
+
+  config.input.data = dict(name=dataset, split=train)
+  pp_common = (
+      '|value_range(-1, 1)|'
+      f'onehot({n_cls}, key="{lbl}", key_result="labels")|'
+      'keep("image", "labels")'
+  )
+
+  if crop == 'inception_crop':
+    pp_train = f'decode|inception_crop({l_res})'
+  elif crop == 'resmall_crop':
+    pp_train = f'decode|resize_small({h_res})|random_crop({l_res})'
+  elif crop == 'resize_crop':
+    pp_train = f'decode|resize({h_res})|random_crop({l_res})'
+  else:
+    raise ValueError(f'Unknown crop: {crop}. Must be one of: '
+                     'inception_crop, resmall_crop, resize_crop')
+  if flip:
+    pp_train += '|flip_lr'
+  config.input.pp = pp_train + pp_common
+
+  pp = f'decode|resize_small({h_res})|central_crop({l_res})' + pp_common
+  config.num_classes = n_cls
+
+  def get_eval(split):
+    return dict(
+        type='classification',
+        data=dict(name=dataset, split=split),
+        loss_name='softmax_xent',
+        log_steps=100,
+        pp_fn=pp,
+    )
+  config.evals = dict(val=get_eval(val), test=get_eval(test))
+
+
+def _set_imagenet_variants(config, h_res=448, l_res=384):
+  """Evaluation tasks on ImageNet variants: v2 and real."""
+  pp = (f'decode|resize_small({h_res})|central_crop({l_res})'
+        '|value_range(-1, 1)|onehot(1000, key="{lbl}", key_result="labels")|'
+        'keep("image", "labels")'
+        )
+
+  # Special-case rename for i1k (val+test -> minival+val)
+  config.evals.minival = config.evals.val
+  config.evals.val = config.evals.test
+  # NOTE: keep test == val for convenience in subsequent analysis.

(diff truncated here by the limited view; the file adds 186 lines in total.)
|
| 138 |
+
|
| 139 |
+
config.evals.real = dict(type='classification')
|
| 140 |
+
config.evals.real.data = dict(name='imagenet2012_real', split='validation')
|
| 141 |
+
config.evals.real.pp_fn = pp.format(lbl='real_label')
|
| 142 |
+
config.evals.real.loss_name = config.loss
|
| 143 |
+
config.evals.real.log_steps = 100
|
| 144 |
+
|
| 145 |
+
config.evals.v2 = dict(type='classification')
|
| 146 |
+
config.evals.v2.data = dict(name='imagenet_v2', split='test')
|
| 147 |
+
config.evals.v2.pp_fn = pp.format(lbl='label')
|
| 148 |
+
config.evals.v2.loss_name = config.loss
|
| 149 |
+
config.evals.v2.log_steps = 100
|
| 150 |
+
|
| 151 |
+
|
| 152 |
+
def get_config(arg=None):
|
| 153 |
+
"""Config for adaptation."""
|
| 154 |
+
arg = bvcc.parse_arg(arg, model='vit', dataset='cifar10', crop='resmall_crop',
|
| 155 |
+
h_res=448, l_res=384, batch_size=512, fsdp=False,
|
| 156 |
+
runlocal=False)
|
| 157 |
+
config = mlc.ConfigDict()
|
| 158 |
+
|
| 159 |
+
config.input = {}
|
| 160 |
+
config.input.batch_size = arg.batch_size if not arg.runlocal else 8
|
| 161 |
+
config.input.shuffle_buffer_size = 50_000 if not arg.runlocal else 100
|
| 162 |
+
|
| 163 |
+
config.log_training_steps = 10
|
| 164 |
+
config.ckpt_steps = 1000
|
| 165 |
+
config.ckpt_timeout = 600
|
| 166 |
+
|
| 167 |
+
# Optimizer section
|
| 168 |
+
config.optax_name = 'big_vision.momentum_hp'
|
| 169 |
+
config.grad_clip_norm = 1.0
|
| 170 |
+
config.wd = None # That's our default, but just being explicit here!
|
| 171 |
+
config.loss = 'softmax_xent'
|
| 172 |
+
config.lr = 0.01
|
| 173 |
+
config.mixup = dict(p=0.0)
|
| 174 |
+
|
| 175 |
+
config.seed = 0
|
| 176 |
+
|
| 177 |
+
_set_dataset(config, arg.dataset, arg.crop, arg.h_res, arg.l_res)
|
| 178 |
+
|
| 179 |
+
_set_model(config, arg.model)
|
| 180 |
+
if arg.fsdp:
|
| 181 |
+
config.mesh = [('data', -1)]
|
| 182 |
+
config.sharding_strategy = [('.*', 'fsdp(axis="data")')]
|
| 183 |
+
config.sharding_rules = [('act_batch', ('data',))]
|
| 184 |
+
config.model.scan = True
|
| 185 |
+
|
| 186 |
+
return config
|
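For reference, a minimal sketch of how a config module like the one above is consumed standalone (assuming big_vision is importable; the arg string is the same one that follows the ':' in the command shown in the docstring):

from big_vision.configs import transfer

config = transfer.get_config('model=vit-i21k-augreg-b/32,dataset=cifar10,crop=resmall_crop')
print(config.model_name, config.num_classes)  # 'vit' 10
config.lr = 0.03  # command-line overrides like --config.lr=0.03 just mutate the ConfigDict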
Tipsomaly/model/big_vision/configs/vit_i1k.py
ADDED
@@ -0,0 +1,177 @@
+# Copyright 2024 Big Vision Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# pylint: disable=line-too-long
+r"""Pre-training ViT on ILSVRC-2012 as in https://arxiv.org/abs/2106.10270
+
+This config does NOT include regularization (dropout, stochastic depth), which
+was shown to help with B/32, B/16, L/16 models in the paper (Figure 4).
+
+This configuration makes use of the "arg" to get_config to select which model
+to run, so a few examples are given below:
+
+Run training of a B/16 model:
+
+big_vision.train \
+    --config big_vision/configs/vit_i1k.py:variant=B/16 \
+    --workdir gs://[your_bucket]/big_vision/`date '+%m-%d_%H%M'`
+
+Run training of a B/32 model with custom aug-strength and 300ep:
+
+big_vision.train \
+    --config big_vision/configs/vit_i1k.py:variant=B/32,aug=light1 \
+    --workdir gs://[your_bucket]/big_vision/`date '+%m-%d_%H%M'` \
+    --config.total_epochs 300
+"""
+
+import big_vision.configs.common as bvcc
+from big_vision.configs.common_fewshot import get_fewshot_lsr
+import ml_collections as mlc
+
+MIXUP_DEF = {
+    'none': dict(p=0.0, fold_in=None),
+    'light1': dict(p=0.0, fold_in=None),
+    'light2': dict(p=0.2, fold_in=None),
+    'medium1': dict(p=0.2, fold_in=None),
+    'medium2': dict(p=0.5, fold_in=None),
+    'strong1': dict(p=0.5, fold_in=None),
+    'strong2': dict(p=0.8, fold_in=None),
+}
+
+RANDAUG_DEF = {
+    'none': '',
+    'light1': 'randaug(2,0)',  # Actually not nothing!
+    'light2': 'randaug(2,10)',
+    'medium1': 'randaug(2,15)',
+    'medium2': 'randaug(2,15)',
+    'strong1': 'randaug(2,20)',
+    'strong2': 'randaug(2,20)',
+}
+
+
+def get_config(arg=None):
+  """Config for training."""
+  arg = bvcc.parse_arg(arg, variant='B/16', runlocal=False, aug='')
+  config = mlc.ConfigDict()
+
+  config.seed = 0
+  config.total_epochs = 300
+  config.num_classes = 1000
+  config.loss = 'sigmoid_xent'
+  config.init_head_bias = -6.9
+
+  # If this gives a KeyError, lookup Fig4 of the paper and add an entry.
+  # Note, this here is a good average between 30ep and 300ep, sometimes you could
+  # find a slightly better setting for either of them.
+  aug_setting = arg.aug or {
+      'Ti/16': 'light1',
+      'S/32': 'medium1',
+      'S/16': 'medium2',
+      'B/32': 'medium2',
+      'B/16': 'medium2',
+      'L/16': 'medium2',
+  }[arg.variant]
+
+  config.input = dict()
+  config.input.data = dict(
+      name='imagenet2012',
+      split='train[:99%]',
+  )
+  config.input.batch_size = 4096
+  config.input.cache = 'raw_data' if arg.runlocal else 'none'  # Needs up to 120GB of RAM!
+  config.input.shuffle_buffer_size = 250_000
+
+  pp_common = (
+      '|value_range(-1, 1)'
+      '|onehot(1000, key="{lbl}", key_result="labels")'
+      '|keep("image", "labels")'
+  )
+  config.input.pp = (
+      'decode_jpeg_and_inception_crop(224)|flip_lr|' +
+      RANDAUG_DEF[aug_setting] +
+      pp_common.format(lbl='label')
+  )
+  pp_eval = 'decode|resize_small(256)|central_crop(224)' + pp_common
+
+  # To continue using the near-defunct randaug op.
+  config.pp_modules = ['ops_general', 'ops_image', 'ops_text', 'archive.randaug']
+
+  # Aggressive pre-fetching because our models here are small, so we not only
+  # can afford it, but we also need it for the smallest models to not be
+  # bottle-necked by the input pipeline. Play around with it for -L models tho.
+  config.input.prefetch = 8
+  config.prefetch_to_device = 4
+
+  config.log_training_steps = 50
+  config.ckpt_steps = 1000
+
+  # Model section
+  config.model_name = 'vit'
+  config.model = dict(
+      variant=arg.variant,
+      rep_size=True,
+      pool_type='tok',
+  )
+
+  # Optimizer section
+  config.grad_clip_norm = 1.0
+  config.optax_name = 'scale_by_adam'
+  config.optax = dict(mu_dtype='bfloat16')
+  # The modified AdaFactor we introduced in https://arxiv.org/abs/2106.04560
+  # almost always behaves exactly like adam, but at a fraction of the memory
+  # cost (specifically, adam_bf16 = +1.5M, adafactor = +0.5M), hence it is a
+  # good idea to try it when you are memory-bound!
+  # config.optax_name = 'big_vision.scale_by_adafactor'
+  # A good flag to play with when hitting instabilities, is the following:
+  # config.optax = dict(beta2_cap=0.95)
+
+  config.lr = 0.001
+  config.wd = 0.0001
+  config.schedule = dict(warmup_steps=10_000, decay_type='cosine')
+
+  config.mixup = MIXUP_DEF[aug_setting]
+
+  # Eval section
+  def get_eval(split, dataset='imagenet2012'):
+    return dict(
+        type='classification',
+        data=dict(name=dataset, split=split),
+        pp_fn=pp_eval.format(lbl='label'),
+        loss_name=config.loss,
+        log_steps=2500,  # Very fast O(seconds) so it's fine to run it often.
+        cache='final_data' if arg.runlocal else 'none',
+    )
+  config.evals = {}
+  config.evals.train = get_eval('train[:2%]')
+  config.evals.minival = get_eval('train[99%:]')
+  config.evals.val = get_eval('validation')
+  config.evals.v2 = get_eval('test', dataset='imagenet_v2')
+  config.evals.real = get_eval('validation', dataset='imagenet2012_real')
+  config.evals.real.pp_fn = pp_eval.format(lbl='real_label')
+
+  config.fewshot = get_fewshot_lsr(runlocal=arg.runlocal)
+  config.fewshot.log_steps = 10_000
+
+  # Make a few things much smaller for quick local debugging testruns.
+  if arg.runlocal:
+    config.input.shuffle_buffer_size = 10
+    config.input.batch_size = 8
+    config.input.cache_raw = False
+    config.evals.train.data.split = 'train[:16]'
+    config.evals.minival.data.split = 'train[:16]'
+    config.evals.val.data.split = 'validation[:16]'
+    config.evals.v2.data.split = 'test[:16]'
+    config.evals.real.data.split = 'validation[:16]'
+
+  return config
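One detail worth unpacking: init_head_bias = -6.9 above pairs with the sigmoid_xent loss so that every class probability starts near the uniform prior 1/1000. A quick check in plain Python:

import math

# sigmoid(b) ≈ exp(b) for very negative b, so b = log(1/1000) gives p ≈ 0.001.
print(math.log(1 / 1000))  # -6.907..., rounded to -6.9 in the config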
Tipsomaly/model/big_vision/configs/vit_i21k.py
ADDED
@@ -0,0 +1,145 @@
+# Copyright 2024 Big Vision Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# pylint: disable=line-too-long
+r"""Pre-training ViT on ImageNet-21k as in https://arxiv.org/abs/2106.10270
+
+This config relies on the Imagenet-21k tfds dataset, which is not yet
+available publicly in TFDS. We intend to add the dataset to public TFDS soon,
+and this config will then be runnable.
+
+Note that regularization (dropout, stochastic depth) is not currently
+implemented. This was not beneficial for ImageNet-21k pre-training.
+"""
+
+import big_vision.configs.common as bvcc
+from big_vision.configs.common_fewshot import get_fewshot_lsr
+import ml_collections as mlc
+
+MIXUP_DEF = {
+    'none': dict(p=0.0, fold_in=None),
+    'light1': dict(p=0.0, fold_in=None),
+    'light2': dict(p=0.2, fold_in=None),
+    'medium1': dict(p=0.2, fold_in=None),
+    'medium2': dict(p=0.5, fold_in=None),
+    'strong1': dict(p=0.5, fold_in=None),
+    'strong2': dict(p=0.8, fold_in=None),
+}
+
+RANDAUG_DEF = {
+    'none': '',
+    'light1': 'randaug(2,0)',  # Actually not nothing!
+    'light2': 'randaug(2,10)',
+    'medium1': 'randaug(2,15)',
+    'medium2': 'randaug(2,15)',
+    'strong1': 'randaug(2,20)',
+    'strong2': 'randaug(2,20)',
+}
+
+
+def get_config(arg=None):
+  """Config for training."""
+  arg = bvcc.parse_arg(arg, variant='B/16', runlocal=False, aug=None)
+  config = mlc.ConfigDict()
+
+  config.seed = 0
+  config.total_epochs = 300
+  config.num_classes = 21843
+  config.init_head_bias = -10.0
+  config.loss = 'sigmoid_xent'
+
+  # If this gives a KeyError, lookup Fig4 of the paper and add an entry.
+  # Note, this here is a good average between 30ep and 300ep, sometimes you could
+  # find a slightly better setting for either of them.
+  aug_setting = {
+      'Ti/16': 'none',
+      'S/32': 'none',
+      'S/16': 'light1',
+      'B/32': 'light2',
+      'B/16': 'light2',
+      'L/16': 'medium2',
+  }[arg.variant]
+
+  config.input = dict()
+  config.input.data = dict(
+      name='imagenet21k',
+      split='full[51200:]',
+  )
+  config.input.batch_size = 4096
+  config.input.shuffle_buffer_size = 250_000  # Per host, so small-ish is ok.
+
+  pp_common = '|value_range(-1, 1)|onehot({onehot_args})|keep("image", "labels")'
+  pp_common_i21k = pp_common.format(onehot_args=f'{config.num_classes}')
+  pp_common_i1k = pp_common.format(onehot_args='1000, key="label", key_result="labels"')
+  config.input.pp = f'decode_jpeg_and_inception_crop(224)|flip_lr|{RANDAUG_DEF[aug_setting]}' + pp_common_i21k
+  pp_eval = 'decode|resize_small(256)|central_crop(224)'
+
+  # To continue using the near-defunct randaug op.
+  config.pp_modules = ['ops_general', 'ops_image', 'ops_text', 'archive.randaug']
+
+  # Aggressive pre-fetching because our models here are small, so we not only
+  # can afford it, but we also need it for the smallest models to not be
+  # bottle-necked by the input pipeline. Play around with it for -L models tho.
+  config.input.prefetch = 8
+  config.prefetch_to_device = 4
+
+  config.log_training_steps = 50
+  config.ckpt_steps = 1000
+
+  # Model section
+  config.model_name = 'vit'
+  config.model = dict(variant=arg.variant, pool_type='gap', posemb='learn')
+
+  # Optimizer section
+  config.optax_name = 'scale_by_adam'
+  config.optax = dict(mu_dtype='bfloat16')
+  config.grad_clip_norm = 1.0
+
+  config.lr = 0.001
+  config.wd = 0.0001
+  config.schedule = dict(warmup_steps=10_000, decay_type='cosine')
+
+  config.mixup = MIXUP_DEF[aug_setting]
+
+  # Evaluations on i21k itself.
+  def eval_i21k(split):
+    return dict(
+        type='classification',
+        data={**config.input.data, 'split': split},
+        pp_fn=pp_eval + pp_common_i21k,
+        loss_name=config.loss,
+        log_steps=1000,  # Very fast O(seconds) so it's fine to run it often.
+    )
+  config.evals = {}
+  config.evals.test = eval_i21k('full[:25_600]')
+  config.evals.val = eval_i21k('full[25_600:51_200]')
+  config.evals.train = eval_i21k('full[51_200:76_800]')
+
+  # Few-shot evaluators
+  config.evals.fewshot = get_fewshot_lsr(runlocal=arg.runlocal)
+  config.evals.fewshot.log_steps = 25_000
+
+  # Make a few things much smaller for quick local debugging testruns.
+  if arg.runlocal:
+    config.input.shuffle_buffer_size = 10
+    config.input.batch_size = 8
+    config.evals.test.data.split = 'full[:16]'
+    config.evals.train.data.split = 'full[:16]'
+    config.evals.val.data.split = 'full[:16]'
+    config.evals.i1k_val.data.split = 'validation[:16]'
+    config.evals.i1k_v2.data.split = 'test[:16]'
+    config.evals.i1k_a.data.split = 'test[:16]'
+    config.evals.i1k_r.data.split = 'test[:16]'
+
+  return config
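The same prior-matching logic gives init_head_bias = -10.0 here (log(1/21843) ≈ -9.99). The pp_common template above is also worth expanding by hand; this standalone snippet prints the two concrete preprocessing strings it produces:

import math

print(math.log(1 / 21843))  # ≈ -9.99, i.e. init_head_bias = -10.0

pp_common = '|value_range(-1, 1)|onehot({onehot_args})|keep("image", "labels")'
print(pp_common.format(onehot_args='21843'))
print(pp_common.format(onehot_args='1000, key="label", key_result="labels"'))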
Tipsomaly/model/big_vision/configs/vit_s16_i1k.py
ADDED
@@ -0,0 +1,105 @@
+# Copyright 2024 Big Vision Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# pylint: disable=line-too-long
+r"""Pre-training ViT-S/16 on ILSVRC-2012 following https://arxiv.org/abs/2205.01580.
+
+This should take 6-7h to finish 90ep on a TPU-v3-8 and reach 76.5%,
+see the tech report for more details.
+
+Command to run:
+
+big_vision.train \
+    --config big_vision/configs/vit_s16_i1k.py \
+    --workdir gs://[your_bucket]/big_vision/`date '+%m-%d_%H%M'`
+
+To run for 300ep, add `--config.total_epochs 300` to the command.
+"""
+
+import ml_collections as mlc
+
+
+def get_config():
+  """Config for training."""
+  config = mlc.ConfigDict()
+
+  config.seed = 0
+  config.total_epochs = 90
+  config.num_classes = 1000
+  config.loss = 'softmax_xent'
+
+  config.input = {}
+  config.input.data = dict(
+      name='imagenet2012',
+      split='train[:99%]',
+  )
+  config.input.batch_size = 1024
+  config.input.cache_raw = True  # Needs up to 120GB of RAM!
+  config.input.shuffle_buffer_size = 250_000
+
+  pp_common = (
+      '|value_range(-1, 1)'
+      '|onehot(1000, key="{lbl}", key_result="labels")'
+      '|keep("image", "labels")'
+  )
+  config.input.pp = (
+      'decode_jpeg_and_inception_crop(224)|flip_lr|randaug(2,10)' +
+      pp_common.format(lbl='label')
+  )
+  pp_eval = 'decode|resize_small(256)|central_crop(224)' + pp_common
+
+  # To continue using the near-defunct randaug op.
+  config.pp_modules = ['ops_general', 'ops_image', 'ops_text', 'archive.randaug']
+
+  config.log_training_steps = 50
+  config.ckpt_steps = 1000
+
+  # Model section
+  config.model_name = 'vit'
+  config.model = dict(
+      variant='S/16',
+      rep_size=True,
+      pool_type='gap',
+      posemb='sincos2d',
+  )
+
+  # Optimizer section
+  config.grad_clip_norm = 1.0
+  config.optax_name = 'scale_by_adam'
+  config.optax = dict(mu_dtype='bfloat16')
+
+  config.lr = 0.001
+  config.wd = 0.0001
+  config.schedule = dict(warmup_steps=10_000, decay_type='cosine')
+
+  config.mixup = dict(p=0.2, fold_in=None)
+
+  # Eval section
+  def get_eval(split, dataset='imagenet2012'):
+    return dict(
+        type='classification',
+        data=dict(name=dataset, split=split),
+        pp_fn=pp_eval.format(lbl='label'),
+        loss_name=config.loss,
+        log_steps=2500,  # Very fast O(seconds) so it's fine to run it often.
+    )
+  config.evals = {}
+  config.evals.train = get_eval('train[:2%]')
+  config.evals.minival = get_eval('train[99%:]')
+  config.evals.val = get_eval('validation')
+  config.evals.v2 = get_eval('test', dataset='imagenet_v2')
+  config.evals.real = get_eval('validation', dataset='imagenet2012_real')
+  config.evals.real.pp_fn = pp_eval.format(lbl='real_label')
+
+  return config
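A rough sanity check of what total_epochs = 90 means in steps for this config (assuming the standard ImageNet-1k train split of 1,281,167 examples, 99% of which is used for training):

n_train = int(1_281_167 * 0.99)    # ≈ 1,268,355 training images
steps_per_epoch = n_train // 1024  # ≈ 1,238 at batch_size 1024
print(90 * steps_per_epoch)        # ≈ 111,420 steps; warmup covers 10,000 of them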
Tipsomaly/model/big_vision/datasets/core.py
ADDED
@@ -0,0 +1,77 @@
+# Copyright 2024 Big Vision Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Core data functions, dispatch calls to the requested dataset."""
+import importlib
+
+
+# Note: intentionally not using ABC to avoid forcing implementation of every
+# method, since one can imagine train-only datasets for example.
+class DataSource:
+  """The API that any data source should implement."""
+
+  def get_tfdata(self, ordered, *, process_split=True, allow_cache=True):
+    """Creates this data object as a tf.data.Dataset.
+
+    This will be called separately in each process, and it is up to the dataset
+    implementation to shard it accordingly if desired!
+
+    Args:
+      ordered: if True, the dataset should use deterministic ordering, if False
+        it may have undefined ordering. Think of True == val, False == train.
+      process_split: if False then every process receives the entire dataset
+        (e.g. for evaluators running in a single process).
+      allow_cache: whether to allow caching the opened data or not.
+
+    Returns:
+      A tf.data.Dataset object.
+
+    Raises:
+      RuntimeError: if not implemented by the dataset, but called.
+    """
+    raise RuntimeError(f"not implemented for {self.__class__.__name__}")
+
+  @property
+  def total_examples(self):
+    """Returns number of examples in the dataset, regardless of sharding."""
+    raise RuntimeError(f"not implemented for {self.__class__.__name__}")
+
+  def num_examples_per_process(self):
+    """Returns a list of the number of examples for each process.
+
+    This is only needed for datasets that should go through make_for_inference.
+
+    Returns:
+      A list of the number of examples for each process.
+
+      Ideally, this would always be `[total() / nprocess] * nprocess`, but in
+      reality we can almost never perfectly shard a dataset across an arbitrary
+      number of processes.
+
+      One alternative option that can work in some cases is to not even shard
+      the dataset and thus return `[num_examples()] * nprocess`.
+
+    Raises:
+      RuntimeError: if not implemented by the dataset, but called.
+    """
+    raise RuntimeError(f"not implemented for {self.__class__.__name__}")
+
+
+def get(name, **kw):
+  if name.startswith("bv:"):
+    mod = importlib.import_module(f"big_vision.datasets.{name[3:]}")
+    return mod.DataSource(**kw)
+  else:
+    mod = importlib.import_module("big_vision.datasets.tfds")
+    return mod.DataSource(name, **kw)
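A sketch of the dispatch in get() above (the TFDS call assumes the dataset is available locally or via try_gcs; the jsonl path is a hypothetical file):

import big_vision.datasets.core as ds_core

# No prefix: treated as a TFDS dataset name, handled by big_vision.datasets.tfds.
val_src = ds_core.get('imagenet2012', split='validation')

# 'bv:' prefix: dispatched to the named big_vision dataset module.
jsonl_src = ds_core.get('bv:jsonl', fname='/tmp/demo.jsonl')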
Tipsomaly/model/big_vision/datasets/jsonl.py
ADDED
@@ -0,0 +1,177 @@
+# Copyright 2024 Big Vision Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Simple data input from .jsonl files."""
+
+import hashlib
+import json
+from multiprocessing.pool import ThreadPool
+import os
+import tempfile
+import urllib.request
+
+from absl import logging
+import big_vision.datasets.core as ds_core
+import jax
+import numpy as np
+import overrides
+import tensorflow as tf
+
+
+def cached_download(url, dest=None, verbose=True):
+  """Download `url` to local file and return path to that, but with caching."""
+  # NOTE: there is a small chance of saving corrupted data if the process is
+  # interrupted in the middle of writing the file. Then, reading in the input
+  # pipeline will fail, and the fix is to nuke the temp folder.
+
+  # Compute a temp name based on the URL, so we can check if we already
+  # downloaded it before.
+  dest = dest or os.path.join(tempfile.gettempdir(), "bv")
+  os.makedirs(dest, exist_ok=True)
+  dest = os.path.join(dest, hashlib.md5(url.encode()).hexdigest())
+
+  # NOTE: we should use last-modified header to know whether to re-download.
+  if os.path.isfile(dest):
+    return dest
+
+  if verbose:
+    print(f"\rRetrieving {url} into {dest}", end="", flush=True)
+
+  with urllib.request.urlopen(url) as f:
+    data = f.read()
+  with open(dest, "wb+") as f:
+    f.write(data)
+  return dest
+
+
+class DataSource(ds_core.DataSource):
+  """.jsonl DataSource."""
+
+  def __init__(self, fname, *, fopen_keys=(), download_keys=(),
+               start=0, stop=float("inf")):
+    """Create data-source that's jsonl + data files (eg images).
+
+    This correctly supports multi-host in that each host only reads a subset of
+    the dataset automatically. However, currently, all hosts download all items
+    if `download_keys` is specified. TODO: b/lbeyer - This can be improved.
+
+    Args:
+      fname: str, the path to the jsonl file that holds the dataset.
+      fopen_keys: collection of str or dict, the keys in the dataset whose
+        string value actually is a file-path that should be opened and read,
+        and its content is what goes into the batch (eg image filenames
+        commonly ["image"]).
+        If a dict, the values are folders prefixed to the filenames.
+        Supports gs:// for reading from buckets.
+      download_keys: collection of str, the keys in the dataset whose string
+        value actually is a URL from which the file should be downloaded first.
+        Files are downloaded to a persistent tmp folder using the URL hash as
+        filename. If the file already exists, the download is skipped.
+        Must be a subset of `fopen_keys`.
+      start: int, index of the first row to use; use for slicing the data.
+      stop: int or inf, index of the row after the last one to use.
+
+    Note:
+      This simple data input does not allow for nested/hierarchical values,
+      or in any way more complicated values like vectors. Use TFDS for that.
+
+      The way start/stop arguments are used is as in list slicing [start:stop].
+    """
+    self.examples = []
+
+    with tf.io.gfile.GFile(fname) as f:
+      for i, line in enumerate(f):
+        if (start or 0) <= i < (stop or float("inf")):
+          try:
+            self.examples.append(json.loads(line))
+          except json.decoder.JSONDecodeError as e:
+            raise ValueError(f"Invalid JSON in line {i}:\n{line}") from e
+
+    if download_keys:
+      for k in download_keys:
+        assert k in fopen_keys, (
+            f"{k} in download_keys but missing from fopen_keys {fopen_keys}")
+
+      # TODO: b/lbeyer - use info from trainer instead, move that to utils.
+      logging.info(  # pylint: disable=logging-fstring-interpolation
+          f"\u001b[33mNOTE\u001b[0m: Downloading {download_keys} "
+          f"for dataset {fname} ({len(self.examples)} examples) ...")
+
+      def _dl_one(ex):
+        for k in download_keys:
+          ex[k] = cached_download(ex[k])
+
+      ThreadPool(100).map(_dl_one, self.examples)
+      print("Done")
+      logging.info("\u001b[33mNOTE\u001b[0m: Done downloading.")
+
+    # Normalize.
+    if isinstance(fopen_keys, (list, tuple)):
+      self.fopen_keys = {k: "" for k in fopen_keys}
+    else:
+      self.fopen_keys = fopen_keys or {}
+
+    # We need to apply the fopen path prefix here already, because by the time
+    # we actually read the files in TF, things are symbolic :(
+    for ex in self.examples:
+      for k, dirname in self.fopen_keys.items():
+        ex[k] = os.path.join(dirname, ex[k])
+
+  def _indices(self, *, process_split=True, process_index=None):
+    indices = np.arange(len(self.examples))
+
+    if not process_split:
+      return list(indices)
+
+    pid = jax.process_index() if process_index is None else process_index
+    return list(np.array_split(indices, jax.process_count())[pid])
+
+  @overrides.overrides
+  def get_tfdata(self, ordered=False, *, process_split=True, allow_cache=True):
+    del allow_cache  # We don't cache anything anyways.
+    assert not process_split or len(self.examples) >= jax.process_count(), (
+        "Process splitting the data with fewer examples than processes!?")
+
+    my_idxs = self._indices(process_split=process_split)
+    if not ordered:
+      np.random.shuffle(my_idxs)
+
+    dataset = tf.data.Dataset.from_generator(
+        generator=lambda: ({"id": str(i), **self.examples[i]} for i in my_idxs),
+        output_signature={
+            "id": _guess_signature("0"),
+            **{k: _guess_signature(v) for k, v in self.examples[0].items()},
+        })
+
+    def _read_files(example):
+      for k in self.fopen_keys:
+        example[k] = tf.io.read_file(example[k])
+      return example
+    dataset = dataset.map(_read_files)
+
+    return dataset
+
+  @property
+  @overrides.overrides
+  def total_examples(self):
+    return len(self.examples)
+
+  @overrides.overrides
+  def num_examples_per_process(self):
+    return [len(self._indices(process_index=pid))
+            for pid in range(jax.process_count())]
+
+
+def _guess_signature(value):
+  return tf.TensorSpec.from_tensor(tf.constant(value))
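A small end-to-end sketch of the DataSource above (file contents and paths are made up; passing fopen_keys as a dict prefixes each filename with the given folder):

import json

rows = [{"image": "cat.jpg", "label": 0}, {"image": "dog.jpg", "label": 1}]
with open("/tmp/demo.jsonl", "w") as f:
    for r in rows:
        f.write(json.dumps(r) + "\n")

from big_vision.datasets import jsonl
src = jsonl.DataSource("/tmp/demo.jsonl", fopen_keys={"image": "/tmp/imgs"})
print(src.total_examples)  # 2; get_tfdata() would then tf.io.read_file each image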
Tipsomaly/model/big_vision/datasets/sequence_packing.py
ADDED
@@ -0,0 +1,77 @@
+# Copyright 2024 Big Vision Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Packed Sequence Op."""
+
+# Forked from
+# https://github.com/google/maxtext/blob/main/MaxText/sequence_packing.py.
+
+
+from typing import Dict, Optional, List, Union
+
+from flax import traverse_util
+import tensorflow as tf
+
+AUTOTUNE = tf.data.experimental.AUTOTUNE
+FLATTEN_SEPARATOR = "<|sep|>"
+
+
+def pack_dataset(
+    dataset: tf.data.Dataset,
+    batch_size: int | None,
+    key2length: Union[int, Dict[str, int]],
+    keys: Optional[List[str | tuple[str, ...]]] = None) -> tf.data.Dataset:
+  """Creates a 'packed' version of a dataset on-the-fly.
+
+  Wrap `tensorflow.grain` ops.
+
+  This is meant to replace the irritation of having to create a separate
+  "packed" version of a dataset to train efficiently on TPU.
+  Each example in the output dataset represents several examples in the
+  input dataset.
+
+  For each key in the input dataset, two additional keys are created:
+  <key>_segment_ids: an int32 tensor identifying the parts
+     representing the original example.
+  <key>_positions: an int32 tensor identifying the position within the original
+     example.
+
+  Example:
+  Two input examples get combined to form an output example.
+  The input examples are:
+  {"inputs": [8, 7, 1, 0], "targets":[4, 1, 0]}
+  {"inputs": [2, 3, 4, 1], "targets":[5, 6, 1]}
+  The output example is:
+  {
+      "inputs": [8, 7, 1, 2, 3, 4, 1, 0, 0, 0]
+      "inputs_seg": [1, 1, 1, 2, 2, 2, 2, 0, 0, 0]
+      "inputs_pos": [0, 1, 2, 0, 1, 2, 3, 0, 0, 0]
+      "targets": [4, 1, 5, 6, 1, 0, 0, 0, 0, 0]
+      "targets_seg": [1, 1, 2, 2, 2, 0, 0, 0, 0, 0]
+      "targets_pos": [0, 1, 0, 1, 2, 0, 0, 0, 0, 0]
+  }
+  0 represents padding in both the inputs and the outputs.
+  Sequences in the incoming examples are truncated to length "length", and the
+  sequences in the output examples all have fixed (padded) length "length".
+
+  Args:
+    dataset: A `tf.data.Dataset`.
+    batch_size: Batch size of the packed dataset.
+    key2length: An integer, or a dict from feature-key to integer.
+    keys: A list of strings (e.g. ["inputs", "targets"]).
+
+  Returns:
+    A `tf.data.Dataset`.
+  """
+  raise ValueError("Not implemented in OSS yet.")
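Since the OSS version of pack_dataset just raises, here is the docstring's example reproduced by hand in NumPy, to make the segment/position bookkeeping concrete:

import numpy as np

a, b, length = [8, 7, 1], [2, 3, 4, 1], 10
inputs = np.zeros(length, np.int32)
inputs[:len(a)] = a
inputs[len(a):len(a) + len(b)] = b
seg = np.int32([1] * len(a) + [2] * len(b) + [0] * (length - len(a) - len(b)))
pos = np.int32(list(range(len(a))) + list(range(len(b))) + [0] * 3)
print(inputs)  # [8 7 1 2 3 4 1 0 0 0]
print(seg)     # [1 1 1 2 2 2 2 0 0 0]
print(pos)     # [0 1 2 0 1 2 3 0 0 0]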
Tipsomaly/model/big_vision/datasets/tfds.py
ADDED
@@ -0,0 +1,94 @@
+# Copyright 2024 Big Vision Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""TensorFlow Datasets as data source for big_vision."""
+import functools
+
+import big_vision.datasets.core as ds_core
+import jax
+import numpy as np
+import overrides
+import tensorflow as tf
+import tensorflow_datasets as tfds
+
+
+class DataSource(ds_core.DataSource):
+  """Use TFDS as a data source."""
+
+  def __init__(self, name, split, data_dir=None, skip_decode=("image",)):
+    self.builder = _get_builder(name, data_dir)
+    self.split = split
+    # Each host is responsible for a fixed subset of data
+    process_splits = tfds.even_splits(split, jax.process_count())
+    self.process_split = process_splits[jax.process_index()]
+    self.skip_decode = skip_decode
+
+  @overrides.overrides
+  def get_tfdata(
+      self, ordered=False, *, process_split=True, allow_cache=True, **kw):
+    # The tf.data may use a lot of RAM, so we need to expose the option of not
+    # keeping this in memory when we use lots of input pipelines, such as when
+    # having many ephemeral evaluators.
+    return (_cached_get_dataset if allow_cache else _get_dataset)(
+        self.builder, self.skip_decode,
+        split=self.process_split if process_split else self.split,
+        shuffle_files=not ordered,
+        **kw)
+
+  @property
+  @overrides.overrides
+  def total_examples(self):
+    return self.builder.info.splits[self.split].num_examples
+
+  @overrides.overrides
+  def num_examples_per_process(self):
+    splits = tfds.even_splits(self.split, jax.process_count())
+    return [self.builder.info.splits[s].num_examples for s in splits]
+
+
+@functools.cache
+def _get_builder(dataset, data_dir):
+  if dataset == "from_data_dir":
+    return tfds.builder_from_directory(data_dir)
+  else:
+    return tfds.builder(dataset, data_dir=data_dir, try_gcs=True)
+
+
+# Cache as it may well take 1-2min on large datasets, and we may use the same
+# dataset multiple times (eg various evaluators).
+def _get_dataset(builder, skip_decode, shuffle_files, split=None, **rckw):
+  """Returns a tf.data to be used."""
+  ds = builder.as_dataset(
+      split=split, shuffle_files=shuffle_files,
+      read_config=tfds.ReadConfig(
+          skip_prefetch=True,  # We prefetch after pipeline.
+          try_autocache=False,  # We control this, esp. for few-shot.
+          add_tfds_id=True,
+          **rckw,
+      ),
+      decoders={
+          f: tfds.decode.SkipDecoding()
+          for f in skip_decode if f in builder.info.features
+      })
+
+  def _hash_tfds_id(example):
+    id_ = tf.strings.to_hash_bucket_strong(
+        example["tfds_id"],
+        np.iinfo(np.uint32).max,  # Max value
+        [3714561454027272724, 8800639020734831960])  # Magic.
+    example["_id"] = tf.bitcast(id_, tf.int32)[0]  # good device dtype.
+    return example
+
+  return ds.map(_hash_tfds_id)
+_cached_get_dataset = functools.cache(_get_dataset)
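A sketch of the per-process sharding set up in __init__ above (assuming a 4-process job; each process then only ever reads its own sub-split):

import tensorflow_datasets as tfds

splits = tfds.even_splits('train', n=4)
print(len(splits))  # 4 sub-splits covering 'train' near-evenly; process i uses splits[i]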
Tipsomaly/model/big_vision/evaluators/__init__.py
ADDED
File without changes
Tipsomaly/model/big_vision/evaluators/classification.py
ADDED
@@ -0,0 +1,76 @@
+# Copyright 2024 Big Vision Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Evaluator for the classification task."""
+# pylint: disable=consider-using-from-import
+
+import functools
+
+from big_vision.evaluators import common
+import big_vision.utils as u
+import jax
+import jax.numpy as jnp
+
+
+# Temporary global flag to facilitate backwards compatibility. Will be removed
+# by the end of year 2023.
+API = 'jit'
+
+
+# To avoid re-compiling the function for every new instance of the same
+# evaluator on a different dataset!
+@functools.cache
+def get_eval_fn(predict_fn, loss_name):
+  """Produces eval function, also applies jit."""
+  @jax.jit
+  def _eval_fn(train_state, batch, labels, mask):
+    logits, *_ = predict_fn(train_state, batch)
+
+    # Ignore the entries with all zero labels for evaluation.
+    mask *= labels.max(axis=1)
+
+    loss = getattr(u, loss_name)(
+        logits=logits, labels=labels, reduction=False)
+    loss = jnp.sum(loss * mask)
+
+    top1_idx = jnp.argmax(logits, axis=1)
+    # Extracts the label at the highest logit index for each image.
+    top1_correct = jnp.take_along_axis(
+        labels, top1_idx[:, None], axis=1)[:, 0]
+    ncorrect = jnp.sum(top1_correct * mask)
+    nseen = jnp.sum(mask)
+    return ncorrect, loss, nseen
+  return _eval_fn
+
+
+class Evaluator:
+  """Classification evaluator."""
+
+  def __init__(self, predict_fn, loss_name, label_key='labels', **kw):
+    self.get_data_iter, self.steps = common.eval_input_pipeline(**kw)
+    self.eval_fn = get_eval_fn(predict_fn, loss_name)
+    self.label_key = label_key
+
+  def run(self, train_state):
+    """Computes all metrics."""
+    ncorrect, loss, nseen = 0, 0, 0
+    for _, batch in zip(range(self.steps), self.get_data_iter()):
+      labels, mask = batch.pop(self.label_key), batch.pop('_mask')
+      batch_ncorrect, batch_losses, batch_nseen = jax.device_get(
+          self.eval_fn(train_state, batch, labels, mask))
+      ncorrect += batch_ncorrect
+      loss += batch_losses
+      nseen += batch_nseen
+    yield ('prec@1', ncorrect / nseen)
+    yield ('loss', loss / nseen)
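The `mask *= labels.max(axis=1)` line above drops examples whose one-hot labels are all zero (as happens, e.g., for ReaL examples with an empty label set) in addition to the padded batch entries marked by '_mask'; a tiny illustration:

import jax.numpy as jnp

labels = jnp.array([[0., 1.], [0., 0.]])  # second row has no valid label
mask = jnp.array([1., 1.])                # both entries are real, non-padding
print(mask * labels.max(axis=1))          # [1. 0.] -> second row excluded from metrics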
Tipsomaly/model/big_vision/evaluators/common.py
ADDED
@@ -0,0 +1,228 @@
+# Copyright 2024 Big Vision Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Utils for evaluators in general."""
+
+import dataclasses
+import functools
+import importlib
+import json
+import os
+from typing import Any, Callable
+
+from absl import flags
+from big_vision import input_pipeline
+from big_vision.datasets import core as ds_core
+from big_vision.pp import builder as pp_builder
+import big_vision.utils as u
+import flax
+import jax
+import numpy as np
+
+from tensorflow.io import gfile
+
+
+def from_config(config, predict_fns,
+                write_note=lambda s: s,
+                get_steps=lambda key, cfg: cfg[f"{key}_steps"],
+                devices=None):
+  """Creates a list of evaluators based on `config`."""
+  evaluators = []
+  specs = config.get("evals", {})
+
+  for name, cfg in specs.items():
+    write_note(name)
+
+    # Pop all generic settings off so we're left with eval's kwargs in the end.
+    cfg = cfg.to_dict()
+    module = cfg.pop("type", name)
+    pred_key = cfg.pop("pred", "predict")
+    pred_kw = cfg.pop("pred_kw", None)
+    prefix = cfg.pop("prefix", f"{name}/")
+    cfg.pop("skip_first", None)
+    logsteps = get_steps("log", cfg)
+    for typ in ("steps", "epochs", "examples", "percent"):
+      cfg.pop(f"log_{typ}", None)
+
+    # Use same batch_size as eval by default, to reduce fragmentation.
+    # TODO: eventually remove all the deprecated names...
+    cfg["batch_size"] = cfg.get("batch_size") or config.get("batch_size_eval") or config.get("input.batch_size") or config.get("batch_size")  # pylint: disable=line-too-long
+
+    module = importlib.import_module(f"big_vision.evaluators.{module}")
+
+    if devices is not None:
+      cfg["devices"] = devices
+
+    api_type = getattr(module, "API", "pmap")
+    if api_type == "pmap" and "devices" in cfg:
+      raise RuntimeError(
+          "You are seemingly using the old pmap-based evaluator, but with "
+          "jit-based train loop, see (internal link) for more details.")
+    if api_type == "jit" and "devices" not in cfg:
+      raise RuntimeError(
+          "You are seemingly using new jit-based evaluator, but with "
+          "old pmap-based train loop, see (internal link) for more details.")
+
+    try:
+      predict_fn = predict_fns[pred_key]
+    except KeyError as e:
+      raise ValueError(
+          f"Unknown predict_fn '{pred_key}'. Available predict_fns are:\n"
+          + "\n".join(predict_fns)) from e
+    if pred_kw is not None:
+      predict_fn = _CacheablePartial(predict_fn, flax.core.freeze(pred_kw))
+    evaluator = module.Evaluator(predict_fn, **cfg)
+    evaluators.append((name, evaluator, logsteps, prefix))
+
+  return evaluators
+
+
+@dataclasses.dataclass(frozen=True, eq=True)
+class _CacheablePartial:
+  """partial(fn, **kwargs) that defines hash and eq - to help with jit caches.
+
+  This is particularly common in evaluators when one has many evaluator
+  instances that run on different slices of data.
+
+  Example:
+
+  ```
+  f1 = _CacheablePartial(fn, a=1)
+  jax.jit(f1)(...)
+  jax.jit(_CacheablePartial(fn, a=1))(...)  # fn won't be retraced.
+  del f1
+  jax.jit(_CacheablePartial(fn, a=1))(...)  # fn will be retraced.
+  ```
+  """
+  fn: Callable[..., Any]
+  kwargs: flax.core.FrozenDict
+
+  def __call__(self, *args, **kwargs):
+    return functools.partial(self.fn, **self.kwargs)(*args, **kwargs)
+
+
+def eval_input_pipeline(
+    data, pp_fn, batch_size, devices, keep_on_cpu=(),
+    cache="pipeline", prefetch=1, warmup=False,
+):
+  """Create an input pipeline in the way used by most evaluators.
+
+  Args:
+    data: The configuration to create the data source (like for training).
+    pp_fn: A string representing the preprocessing to be performed.
+    batch_size: The batch size to use.
+    devices: The devices that the batches are sharded and pre-fetched onto.
+    keep_on_cpu: See input_pipeline.start_global. Entries in the batch that
+      should be kept on the CPU, hence could be ragged or of string type.
+    cache: One of "none", "pipeline", "raw_data", "final_data". Determines what
+      part of the input stream should be cached across evaluator runs. They use
+      more and more RAM, but make evals faster, in that order.
+      - "none": Entirely re-create and destroy the input pipeline each run.
+      - "pipeline": Keep the (tf.data) pipeline object alive across runs.
+      - "raw_data": Cache the full raw data before pre-processing.
+      - "final_data": Cache the full data after pre-processing.
+    prefetch: How many batches to fetch ahead.
+    warmup: Start fetching the first batch at creation time (right now),
+      instead of once the iteration starts.
+
+  Returns:
+    A tuple (get_iter, steps), the first element is a function that returns
+    the iterator to be used for an evaluation, the second one is how many steps
+    should be iterated for doing one evaluation.
+  """
+  assert (
+      cache is None
+      or cache.lower() in ("none", "pipeline", "raw_data", "final_data")
+  ), f"Unknown value for cache: {cache}"
+  data_source = ds_core.get(**data)
+  tfdata, steps = input_pipeline.make_for_inference(
|
| 150 |
+
data_source.get_tfdata(ordered=True, allow_cache=cache.lower() != "none"),
|
| 151 |
+
batch_size=batch_size,
|
| 152 |
+
num_ex_per_process=data_source.num_examples_per_process(),
|
| 153 |
+
preprocess_fn=pp_builder.get_preprocess_fn(pp_fn, str(data)),
|
| 154 |
+
cache_final=cache == "raw_data",
|
| 155 |
+
cache_raw=cache == "final_data")
|
| 156 |
+
get_data_iter = lambda: input_pipeline.start_global(
|
| 157 |
+
tfdata, devices, prefetch, keep_on_cpu, warmup)
|
| 158 |
+
|
| 159 |
+
# Possibly create one persistent iterator:
|
| 160 |
+
if cache in ("pipeline", "raw_data", "final_data"):
|
| 161 |
+
data_iter = get_data_iter()
|
| 162 |
+
get_data_iter = lambda: data_iter
|
| 163 |
+
|
| 164 |
+
return get_data_iter, steps
|
| 165 |
+
|
| 166 |
+
|
| 167 |
+
def process_sum(tree):
|
| 168 |
+
"""Sums the pytree across all processes."""
|
| 169 |
+
if jax.process_count() == 1: # Avoids corner-cases on donuts.
|
| 170 |
+
return tree
|
| 171 |
+
|
| 172 |
+
with jax.transfer_guard_device_to_host("allow"):
|
| 173 |
+
gathered = jax.experimental.multihost_utils.process_allgather(tree)
|
| 174 |
+
return jax.tree.map(functools.partial(np.sum, axis=0), gathered)
|
| 175 |
+
|
| 176 |
+
|
| 177 |
+
def resolve_outfile(outfile, split="", **kw):
|
| 178 |
+
if not outfile:
|
| 179 |
+
return None
|
| 180 |
+
|
| 181 |
+
# A caveat: when workdir doesn't exist but is in the `outfile`, we should
|
| 182 |
+
# skip. This is common in small runs or runlocal debuggings.
|
| 183 |
+
if "{workdir}" in outfile and not flags.FLAGS.workdir:
|
| 184 |
+
return None
|
| 185 |
+
|
| 186 |
+
return outfile.format(
|
| 187 |
+
workdir=flags.FLAGS.workdir,
|
| 188 |
+
split="".join(c if c not in "[]%:" else "_" for c in (split or "")),
|
| 189 |
+
step=getattr(u.chrono, "prev_step", None),
|
| 190 |
+
**kw,
|
| 191 |
+
)
|
| 192 |
+
|
| 193 |
+
|
| 194 |
+
def multiprocess_write_json(outfile, jobj): # jobj = "json object"
|
| 195 |
+
"""Write a single json file combining all processes' `jobj`s."""
|
| 196 |
+
if not outfile:
|
| 197 |
+
return
|
| 198 |
+
|
| 199 |
+
outfile = resolve_outfile(outfile)
|
| 200 |
+
gfile.makedirs(os.path.dirname(outfile))
|
| 201 |
+
|
| 202 |
+
if isinstance(jobj, list):
|
| 203 |
+
combine_fn = list.extend
|
| 204 |
+
elif isinstance(jobj, dict):
|
| 205 |
+
combine_fn = dict.update
|
| 206 |
+
else:
|
| 207 |
+
raise TypeError(f"Can only write list or dict jsons, but got {type(jobj)}")
|
| 208 |
+
|
| 209 |
+
# First, each process writes its own file.
|
| 210 |
+
with gfile.GFile(outfile + f".p{jax.process_index()}", "w+") as f:
|
| 211 |
+
f.write(json.dumps(jobj))
|
| 212 |
+
|
| 213 |
+
u.sync() # Wait for all files to be written; `with` above does close/flush.
|
| 214 |
+
|
| 215 |
+
# Have process 0 collect, concat, and write final output.
|
| 216 |
+
all_json = type(jobj)()
|
| 217 |
+
if jax.process_index() == 0:
|
| 218 |
+
for pid in range(jax.process_count()):
|
| 219 |
+
with gfile.GFile(outfile + f".p{pid}", "r") as f:
|
| 220 |
+
combine_fn(all_json, json.loads(f.read()))
|
| 221 |
+
with gfile.GFile(outfile, "w+") as f:
|
| 222 |
+
f.write(json.dumps(all_json))
|
| 223 |
+
|
| 224 |
+
# Cleanup time
|
| 225 |
+
u.sync()
|
| 226 |
+
gfile.remove(outfile + f".p{jax.process_index()}")
|
| 227 |
+
|
| 228 |
+
return all_json
|
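
A quick, runnable sketch of how `from_config` consumes one entry of `config.evals`. The spec values below are made up for illustration; the point is that after the generic keys are popped, everything remaining in `cfg` becomes the evaluator's keyword arguments:

```python
# Made-up eval spec; mirrors the generic-key handling in from_config above.
cfg = {
    "type": "mean", "pred": "predict", "prefix": "val/", "log_steps": 1000,
    "data": {"name": "imagenet2012", "split": "validation"},
    "pp_fn": "decode|resize_small(256)|central_crop(224)|value_range(-1,1)",
}
module = cfg.pop("type", "val")        # -> big_vision.evaluators.mean
pred_key = cfg.pop("pred", "predict")  # key into the predict_fns dict
prefix = cfg.pop("prefix", "val/")
logsteps = cfg["log_steps"]            # what get_steps("log", cfg) reads
for typ in ("steps", "epochs", "examples", "percent"):
  cfg.pop(f"log_{typ}", None)
print(module, pred_key, prefix, logsteps)
print(cfg)  # whatever remains becomes Evaluator(**cfg) keyword arguments
```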
Tipsomaly/model/big_vision/evaluators/fewshot_lsr.py
ADDED
@@ -0,0 +1,245 @@
+# Copyright 2024 Big Vision Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Utils for few-shot evaluation."""
+# pylint: disable=consider-using-from-import,g-importing-member
+
+import functools
+
+import big_vision.datasets.core as ds_core
+import big_vision.input_pipeline as input_pipeline
+import big_vision.pp.builder as pp_builder
+import big_vision.utils as u
+import jax
+import jax.numpy as jnp
+from jax.sharding import NamedSharding as Sharding
+from jax.sharding import PartitionSpec as P
+import numpy as np
+
+BIAS_CONSTANT = 100.0
+
+# Temporary global flag to facilitate backwards compatibility. Will be removed
+# by the end of year 2023.
+API = "jit"
+
+
+# Setup function for few-shot regression on CPU to avoid "polluting" the TPU.
+@u.jit_cpu(static_argnums=(2,))
+def _precompute_cache(x, y, num_classes):
+  """Cache quantities to speed-up the computation of L2-regularized least-sq."""
+  # Whiten
+  mean = jnp.mean(x, axis=0, keepdims=True)
+  std = jnp.std(x, axis=0, keepdims=True) + 1e-5
+  x = (x - mean) / std
+
+  # Add a constant feature for the bias, large so it's almost unregularized:
+  x = jnp.pad(x, ((0, 0), (0, 1)), constant_values=BIAS_CONSTANT)
+
+  # To one-hot representation rescaled into {-1, 1}
+  y = 2.0 * jax.nn.one_hot(y, num_classes) - 1.0
+
+  num_points, dim = x.shape
+  # Let N be the number of points, D the dimension and C the number of classes.
+  # We have x of shape (N, D) and y of shape (N, C).
+  # For least-squares, we can compute
+  #
+  #   (A) when N >= D, (x^T x + l2 Id)^{-1} x^T y
+  #   (B) when D > N, x^T (x x^T + l2 Id)^{-1} y
+  #
+  # We pre-compute the eigen-decomposition of either x^T x or x x^T which
+  # becomes q diag(eigs) q^T with q unitary matrix either (D, D) or (N, N)
+  # and eigs a vector (D,) or (N,).
+  #
+  # For any l2 > 0, we can compute (x^T x + l2 Id)^{-1} or (x x^T + l2 Id)^{-1}
+  # by simply computing q (diag(eigs) + l2 Id)^{-1} q^T.
+  # (SVD would be more natural here, but it proved slower, so we use eigh)
+  #
+  # Both cases (A) and (B) can be viewed as lhs (diag(eigs) + l2 Id)^{-1} rhs,
+  # where lhs/rhs are pre-computed left/right-hand sides to specify.
+  #
+  # Detailed evaluation in terms of time and fewshot metrics can be found in
+  # (internal link)
+  #
+  # Implemented by Rodolphe Jenatton.
+  if num_points >= dim:
+    eigs, q = jnp.linalg.eigh(x.T @ x)
+    rhs = q.T @ (x.T @ y)
+    lhs = q
+  else:
+    eigs, q = jnp.linalg.eigh(x @ x.T)
+    rhs = q.T @ y
+    lhs = x.T @ q
+
+  cache = {
+      "eigs": eigs,
+      "rhs": rhs,
+      "lhs": lhs,
+      "mean": mean,
+      "std": std
+  }
+  return cache
+
+
+@u.jit_cpu()
+def _eig_fewshot_acc_fn(cache, x_test, y_test, l2_reg):
+  """Computes (x,y) linear regression accuracy on (x_test, y_test)."""
+
+  x_test = (x_test - cache["mean"]) / cache["std"]
+  x_test = jnp.pad(x_test, ((0, 0), (0, 1)), constant_values=BIAS_CONSTANT)
+
+  rhs = cache["rhs"]
+  lhs = cache["lhs"]
+  eigs = cache["eigs"]
+
+  # See comments in _precompute_cache for context about the formula.
+  scaling = 1.0 / (eigs + l2_reg * jnp.ones_like(eigs))
+  scaling = scaling.reshape((1, -1))
+  w = (lhs * scaling) @ rhs
+  # Predict test-set values and measure their accuracy
+  preds = jnp.argmax(x_test @ w, axis=1)
+  return jnp.mean(preds == y_test)
+
+
+class Evaluator:
+  """Class for few-shot evaluation."""
+
+  def __init__(self, predict_fn, batch_size,
+               datasets, shots, l2_reg,
+               pp_train, pp_eval, display_first,
+               representation_layer=None, num_seeds=3,
+               label_key="label", mask_key="_mask", data_dir=None, *,
+               devices):
+    self.datasets = datasets
+    self.shots = shots
+    self.l2_reg = l2_reg
+    self.batch_size = batch_size
+    self.pp_tr = pp_train
+    self.pp_te = pp_eval
+    self.display_first = display_first
+    self._datasets = {}  # Cache for tfds data. Persists while object is alive.
+    self._repr = {}  # Cache for precomputed repr. Persists within the run call.
+    self.num_seeds = num_seeds
+    self.label_key = label_key
+    self.mask_key = mask_key
+    self.data_dir = data_dir
+    self.devices = devices
+    self.mesh = jax.sharding.Mesh(devices, ("devices",))
+    self.repr_fn = self.get_representation_fn(
+        predict_fn, representation_layer)
+
+  def get_representation_fn(self, predict_fn, representation_layer):
+    # `out_shardings=Sharding(self.mesh, P())` will "all_gather" the outputs.
+    @functools.partial(jax.jit, out_shardings=Sharding(self.mesh, P()))
+    def _repr_fn(train_state, batch, labels, mask):
+      zimg, *_, out = predict_fn(train_state, batch)
+      if representation_layer is not None:
+        rep = u.tree_get(out, representation_layer)
+      else:
+        rep = zimg
+      return rep, labels, mask
+    return _repr_fn
+
+  # Setup input pipeline.
+  def _get_dataset(self, dataset, train_split, test_split):
+    """Lazy-loads given dataset."""
+    key = (dataset, train_split, test_split)
+    try:
+      return self._datasets[key]
+    except KeyError:
+      # NOTE: only supporting TFDS data for now for bwd compat/lazyness.
+      train_data = ds_core.get(
+          name=dataset, split=train_split, data_dir=self.data_dir
+      )
+      test_data = ds_core.get(
+          name=dataset, split=test_split, data_dir=self.data_dir
+      )
+      train_ds, batches_tr = input_pipeline.make_for_inference(
+          train_data.get_tfdata(ordered=True),
+          num_ex_per_process=train_data.num_examples_per_process(),
+          batch_size=self.batch_size,
+          preprocess_fn=pp_builder.get_preprocess_fn(self.pp_tr))
+      test_ds, batches_te = input_pipeline.make_for_inference(
+          test_data.get_tfdata(ordered=True),
+          num_ex_per_process=test_data.num_examples_per_process(),
+          batch_size=self.batch_size,
+          preprocess_fn=pp_builder.get_preprocess_fn(self.pp_te))
+
+      num_classes = train_data.builder.info.features[self.label_key].num_classes
+      return self._datasets.setdefault(
+          key, (train_ds, batches_tr, test_ds, batches_te, num_classes))
+
+  def _get_repr(self, params, data, steps):
+    """Compute representation for the whole dataset."""
+    pre_logits_list = []
+    labels_list = []
+    for batch, _ in zip(
+        input_pipeline.start_global(data, self.devices, 0), range(steps)):
+      labels, mask = batch.pop(self.label_key), batch.pop(self.mask_key)
+      pre_logits, labels, mask = jax.device_get(self.repr_fn(
+          params, batch, labels, mask))
+      mask = mask.astype(bool)
+      pre_logits_list.append(pre_logits[mask])
+      labels_list.append(labels[mask])
+    pre_logits = np.concatenate(pre_logits_list, axis=0)
+    labels = np.concatenate(labels_list, axis=0)
+
+    return pre_logits, labels
+
+  def compute_fewshot_metrics(self, train_state, seed,
+                              dataset, train_split, test_split):
+    """Compute few-shot metrics on one dataset."""
+    if dataset in self._repr:
+      repr_train, labels_train, repr_test, labels_test, num_classes = (
+          self._repr[dataset])
+    else:
+      train_ds, steps_tr, test_ds, steps_te, num_classes = self._get_dataset(
+          dataset, train_split, test_split)
+      repr_train, labels_train = self._get_repr(train_state, train_ds, steps_tr)
+      repr_test, labels_test = self._get_repr(train_state, test_ds, steps_te)
+      self._repr[dataset] = (repr_train, labels_train,
+                             repr_test, labels_test,
+                             num_classes)
+
+    # Collect where we have samples of which classes.
+    rng = np.random.default_rng(seed)
+    class_indices = [rng.permutation(np.where(labels_train == cls_i)[0])
+                     for cls_i in range(num_classes)]
+
+    results = {}
+    for shots in self.shots:
+      all_idx = [indices[:shots] for indices in class_indices]
+      all_idx = np.concatenate(all_idx, axis=0)
+      x = u.put_cpu(repr_train[all_idx])
+      y = u.put_cpu(labels_train[all_idx])
+      repr_test, labels_test = u.put_cpu((repr_test, labels_test))
+
+      # Note the code is optimized to solve multiple LSR tasks for changing l2
+      # strength, even though we currently use the fixed l2_reg constant.
+      cache = _precompute_cache(x, y, num_classes)
+      acc = _eig_fewshot_acc_fn(
+          cache, repr_test, labels_test, u.put_cpu(self.l2_reg))
+      results[shots] = jax.device_get(acc)
+
+    return results
+
+  def run(self, train_state):
+    """New API executed in terms of old API."""
+    self._repr = {}
+    for seed in range(self.num_seeds):
+      for name, dataset_args in self.datasets.items():
+        result = self.compute_fewshot_metrics(train_state, seed, *dataset_args)
+        for shots, v in result.items():
+          prefix = "a/" if (name, shots) in self.display_first else "z/"
+          suffix = f"-seed-{seed}"
+          yield f"{prefix}{name}_{shots}shot{suffix}", v
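
The closed-form ridge regression behind `_precompute_cache` and `_eig_fewshot_acc_fn` is easy to sanity-check in plain NumPy. A minimal sketch (ignoring the whitening and bias feature for brevity; shapes and the seed are illustrative) verifying that the eigendecomposition route matches a direct solve:

```python
import numpy as np

rng = np.random.default_rng(0)
n, d, c = 64, 16, 10                       # points, feature dim, classes
x = rng.normal(size=(n, d))
y = 2.0 * np.eye(c)[rng.integers(0, c, n)] - 1.0  # one-hot rescaled to {-1, 1}
l2 = 2.0 ** 10

# Direct ridge solution for the N >= D branch: (x^T x + l2*I)^-1 x^T y.
w_direct = np.linalg.solve(x.T @ x + l2 * np.eye(d), x.T @ y)

# Eigendecomposition route, as cached by _precompute_cache: with
# x^T x = q diag(eigs) q^T, the inverse is q diag(1/(eigs + l2)) q^T.
eigs, q = np.linalg.eigh(x.T @ x)
rhs = q.T @ (x.T @ y)                      # the pre-computed right-hand side
w_eig = (q * (1.0 / (eigs + l2))) @ rhs    # same contraction as (lhs*scaling)@rhs

assert np.allclose(w_direct, w_eig)
```

The payoff of the cached form is that sweeping over `l2` values only changes the cheap diagonal rescaling, not the expensive decomposition.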
Tipsomaly/model/big_vision/evaluators/mean.py
ADDED
@@ -0,0 +1,80 @@
+# Copyright 2024 Big Vision Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Evaluator for computing mean of per-example metrics.
+
+This evaluator can be used in two ways:
+1. Create a new evaluator with reduced boilerplate by inheriting from it.
+2. For quick prototyping, use this with predict_fns which return the metrics.
+"""
+from functools import partial
+from typing import Mapping
+
+from big_vision.evaluators import common
+
+import jax
+import jax.numpy as jnp
+import numpy as np
+
+
+# Temporary global flag to facilitate backwards compatibility. Will be removed
+# by the end of year 2023.
+API = 'jit'
+
+
+# Note: global to avoid jax re-compiling across different evaluator instances.
+@partial(jax.jit, static_argnums=0)
+def _run_predict_fn(predict_fn, train_state, batch):
+  """Sum per-example metrics weighted by `_mask`."""
+  metrics = predict_fn(train_state, batch)
+  mask = batch['_mask']
+  # Sanity check output format of predict_fn.
+  assert isinstance(metrics, Mapping), 'predict_fn must return a dict'
+  for y in jax.tree.leaves(metrics):
+    if y.shape != mask.shape:
+      raise ValueError(
+          f'Expected per-example metrics of shape {mask.shape} found '
+          f'{jax.tree.map(lambda x: x.shape, metrics)}.')
+  metrics = {**metrics, '_mask': mask}
+  return jax.tree.map(lambda x: jnp.sum(jnp.where(mask, x, 0)), metrics)
+
+
+class Evaluator:
+  """Report the mean of per-example metrics computed by predict_fn.
+
+  `predict_fn(params, batch)` must return a dict from metric name to
+  per-example metrics of shape [batch_size].
+  """
+
+  def __init__(self, predict_fn, **kw):
+    self.get_data_iter, self.steps = common.eval_input_pipeline(**kw)
+    self.predict_fn = partial(_run_predict_fn, predict_fn)
+
+  def run(self, train_state):
+    """Computes all metrics."""
+    metrics = []
+
+    # Compute batch metrics without blocking.
+    for _, batch in zip(range(self.steps), self.get_data_iter()):
+      batch_metrics = self.predict_fn(train_state, batch)
+      metrics.append(batch_metrics)
+
+    # Transfer metrics (blocking).
+    metrics = jax.device_get(metrics)
+
+    # Accumulate metrics across batches.
+    metrics_sum = jax.tree.map(lambda *x: np.sum(x), *metrics)
+    mask_sum = metrics_sum.pop('_mask')
+    for key, value_sum in metrics_sum.items():
+      yield (key, value_sum / mask_sum)
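
For reference, a minimal runnable sketch of a `predict_fn` this evaluator can consume. The "model" here is just a linear classifier whose weights live in `train_state`; every name and shape is illustrative, not part of the diff above:

```python
import jax
import jax.numpy as jnp

def predict_fn(train_state, batch):
  feats = batch["image"].reshape(batch["image"].shape[0], -1)
  logits = feats @ train_state["w"]
  loss = -jax.nn.log_softmax(logits)[jnp.arange(logits.shape[0]), batch["label"]]
  acc = (jnp.argmax(logits, axis=-1) == batch["label"]).astype(jnp.float32)
  return {"loss": loss, "acc": acc}  # each of shape [batch_size], as required

batch = {"image": jnp.zeros((4, 8, 8, 3)),
         "label": jnp.array([0, 1, 2, 3]),
         "_mask": jnp.ones(4, dtype=jnp.int32)}
train_state = {"w": jnp.zeros((8 * 8 * 3, 10))}
print(jax.tree.map(jnp.shape, predict_fn(train_state, batch)))
```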
Tipsomaly/model/big_vision/evaluators/save.py
ADDED
@@ -0,0 +1,121 @@
+# Copyright 2024 Big Vision Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Evaluator that saves inputs and outputs of prediction functions."""
+import functools
+
+from absl import flags
+from absl import logging
+
+from big_vision import input_pipeline
+from big_vision import optax as bv_optax
+from big_vision import utils
+from big_vision.datasets import core as ds_core
+from big_vision.pp import builder as pp_builder
+
+import jax
+import numpy as np
+
+# Temporary global flag to facilitate backwards compatibility. Will be removed
+# by the end of year 2023.
+API = 'jit'
+
+
+# Note: global to avoid jax re-compiling across different evaluator instances.
+def _run_predict_fn(predict_fn, train_state, batch):
+  """Run predict_fn and gather all outputs on all devices."""
+  y = predict_fn(train_state, batch)
+  return {'inputs': batch, 'outputs': y}
+
+
+class Evaluator:
+  """Evaluator that saves the inputs and outputs of a prediction function.
+
+  Example configuration:
+
+  ```
+  config.evals.save_pred = {
+      'type': 'save',
+      'pred': 'inference',
+      'outfile': '{workdir}/inference-{step:09d}.npz',
+      'data': ..., 'pp_fn': ..., 'log_steps': ...,
+  }
+  ```
+
+  Results can then be easily inspected in a notebook such as:
+
+  ```
+  results = utils.load_checkpoint("<full_path_to_outfile>")
+  inputs, outputs = (results["inputs"], results["outputs"])
+  ```
+  """
+
+  def __init__(self, predict_fn, data, pp_fn, batch_size, outfile,
+               cache_final=True, cache_raw=False, prefetch=1, *, devices):
+    replicate = jax.sharding.NamedSharding(
+        jax.sharding.Mesh(devices, ('devices',)),
+        jax.sharding.PartitionSpec()
+    )
+    self.predict_fn = functools.partial(
+        jax.jit(_run_predict_fn, static_argnums=0, out_shardings=replicate),
+        predict_fn,
+    )
+
+    data = ds_core.get(**data)
+    self.dataset, self.steps = input_pipeline.make_for_inference(
+        data.get_tfdata(ordered=True),
+        batch_size=batch_size,
+        num_ex_per_process=data.num_examples_per_process(),
+        preprocess_fn=pp_builder.get_preprocess_fn(pp_fn),
+        cache_final=cache_final,
+        cache_raw=cache_raw,
+    )
+    self.data_iter = input_pipeline.start_global(
+        self.dataset, devices, prefetch
+    )
+
+    self.outfile = outfile
+
+  def run(self, train_state):
+    """Compute all predictions, gather in main host and save in outfile."""
+    step = jax.device_get(bv_optax.get_count(train_state['opt'], jittable=True))
+    outfile = self.outfile.format(workdir=flags.FLAGS.workdir, step=step)
+
+    count = 0
+    outputs = []
+    for _, batch in zip(range(self.steps), self.data_iter):
+      out = self.predict_fn(train_state, batch)
+      if jax.process_index():
+        continue
+
+      out = jax.device_get(out)
+      mask = out['inputs']['_mask']
+      out = jax.tree.map(lambda x: x[mask == 1], out)  # pylint: disable=cell-var-from-loop
+      count += mask.shape[0]
+      out['inputs'].pop('_mask')
+      outputs.append(out)
+
+      logging.log_every_n_seconds(
+          logging.INFO, 'Processed %i examples so far.', 60,
+          count)
+
+    if jax.process_index():
+      return
+
+    logging.info('Saving %d examples in %s', count, outfile)
+    outputs = jax.tree.map(lambda *x: np.concatenate(x, axis=0), *outputs)
+    utils.save_checkpoint(outputs, outfile, compressed=True)
+    return
+
+    yield None  # pylint: disable=unreachable
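
One detail worth seeing concretely: the `outfile` config value is an ordinary `str.format` template, resolved per run with the current workdir and step. A tiny sketch with made-up paths:

```python
# This is exactly what Evaluator.run does with self.outfile.
outfile = '{workdir}/inference-{step:09d}.npz'
print(outfile.format(workdir='/tmp/workdir', step=1000))
# -> /tmp/workdir/inference-000001000.npz
```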
Tipsomaly/model/big_vision/models/__init__.py
ADDED
File without changes
Tipsomaly/model/big_vision/models/bit.py
ADDED
@@ -0,0 +1,162 @@
+# Copyright 2024 Big Vision Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""ResNet V1 with GroupNorm."""
+
+from typing import Optional, Sequence, Union
+
+from big_vision import utils
+from big_vision.models import common
+import flax
+import flax.linen as nn
+import flax.training.checkpoints
+import jax.numpy as jnp
+import numpy as np
+
+
+def weight_standardize(w, axis, eps):
+  w = w - jnp.mean(w, axis=axis)
+  w = w / (jnp.std(w, axis=axis) + eps)
+  return w
+
+
+class StdConv(nn.Conv):
+
+  def param(self, name, *a, **kw):
+    param = super().param(name, *a, **kw)
+    if name == "kernel":
+      param = weight_standardize(param, axis=[0, 1, 2], eps=1e-5)
+    return param
+
+
+class ResidualUnit(nn.Module):
+  """Bottleneck ResNet block."""
+  nmid: Optional[int] = None
+  strides: Sequence[int] = (1, 1)
+
+  @nn.compact
+  def __call__(self, x):
+    nmid = self.nmid or x.shape[-1] // 4
+    nout = nmid * 4
+
+    residual = x
+    if x.shape[-1] != nout or self.strides != (1, 1):
+      residual = StdConv(nout, (1, 1), self.strides, use_bias=False,
+                         name="conv_proj")(residual)
+      residual = nn.GroupNorm(name="gn_proj")(residual)
+
+    y = StdConv(nmid, (1, 1), use_bias=False, name="conv1")(x)
+    y = nn.GroupNorm(name="gn1")(y)
+    y = nn.relu(y)
+    y = StdConv(nmid, (3, 3), self.strides, use_bias=False, name="conv2")(y)
+    y = nn.GroupNorm(name="gn2")(y)
+    y = nn.relu(y)
+    y = StdConv(nout, (1, 1), use_bias=False, name="conv3")(y)
+
+    y = nn.GroupNorm(name="gn3", scale_init=nn.initializers.zeros)(y)
+    y = nn.relu(residual + y)
+    return y
+
+
+class ResNetStage(nn.Module):
+  """One stage of ResNet."""
+  block_size: int
+  first_stride: Sequence[int] = (1, 1)
+  nmid: Optional[int] = None
+
+  @nn.compact
+  def __call__(self, x):
+    x = ResidualUnit(self.nmid, strides=self.first_stride, name="unit1")(x)
+    for i in range(1, self.block_size):
+      x = ResidualUnit(self.nmid, name=f"unit{i + 1}")(x)
+    return x
+
+
+class Model(nn.Module):
+  """ResNetV1."""
+  num_classes: Optional[int] = None
+  width: float = 1
+  depth: Union[int, Sequence[int]] = 50
+
+  @nn.compact
+  def __call__(self, image, *, train=False):
+    del train  # Unused
+    blocks = get_block_desc(self.depth)
+    width = int(64 * self.width)
+
+    out = {}
+
+    # Root block
+    x = StdConv(width, (7, 7), (2, 2), use_bias=False, name="conv_root")(image)
+    x = nn.GroupNorm(name="gn_root")(x)
+    x = nn.relu(x)
+    x = nn.max_pool(x, (3, 3), strides=(2, 2), padding="SAME")
+    out["stem"] = x
+
+    # Stages
+    x = ResNetStage(blocks[0], nmid=width, name="block1")(x)
+    out["stage1"] = x
+    for i, block_size in enumerate(blocks[1:], 1):
+      x = ResNetStage(block_size, nmid=width * 2 ** i,
+                      first_stride=(2, 2), name=f"block{i + 1}")(x)
+      out[f"stage{i + 1}"] = x
+    out["pre_logits_2d"] = x
+
+    # Head
+    x = out["pre_logits"] = jnp.mean(x, axis=(1, 2))
+
+    if self.num_classes:
+      head = nn.Dense(self.num_classes, name="head",
+                      kernel_init=nn.initializers.zeros)
+      out["logits_2d"] = head(out["pre_logits_2d"])
+      x = out["logits"] = head(out["pre_logits"])
+
+    return x, out
+
+
+# A dictionary mapping the number of layers in a resnet to the number of
+# blocks in each stage of the model.
+# NOTE: Does not include 18/34 as they also need non-bottleneck block!
+def get_block_desc(depth):
+  if isinstance(depth, list):  # Be robust to silly mistakes.
+    depth = tuple(depth)
+  return {
+      26: [2, 2, 2, 2],  # From timm, gets ~75% on ImageNet.
+      50: [3, 4, 6, 3],
+      101: [3, 4, 23, 3],
+      152: [3, 8, 36, 3],
+      200: [3, 24, 36, 3]
+  }.get(depth, depth)
+
+
+def fix_old_checkpoints(params):
+  """Modifies params from old checkpoints to run with current implementation."""
+  params = flax.core.unfreeze(
+      flax.training.checkpoints.convert_pre_linen(params))
+  # Old linen used to store non-squeezed GN params.
+  params = flax.traverse_util.unflatten_dict({
+      k: np.squeeze(v) if (set(k)
+                           & {"gn_root", "gn_proj", "gn1", "gn2", "gn3"}) else v
+      for k, v in flax.traverse_util.flatten_dict(params).items()
+  })
+  return params
+
+
+def load(init_params, init_file, model_cfg, dont_load=()):
+  """Load init from checkpoint."""
+  del model_cfg  # Unused
+  params = utils.load_params(init_file)
+  params = common.merge_params(params, init_params, dont_load)
+  params = fix_old_checkpoints(params)
+  return params
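
A quick shape-check sketch for the ResNet above, assuming the package is importable as `big_vision` from this repo layout (the pre-logits width 2048 follows from the 512 * 4 bottleneck channels after block4):

```python
import jax
import jax.numpy as jnp
from big_vision.models import bit

model = bit.Model(num_classes=10, depth=50, width=1)
dummy = jnp.zeros((1, 224, 224, 3))
params = model.init(jax.random.PRNGKey(0), dummy)
logits, out = model.apply(params, dummy)
print(logits.shape)             # (1, 10)
print(out["pre_logits"].shape)  # (1, 2048)
```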
Tipsomaly/model/big_vision/models/bit_paper.py
ADDED
@@ -0,0 +1,260 @@
+# Copyright 2024 Big Vision Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""BiT models as in the paper (ResNet V2) w/ loading of public weights.
+
+See reproduction proof: http://(internal link)/qY70qs6j944
+"""
+
+import functools
+import re
+from typing import Optional, Sequence, Union
+
+from big_vision import utils as u
+from big_vision.models import bit
+from big_vision.models import common
+import flax.linen as nn
+import jax.numpy as jnp
+
+
+def standardize(x, axis, eps):
+  x = x - jnp.mean(x, axis=axis, keepdims=True)
+  x = x / jnp.sqrt(jnp.mean(jnp.square(x), axis=axis, keepdims=True) + eps)
+  return x
+
+
+# We define our own, because we compute the normalizing variance slightly
+# differently, which does affect performance when loading pre-trained weights!
+class GroupNorm(nn.Module):
+  """Group normalization (arxiv.org/abs/1803.08494)."""
+  ngroups: int = 32
+
+  @nn.compact
+  def __call__(self, x):
+
+    input_shape = x.shape
+    group_shape = x.shape[:-1] + (self.ngroups, x.shape[-1] // self.ngroups)
+
+    x = x.reshape(group_shape)
+
+    # Standardize along spatial and group dimensions
+    x = standardize(x, axis=[1, 2, 4], eps=1e-5)
+    x = x.reshape(input_shape)
+
+    bias_scale_shape = tuple([1, 1, 1] + [input_shape[-1]])
+    x = x * self.param('scale', nn.initializers.ones, bias_scale_shape)
+    x = x + self.param('bias', nn.initializers.zeros, bias_scale_shape)
+    return x
+
+
+class StdConv(nn.Conv):
+
+  def param(self, name, *a, **kw):
+    param = super().param(name, *a, **kw)
+    if name == 'kernel':
+      param = standardize(param, axis=[0, 1, 2], eps=1e-10)
+    return param
+
+
+class RootBlock(nn.Module):
+  """Root block of ResNet."""
+  width: int
+
+  @nn.compact
+  def __call__(self, x):
+    x = StdConv(self.width, (7, 7), (2, 2), padding=[(3, 3), (3, 3)],
+                use_bias=False, name='conv_root')(x)
+    x = nn.max_pool(x, (3, 3), strides=(2, 2), padding=[(1, 1), (1, 1)])
+    return x
+
+
+class ResidualUnit(nn.Module):
+  """Bottleneck ResNet block."""
+  nmid: Optional[int] = None
+  strides: Sequence[int] = (1, 1)
+
+  @nn.compact
+  def __call__(self, x):
+    nmid = self.nmid or x.shape[-1] // 4
+    nout = nmid * 4
+    conv = functools.partial(StdConv, use_bias=False)
+
+    residual = x
+    x = GroupNorm(name='gn1')(x)
+    x = nn.relu(x)
+
+    if x.shape[-1] != nout or self.strides != (1, 1):
+      residual = conv(nout, (1, 1), self.strides, name='conv_proj')(x)
+
+    x = conv(nmid, (1, 1), name='conv1')(x)
+    x = GroupNorm(name='gn2')(x)
+    x = nn.relu(x)
+    x = conv(nmid, (3, 3), self.strides, padding=[(1, 1), (1, 1)],
+             name='conv2')(x)
+    x = GroupNorm(name='gn3')(x)
+    x = nn.relu(x)
+    x = conv(nout, (1, 1), name='conv3')(x)
+
+    return x + residual
+
+
+class ResNetStage(nn.Module):
+  """A stage (sequence of same-resolution blocks)."""
+  block_size: int
+  nmid: Optional[int] = None
+  first_stride: Sequence[int] = (1, 1)
+
+  @nn.compact
+  def __call__(self, x):
+    out = {}
+    x = out['unit01'] = ResidualUnit(
+        self.nmid, strides=self.first_stride, name='unit01')(x)
+    for i in range(1, self.block_size):
+      x = out[f'unit{i+1:02d}'] = ResidualUnit(
+          self.nmid, name=f'unit{i+1:02d}')(x)
+    return x, out
+
+
+class Model(nn.Module):
+  """ResNetV2."""
+  num_classes: Optional[int] = None
+  width: int = 1
+  depth: Union[int, Sequence[int]] = 50  # 50/101/152, or list of block depths.
+  head_zeroinit: bool = True
+
+  @nn.compact
+  def __call__(self, image, *, train=False):
+    blocks = bit.get_block_desc(self.depth)
+    width = int(64 * self.width)
+    out = {}
+
+    x = out['stem'] = RootBlock(width=width, name='root_block')(image)
+
+    # Blocks
+    x, out['stage1'] = ResNetStage(blocks[0], nmid=width, name='block1')(x)
+    for i, block_size in enumerate(blocks[1:], 1):
+      x, out[f'stage{i + 1}'] = ResNetStage(
+          block_size, width * 2 ** i,
+          first_stride=(2, 2), name=f'block{i + 1}')(x)
+
+    # Pre-head
+    x = out['norm_pre_head'] = GroupNorm(name='norm-pre-head')(x)
+    x = out['pre_logits_2d'] = nn.relu(x)
+    x = out['pre_logits'] = jnp.mean(x, axis=(1, 2))
+
+    # Head
+    if self.num_classes:
+      kw = {'kernel_init': nn.initializers.zeros} if self.head_zeroinit else {}
+      head = nn.Dense(self.num_classes, name='head', **kw)
+      out['logits_2d'] = head(out['pre_logits_2d'])
+      x = out['logits'] = head(out['pre_logits'])
+
+    return x, out
+
+
+def load(init_params, init_file, model_cfg, dont_load=()):
+  """Loads the TF-dumped NumPy or big_vision checkpoint.
+
+  Args:
+    init_params: random init params from which the new head is taken.
+    init_file: comes from `config.model_init`, can either be an absolute
+      path (ie starts with /) to the checkpoint, or a string like
+      "L-imagenet2012" describing one of the variants from the paper.
+    model_cfg: the model configuration.
+    dont_load: list of param names to be reset to init.
+
+  Returns:
+    The loaded parameters.
+  """
+
+  # Support for vanity model names from the paper.
+  vanity = {
+      'FunMatch-224px-i1k82.8': 'gs://bit_models/distill/R50x1_224.npz',
+      'FunMatch-160px-i1k80.5': 'gs://bit_models/distill/R50x1_160.npz',
+  }
+  if init_file[0] in ('L', 'M', 'S'):  # The models from the original paper.
+    # Supported names are of the following type:
+    # - 'M' or 'S': the original "upstream" model without fine-tuning.
+    # - 'M-ILSVRC2012': i21k model fine-tuned on i1k.
+    # - 'M-run0-caltech101': i21k model fine-tuned on VTAB's caltech101.
+    #   each VTAB fine-tuning was run 3x, so there's run0, run1, run2.
+    if '-' in init_file:
+      up, down = init_file[0], init_file[1:]
+    else:
+      up, down = init_file, ''
+    down = {'-imagenet2012': '-ILSVRC2012'}.get(down, down)  # normalize
+    fname = f'BiT-{up}-R{model_cfg.depth}x{model_cfg.width}{down}.npz'
+    fname = f'gs://bit_models/{fname}'
+  else:
+    fname = vanity.get(init_file, init_file)
+
+  params = u.load_params(fname)
+  params = maybe_convert_big_transfer_format(params)
+  return common.merge_params(params, init_params, dont_load)
+
+
+def maybe_convert_big_transfer_format(params_tf):
+  """If the checkpoint comes from legacy codebase, convert it."""
+
+  # Only do anything at all if we recognize the format.
+  if 'resnet' not in params_tf:
+    return params_tf
+
+  # For ease of processing and backwards compatibility, flatten again:
+  params_tf = dict(u.tree_flatten_with_names(params_tf)[0])
+
+  # Works around some files containing weird naming of variables:
+  for k in list(params_tf):
+    k2 = re.sub('/standardized_conv2d_\\d+/', '/standardized_conv2d/', k)
+    if k2 != k:
+      params_tf[k2] = params_tf[k]
+      del params_tf[k]
+
+  params = {
+      'root_block': {'conv_root': {'kernel': params_tf[
+          'resnet/root_block/standardized_conv2d/kernel']}},
+      'norm-pre-head': {
+          'bias': params_tf['resnet/group_norm/beta'][None, None, None],
+          'scale': params_tf['resnet/group_norm/gamma'][None, None, None],
+      },
+      'head': {
+          'kernel': params_tf['resnet/head/conv2d/kernel'][0, 0],
+          'bias': params_tf['resnet/head/conv2d/bias'],
+      }
+  }
+
+  for block in ('block1', 'block2', 'block3', 'block4'):
+    params[block] = {}
+    units = set([re.findall(r'unit\d+', p)[0] for p in params_tf.keys()
+                 if p.find(block) >= 0])
+    for unit in units:
+      params[block][unit] = {}
+      for i, group in enumerate('abc', 1):
+        params[block][unit][f'conv{i}'] = {
+            'kernel': params_tf[f'resnet/{block}/{unit}/{group}/standardized_conv2d/kernel']  # pylint: disable=line-too-long
+        }
+        params[block][unit][f'gn{i}'] = {
+            'bias': params_tf[f'resnet/{block}/{unit}/{group}/group_norm/beta'][None, None, None],  # pylint: disable=line-too-long
+            'scale': params_tf[f'resnet/{block}/{unit}/{group}/group_norm/gamma'][None, None, None],  # pylint: disable=line-too-long
+        }
+
+      projs = [p for p in params_tf.keys()
+               if p.find(f'{block}/{unit}/a/proj') >= 0]
+      assert len(projs) <= 1
+      if projs:
+        params[block][unit]['conv_proj'] = {
+            'kernel': params_tf[projs[0]]
+        }
+
+  return params
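
The checkpoint-name resolution in `load` can be traced with a few lines of plain Python (depth=50, width=1 chosen for illustration):

```python
# Mirrors the vanity-name branch of bit_paper.load above.
for init_file in ('M', 'M-imagenet2012', 'M-run0-caltech101'):
  if '-' in init_file:
    up, down = init_file[0], init_file[1:]
  else:
    up, down = init_file, ''
  down = {'-imagenet2012': '-ILSVRC2012'}.get(down, down)  # normalize
  print(f'gs://bit_models/BiT-{up}-R50x1{down}.npz')
# gs://bit_models/BiT-M-R50x1.npz
# gs://bit_models/BiT-M-R50x1-ILSVRC2012.npz
# gs://bit_models/BiT-M-R50x1-run0-caltech101.npz
```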
Tipsomaly/model/big_vision/models/common.py
ADDED
@@ -0,0 +1,133 @@
+# Copyright 2024 Big Vision Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Utilities shared across models."""
+
+from absl import logging
+import big_vision.utils as u
+import flax.linen as nn
+import jax
+import jax.numpy as jnp
+
+
+def merge_params(loaded, inited, dont_load=(), match_dtype=False):
+  """Makes `loaded` pytree match `init`, warning or failing on mismatch.
+
+  Args:
+    loaded: pytree of parameters, typically loaded from a checkpoint.
+    inited: pytree of parameters, typically coming from model init.
+    dont_load: List of regexes for parameters which shall not be taken
+      from `loaded`, either because they should remain at their init value,
+      or because they are missing on either side.
+    match_dtype: returned pytree has leaves converted to dtype from `inited`.
+
+  Returns:
+    If successful, a new pytree which matches the structure of `init`
+    but contains values from `loaded`, except for `dont_load`.
+
+    If structures don't match and mismatches are not covered by regexes in
+    `dont_load` argument, then raises an exception with more information.
+  """
+  if inited is None:  # A useful shortcut for example for colabs.
+    return loaded
+
+  dont_load = u.check_and_compile_patterns(dont_load)
+
+  def should_merge(name):
+    return not any(pattern.fullmatch(name) for pattern in dont_load)
+
+  loaded_flat, _ = u.tree_flatten_with_names(loaded)
+  inited_flat, _ = u.tree_flatten_with_names(inited)
+  loaded_flat = {k: v for k, v in loaded_flat}
+  inited_flat = {k: v for k, v in inited_flat}
+
+  # Let's first build the pytree from all common keys.
+  merged = {}
+  for name, init_val in inited_flat.items():
+    # param is present in both. Load or ignore it!
+    if name in loaded_flat and should_merge(name):
+      merged[name] = loaded_flat[name]
+      if match_dtype:
+        merged[name] = loaded_flat[name].astype(init_val.dtype)
+    else:
+      logging.info("Ignoring checkpoint and using init value for %s", name)
+      merged[name] = init_val
+
+  def pp(title, names, indent="  "):  # Just pretty-printing
+    if names:
+      return f"{title}:\n" + "\n".join(f"{indent}{k}" for k in sorted(names))
+    else:
+      return ""
+
+  # Now, if there are keys that only exist in inited or loaded, be helpful:
+  not_in_loaded = inited_flat.keys() - loaded_flat.keys()
+  not_in_inited = loaded_flat.keys() - inited_flat.keys()
+  logging.info(pp("Parameters in model but not in checkpoint", not_in_loaded))
+  logging.info(pp("Parameters in checkpoint but not in model", not_in_inited))
+
+  # And now see if any of them are not explicitly ignored => an error
+  not_in_loaded = {k for k in not_in_loaded if should_merge(k)}
+  not_in_inited = {k for k in not_in_inited if should_merge(k)}
+
+  if not_in_loaded or not_in_inited:
+    raise ValueError(
+        pp("Params in checkpoint", loaded_flat.keys()) + "\n" +
+        pp("Params in model (code)", inited_flat.keys()) + "\n" +
+        pp("Params in model (code) but not in checkpoint and not `dont_load`ed",
+           not_in_loaded, indent=" - ") + "\n" +  # Special indent for tests.
+        pp("Params in checkpoint but not in model (code) and not `dont_load`ed",
+           not_in_inited, indent=" + "))  # Special indent for tests.
+
+  return u.recover_tree(merged.keys(), merged.values())
+
+
+class AddPositionEmbs(nn.Module):
+  """Adds positional embeddings to the inputs, supports caching for decode.
+
+  Attributes:
+    decode: whether to run in single-position autoregressive mode.
+  """
+  decode: bool = False
+
+  @nn.compact
+  def __call__(self, inputs, posemb):
+    """Applies AddPositionEmbs module.
+
+    Adds posemb to the inputs, supports single-position autoregressive mode.
+
+    Args:
+      inputs: input data [batch_size, seq_len, emb_dim].
+      posemb: positional embeddings.
+
+    Returns:
+      output: inputs modulated by pos-embeddings [batch_size, seq_len, emb_dim].
+    """
+    assert inputs.ndim == 3, f"Unexpected inputs shape: {inputs.shape}"
+    _, seq_len, emb_dim = inputs.shape
+    pe = posemb[:, :seq_len, :]
+
+    if self.decode:
+      is_initialized = self.has_variable("cache", "cache_index")
+      # We use a cache position index for tracking decoding position.
+      cache_index = self.variable("cache", "cache_index",
+                                  lambda: jnp.array(0, dtype=jnp.uint32))
+      if is_initialized:
+        i = cache_index.value
+        cache_index.value = i + 1
+        # Returns posemb[0, i, :], the positional embedding for the
+        # current decoding position.
+        pe = jax.lax.dynamic_slice(posemb,
+                                   start_indices=jnp.array((0, i, 0)),
+                                   slice_sizes=(1, 1, emb_dim))
+    return inputs + pe
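
A small sketch of `merge_params` above. It assumes big_vision's name flattening joins nested keys with "/" (so the `dont_load` regexes match names like `head/kernel`); that separator is an assumption of this example rather than something shown in the diff:

```python
import numpy as np
from big_vision.models.common import merge_params

# The checkpoint lacks the head and carries an extra temperature parameter;
# both mismatches are covered by `dont_load` regexes, so the merge succeeds.
loaded = {"encoder": {"kernel": np.ones((2, 2))}, "t": np.float32(0.1)}
inited = {"encoder": {"kernel": np.zeros((2, 2))},
          "head": {"kernel": np.zeros((2, 4))}}
merged = merge_params(loaded, inited, dont_load=("head/.*", "t"))
assert (merged["encoder"]["kernel"] == 1).all()  # taken from the checkpoint
assert (merged["head"]["kernel"] == 0).all()     # kept at its init value
```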
Tipsomaly/model/big_vision/models/mlp_mixer.py
ADDED
|
@@ -0,0 +1,177 @@
# Copyright 2024 Big Vision Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""MLP-Mixer model."""

from typing import Optional, Tuple
from absl import logging

from big_vision import utils
from big_vision.models import common

import einops
import flax.linen as nn
import flax.training.checkpoints
import jax
import jax.numpy as jnp


class MlpBlock(nn.Module):
  mlp_dim: int

  @nn.compact
  def __call__(self, x):
    y = nn.Dense(self.mlp_dim)(x)
    y = nn.gelu(y)
    return nn.Dense(x.shape[-1])(y)


class MixerBlock(nn.Module):
  """Mixer block layer."""
  tokens_mlp_dim: int
  channels_mlp_dim: int
  drop_p: float

  @nn.compact
  def __call__(self, x, *, train=False):
    y = nn.LayerNorm()(x)
    y = jnp.swapaxes(y, 1, 2)
    y = MlpBlock(self.tokens_mlp_dim, name="token_mixing")(y)
    y = jnp.swapaxes(y, 1, 2)
    x = x + y * _stoch_depth_mask(x, self.drop_p, not train, self.make_rng)
    y = nn.LayerNorm()(x)
    y = MlpBlock(self.channels_mlp_dim, name="channel_mixing")(y)
    return x + y * _stoch_depth_mask(x, self.drop_p, not train, self.make_rng)


class MlpMixer(nn.Module):
  """Mixer architecture."""
  patch_size: Tuple[int, int]
  num_classes: Optional[int]
  num_blocks: int
  hidden_dim: int
  tokens_mlp_dim: int
  channels_mlp_dim: int
  model_name: Optional[str] = None
  stoch_depth: float = 0.0

  @nn.compact
  def __call__(self, image, *, train=False):
    out = {}
    x = out["stem"] = nn.Conv(self.hidden_dim, self.patch_size,
                              strides=self.patch_size, name="stem")(image)
    x = out["input_tokens"] = einops.rearrange(x, "n h w c -> n (h w) c")
    for i in range(self.num_blocks):
      drop_p = (i / max(self.num_blocks - 1, 1)) * self.stoch_depth
      x = out[f"block_{i}"] = MixerBlock(
          self.tokens_mlp_dim, self.channels_mlp_dim, drop_p)(x, train=train)
    x = nn.LayerNorm(name="pre_head_layer_norm")(x)
    x = out["pre_logits"] = jnp.mean(x, axis=1)
    if self.num_classes:
      x = out["logits"] = nn.Dense(
          self.num_classes, kernel_init=nn.initializers.zeros, name="head")(x)
    return x, out


def Model(num_classes=None, *, variant=None, **kw):  # pylint: disable=invalid-name
  """Factory function to easily create a Model variant like "L/16"."""

  if variant is not None:
    model_size, patch = variant.split("/")
    kw.setdefault("patch_size", (int(patch), int(patch)))
    config = {
        "S": {
            "hidden_dim": 512,
            "num_blocks": 8,
            "channels_mlp_dim": 2048,
            "tokens_mlp_dim": 256
        },
        "B": {
            "hidden_dim": 768,
            "num_blocks": 12,
            "channels_mlp_dim": 3072,
            "tokens_mlp_dim": 384
        },
        "L": {
            "hidden_dim": 1024,
            "num_blocks": 24,
            "channels_mlp_dim": 4096,
            "tokens_mlp_dim": 512
        },
        "H": {
            "hidden_dim": 1280,
            "num_blocks": 32,
            "channels_mlp_dim": 5120,
            "tokens_mlp_dim": 640
        },
    }[model_size]

    for k, v in config.items():
      kw.setdefault(k, v)

  logging.info("Mixer config: %s", kw)
  return MlpMixer(num_classes=num_classes, **kw)


def load(init_params, init_file, model_cfg, dont_load=()):
  """Load checkpoint."""

  del model_cfg
  # Shortcut names for some canonical paper checkpoints:
  init_file = {
      # pylint: disable=line-too-long
      # Pretrained models from the MLP-Mixer paper: https://arxiv.org/abs/2105.01601.
      "B-i1k/16": "gs://mixer_models/imagenet1k/Mixer-B_16.npz",
      "L-i1k/16": "gs://mixer_models/imagenet1k/Mixer-L_16.npz",
      "B-i21k/16": "gs://mixer_models/imagenet21k/Mixer-B_16.npz",
      "L-i21k/16": "gs://mixer_models/imagenet21k/Mixer-L_16.npz",
      # pylint: enable=line-too-long
  }.get(init_file, init_file)
  restored_params = utils.load_params(init_file)
  restored_params = flax.training.checkpoints.convert_pre_linen(restored_params)

  if "Mixer" in restored_params:
    restored_params["pre_head_layer_norm"] = restored_params["Mixer"].pop(
        "encoder_norm"
    )
    restored_params["stem"] = restored_params.pop("embedding")
    def unflatten_dense(d):
      return {
          "Dense_0": {
              "bias": d["bias1"].squeeze(),
              "kernel": d["kernel1"].squeeze(),
          },
          "Dense_1": {
              "bias": d["bias2"].squeeze(),
              "kernel": d["kernel2"].squeeze(),
          },
      }
    for k, v in restored_params["Mixer"].items():
      assert k.startswith("encoderblock_"), k
      v["token_mixing"] = unflatten_dense(v.pop("token_mixing_phase_0"))
      v["channel_mixing"] = unflatten_dense(v.pop("channel_mixing_phase_0"))
      restored_params["MixerBlock_" + k[len("encoderblock_"):]] = v
    del restored_params["Mixer"]

  # possibly use the random init for some of the params (such as, the head).
  restored_params = common.merge_params(restored_params, init_params, dont_load)

  return restored_params


def _stoch_depth_mask(x, drop_p, deterministic, make_rng):
  if not deterministic and drop_p:
    shape = (x.shape[0],) + (1,) * (x.ndim - 1)
    return 1.0 - jax.random.bernoulli(make_rng("dropout"), drop_p, shape)
  return 1.0
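
A quick orientation on the Mixer file above: token mixing transposes the token and channel axes so an MlpBlock mixes information across patches, channel mixing does the same across features, and _stoch_depth_mask drops whole residual branches per example during training. A smoke-test sketch, assuming big_vision is importable (editor's illustration, not repo code):

import jax
import jax.numpy as jnp
from big_vision.models.mlp_mixer import Model

model = Model(num_classes=10, variant="S/16")    # hidden_dim=512, 8 blocks
img = jnp.zeros((1, 224, 224, 3))
params = model.init(jax.random.PRNGKey(0), img)["params"]
logits, aux = model.apply({"params": params}, img)
assert logits.shape == (1, 10)
assert aux["input_tokens"].shape == (1, 196, 512)  # 14x14 patches of 16px
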
Tipsomaly/model/big_vision/models/vit.py
ADDED
@@ -0,0 +1,505 @@
# Copyright 2024 Big Vision Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""A refactored and simplified ViT.

However, the names of modules are made to match the old ones for easy loading.
"""

from typing import Optional, Sequence, Union

from absl import logging
from big_vision import utils
from big_vision.models import common
import flax
import flax.linen as nn
import flax.training.checkpoints
import jax
import jax.numpy as jnp
import numpy as np
import scipy.ndimage


def posemb_sincos_2d(h, w, width, temperature=10_000., dtype=jnp.float32):
  """Follows the MoCo v3 logic."""
  y, x = jnp.mgrid[:h, :w]

  assert width % 4 == 0, "Width must be mult of 4 for sincos posemb"
  omega = jnp.arange(width // 4) / (width // 4 - 1)
  omega = 1. / (temperature**omega)
  y = jnp.einsum("m,d->md", y.flatten(), omega)
  x = jnp.einsum("m,d->md", x.flatten(), omega)
  pe = jnp.concatenate([jnp.sin(x), jnp.cos(x), jnp.sin(y), jnp.cos(y)], axis=1)
  return jnp.asarray(pe, dtype)[None, :, :]


def get_posemb(self, typ, seqshape, width, name, dtype=jnp.float32):
  if typ == "learn":
    return self.param(name, nn.initializers.normal(stddev=1/np.sqrt(width)),
                      (1, np.prod(seqshape), width), dtype)
  elif typ == "sincos2d":
    return posemb_sincos_2d(*seqshape, width, dtype=dtype)
  else:
    raise ValueError(f"Unknown posemb type: {typ}")

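
# --- Editor's illustration (hypothetical, not part of the original file) ---
# The "sincos2d" table is deterministic, so such checkpoints carry no posemb
# parameters; only "learn" creates one via self.param. A quick shape check:
def _demo_posemb_shapes():  # pragma: no cover -- illustrative only
  pe = posemb_sincos_2d(8, 8, 64)
  assert pe.shape == (1, 8 * 8, 64)
  assert pe.dtype == jnp.float32
# ---------------------------------------------------------------------------
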
class MlpBlock(nn.Module):
  """Transformer MLP / feed-forward block."""
  mlp_dim: Optional[int] = None  # Defaults to 4x input dim
  dropout: float = 0.0
  dtype_mm: str = "float32"

  @nn.compact
  def __call__(self, x, deterministic=True):
    """Applies Transformer MlpBlock module."""
    inits = dict(
        kernel_init=nn.initializers.xavier_uniform(),
        bias_init=nn.initializers.normal(stddev=1e-6),
    )

    d = x.shape[-1]
    x = nn.Dense(self.mlp_dim or 4 * d, dtype=self.dtype_mm, **inits)(x)
    # In some extreme batch-size cases, this is needed as of Sept 2024:
    x = nn.with_logical_constraint(x, ("act_batch", "act_len", "act_emb"))
    x = nn.gelu(x)
    x = nn.Dropout(rate=self.dropout)(x, deterministic)
    x = nn.Dense(d, dtype=self.dtype_mm, **inits)(x)
    return x


class Encoder1DBlock(nn.Module):
  """Single transformer encoder block (MHSA + MLP)."""
  mlp_dim: Optional[int] = None  # Defaults to 4x input dim
  num_heads: int = 12
  dropout: float = 0.0
  dtype_mm: str = "float32"

  @nn.compact
  def __call__(self, x, deterministic=True):
    out = {}
    x = nn.with_logical_constraint(x, ("act_batch", "act_len", "act_emb"))
    y = nn.LayerNorm()(x)
    y = out["sa"] = nn.MultiHeadDotProductAttention(
        num_heads=self.num_heads,
        kernel_init=nn.initializers.xavier_uniform(),
        deterministic=deterministic,
        dtype=self.dtype_mm,
    )(y, y)
    y = nn.with_logical_constraint(y, ("act_batch", "act_len", "act_emb"))
    y = nn.Dropout(rate=self.dropout)(y, deterministic)
    x = out["+sa"] = x + y

    y = nn.LayerNorm()(x)
    y = out["mlp"] = MlpBlock(
        mlp_dim=self.mlp_dim, dropout=self.dropout,
        dtype_mm=self.dtype_mm,
    )(y, deterministic)
    y = nn.with_logical_constraint(y, ("act_batch", "act_len", "act_emb"))
    y = nn.Dropout(rate=self.dropout)(y, deterministic)
    x = out["+mlp"] = x + y
    x = nn.with_logical_constraint(x, ("act_batch", "act_len", "act_emb"))
    return x, out


class Encoder(nn.Module):
  """Transformer Model Encoder for sequence to sequence translation."""
  depth: int
  mlp_dim: Optional[int] = None  # Defaults to 4x input dim
  num_heads: int = 12
  dropout: float = 0.0
  scan: bool = False
  remat_policy: str = "nothing_saveable"
  dtype_mm: str = "float32"

  @nn.compact
  def __call__(self, x, deterministic=True):
    out = {}

    if self.scan:
      block = nn.remat(
          Encoder1DBlock,
          prevent_cse=False,
          static_argnums=(2,),  # 0=self, 2=deterministic
          policy=getattr(jax.checkpoint_policies, self.remat_policy, None),
      )
      x, scan_out = nn.scan(
          block,
          variable_axes={"params": 0},
          split_rngs={"params": True, "dropout": True},
          in_axes=nn.broadcast,
          length=self.depth)(
              name="encoderblock",
              dtype_mm=self.dtype_mm,
              mlp_dim=self.mlp_dim,
              num_heads=self.num_heads,
              dropout=self.dropout)(x, deterministic)
      for lyr in range(self.depth):
        out[f"block{lyr:02d}"] = jax.tree.map(lambda o, l=lyr: o[l], scan_out)
    else:
      # Input Encoder
      for lyr in range(self.depth):
        block_cur = Encoder1DBlock(
            name=f"encoderblock_{lyr}",
            dtype_mm=self.dtype_mm,
            mlp_dim=self.mlp_dim, num_heads=self.num_heads,
            dropout=self.dropout)
        x, out[f"block{lyr:02d}"] = block_cur(x, deterministic)
      out["pre_ln"] = x  # Alias for last block, but without the number in it.

    return nn.LayerNorm(name="encoder_norm")(x), out


class MAPHead(nn.Module):
  """Multihead Attention Pooling."""
  mlp_dim: Optional[int] = None  # Defaults to 4x input dim
  num_heads: int = 12

  @nn.compact
  def __call__(self, x):
    n, l, d = x.shape  # pylint: disable=unused-variable
    probe = self.param("probe", nn.initializers.xavier_uniform(),
                       (1, 1, d), x.dtype)
    probe = jnp.tile(probe, [n, 1, 1])

    x = nn.MultiHeadDotProductAttention(
        num_heads=self.num_heads,
        kernel_init=nn.initializers.xavier_uniform())(probe, x)

    # TODO: dropout on head?
    y = nn.LayerNorm()(x)
    x = x + MlpBlock(mlp_dim=self.mlp_dim)(y)
    return x[:, 0]


class _Model(nn.Module):
  """ViT model."""

  num_classes: Optional[int] = None
  patch_size: Sequence[int] = (16, 16)
  width: int = 768
  depth: int = 12
  mlp_dim: Optional[int] = None  # Defaults to 4x input dim
  num_heads: int = 12
  posemb: str = "learn"  # Can also be "sincos2d"
  rep_size: Union[int, bool] = False
  dropout: float = 0.0
  pool_type: str = "gap"  # Can also be "map" or "tok"
  head_zeroinit: bool = True
  scan: bool = False
  # or "dots_with_no_batch_dims_saveable" for more speed (memory costly)
  remat_policy: str = "nothing_saveable"
  dtype_mm: str = "float32"

  @nn.compact
  def __call__(self, image, *, train=False):
    out = {}

    image = jnp.asarray(image, self.dtype_mm)

    # Patch extraction
    x = out["stem"] = nn.Conv(
        self.width, self.patch_size, strides=self.patch_size,
        padding="VALID", name="embedding", dtype=self.dtype_mm)(image)

    n, h, w, c = x.shape
    x = jnp.reshape(x, [n, h * w, c])

    # Add posemb before adding extra token.
    x = out["with_posemb"] = x + get_posemb(
        self, self.posemb, (h, w), c, "pos_embedding", x.dtype)

    if self.pool_type == "tok":
      cls = self.param("cls", nn.initializers.zeros, (1, 1, c), x.dtype)
      x = jnp.concatenate([jnp.tile(cls, [n, 1, 1]), x], axis=1)

    n, l, c = x.shape  # pylint: disable=unused-variable
    x = nn.Dropout(rate=self.dropout)(x, not train)

    x, out["encoder"] = Encoder(
        depth=self.depth,
        mlp_dim=self.mlp_dim,
        num_heads=self.num_heads,
        dropout=self.dropout,
        scan=self.scan,
        remat_policy=self.remat_policy,
        dtype_mm=self.dtype_mm,
        name="Transformer")(
            x, deterministic=not train)
    encoded = out["encoded"] = x

    if self.pool_type == "map":
      x = out["head_input"] = MAPHead(
          num_heads=self.num_heads, mlp_dim=self.mlp_dim)(x)
    elif self.pool_type == "gap":
      x = out["head_input"] = jnp.mean(x, axis=1)
    elif self.pool_type == "0":
      x = out["head_input"] = x[:, 0]
    elif self.pool_type == "tok":
      x = out["head_input"] = x[:, 0]
      encoded = encoded[:, 1:]
    elif self.pool_type == "none":
      pass
    else:
      raise ValueError(f"Unknown pool type: '{self.pool_type}'")

    x_2d = jnp.reshape(encoded, [n, h, w, -1])

    if self.rep_size:
      rep_size = self.width if self.rep_size is True else self.rep_size
      hid = nn.Dense(rep_size, name="pre_logits")
      # NOTE: In the past we did not include tanh in pre_logits.
      # For few-shot, it should not matter much, as it whitens anyways.
      x_2d = nn.tanh(hid(x_2d))
      x = nn.tanh(hid(x))

    out["pre_logits_2d"] = x_2d
    out["pre_logits"] = x

    if self.num_classes:
      kw = {"kernel_init": nn.initializers.zeros} if self.head_zeroinit else {}
      head = nn.Dense(self.num_classes, name="head", **kw)
      x_2d = out["logits_2d"] = head(x_2d)
      x = out["logits"] = head(x)

    return x, out


def Model(num_classes=None, *, variant=None, **kw):  # pylint: disable=invalid-name
  """Factory function, because linen really doesn't like what I'm doing!"""
  return _Model(num_classes, **{**decode_variant(variant), **kw})


def decode_variant(variant):
  """Converts a string like "B" or "B/32" into a params dict."""
  if variant is None:
    return {}

  v, patch = variant, {}
  if "/" in variant:
    v, patch = variant.split("/")
    patch = {"patch_size": (int(patch), int(patch))}

  return {
      # pylint:disable=line-too-long
      # Reference: Table 2 of https://arxiv.org/abs/2106.04560.
      "width": {"mu": 32, "Ti": 192, "S": 384, "M": 512, "B": 768, "L": 1024, "So400m": 1152, "H": 1280, "g": 1408, "g-opt": 1536, "G": 1664, "G-opt": 1536, "e": 1792}[v],
      "depth": {"mu": 1, "Ti": 12, "S": 12, "M": 12, "B": 12, "L": 24, "So400m": 27, "H": 32, "g": 40, "g-opt": 40, "G": 48, "G-opt": 48, "e": 56}[v],
      "mlp_dim": {"mu": 128, "Ti": 768, "S": 1536, "M": 2048, "B": 3072, "L": 4096, "So400m": 4304, "H": 5120, "g": 6144, "g-opt": 6144, "G": 8192, "G-opt": 8192, "e": 15360}[v],
      "num_heads": {"mu": 2, "Ti": 3, "S": 6, "M": 8, "B": 12, "L": 16, "So400m": 16, "H": 16, "g": 16, "g-opt": 16, "G": 16, "G-opt": 16, "e": 16}[v],
      # pylint:enable=line-too-long
      **patch
  }


def resample_posemb(old, new):
  """This function implements "high-res finetuning" for transformer models."""
  # Rescale the grid of position embeddings. Param shape is (1, N, width).
  if old.shape == new.shape:
    return old

  logging.info("ViT: resize %s to %s", old.shape, new.shape)
  gs_old = int(np.sqrt(old.shape[1]))
  gs_new = int(np.sqrt(new.shape[1]))
  logging.info("ViT: grid-size from %s to %s", gs_old, gs_new)
  grid = old.reshape(gs_old, gs_old, -1)

  zoom = (gs_new/gs_old, gs_new/gs_old, 1)
  grid = scipy.ndimage.zoom(grid, zoom, order=1)
  grid = grid.reshape(1, gs_new*gs_new, -1)
  return grid


def fix_old_checkpoints(params):
  """Fix small bwd incompat that can't be resolved with names in model def."""

  params = flax.core.unfreeze(
      flax.training.checkpoints.convert_pre_linen(params))

  # Original ViT paper variant had posemb in a module:
  if "posembed_input" in params["Transformer"]:
    logging.info("ViT: Loading and fixing VERY old posemb")
    posemb = params["Transformer"].pop("posembed_input")
    params["pos_embedding"] = posemb["pos_embedding"]

  # Widely used version before 2022 had posemb in Encoder:
  if "pos_embedding" in params["Transformer"]:
    logging.info("ViT: Loading and fixing old posemb")
    params["pos_embedding"] = params["Transformer"].pop("pos_embedding")

  # Old vit.py used to first concat [cls] token, then add posemb.
  # This means a B/32@224px would have 7x7+1 posembs. This is useless and clumsy
  # so we changed to add posemb then concat [cls]. We can recover the old
  # checkpoint by manually summing [cls] token and its posemb entry.
  if "pos_embedding" in params:
    pe = params["pos_embedding"]
    if int(np.sqrt(pe.shape[1])) ** 2 + 1 == int(pe.shape[1]):
      logging.info("ViT: Loading and fixing combined cls+posemb")
      pe_cls, params["pos_embedding"] = pe[:, :1], pe[:, 1:]
      if "cls" in params:
        params["cls"] += pe_cls

  # MAP-head variants during ViT-G development had it inlined:
  if "probe" in params:
    params["MAPHead_0"] = {
        k: params.pop(k) for k in
        ["probe", "MlpBlock_0", "MultiHeadDotProductAttention_0", "LayerNorm_0"]
    }

  return params


def pyloop_to_scan(params_pyloop):
  """Converts a python for-loop ViT checkpoint to a lax.scan based one."""
  # On a high level, they are the same except that the for loop has separate
  # array pytrees for each encoderblock, while the scan one has just one
  # encoderblock pytree, with all block's params concatenated.

  params_scan = jax.tree.map(lambda x: x, params_pyloop)  # Structural copy
  t = params_scan["Transformer"]

  # Find highest index of encoderblocks in the checkpoint (they start at 0):
  encoderblocks = {k for k in t if k.startswith("encoderblock_")}
  depth = 1 + max({int(k.split("_")[-1]) for k in encoderblocks})

  def stack(*values):
    return np.stack(values)

  # Stack all encoderblocks into a single one:
  t["encoderblock"] = jax.tree.map(
      stack, *[t[f"encoderblock_{lyr}"] for lyr in range(depth)])

  for lyr in range(depth):
    del t[f"encoderblock_{lyr}"]

  return params_scan


def scan_to_pyloop(params_scan):
  """Converts a lax.scan ViT checkpoint to a python for-loop based one."""
  # See comment in pyloop_to_scan.

  params_scan = jax.tree.map(lambda x: x, params_scan)  # Structural copy
  t = params_scan["Transformer"]

  # Find out how many encoderblocks there are
  depth = len(t["encoderblock"]["LayerNorm_0"]["bias"])

  # Create that many encoderblocks, each with their slice of their sub-pytree.
  for lyr in range(depth):
    block = jax.tree.map(lambda x, lyr=lyr: x[lyr], t["encoderblock"])
    t[f"encoderblock_{lyr}"] = block

  del t["encoderblock"]
  return params_scan


def load(init_params, init_file, model_cfg, dont_load=()):  # pylint: disable=invalid-name because we had to CamelCase above.
  """Load init from checkpoint, both old model and this one. +Hi-res posemb."""
  init_file = VANITY_NAMES.get(init_file, init_file)
  restored_params = utils.load_params(init_file)

  restored_params = fix_old_checkpoints(restored_params)

  # Detect attempts to load non-scan checkpoint into scan model.
  if (model_cfg.get("scan") and
      "encoderblock" not in restored_params["Transformer"]):
    restored_params = pyloop_to_scan(restored_params)
  if (not model_cfg.get("scan")
      and "encoderblock" in restored_params["Transformer"]):
    restored_params = scan_to_pyloop(restored_params)

  # Possibly use the random init for some of the params (such as the head).
  restored_params = common.merge_params(restored_params, init_params, dont_load)

  # Resample posemb if needed.
  # TODO: Take this from model_cfg to avoid need for init_params.
  if init_params and "pos_embedding" in init_params:
    restored_params["pos_embedding"] = resample_posemb(
        old=restored_params["pos_embedding"],
        new=init_params["pos_embedding"])

  return restored_params


# Shortcut names for some canonical paper checkpoints:
VANITY_NAMES = {
    # pylint: disable=line-too-long
    # Recommended models from https://arxiv.org/abs/2106.10270
    # Many more models at https://github.com/google-research/vision_transformer
    "howto-i21k-Ti/16": "gs://vit_models/augreg/Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0.npz",
    "howto-i21k-S/32": "gs://vit_models/augreg/S_32-i21k-300ep-lr_0.001-aug_none-wd_0.1-do_0.0-sd_0.0.npz",
    "howto-i21k-S/16": "gs://vit_models/augreg/S_16-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0.npz",
    "howto-i21k-B/32": "gs://vit_models/augreg/B_32-i21k-300ep-lr_0.001-aug_light1-wd_0.1-do_0.0-sd_0.0.npz",
    "howto-i21k-B/16": "gs://vit_models/augreg/B_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0.npz",
    "howto-i21k-B/8": "gs://vit_models/augreg/B_8-i21k-300ep-lr_0.001-aug_medium2-wd_0.1-do_0.0-sd_0.0.npz",
    "howto-i21k-L/16": "gs://vit_models/augreg/L_16-i21k-300ep-lr_0.001-aug_strong1-wd_0.1-do_0.0-sd_0.0.npz",

    # Better plain vit-s16 baselines from https://arxiv.org/abs/2205.01580
    "i1k-s16-90ep": "gs://big_vision/vit_s16_i1k_90ep.npz",
    "i1k-s16-150ep": "gs://big_vision/vit_s16_i1k_150ep.npz",
    "i1k-s16-300ep": "gs://big_vision/vit_s16_i1k_300ep.npz",

    # DeiT-3 checkpoints from https://github.com/facebookresearch/deit/blob/main/README_revenge.md
    # First layer converted to take inputs in [-1,1]
    "deit3_S_224_1k": "gs://big_vision/zoo/deit3/bv_deit_3_small_224_1k.npz",
    "deit3_S_224_21k": "gs://big_vision/zoo/deit3/bv_deit_3_small_224_21k.npz",
    "deit3_S_384_1k": "gs://big_vision/zoo/deit3/bv_deit_3_small_384_1k.npz",
    "deit3_S_384_21k": "gs://big_vision/zoo/deit3/bv_deit_3_small_384_21k.npz",
    "deit3_B_224_1k": "gs://big_vision/zoo/deit3/bv_deit_3_base_224_1k.npz",
    "deit3_B_224_21k": "gs://big_vision/zoo/deit3/bv_deit_3_base_224_21k.npz",
    "deit3_B_384_1k": "gs://big_vision/zoo/deit3/bv_deit_3_base_384_1k.npz",
    "deit3_B_384_21k": "gs://big_vision/zoo/deit3/bv_deit_3_base_384_21k.npz",
    "deit3_L_224_1k": "gs://big_vision/zoo/deit3/bv_deit_3_large_224_1k.npz",
    "deit3_L_224_21k": "gs://big_vision/zoo/deit3/bv_deit_3_large_224_21k.npz",
    "deit3_L_384_1k": "gs://big_vision/zoo/deit3/bv_deit_3_large_384_1k.npz",
    "deit3_L_384_21k": "gs://big_vision/zoo/deit3/bv_deit_3_large_384_21k.npz",

    # SigLIP image encoder checkpoints from https://arxiv.org/abs/2303.15343
    "SigLIP B/16 224": "gs://big_vision/siglip/webli_en_b16_224_63724782.npz:img",
    "SigLIP B/16 256": "gs://big_vision/siglip/webli_en_b16_256_60500360.npz:img",
    "SigLIP B/16 384": "gs://big_vision/siglip/webli_en_b16_384_68578854.npz:img",
    "SigLIP B/16 512": "gs://big_vision/siglip/webli_en_b16_512_68580893.npz:img",
    "SigLIP L/16 256": "gs://big_vision/siglip/webli_en_l16_256_60552751.npz:img",
    "SigLIP L/16 384": "gs://big_vision/siglip/webli_en_l16_384_63634585.npz:img",
    "SigLIP So400m/14 224": "gs://big_vision/siglip/webli_en_so400m_224_57633886.npz:img",
    "SigLIP So400m/14 384": "gs://big_vision/siglip/webli_en_so400m_384_58765454.npz:img",
    "SigLIP B/16-i18n 256": "gs://big_vision/siglip/webli_i18n_b16_256_66117334.npz:img",

    # SigLIP 2 image encoder checkpoints from https://arxiv.org/abs/2502.14786
    "SigLIP2 B/16 224": "gs://big_vision/siglip2/siglip2_b16_224.npz:img",
    "SigLIP2 B/16 256": "gs://big_vision/siglip2/siglip2_b16_256.npz:img",
    "SigLIP2 B/16 384": "gs://big_vision/siglip2/siglip2_b16_384.npz:img",
    "SigLIP2 B/16 512": "gs://big_vision/siglip2/siglip2_b16_512.npz:img",
    "SigLIP2 B/32 256": "gs://big_vision/siglip2/siglip2_b32_256.npz:img",
    "SigLIP2 L/16 256": "gs://big_vision/siglip2/siglip2_l16_256.npz:img",
    "SigLIP2 L/16 384": "gs://big_vision/siglip2/siglip2_l16_384.npz:img",
    "SigLIP2 L/16 512": "gs://big_vision/siglip2/siglip2_l16_512.npz:img",
    "SigLIP2 So400m/14 224": "gs://big_vision/siglip2/siglip2_so400m14_224.npz:img",
    "SigLIP2 So400m/14 384": "gs://big_vision/siglip2/siglip2_so400m14_384.npz:img",
    "SigLIP2 So400m/16 256": "gs://big_vision/siglip2/siglip2_so400m16_256.npz:img",
    "SigLIP2 So400m/16 384": "gs://big_vision/siglip2/siglip2_so400m16_384.npz:img",
    "SigLIP2 So400m/16 512": "gs://big_vision/siglip2/siglip2_so400m16_512.npz:img",
    "SigLIP2 g-opt/16 256": "gs://big_vision/siglip2/siglip2_g-opt16_256.npz:img",
    "SigLIP2 g-opt/16 384": "gs://big_vision/siglip2/siglip2_g-opt16_384.npz:img",
    # SigLIP 2 NaFlex image encoder checkpoints.
    # These need `proj.image_text.naflex_vit.py` as the image encoder model
    # and a non-standard preprocessing, see configs/proj/image_text/README_siglip2.md.
    "SigLIP2 B/16 NaFlex": "gs://big_vision/siglip2/siglip2_b16_naflex.npz:img",
    "SigLIP2 So400m/16 NaFlex": "gs://big_vision/siglip2/siglip2_so400m16_naflex.npz:img",
    # pylint: enable=line-too-long
}
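
Tying vit.py together: decode_variant maps a name like "S/16" onto width/depth/mlp_dim/num_heads plus patch size, the Model factory forwards that into _Model, and resample_posemb bilinearly rescales a square posemb grid for high-res finetuning. A short sketch, assuming big_vision is importable (editor's illustration, not repo code):

import jax
import jax.numpy as jnp
import numpy as np
from big_vision.models.vit import Model, decode_variant, resample_posemb

print(decode_variant("S/16"))   # width=384, depth=12, mlp_dim=1536, heads=6

model = Model(num_classes=1000, variant="S/16")
img = jnp.zeros((1, 224, 224, 3))
params = model.init(jax.random.PRNGKey(0), img)["params"]
logits, out = model.apply({"params": params}, img)
assert logits.shape == (1, 1000) and out["pre_logits"].shape == (1, 384)

# 224px -> 384px at patch 16: resample the 14x14 posemb grid to 24x24.
old, new = np.zeros((1, 196, 384)), np.zeros((1, 576, 384))
assert resample_posemb(old, new).shape == (1, 576, 384)
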
Tipsomaly/model/big_vision/pp/__init__.py
ADDED
File without changes
Tipsomaly/model/big_vision/pp/autoaugment.py
ADDED
@@ -0,0 +1,700 @@
# Copyright 2023 Big Vision Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""AutoAugment and RandAugment policies for enhanced image preprocessing.

AutoAugment Reference: https://arxiv.org/abs/1805.09501
RandAugment Reference: https://arxiv.org/abs/1909.13719

This code is forked from
https://github.com/tensorflow/tpu/blob/11d0db15cf1c3667f6e36fecffa111399e008acd/models/official/efficientnet/autoaugment.py
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import dataclasses
import inspect
import math
import tensorflow.compat.v1 as tf
from tensorflow_addons import image as contrib_image

# This signifies the max integer that the controller RNN could predict for the
# augmentation scheme.
_MAX_LEVEL = 10.


@dataclasses.dataclass
class HParams:
  """Parameters for AutoAugment and RandAugment."""
  cutout_const: int
  translate_const: int


def policy_v0():
  """Autoaugment policy that was used in AutoAugment Paper."""
  # Each tuple is an augmentation operation of the form
  # (operation, probability, magnitude). Each element in policy is a
  # sub-policy that will be applied sequentially on the image.
  policy = [
      [('Equalize', 0.8, 1), ('ShearY', 0.8, 4)],
      [('Color', 0.4, 9), ('Equalize', 0.6, 3)],
      [('Color', 0.4, 1), ('Rotate', 0.6, 8)],
      [('Solarize', 0.8, 3), ('Equalize', 0.4, 7)],
      [('Solarize', 0.4, 2), ('Solarize', 0.6, 2)],
      [('Color', 0.2, 0), ('Equalize', 0.8, 8)],
      [('Equalize', 0.4, 8), ('SolarizeAdd', 0.8, 3)],
      [('ShearX', 0.2, 9), ('Rotate', 0.6, 8)],
      [('Color', 0.6, 1), ('Equalize', 1.0, 2)],
      [('Invert', 0.4, 9), ('Rotate', 0.6, 0)],
      [('Equalize', 1.0, 9), ('ShearY', 0.6, 3)],
      [('Color', 0.4, 7), ('Equalize', 0.6, 0)],
      [('Posterize', 0.4, 6), ('AutoContrast', 0.4, 7)],
      [('Solarize', 0.6, 8), ('Color', 0.6, 9)],
      [('Solarize', 0.2, 4), ('Rotate', 0.8, 9)],
      [('Rotate', 1.0, 7), ('TranslateY', 0.8, 9)],
      [('ShearX', 0.0, 0), ('Solarize', 0.8, 4)],
      [('ShearY', 0.8, 0), ('Color', 0.6, 4)],
      [('Color', 1.0, 0), ('Rotate', 0.6, 2)],
      [('Equalize', 0.8, 4), ('Equalize', 0.0, 8)],
      [('Equalize', 1.0, 4), ('AutoContrast', 0.6, 2)],
      [('ShearY', 0.4, 7), ('SolarizeAdd', 0.6, 7)],
      [('Posterize', 0.8, 2), ('Solarize', 0.6, 10)],
      [('Solarize', 0.6, 8), ('Equalize', 0.6, 1)],
      [('Color', 0.8, 6), ('Rotate', 0.4, 5)],
  ]
  return policy


def policy_vtest():
  """Autoaugment test policy for debugging."""
  # Each tuple is an augmentation operation of the form
  # (operation, probability, magnitude). Each element in policy is a
  # sub-policy that will be applied sequentially on the image.
  policy = [
      [('TranslateX', 1.0, 4), ('Equalize', 1.0, 10)],
  ]
  return policy


def blend(image1, image2, factor):
  """Blend image1 and image2 using 'factor'.

  Factor can be above 0.0. A value of 0.0 means only image1 is used.
  A value of 1.0 means only image2 is used. A value between 0.0 and
  1.0 means we linearly interpolate the pixel values between the two
  images. A value greater than 1.0 "extrapolates" the difference
  between the two pixel values, and we clip the results to values
  between 0 and 255.

  Args:
    image1: An image Tensor of type uint8.
    image2: An image Tensor of type uint8.
    factor: A floating point value above 0.0.

  Returns:
    A blended image Tensor of type uint8.
  """
  if factor == 0.0:
    return tf.convert_to_tensor(image1)
  if factor == 1.0:
    return tf.convert_to_tensor(image2)

  image1 = tf.to_float(image1)
  image2 = tf.to_float(image2)

  difference = image2 - image1
  scaled = factor * difference

  # Do addition in float.
  temp = tf.to_float(image1) + scaled

  # Interpolate
  if factor > 0.0 and factor < 1.0:
    # Interpolation means we always stay within 0 and 255.
    return tf.cast(temp, tf.uint8)

  # Extrapolate:
  #
  # We need to clip and then cast.
  return tf.cast(tf.clip_by_value(temp, 0.0, 255.0), tf.uint8)

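
# --- Editor's illustration (hypothetical, not part of the original file) ---
# blend() with factor in (0, 1) interpolates; factor > 1 extrapolates past
# image2 and relies on the final clip to stay in [0, 255]:
def _demo_blend():  # pragma: no cover -- illustrative only
  dark = tf.fill([2, 2, 3], tf.constant(64, tf.uint8))
  gray = tf.fill([2, 2, 3], tf.constant(128, tf.uint8))
  assert int(blend(dark, gray, 0.5)[0, 0, 0]) == 96    # halfway
  assert int(blend(dark, gray, 2.0)[0, 0, 0]) == 192   # 64 + 2 * (128 - 64)
# ---------------------------------------------------------------------------
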
def cutout(image, pad_size, replace=0):
  """Apply cutout (https://arxiv.org/abs/1708.04552) to image.

  This operation applies a (2*pad_size x 2*pad_size) mask of zeros to
  a random location within `img`. The pixel values filled in will be of the
  value `replace`. The location where the mask will be applied is randomly
  chosen uniformly over the whole image.

  Args:
    image: An image Tensor of type uint8.
    pad_size: Specifies how big the zero mask that will be generated is that
      is applied to the image. The mask will be of size
      (2*pad_size x 2*pad_size).
    replace: What pixel value to fill in the image in the area that has
      the cutout mask applied to it.

  Returns:
    An image Tensor that is of type uint8.
  """
  image_height = tf.shape(image)[0]
  image_width = tf.shape(image)[1]

  # Sample the center location in the image where the zero mask will be applied.
  cutout_center_height = tf.random_uniform(
      shape=[], minval=0, maxval=image_height,
      dtype=tf.int32)

  cutout_center_width = tf.random_uniform(
      shape=[], minval=0, maxval=image_width,
      dtype=tf.int32)

  lower_pad = tf.maximum(0, cutout_center_height - pad_size)
  upper_pad = tf.maximum(0, image_height - cutout_center_height - pad_size)
  left_pad = tf.maximum(0, cutout_center_width - pad_size)
  right_pad = tf.maximum(0, image_width - cutout_center_width - pad_size)

  cutout_shape = [image_height - (lower_pad + upper_pad),
                  image_width - (left_pad + right_pad)]
  padding_dims = [[lower_pad, upper_pad], [left_pad, right_pad]]
  mask = tf.pad(
      tf.zeros(cutout_shape, dtype=image.dtype),
      padding_dims, constant_values=1)
  mask = tf.expand_dims(mask, -1)
  mask = tf.tile(mask, [1, 1, 3])
  image = tf.where(
      tf.equal(mask, 0),
      tf.ones_like(image, dtype=image.dtype) * replace,
      image)
  return image


def solarize(image, threshold=128):
  # For each pixel in the image, select the pixel
  # if the value is less than the threshold.
  # Otherwise, subtract 255 from the pixel.
  return tf.where(image < threshold, image, 255 - image)


def solarize_add(image, addition=0, threshold=128):
  # For each pixel in the image less than threshold
  # we add 'addition' amount to it and then clip the
  # pixel value to be between 0 and 255. The value
  # of 'addition' is between -128 and 128.
  added_image = tf.cast(image, tf.int64) + addition
  added_image = tf.cast(tf.clip_by_value(added_image, 0, 255), tf.uint8)
  return tf.where(image < threshold, added_image, image)


def color(image, factor):
  """Equivalent of PIL Color."""
  degenerate = tf.image.grayscale_to_rgb(tf.image.rgb_to_grayscale(image))
  return blend(degenerate, image, factor)


def contrast(image, factor):
  """Equivalent of PIL Contrast."""
  degenerate = tf.image.rgb_to_grayscale(image)
  # Cast before calling tf.histogram.
  degenerate = tf.cast(degenerate, tf.int32)

  # Compute the grayscale histogram, then compute the mean pixel value,
  # and create a constant image size of that value. Use that as the
  # blending degenerate target of the original image.
  hist = tf.histogram_fixed_width(degenerate, [0, 255], nbins=256)
  mean = tf.reduce_sum(tf.cast(hist, tf.float32)) / 256.0
  degenerate = tf.ones_like(degenerate, dtype=tf.float32) * mean
  degenerate = tf.clip_by_value(degenerate, 0.0, 255.0)
  degenerate = tf.image.grayscale_to_rgb(tf.cast(degenerate, tf.uint8))
  return blend(degenerate, image, factor)


def brightness(image, factor):
  """Equivalent of PIL Brightness."""
  degenerate = tf.zeros_like(image)
  return blend(degenerate, image, factor)


def posterize(image, bits):
  """Equivalent of PIL Posterize."""
  shift = 8 - bits
  return tf.bitwise.left_shift(tf.bitwise.right_shift(image, shift), shift)


def rotate(image, degrees, replace):
  """Rotates the image by degrees either clockwise or counterclockwise.

  Args:
    image: An image Tensor of type uint8.
    degrees: Float, a scalar angle in degrees to rotate all images by. If
      degrees is positive the image will be rotated clockwise otherwise it will
      be rotated counterclockwise.
    replace: A one or three value 1D tensor to fill empty pixels caused by
      the rotate operation.

  Returns:
    The rotated version of image.
  """
  # Convert from degrees to radians.
  degrees_to_radians = math.pi / 180.0
  radians = degrees * degrees_to_radians

  # In practice, we should randomize the rotation degrees by flipping
  # it negatively half the time, but that's done on 'degrees' outside
  # of the function.
  image = contrib_image.rotate(wrap(image), radians)
  return unwrap(image, replace)


def translate_x(image, pixels, replace):
  """Equivalent of PIL Translate in X dimension."""
  image = contrib_image.translate(wrap(image), [-pixels, 0])
  return unwrap(image, replace)


def translate_y(image, pixels, replace):
  """Equivalent of PIL Translate in Y dimension."""
  image = contrib_image.translate(wrap(image), [0, -pixels])
  return unwrap(image, replace)


def shear_x(image, level, replace):
  """Equivalent of PIL Shearing in X dimension."""
  # Shear parallel to x axis is a projective transform
  # with a matrix form of:
  # [1  level
  #  0  1].
  image = contrib_image.transform(
      wrap(image), [1., level, 0., 0., 1., 0., 0., 0.])
  return unwrap(image, replace)


def shear_y(image, level, replace):
  """Equivalent of PIL Shearing in Y dimension."""
  # Shear parallel to y axis is a projective transform
  # with a matrix form of:
  # [1      0
  #  level  1].
  image = contrib_image.transform(
      wrap(image), [1., 0., 0., level, 1., 0., 0., 0.])
  return unwrap(image, replace)


def autocontrast(image):
  """Implements Autocontrast function from PIL using TF ops.

  Args:
    image: A 3D uint8 tensor.

  Returns:
    The image after it has had autocontrast applied to it and will be of type
    uint8.
  """

  def scale_channel(image):
    """Scale the 2D image using the autocontrast rule."""
    # A possibly cheaper version can be done using cumsum/unique_with_counts
    # over the histogram values, rather than iterating over the entire image,
    # to compute mins and maxes.
    lo = tf.to_float(tf.reduce_min(image))
    hi = tf.to_float(tf.reduce_max(image))

    # Scale the image, making the lowest value 0 and the highest value 255.
    def scale_values(im):
      scale = 255.0 / (hi - lo)
      offset = -lo * scale
      im = tf.to_float(im) * scale + offset
      im = tf.clip_by_value(im, 0.0, 255.0)
      return tf.cast(im, tf.uint8)

    result = tf.cond(hi > lo, lambda: scale_values(image), lambda: image)
    return result

  # Assumes RGB for now. Scales each channel independently
  # and then stacks the result.
  s1 = scale_channel(image[:, :, 0])
  s2 = scale_channel(image[:, :, 1])
  s3 = scale_channel(image[:, :, 2])
  image = tf.stack([s1, s2, s3], 2)
  return image


def sharpness(image, factor):
  """Implements Sharpness function from PIL using TF ops."""
  orig_image = image
  image = tf.cast(image, tf.float32)
  # Make image 4D for conv operation.
  image = tf.expand_dims(image, 0)
  # SMOOTH PIL Kernel.
  kernel = tf.constant(
      [[1, 1, 1], [1, 5, 1], [1, 1, 1]], dtype=tf.float32,
      shape=[3, 3, 1, 1]) / 13.
  # Tile across channel dimension.
  kernel = tf.tile(kernel, [1, 1, 3, 1])
  strides = [1, 1, 1, 1]
  with tf.device('/cpu:0'):
    # Some augmentation that uses depth-wise conv will cause crashing when
    # training on GPU. See ((internal link)) for details.
    degenerate = tf.nn.depthwise_conv2d(
        image, kernel, strides, padding='VALID', rate=[1, 1])
  degenerate = tf.clip_by_value(degenerate, 0.0, 255.0)
  degenerate = tf.squeeze(tf.cast(degenerate, tf.uint8), [0])

  # For the borders of the resulting image, fill in the values of the
  # original image.
  mask = tf.ones_like(degenerate)
  padded_mask = tf.pad(mask, [[1, 1], [1, 1], [0, 0]])
  padded_degenerate = tf.pad(degenerate, [[1, 1], [1, 1], [0, 0]])
  result = tf.where(tf.equal(padded_mask, 1), padded_degenerate, orig_image)

  # Blend the final result.
  return blend(result, orig_image, factor)


def equalize(image):
  """Implements Equalize function from PIL using TF ops."""
  def scale_channel(im, c):
    """Scale the data in the channel to implement equalize."""
    im = tf.cast(im[:, :, c], tf.int32)
    # Compute the histogram of the image channel.
    histo = tf.histogram_fixed_width(im, [0, 255], nbins=256)

    # For the purposes of computing the step, filter out the nonzeros.
    nonzero = tf.where(tf.not_equal(histo, 0))
    nonzero_histo = tf.reshape(tf.gather(histo, nonzero), [-1])
    step = (tf.reduce_sum(nonzero_histo) - nonzero_histo[-1]) // 255

    def build_lut(histo, step):
      # Compute the cumulative sum, shifting by step // 2
      # and then normalization by step.
      lut = (tf.cumsum(histo) + (step // 2)) // step
      # Shift lut, prepending with 0.
      lut = tf.concat([[0], lut[:-1]], 0)
      # Clip the counts to be in range. This is done
      # in the C code for image.point.
      return tf.clip_by_value(lut, 0, 255)

    # If step is zero, return the original image. Otherwise, build
    # lut from the full histogram and step and then index from it.
    result = tf.cond(tf.equal(step, 0),
                     lambda: im,
                     lambda: tf.gather(build_lut(histo, step), im))

    return tf.cast(result, tf.uint8)

  # Assumes RGB for now. Scales each channel independently
  # and then stacks the result.
  s1 = scale_channel(image, 0)
  s2 = scale_channel(image, 1)
  s3 = scale_channel(image, 2)
  image = tf.stack([s1, s2, s3], 2)
  return image


def invert(image):
  """Inverts the image pixels."""
  image = tf.convert_to_tensor(image)
  return 255 - image


def wrap(image):
  """Returns 'image' with an extra channel set to all 1s."""
  shape = tf.shape(image)
  extended_channel = tf.ones([shape[0], shape[1], 1], image.dtype)
  extended = tf.concat([image, extended_channel], 2)
  return extended


def unwrap(image, replace):
  """Unwraps an image produced by wrap.

  Where there is a 0 in the last channel for every spatial position,
  the rest of the three channels in that spatial dimension are grayed
  (set to 128). Operations like translate and shear on a wrapped
  Tensor will leave 0s in empty locations. Some transformations look
  at the intensity of values to do preprocessing, and we want these
  empty pixels to assume the 'average' value, rather than pure black.

  Args:
    image: A 3D Image Tensor with 4 channels.
    replace: A one or three value 1D tensor to fill empty pixels.

  Returns:
    image: A 3D image Tensor with 3 channels.
  """
  image_shape = tf.shape(image)
  # Flatten the spatial dimensions.
  flattened_image = tf.reshape(image, [-1, image_shape[2]])

  # Find all pixels where the last channel is zero.
  alpha_channel = flattened_image[:, 3]

  replace = tf.concat([replace, tf.ones([1], image.dtype)], 0)

  # Where they are zero, fill them in with 'replace'.
  flattened_image = tf.where(
      tf.equal(alpha_channel, 0),
      tf.ones_like(flattened_image, dtype=image.dtype) * replace,
      flattened_image)

  image = tf.reshape(flattened_image, image_shape)
  image = tf.slice(image, [0, 0, 0], [image_shape[0], image_shape[1], 3])
  return image


NAME_TO_FUNC = {
    'AutoContrast': autocontrast,
    'Equalize': equalize,
    'Invert': invert,
    'Rotate': rotate,
    'Posterize': posterize,
    'Solarize': solarize,
    'SolarizeAdd': solarize_add,
    'Color': color,
    'Contrast': contrast,
    'Brightness': brightness,
    'Sharpness': sharpness,
    'ShearX': shear_x,
    'ShearY': shear_y,
    'TranslateX': translate_x,
    'TranslateY': translate_y,
    'Cutout': cutout,
}


def _randomly_negate_tensor(tensor):
"""With 50% prob turn the tensor negative."""
|
| 468 |
+
should_flip = tf.cast(tf.floor(tf.random_uniform([]) + 0.5), tf.bool)
|
| 469 |
+
final_tensor = tf.cond(should_flip, lambda: tensor, lambda: -tensor)
|
| 470 |
+
return final_tensor
|
| 471 |
+
|
| 472 |
+
|
| 473 |
+
def _rotate_level_to_arg(level):
|
| 474 |
+
level = (level/_MAX_LEVEL) * 30.
|
| 475 |
+
level = _randomly_negate_tensor(level)
|
| 476 |
+
return (level,)
|
| 477 |
+
|
| 478 |
+
|
| 479 |
+
def _shrink_level_to_arg(level):
|
| 480 |
+
"""Converts level to ratio by which we shrink the image content."""
|
| 481 |
+
if level == 0:
|
| 482 |
+
return (1.0,) # if level is zero, do not shrink the image
|
| 483 |
+
# Maximum shrinking ratio is 2.9.
|
| 484 |
+
level = 2. / (_MAX_LEVEL / level) + 0.9
|
| 485 |
+
return (level,)
|
| 486 |
+
|
| 487 |
+
|
| 488 |
+
def _enhance_level_to_arg(level):
|
| 489 |
+
return ((level/_MAX_LEVEL) * 1.8 + 0.1,)
|
| 490 |
+
|
| 491 |
+
|
| 492 |
+
def _shear_level_to_arg(level):
|
| 493 |
+
level = (level/_MAX_LEVEL) * 0.3
|
| 494 |
+
# Flip level to negative with 50% chance.
|
| 495 |
+
level = _randomly_negate_tensor(level)
|
| 496 |
+
return (level,)
|
| 497 |
+
|
| 498 |
+
|
| 499 |
+
def _translate_level_to_arg(level, translate_const):
|
| 500 |
+
level = (level/_MAX_LEVEL) * float(translate_const)
|
| 501 |
+
# Flip level to negative with 50% chance.
|
| 502 |
+
level = _randomly_negate_tensor(level)
|
| 503 |
+
return (level,)
|
| 504 |
+
|
| 505 |
+
|
| 506 |
+
def level_to_arg(hparams):
|
| 507 |
+
return {
|
| 508 |
+
'AutoContrast': lambda level: (),
|
| 509 |
+
'Equalize': lambda level: (),
|
| 510 |
+
'Invert': lambda level: (),
|
| 511 |
+
'Rotate': _rotate_level_to_arg,
|
| 512 |
+
'Posterize': lambda level: (int((level/_MAX_LEVEL) * 4),),
|
| 513 |
+
'Solarize': lambda level: (int((level/_MAX_LEVEL) * 256),),
|
| 514 |
+
'SolarizeAdd': lambda level: (int((level/_MAX_LEVEL) * 110),),
|
| 515 |
+
'Color': _enhance_level_to_arg,
|
| 516 |
+
'Contrast': _enhance_level_to_arg,
|
| 517 |
+
'Brightness': _enhance_level_to_arg,
|
| 518 |
+
'Sharpness': _enhance_level_to_arg,
|
| 519 |
+
'ShearX': _shear_level_to_arg,
|
| 520 |
+
'ShearY': _shear_level_to_arg,
|
| 521 |
+
'Cutout': lambda level: (int((level/_MAX_LEVEL) * hparams.cutout_const),),
|
| 522 |
+
'TranslateX': lambda level: _translate_level_to_arg(
|
| 523 |
+
level, hparams.translate_const),
|
| 524 |
+
'TranslateY': lambda level: _translate_level_to_arg(
|
| 525 |
+
level, hparams.translate_const),
|
| 526 |
+
# pylint:enable=g-long-lambda
|
| 527 |
+
}
|
| 528 |
+
|
| 529 |
+
|
| 530 |
+
def _parse_policy_info(name, prob, level, replace_value, augmentation_hparams):
|
| 531 |
+
"""Return the function that corresponds to `name` and update `level` param."""
|
| 532 |
+
func = NAME_TO_FUNC[name]
|
| 533 |
+
args = level_to_arg(augmentation_hparams)[name](level)
|
| 534 |
+
|
| 535 |
+
# Check to see if prob is passed into function. This is used for operations
|
| 536 |
+
# where we alter bboxes independently.
|
| 537 |
+
# pytype:disable=wrong-arg-types
|
| 538 |
+
if 'prob' in inspect.getfullargspec(func).args:
|
| 539 |
+
args = tuple([prob] + list(args))
|
| 540 |
+
# pytype:enable=wrong-arg-types
|
| 541 |
+
|
| 542 |
+
# Add in replace arg if it is required for the function that is being called.
|
| 543 |
+
# pytype:disable=wrong-arg-types
|
| 544 |
+
if 'replace' in inspect.getfullargspec(func).args:
|
| 545 |
+
# Make sure replace is the final argument
|
| 546 |
+
assert 'replace' == inspect.getfullargspec(func).args[-1]
|
| 547 |
+
args = tuple(list(args) + [replace_value])
|
| 548 |
+
# pytype:enable=wrong-arg-types
|
| 549 |
+
|
| 550 |
+
return (func, prob, args)
|
| 551 |
+
|
| 552 |
+
|
| 553 |
+
def _apply_func_with_prob(func, image, args, prob):
|
| 554 |
+
"""Apply `func` to image w/ `args` as input with probability `prob`."""
|
| 555 |
+
assert isinstance(args, tuple)
|
| 556 |
+
|
| 557 |
+
# If prob is a function argument, then this randomness is being handled
|
| 558 |
+
# inside the function, so make sure it is always called.
|
| 559 |
+
# pytype:disable=wrong-arg-types
|
| 560 |
+
if 'prob' in inspect.getfullargspec(func).args:
|
| 561 |
+
prob = 1.0
|
| 562 |
+
# pytype:enable=wrong-arg-types
|
| 563 |
+
|
| 564 |
+
# Apply the function with probability `prob`.
|
| 565 |
+
should_apply_op = tf.cast(
|
| 566 |
+
tf.floor(tf.random_uniform([], dtype=tf.float32) + prob), tf.bool)
|
| 567 |
+
augmented_image = tf.cond(
|
| 568 |
+
should_apply_op,
|
| 569 |
+
lambda: func(image, *args),
|
| 570 |
+
lambda: image)
|
| 571 |
+
return augmented_image
|
| 572 |
+
|
| 573 |
+
|
| 574 |
+
def select_and_apply_random_policy(policies, image):
|
| 575 |
+
"""Select a random policy from `policies` and apply it to `image`."""
|
| 576 |
+
policy_to_select = tf.random_uniform([], maxval=len(policies), dtype=tf.int32)
|
| 577 |
+
# Note that using tf.case instead of tf.conds would result in significantly
|
| 578 |
+
# larger graphs and would even break export for some larger policies.
|
| 579 |
+
for (i, policy) in enumerate(policies):
|
| 580 |
+
image = tf.cond(
|
| 581 |
+
tf.equal(i, policy_to_select),
|
| 582 |
+
lambda selected_policy=policy: selected_policy(image),
|
| 583 |
+
lambda: image)
|
| 584 |
+
return image
|
| 585 |
+
|
| 586 |
+
|
| 587 |
+
def build_and_apply_nas_policy(policies, image,
|
| 588 |
+
augmentation_hparams):
|
| 589 |
+
"""Build a policy from the given policies passed in and apply to image.
|
| 590 |
+
Args:
|
| 591 |
+
policies: list of lists of tuples in the form `(func, prob, level)`, `func`
|
| 592 |
+
is a string name of the augmentation function, `prob` is the probability
|
| 593 |
+
of applying the `func` operation, `level` is the input argument for
|
| 594 |
+
`func`.
|
| 595 |
+
image: tf.Tensor that the resulting policy will be applied to.
|
| 596 |
+
augmentation_hparams: Hparams associated with the NAS learned policy.
|
| 597 |
+
Returns:
|
| 598 |
+
A version of image that now has data augmentation applied to it based on
|
| 599 |
+
the `policies` pass into the function.
|
| 600 |
+
"""
|
| 601 |
+
replace_value = [128, 128, 128]
|
| 602 |
+
|
| 603 |
+
# func is the string name of the augmentation function, prob is the
|
| 604 |
+
# probability of applying the operation and level is the parameter associated
|
| 605 |
+
# with the tf op.
|
| 606 |
+
|
| 607 |
+
# tf_policies are functions that take in an image and return an augmented
|
| 608 |
+
# image.
|
| 609 |
+
tf_policies = []
|
| 610 |
+
for policy in policies:
|
| 611 |
+
tf_policy = []
|
| 612 |
+
# Link string name to the correct python function and make sure the correct
|
| 613 |
+
# argument is passed into that function.
|
| 614 |
+
for policy_info in policy:
|
| 615 |
+
policy_info = list(policy_info) + [replace_value, augmentation_hparams]
|
| 616 |
+
|
| 617 |
+
tf_policy.append(_parse_policy_info(*policy_info))
|
| 618 |
+
# Now build the tf policy that will apply the augmentation procedue
|
| 619 |
+
# on image.
|
| 620 |
+
def make_final_policy(tf_policy_):
|
| 621 |
+
def final_policy(image_):
|
| 622 |
+
for func, prob, args in tf_policy_:
|
| 623 |
+
image_ = _apply_func_with_prob(
|
| 624 |
+
func, image_, args, prob)
|
| 625 |
+
return image_
|
| 626 |
+
return final_policy
|
| 627 |
+
tf_policies.append(make_final_policy(tf_policy))
|
| 628 |
+
|
| 629 |
+
augmented_image = select_and_apply_random_policy(
|
| 630 |
+
tf_policies, image)
|
| 631 |
+
return augmented_image
|
| 632 |
+
|
| 633 |
+
|
| 634 |
+
def distort_image_with_autoaugment(image, augmentation_name):
|
| 635 |
+
"""Applies the AutoAugment policy to `image`.
|
| 636 |
+
AutoAugment is from the paper: https://arxiv.org/abs/1805.09501.
|
| 637 |
+
Args:
|
| 638 |
+
image: `Tensor` of shape [height, width, 3] representing an image.
|
| 639 |
+
augmentation_name: The name of the AutoAugment policy to use. The available
|
| 640 |
+
options are `v0` and `test`. `v0` is the policy used for
|
| 641 |
+
all of the results in the paper and was found to achieve the best results
|
| 642 |
+
on the COCO dataset. `v1`, `v2` and `v3` are additional good policies
|
| 643 |
+
found on the COCO dataset that have slight variation in what operations
|
| 644 |
+
were used during the search procedure along with how many operations are
|
| 645 |
+
applied in parallel to a single image (2 vs 3).
|
| 646 |
+
Returns:
|
| 647 |
+
A tuple containing the augmented versions of `image`.
|
| 648 |
+
"""
|
| 649 |
+
available_policies = {'v0': policy_v0,
|
| 650 |
+
'test': policy_vtest}
|
| 651 |
+
if augmentation_name not in available_policies:
|
| 652 |
+
raise ValueError('Invalid augmentation_name: {}'.format(augmentation_name))
|
| 653 |
+
|
| 654 |
+
policy = available_policies[augmentation_name]()
|
| 655 |
+
# Hparams that will be used for AutoAugment.
|
| 656 |
+
augmentation_hparams = HParams(
|
| 657 |
+
cutout_const=100, translate_const=250)
|
| 658 |
+
|
| 659 |
+
return build_and_apply_nas_policy(policy, image, augmentation_hparams)
|
| 660 |
+
|
| 661 |
+
|
| 662 |
+
def distort_image_with_randaugment(image, num_layers, magnitude):
|
| 663 |
+
"""Applies the RandAugment policy to `image`.
|
| 664 |
+
RandAugment is from the paper https://arxiv.org/abs/1909.13719,
|
| 665 |
+
Args:
|
| 666 |
+
image: `Tensor` of shape [height, width, 3] representing an image.
|
| 667 |
+
num_layers: Integer, the number of augmentation transformations to apply
|
| 668 |
+
sequentially to an image. Represented as (N) in the paper. Usually best
|
| 669 |
+
values will be in the range [1, 3].
|
| 670 |
+
magnitude: Integer, shared magnitude across all augmentation operations.
|
| 671 |
+
Represented as (M) in the paper. Usually best values are in the range
|
| 672 |
+
[5, 30].
|
| 673 |
+
Returns:
|
| 674 |
+
The augmented version of `image`.
|
| 675 |
+
"""
|
| 676 |
+
replace_value = [128] * 3
|
| 677 |
+
tf.logging.info('Using RandAug.')
|
| 678 |
+
augmentation_hparams = HParams(
|
| 679 |
+
cutout_const=40, translate_const=100)
|
| 680 |
+
available_ops = [
|
| 681 |
+
'AutoContrast', 'Equalize', 'Invert', 'Rotate', 'Posterize',
|
| 682 |
+
'Solarize', 'Color', 'Contrast', 'Brightness', 'Sharpness',
|
| 683 |
+
'ShearX', 'ShearY', 'TranslateX', 'TranslateY', 'Cutout', 'SolarizeAdd']
|
| 684 |
+
|
| 685 |
+
for layer_num in range(num_layers):
|
| 686 |
+
op_to_select = tf.random_uniform(
|
| 687 |
+
[], maxval=len(available_ops), dtype=tf.int32)
|
| 688 |
+
random_magnitude = float(magnitude)
|
| 689 |
+
with tf.name_scope('randaug_layer_{}'.format(layer_num)):
|
| 690 |
+
for (i, op_name) in enumerate(available_ops):
|
| 691 |
+
prob = tf.random_uniform([], minval=0.2, maxval=0.8, dtype=tf.float32)
|
| 692 |
+
func, _, args = _parse_policy_info(op_name, prob, random_magnitude,
|
| 693 |
+
replace_value, augmentation_hparams)
|
| 694 |
+
image = tf.cond(
|
| 695 |
+
tf.equal(i, op_to_select),
|
| 696 |
+
lambda selected_func=func, selected_args=args: selected_func(
|
| 697 |
+
image, *selected_args),
|
| 698 |
+
# pylint:enable=g-long-lambda
|
| 699 |
+
lambda: image)
|
| 700 |
+
return image
|
Tipsomaly/model/big_vision/pp/builder.py
ADDED
@@ -0,0 +1,85 @@
+# Copyright 2024 Big Vision Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Preprocessing builder."""
+
+from absl import logging
+from big_vision.pp import registry
+import tensorflow as tf
+
+
+def get_preprocess_fn(pp_pipeline, log_data=True, log_steps=False):
+  """Transform an input string into the preprocessing function.
+
+  The minilanguage is as follows:
+
+    fn1|fn2(arg, arg2,...)|...
+
+  And describes the successive application of the various `fn`s to the input,
+  where each function can optionally have one or more arguments, which are
+  either positional or key/value, as dictated by the `fn`.
+
+  The output preprocessing function expects a dictionary as input. This
+  dictionary should have a key "image" that corresponds to a 3D tensor
+  (height x width x channel).
+
+  Args:
+    pp_pipeline: A string describing the pre-processing pipeline. If empty or
+      None, no preprocessing will be executed.
+    log_data: Whether to log the data before and after preprocessing. Can also
+      be a string to show in the log for debugging, for example dataset name.
+    log_steps: Whether to log the steps of the preprocessing pipeline.
+
+  Returns:
+    preprocessing function.
+
+  Raises:
+    ValueError: if preprocessing function name is unknown
+  """
+
+  names, ops, spec_strings = [], [], []
+  if pp_pipeline:
+    for op_spec in pp_pipeline.split("|"):
+      if not op_spec: continue  # Skip empty section instead of error.
+      try:
+        ops.append(registry.Registry.lookup(f"preprocess_ops.{op_spec}")())
+        names.append(registry.parse_name(op_spec)[0])
+        spec_strings.append(op_spec)
+      except SyntaxError as err:
+        raise ValueError(f"Syntax error on: {op_spec}") from err
+
+  def _preprocess_fn(data):
+    """The preprocessing function that is returned."""
+    nonlocal log_data, log_steps
+
+    # Apply all the individual steps in sequence.
+    if log_data:
+      logging.info("Data before pre-processing (%s):\n%s", log_data, data)
+    for name, op, spec in zip(names, ops, spec_strings):
+      if log_steps:
+        logging.info("Pre-processing step (%s): %s\n%s", name, spec, data)
+      with tf.name_scope(name):
+        data = op(data)
+
+      # Validate input
+      if not isinstance(data, dict):
+        raise ValueError("Argument `data` must be a dictionary, "
+                         "not %s" % str(type(data)))
+
+    if log_data:
+      logging.info("Data after pre-processing (%s):\n%s", log_data, data)
+    log_data = False  # For eager&pygrain: only log first one of each pipeline.
+    return data
+
+  return _preprocess_fn
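A short usage sketch of the minilanguage documented above (illustrative only; the op names used here are the ones registered by the ops_general.py and ops_image.py files added later in this commit):

from big_vision.pp import builder
from big_vision.pp import ops_general  # pylint: disable=unused-import
from big_vision.pp import ops_image  # pylint: disable=unused-import
import numpy as np

# Each "|"-separated segment names a registered op, optionally with args.
pp_fn = builder.get_preprocess_fn(
    "inception_crop|resize(256)|central_crop((224, 224))|"
    "value_range(-1, 1)|keep('image')")
out = pp_fn({"image": np.random.randint(0, 256, [640, 480, 3])})
# out is again a dict; out["image"] now has shape (224, 224, 3) in [-1, 1].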
Tipsomaly/model/big_vision/pp/builder_test.py
ADDED
@@ -0,0 +1,72 @@
+# Copyright 2024 Big Vision Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Tests for builder."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from big_vision.pp import builder
+from big_vision.pp import ops_general  # pylint: disable=unused-import
+from big_vision.pp import ops_image  # pylint: disable=unused-import
+import numpy as np
+import tensorflow.compat.v1 as tf
+
+
+class BuilderTest(tf.test.TestCase):
+
+  def testSingle(self):
+    pp_fn = builder.get_preprocess_fn("resize(256)")
+    x = np.random.randint(0, 256, [640, 480, 3])
+    image = pp_fn({"image": x})["image"]
+    self.assertEqual(image.numpy().shape, (256, 256, 3))
+
+  def testEmpty(self):
+    pp_fn = builder.get_preprocess_fn("||inception_crop|||resize(256)||")
+
+    # Typical image input
+    x = np.random.randint(0, 256, [640, 480, 3])
+    image = pp_fn({"image": x})["image"]
+    self.assertEqual(image.numpy().shape, (256, 256, 3))
+
+  def testPreprocessingPipeline(self):
+    pp_str = ("inception_crop|resize(256)|resize((256, 256))|"
+              "central_crop((80, 120))|flip_lr|value_range(0,1)|"
+              "value_range(-1,1)")
+    pp_fn = builder.get_preprocess_fn(pp_str)
+
+    # Typical image input
+    x = np.random.randint(0, 256, [640, 480, 3])
+    image = pp_fn({"image": x})["image"]
+    self.assertEqual(image.numpy().shape, (80, 120, 3))
+    self.assertLessEqual(np.max(image.numpy()), 1)
+    self.assertGreaterEqual(np.min(image.numpy()), -1)
+
+  def testNumArgsException(self):
+
+    x = np.random.randint(0, 256, [640, 480, 3])
+    for pp_str in [
+        "inception_crop(1)",
+        "resize()",
+        "resize(1, 1, 1)",
+        "flip_lr(1)",
+        "central_crop()",
+    ]:
+      with self.assertRaises(BaseException):
+        builder.get_preprocess_fn(pp_str)(x)
+
+
+if __name__ == "__main__":
+  tf.test.main()
Tipsomaly/model/big_vision/pp/ops_general.py
ADDED
@@ -0,0 +1,468 @@
+# Copyright 2024 Big Vision Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Generic tensor preprocessing ops.
+
+All preprocessing ops should return data processing functors. Data is
+represented as a dictionary of (TF) tensors. The functors output a modified
+dictionary.
+"""
+
+import collections
+
+from big_vision.pp import utils
+from big_vision.pp.registry import Registry
+import big_vision.utils as bv_utils
+import jax
+import numpy as np
+import tensorflow as tf
+
+
+@Registry.register("preprocess_ops.value_range")
+@utils.InKeyOutKey()
+def get_value_range(vmin=-1, vmax=1, in_min=0, in_max=255.0, clip_values=False):
+  """Transforms a [in_min,in_max] image to [vmin,vmax] range.
+
+  Input ranges in_min/in_max can be equal-size lists to rescale the individual
+  channels independently.
+
+  Args:
+    vmin: A scalar. Output min value.
+    vmax: A scalar. Output max value.
+    in_min: A scalar or a list of input min values to scale. If a list, the
+      length should match to the number of channels in the image.
+    in_max: A scalar or a list of input max values to scale. If a list, the
+      length should match to the number of channels in the image.
+    clip_values: Whether to clip the output values to the provided ranges.
+
+  Returns:
+    A function to rescale the values.
+  """
+
+  def _value_range(image):
+    """Scales values in given range."""
+    in_min_t = tf.constant(in_min, tf.float32)
+    in_max_t = tf.constant(in_max, tf.float32)
+    image = tf.cast(image, tf.float32)
+    image = (image - in_min_t) / (in_max_t - in_min_t)
+    image = vmin + image * (vmax - vmin)
+    if clip_values:
+      image = tf.clip_by_value(image, vmin, vmax)
+    return image
+
+  return _value_range
+
+
+@Registry.register("preprocess_ops.lookup")
+@utils.InKeyOutKey()
+def get_lookup(mapping, npzkey="fnames", sep=None):
+  """Map string to number."""
+
+  # For NumPy files, we use the `npzkey` array in that file as the list of
+  # strings which are mapped to their index in that array.
+  # This is especially useful when other data (eg precomputed predictions)
+  # goes along with this mapping, to have everything in one place (the npz).
+  if mapping.endswith(".npz"):
+    with tf.io.gfile.GFile(mapping, "rb") as f:
+      keys = np.array(np.load(f, allow_pickle=False)[npzkey])
+    vals = np.arange(len(keys))
+
+  # Otherwise, we simply use the file as a text file, with either of:
+  # - a string per line, mapped to its line-number
+  # - a pair, separated by `sep` per line, first value being the string,
+  #   second value being the integer that the string is mapped to.
+  else:
+    with tf.io.gfile.GFile(mapping, "r") as f:
+      buf = f.read()
+    if sep is None:  # values are the line numbers
+      keys = buf.splitlines()
+      vals = np.arange(len(keys))
+    else:  # each line is key<sep>val, also make val int
+      keys, vals = zip(*[l.split(sep) for l in buf.splitlines()])
+      vals = [int(v) for v in vals]
+
+  def _do_the_mapping(needle):
+    """Map string to number."""
+    with tf.init_scope():  # (Originally added for performance reasons.)
+      table = tf.lookup.StaticHashTable(
+          tf.lookup.KeyValueTensorInitializer(keys, vals), -1)
+    return table.lookup(needle)
+
+  return _do_the_mapping
+
+
+@Registry.register("preprocess_ops.onehot")
+def get_onehot(depth,
+               key="labels",
+               key_result=None,
+               multi=True,
+               on=1.0,
+               off=0.0):
+  """One-hot encodes the input.
+
+  Args:
+    depth: Length of the one-hot vector (how many classes).
+    key: Key of the data to be one-hot encoded.
+    key_result: Key under which to store the result (same as `key` if None).
+    multi: If there are multiple labels, whether to merge them into the same
+      "multi-hot" vector (True) or keep them as an extra dimension (False).
+    on: Value to fill in for the positive label (default: 1).
+    off: Value to fill in for negative labels (default: 0).
+
+  Returns:
+    Data dictionary.
+  """
+
+  def _onehot(data):
+    # When there's more than one label, this is significantly more efficient
+    # than using tf.one_hot followed by tf.reduce_max; we tested.
+    labels = data[key]
+    labels = tf.cast(labels, tf.int64)  # both scatter and one_hot expect this
+    if labels.shape.rank > 0 and multi:
+      x = tf.scatter_nd(labels[:, None], tf.ones(tf.shape(labels)[0]), (depth,))
+      x = tf.clip_by_value(x, 0, 1) * (on - off) + off
+    else:
+      x = tf.one_hot(labels, depth, on_value=on, off_value=off)
+    data[key_result or key] = x
+    return data
+
+  return _onehot
+
+
+@Registry.register("preprocess_ops.keep")
+def get_keep(*keys):
+  """Keeps only the given keys."""
+
+  def _keep(data):
+    return {k: v for k, v in data.items() if k in keys}
+
+  return _keep
+
+
+@Registry.register("preprocess_ops.drop")
+def get_drop(*keys):
+  """Drops the given keys."""
+
+  def _drop(data):
+    return {k: v for k, v in data.items() if k not in keys}
+
+  return _drop
+
+
+@Registry.register("preprocess_ops.copy")
+def get_copy(inkey, outkey):
+  """Copies value of `inkey` into `outkey`."""
+
+  def _copy(data):
+    # A "semi-deep" copy. deepcopy doesn't work when tf tensors are part of
+    # the game. What we want is to only copy the python structure (dicts,
+    # lists) and keep tensors as they are, since we never modify them
+    # in-place anyways. The following achieves exactly that.
+    data[outkey] = jax.tree.map(lambda x: x, data[inkey])
+    return data
+
+  return _copy
+
+
+@Registry.register("preprocess_ops.squeeze_last_dim")
+@utils.InKeyOutKey()
+def get_squeeze_last_dim():
+  def _squeeze_last_dim(x):
+    return tf.squeeze(x, axis=-1)
+  return _squeeze_last_dim
+
+
+@Registry.register("preprocess_ops.concat")
+def get_concat(inkeys, outkey=None, axis=-1):
+  """Concatenates elements along some axis."""
+
+  def _concat(data):
+    data[outkey or inkeys[0]] = tf.concat([data[k] for k in inkeys], axis)
+    return data
+
+  return _concat
+
+
+@Registry.register("preprocess_ops.rag_tensor")
+@utils.InKeyOutKey()
+def get_rag_tensor():
+  """Converts the specified feature to ragged tensor."""
+
+  def rag_tensor(raw_tensor):
+    # Note: Add one more dimension as `from_tensor` requires at least rank 2.
+    return tf.RaggedTensor.from_tensor(raw_tensor[None])
+
+  return rag_tensor
+
+
+@Registry.register("preprocess_ops.pad_to_shape")
+@utils.InKeyOutKey()
+def get_pad_to_shape(shape, pad_value=0, where="after"):
+  """Pads tensor to specified `shape`."""
+
+  def _pads(cur, tgt):
+    if tgt is None:
+      return [0, 0]
+    diff = tgt - cur
+    return {
+        "before": [diff, 0],
+        "after": [0, diff],
+        "both": [diff // 2, diff - diff // 2],
+    }[where]
+
+  def _pad_to_shape(x):
+    assert len(x.shape.as_list()) == len(shape)
+    paddings = [_pads(tgt=shape[i], cur=tf.shape(x)[i])
+                for i in range(len(shape))]
+    constant_value = tf.constant(pad_value, x.dtype)
+    ret = tf.pad(x, paddings, constant_values=constant_value)
+    ret.set_shape(shape)
+    return ret
+
+  return _pad_to_shape
+
+
+@Registry.register("preprocess_ops.flatten")
+def get_flatten(keys=None):
+  """Flattens the keys of data with separator '/'."""
+
+  def _flatten(data):
+    flatten_keys = keys or list(data.keys())
+    not_flattened = {k: v for k, v in data.items() if k not in flatten_keys}
+    flattened = {k: v for k, v in data.items() if k in flatten_keys}
+    flattened, _ = bv_utils.tree_flatten_with_names(flattened)
+    return {**dict(flattened), **not_flattened}
+
+  return _flatten
+
+
+@Registry.register("preprocess_ops.reshape")
+@utils.InKeyOutKey()
+def get_reshape(new_shape):
+  """Reshapes tensor to a given new shape.
+
+  Args:
+    new_shape: new shape for the tensor.
+
+  Returns:
+    A function for reshaping a tensor.
+  """
+
+  def _reshape(tensor):
+    """Reshapes a tensor to a given shape."""
+    dtype = tensor.dtype
+    tensor = tf.reshape(tensor, new_shape)
+    return tf.cast(tensor, dtype)
+
+  return _reshape
+
+
+@Registry.register("preprocess_ops.setdefault")
+def get_setdefault(key, value):
+  """If `key` is an empty tensor or missing, set it to `value`."""
+  def _setdefault(data):
+    x = data.get(key, tf.constant(value))
+    v = tf.constant(value, dtype=x.dtype)
+    v = tf.broadcast_to(v, [s or 1 for s in x.shape])
+    data[key] = tf.cond(tf.size(x) > 0, lambda: x, lambda: v)
+    return data
+  return _setdefault
+
+
+@Registry.register("preprocess_ops.choice")
+def get_choice(n="single", key=None, fewer_ok=False, inkey=None, outkey=None):
+  """Chooses the same `n` random entries of all `keys`.
+
+  Args:
+    n: how many entries to randomly sample (without repeat). Possible values:
+      - int: that many entries (or fewer if there's fewer; see `fewer_ok`).
+      - "single": The string "single" only chooses one and drops the leading
+        dim.
+      - [min, max]: A pair means randomly take between min/max examples (incl.).
+    key: str or list of str: See Note.
+    fewer_ok: whether to fail when there's fewer than `n` elements to choose
+      from (and hence set static shape to `n`), or whether to allow it (and
+      hence have unknown static shape).
+    inkey: str or list of str: See Note.
+    outkey: str or list of str: See Note.
+
+  Note:
+    If key/inkey/outkey is a list, then the same random entries are chosen for
+    all of the keys. Other than that, they function the same as InKeyOutKey.
+
+    The outkey can also contain the placeholder `{key}` that'll be replaced
+    by the inkey name.
+
+  Examples:
+    choice(key="alt_text/text")
+    choice(n=128, key=["patches", "positions"])
+    choice(inkey=["questions_i18n", "answers_i18n"], outkey=["q", "a"])
+
+  Returns:
+    The pp op.
+  """
+
+  # Normalize keys:
+  inkeys = utils.maybe_repeat(inkey or key, 1)
+  outkeys = utils.maybe_repeat(outkey or key, 1)
+  outkeys = [ok.format(key=ik) for ok, ik in zip(outkeys, inkeys)]
+
+  # Let's DRY on this condition and give it a name.
+  is_varlen = isinstance(n, (list, tuple))
+  min_n = n[0] if is_varlen else 1 if n == "single" else n
+
+  def _choice(data):
+    # Catch a hard to identify/understand user error:
+    assert data[inkeys[0]].ndim > 0, (
+        f"You're calling `choice` on {inkeys}, a scalar."
+        " That's probably a mistake; double-check and then just don't."
+    )
+
+    nitems = tf.shape(data[inkeys[0]])[0]
+
+    # Sanity check that all keys have same leading dimension, and that it is
+    # at least as large as the minimum requested output.
+    lengths = [tf.shape(data[k])[0] for k in inkeys]
+    checks = [tf.debugging.assert_equal(l, nitems) for l in lengths]
+    if not fewer_ok:  # Since we check for all-same, a single suffices here.
+      checks.append(tf.debugging.assert_greater_equal(nitems, min_n))
+    with tf.control_dependencies(checks):
+      nitems = tf.identity(nitems)
+
+    if n == "single":
+      index = tf.random.uniform([], 0, nitems, dtype=tf.int32)
+    else:
+      # Subsample by shuffling and taking first n, but...
+      indices = tf.random.shuffle(tf.range(nitems))
+      end = n
+      if is_varlen:
+        end = tf.random.uniform([], n[0], n[1] + 1, dtype=tf.int32)
+      # ...keep the order while subsampling (it might have a meaning, eg boxes)
+      indices = tf.sort(indices[:end])
+
+    for ik, ok in zip(inkeys, outkeys):
+      if n == "single":
+        result = data[ik][index]
+      else:
+        result = tf.gather(data[ik], indices, axis=0)
+        if not is_varlen:  # Give static shape when we can.
+          result = tf.ensure_shape(result, [n] + [None] * (result.ndim - 1))
+      data[ok] = result
+
+    return data
+  return _choice
+
+
+def _shuffled_index(count, nitems, seed):
+  """Returns index from a shuffled sequence (items only repeat after epoch)."""
+  nitems = tf.cast(nitems, count.dtype)
+  item_epoch, item_offset = (count // nitems, count % nitems)
+  shuffled_indices = tf.random.experimental.stateless_shuffle(
+      tf.range(nitems), seed=tf.random.fold_in(seed, item_epoch))
+  return shuffled_indices[item_offset]
+
+
+@Registry.register("preprocess_ops.choice_no_replacement")
+def get_choice_no_replacement(key=None, inkey=None, outkey=None):
+  """Chooses the same random (no replacement) entry of all `keys`.
+
+  Note: Consider using this for iterating over small datasets with a small
+  number of epochs. It differs from `choice(n='single')` in that if an
+  example, as identified by its `_id` field, is seen N times, it will cycle
+  through all the inkeys values before repeating them. Additionally, each
+  repetition uses a different order.
+
+  Caveats: requires the dataset to provide a `_id` field and uses host RAM to
+  keep a counter of how often each id is seen. It is also not robust to
+  preemptions.
+
+  Args:
+    key: str or list of str: See Note.
+    inkey: str or list of str: See Note.
+    outkey: str or list of str: See Note.
+
+  Note:
+    If key/inkey/outkey is a list, then the same random entries are chosen for
+    all of the keys. Other than that, they function the same as InKeyOutKey.
+
+    The outkey can also contain the placeholder `{key}` that'll be replaced
+    by the inkey name.
+
+  Examples:
+    choice(key="alt_text/text")
+    choice(key=["patches", "positions"])
+    choice(inkey=["questions_i18n", "answers_i18n"], outkey=["q", "a"])
+
+  Returns:
+    The pp op.
+  """
+  # Normalize keys:
+  inkeys = utils.maybe_repeat(inkey or key, 1)
+  outkeys = utils.maybe_repeat(outkey or key, 1)
+  outkeys = [ok.format(key=ik) for ok, ik in zip(outkeys, inkeys)]
+
+  # TODO: Ideally the data pipeline should provide us with an epoch
+  # counter. For now, count how often we see a given example id and don't
+  # worry about memory consumption. Counter returns 0 the first time an
+  # example is seen.
+  counter = collections.defaultdict(lambda: -1)
+  def _seen_count(example_id):
+    example_id = example_id.item()
+    counter[example_id] += 1
+    return counter[example_id]
+
+  # We need a seed to deterministically decide on a shuffled sequence and use
+  # the number of times an example was seen to iterate through it. The seed
+  # should be different for every instance of a created preprocessing function,
+  # but it has to be fixed for each instance.
+  seed = tf.random.uniform(
+      [2], minval=tf.int32.min, maxval=tf.int32.max, dtype=tf.int32)
+
+  def _choice(data):
+    # Catch a hard to identify/understand user error:
+    assert data[inkeys[0]].ndim > 0, (
+        f"You're calling `choice_no_replacement` on {inkeys}, a scalar."
+        " That's probably a mistake; double-check and then just don't."
+    )
+
+    nitems = tf.shape(data[inkeys[0]])[0]
+
+    # Sanity check that all keys have same leading dimension.
+    checks = [
+        tf.debugging.assert_equal(tf.shape(data[k])[0], nitems)
+        for k in inkeys
+    ]
+    with tf.control_dependencies(checks):
+      nitems = tf.identity(nitems)
+
+    # Using the seed, example id and the number of times an example was seen,
+    # pick an `index` such that items are only repeated after all items are
+    # seen an equal number of times. E.g. it could return indexes from this
+    # sequence: [0, 1, 2, 1, 2, 0, 2, 0, 1, 0, 2, 1, ...].
+    count = tf.numpy_function(
+        _seen_count, (data["_id"],), Tout=tf.int64, stateful=True)
+    count = tf.cast(count, tf.int32)
+    nitems = tf.cast(nitems, tf.int32)
+    shuffle_epoch = count // nitems
+    shuffle_offset = count % nitems
+
+    example_seed = tf.random.fold_in(seed, data["_id"])
+    shuffle_seed = tf.random.fold_in(example_seed, shuffle_epoch)
+    shuffle = tf.random.experimental.stateless_shuffle(
+        tf.range(nitems), seed=shuffle_seed)
+    index = shuffle[shuffle_offset]
+
+    # Select item[index] for all keys.
+    for ik, ok in zip(inkeys, outkeys):
+      data[ok] = data[ik][index]
+    return data
+
+  return _choice
ADDED
|
@@ -0,0 +1,236 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024 Big Vision Authors.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
"""Tests for ops_general."""
|
| 16 |
+
|
| 17 |
+
import copy
|
| 18 |
+
|
| 19 |
+
import big_vision.pp.ops_general as pp
|
| 20 |
+
import numpy as np
|
| 21 |
+
import tensorflow as tf
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
class PreprocessOpsTest(tf.test.TestCase):
|
| 25 |
+
|
| 26 |
+
def tfrun(self, ppfn, data):
|
| 27 |
+
# Run once as standalone, as could happen eg in colab.
|
| 28 |
+
yield {k: np.array(v) for k, v in ppfn(copy.deepcopy(data)).items()}
|
| 29 |
+
|
| 30 |
+
# And then once again as part of tfdata pipeline.
|
| 31 |
+
# You'd be surprised how much these two differ!
|
| 32 |
+
tfdata = tf.data.Dataset.from_tensors(copy.deepcopy(data))
|
| 33 |
+
for npdata in tfdata.map(ppfn).as_numpy_iterator():
|
| 34 |
+
yield npdata
|
| 35 |
+
|
| 36 |
+
def test_value_range(self):
|
| 37 |
+
img = tf.random.uniform((640, 480, 3), 0, 255, tf.int32)
|
| 38 |
+
data = {"image": tf.cast(img, tf.uint8)}
|
| 39 |
+
for out in self.tfrun(pp.get_value_range(-0.5, 0.5), data):
|
| 40 |
+
self.assertLessEqual(np.max(out["image"]), 0.5)
|
| 41 |
+
self.assertGreaterEqual(np.min(out["image"]), -0.5)
|
| 42 |
+
|
| 43 |
+
def test_value_range_custom_input_range(self):
|
| 44 |
+
img = tf.random.uniform((640, 480, 3), 0, 255, tf.int32)
|
| 45 |
+
data = {"image": tf.cast(img, tf.uint8)}
|
| 46 |
+
for out in self.tfrun(pp.get_value_range(-0.5, 0.5, -256, 255, True), data):
|
| 47 |
+
self.assertLessEqual(np.max(out["image"]), 0.5)
|
| 48 |
+
self.assertGreaterEqual(np.min(out["image"]), 0.0)
|
| 49 |
+
|
| 50 |
+
def test_get_keep_drop(self):
|
| 51 |
+
data = {"image": 1, "labels": 2, "something": 3}
|
| 52 |
+
|
| 53 |
+
for data_keep in self.tfrun(pp.get_keep("image", "labels"), data):
|
| 54 |
+
self.assertAllEqual(set(data_keep.keys()), {"image", "labels"})
|
| 55 |
+
|
| 56 |
+
for data_drop in self.tfrun(pp.get_drop("image", "labels"), data):
|
| 57 |
+
self.assertAllEqual(set(data_drop.keys()), {"something"})
|
| 58 |
+
|
| 59 |
+
def test_onehot(self):
|
| 60 |
+
data = {"labels": tf.constant(2, dtype=tf.int64)}
|
| 61 |
+
for out in self.tfrun(pp.get_onehot(4, "labels", multi=True), data):
|
| 62 |
+
self.assertAllClose(out["labels"], [0., 0., 1., 0.])
|
| 63 |
+
|
| 64 |
+
def test_onehot_multi(self):
|
| 65 |
+
data = {"labels": tf.constant([2, 3, 0], dtype=tf.int64)}
|
| 66 |
+
for out in self.tfrun(pp.get_onehot(4, "labels", multi=False), data):
|
| 67 |
+
self.assertAllClose(out["labels"], [
|
| 68 |
+
[0., 0., 1., 0.],
|
| 69 |
+
[0., 0., 0., 1.],
|
| 70 |
+
[1., 0., 0., 0.]])
|
| 71 |
+
|
| 72 |
+
for out in self.tfrun(pp.get_onehot(4, "labels", multi=True), data):
|
| 73 |
+
self.assertAllClose(out["labels"], [1., 0., 1., 1.])
|
| 74 |
+
|
| 75 |
+
def test_onehot_2d(self):
|
| 76 |
+
data = {"labels": tf.constant([[2, 3], [0, 1]], dtype=tf.int64)}
|
| 77 |
+
for out in self.tfrun(pp.get_onehot(4, "labels", multi=False), data):
|
| 78 |
+
self.assertAllClose(out["labels"], [
|
| 79 |
+
[[0., 0., 1., 0.], [0., 0., 0., 1.]],
|
| 80 |
+
[[1., 0., 0., 0.], [0., 1., 0., 0.]]])
|
| 81 |
+
|
| 82 |
+
def test_onehot_smoothing(self):
|
| 83 |
+
data = {"labels": tf.constant([2, 3, 0], dtype=tf.int64)}
|
| 84 |
+
for out in self.tfrun(
|
| 85 |
+
pp.get_onehot(4, "labels", multi=False, on=0.8, off=0.1), data):
|
| 86 |
+
self.assertAllClose(out["labels"], [
|
| 87 |
+
[0.1, 0.1, 0.8, 0.1],
|
| 88 |
+
[0.1, 0.1, 0.1, 0.8],
|
| 89 |
+
[0.8, 0.1, 0.1, 0.1]])
|
| 90 |
+
|
| 91 |
+
for out in self.tfrun(
|
| 92 |
+
pp.get_onehot(4, "labels", multi=True, on=0.8, off=0.1), data):
|
| 93 |
+
self.assertAllClose(out["labels"], [0.8, 0.1, 0.8, 0.8])
|
| 94 |
+
|
| 95 |
+
def test_squeeze_last_dim(self):
|
| 96 |
+
data = {"image": tf.constant(np.zeros((32, 32, 3, 1)))}
|
| 97 |
+
for out in self.tfrun(pp.get_squeeze_last_dim(), data):
|
| 98 |
+
self.assertAllEqual(out["image"].shape, [32, 32, 3])
|
| 99 |
+
|
| 100 |
+
def test_pad_to_shape(self):
|
| 101 |
+
desired_shape = (8, 10)
|
| 102 |
+
for input_shape in [(8, 4), (8, 3), (8, 10), (8, 1)]:
|
| 103 |
+
data = {"x": tf.ones(input_shape, dtype=tf.float32)}
|
| 104 |
+
for out in self.tfrun(
|
| 105 |
+
pp.get_pad_to_shape(desired_shape, pad_value=-1, key="x"), data):
|
| 106 |
+
self.assertEqual(
|
| 107 |
+
tf.reduce_sum(out["x"]),
|
| 108 |
+
2 * np.prod(input_shape) - np.prod(desired_shape))
|
| 109 |
+
|
| 110 |
+
def test_pad_to_shape_none(self):
|
| 111 |
+
data = {"x": tf.ones((8, 4), dtype=tf.float32)}
|
| 112 |
+
for out in self.tfrun(
|
| 113 |
+
pp.get_pad_to_shape((None, 6), pad_value=-1, key="x"), data):
|
| 114 |
+
self.assertEqual(out["x"].shape, (8, 6))
|
| 115 |
+
self.assertEqual(tf.reduce_sum(out["x"]), 8*4 - 8*2)
|
| 116 |
+
|
| 117 |
+
def test_pad_to_shape_which_side(self):
|
| 118 |
+
data = {"x": tf.ones((8, 4), dtype=tf.float32)}
|
| 119 |
+
for where, idxs in [("before", [0]), ("both", [0, -1]), ("after", [-1])]:
|
| 120 |
+
for out in self.tfrun(
|
| 121 |
+
pp.get_pad_to_shape((8, 6), key="x", where=where), data):
|
| 122 |
+
self.assertEqual(out["x"].shape, (8, 6))
|
| 123 |
+
self.assertEqual(tf.reduce_sum(out["x"]), 8*4)
|
| 124 |
+
for i in idxs:
|
| 125 |
+
self.assertEqual(out["x"][0, i], 0)
|
| 126 |
+
|
| 127 |
+
def test_flatten(self):
|
| 128 |
+
d = {"a": {"b": tf.constant([1, 2, 3])}, "c": "str"}
|
| 129 |
+
self.assertEqual(pp.get_flatten()(d), {
|
| 130 |
+
"a/b": tf.constant([1, 2, 3]),
|
| 131 |
+
"c": "str"
|
| 132 |
+
})
|
| 133 |
+
|
| 134 |
+
def test_reshape(self):
|
| 135 |
+
data = {"image": tf.constant(np.zeros((8, 32 * 32 * 3)))}
|
| 136 |
+
for out in self.tfrun(pp.get_reshape(new_shape=(8, 32, 32, 3)), data):
|
| 137 |
+
self.assertAllEqual(out["image"].shape, [8, 32, 32, 3])
|
| 138 |
+
|
| 139 |
+
def test_setdefault(self):
|
| 140 |
+
data = {
|
| 141 |
+
"empty_image": tf.zeros([0, 0, 0]),
|
| 142 |
+
"image": tf.constant(np.arange(9).reshape(3, 3)),
|
| 143 |
+
"empty_text": tf.zeros([0], tf.string),
|
| 144 |
+
"text": tf.constant(["Hello", "World"], tf.string),
|
| 145 |
+
}
|
| 146 |
+
for out in self.tfrun(pp.get_setdefault("empty_image", 1), data):
|
| 147 |
+
self.assertAllEqual(out["empty_image"], np.array([[[1]]]))
|
| 148 |
+
for out in self.tfrun(pp.get_setdefault("image", 1), data):
|
| 149 |
+
self.assertAllEqual(out["image"], data["image"])
|
| 150 |
+
for out in self.tfrun(pp.get_setdefault("empty_text", "Lucas"), data):
|
| 151 |
+
self.assertAllEqual(out["empty_text"], np.array(["Lucas"]))
|
| 152 |
+
for out in self.tfrun(pp.get_setdefault("text", "Lucas"), data):
|
| 153 |
+
self.assertAllEqual(out["text"], data["text"])
|
| 154 |
+
|
| 155 |
+
def _data_for_choice(self):
|
| 156 |
+
return {
|
| 157 |
+
"one_f32": tf.constant([0.42], tf.float32),
|
| 158 |
+
"two_f32": tf.constant([3.14, 0.42], tf.float32),
|
| 159 |
+
"one_str": tf.constant(["Hi"], tf.string),
|
| 160 |
+
"two_str": tf.constant(["Hi", "Lucas"], tf.string),
|
| 161 |
+
"one_vec": tf.reshape(tf.range(2, dtype=tf.float32), (1, 2)),
|
| 162 |
+
"two_vec": tf.reshape(tf.range(4, dtype=tf.float32), (2, 2)),
|
| 163 |
+
}
|
| 164 |
+
|
| 165 |
+
def test_choice(self):
|
| 166 |
+
# Test for the default call (n="single")
|
| 167 |
+
data = self._data_for_choice()
|
| 168 |
+
self.assertEqual(
|
| 169 |
+
pp.get_choice(inkey="one_f32", outkey="choice")(data)["choice"], 0.42)
|
| 170 |
+
self.assertEqual(
|
| 171 |
+
pp.get_choice(inkey="one_str", outkey="choice")(data)["choice"], "Hi")
|
| 172 |
+
self.assertIn(
|
| 173 |
+
pp.get_choice(inkey="two_f32", outkey="choice")(data)["choice"],
|
| 174 |
+
[3.14, 0.42])
|
| 175 |
+
self.assertIn(
|
| 176 |
+
pp.get_choice(inkey="two_str", outkey="choice")(data)["choice"],
|
| 177 |
+
["Hi", "Lucas"])
|
| 178 |
+
|
| 179 |
+
def test_choice_nmax(self):
|
| 180 |
+
# n == nelems should be identity (and keep ordering!)
|
| 181 |
+
data = self._data_for_choice()
|
| 182 |
+
for k in ("one_f32", "one_str", "one_vec"):
|
| 183 |
+
for out in self.tfrun(pp.get_choice(n=1, key=[k]), data):
|
| 184 |
+
self.assertAllEqual(out[k], data[k])
|
| 185 |
+
for out in self.tfrun(pp.get_choice(n=[1, 1], key=[k]), data):
|
| 186 |
+
self.assertAllEqual(out[k], data[k])
|
| 187 |
+
for k in ("two_f32", "two_str", "two_vec"):
|
| 188 |
+
for out in self.tfrun(pp.get_choice(n=2, key=[k]), data):
|
| 189 |
+
self.assertAllEqual(out[k], data[k])
|
| 190 |
+
for out in self.tfrun(pp.get_choice(n=[2, 2], key=[k]), data):
|
| 191 |
+
self.assertAllEqual(out[k], data[k])
|
| 192 |
+
|
| 193 |
+
def test_choice_n(self):
|
| 194 |
+
# n < nelems should be one of them:
|
| 195 |
+
data = self._data_for_choice()
|
| 196 |
+
for k in ("two_f32", "two_str"):
|
| 197 |
+
for out in self.tfrun(pp.get_choice(n=1, key=[k]), data):
|
| 198 |
+
self.assertIn(out[k], data[k])
|
| 199 |
+
|
| 200 |
+
# Special testing for vectors.
|
| 201 |
+
for out in self.tfrun(pp.get_choice(n=1, key=["two_vec"]), data):
|
| 202 |
+
self.assertTrue(tf.logical_or(
|
| 203 |
+
tf.reduce_all(out["two_vec"][0] == data["two_vec"][0]),
|
| 204 |
+
tf.reduce_all(out["two_vec"][0] == data["two_vec"][1]),
|
| 205 |
+
))
|
| 206 |
+
|
| 207 |
+
def test_choice_multi(self):
|
| 208 |
+
# Select consistently across multiple keys.
|
| 209 |
+
data = self._data_for_choice()
|
| 210 |
+
op = pp.get_choice(n=1, key=["two_f32", "two_str"])
|
| 211 |
+
for out in self.tfrun(op, data):
|
| 212 |
+
self.assertTrue(tf.logical_or(
|
| 213 |
+
tf.logical_and(
|
| 214 |
+
tf.reduce_all(out["two_f32"][0] == data["two_f32"][0]),
|
| 215 |
+
tf.reduce_all(out["two_str"][0] == data["two_str"][0]),
|
| 216 |
+
),
|
| 217 |
+
tf.logical_and(
|
| 218 |
+
tf.reduce_all(out["two_f32"][0] == data["two_f32"][1]),
|
| 219 |
+
tf.reduce_all(out["two_str"][0] == data["two_str"][1]),
|
| 220 |
+
),
|
| 221 |
+
))
|
| 222 |
+
|
| 223 |
+
def test_choice_n_range(self):
|
| 224 |
+
# n < nelems should be one of them:
|
| 225 |
+
data = self._data_for_choice()
|
| 226 |
+
for k in ("two_f32", "two_str", "two_vec"):
|
| 227 |
+
for out in self.tfrun(pp.get_choice(n=[1, 2], key=[k]), data):
|
| 228 |
+
self.assertTrue(tf.reduce_any([
|
| 229 |
+
tf.reduce_all(out[k] == data[k][0:1]),
|
| 230 |
+
tf.reduce_all(out[k] == data[k][1:2]),
|
| 231 |
+
tf.reduce_all(out[k] == data[k][0:2]),
|
| 232 |
+
]))
|
| 233 |
+
|
| 234 |
+
|
| 235 |
+
if __name__ == "__main__":
|
| 236 |
+
tf.test.main()
|
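
Aside (illustrative, not part of the upload): the `setdefault`/`choice` pair tested above is the documented replacement for ad-hoc "pick one caption" logic (see the deprecation note in ops_text.py below). A minimal sketch, assuming `big_vision.pp.ops_general` is importable as in this diff:

    import tensorflow as tf
    from big_vision.pp import ops_general

    data = {"captions": tf.zeros([0], tf.string)}  # possibly-empty text field
    # Fall back to "" when the field is empty, then pick one element at random.
    data = ops_general.get_setdefault("captions", "")(data)
    data = ops_general.get_choice(key="captions")(data)
    # data["captions"] is now a single scalar string.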
Tipsomaly/model/big_vision/pp/ops_image.py
ADDED
@@ -0,0 +1,361 @@
+# Copyright 2024 Big Vision Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Image-centric preprocessing ops.
+
+All preprocessing ops should return a data processing functor. Data is
+represented as a dictionary of (TF) tensors. The functors output a modified
+dictionary.
+
+The key named "image" is commonly used for the image, and is a 3D tensor of
+shape (height x width x channels).
+"""
+
+from big_vision.pp import utils
+from big_vision.pp.registry import Registry
+
+import tensorflow as tf
+
+
+@Registry.register("preprocess_ops.decode")
+@utils.InKeyOutKey()
+def get_decode(channels=3, precise=False):
+  """Decode an encoded image string, see tf.io.decode_image.
+
+  Args:
+    channels: see tf.io.decode_image.
+    precise: if False, use default TF image decoding algorithm.
+      If True, change DCT method for JPEG decoding to match PIL/cv2/PyTorch.
+      See also (internal link) for a concrete example.
+
+  Returns:
+    The decoded image.
+  """
+
+  def _decode(image):
+    if precise:
+      return tf.image.decode_jpeg(  # Also supports png btw.
+          image, channels=channels, dct_method="INTEGER_ACCURATE")
+    else:
+      return tf.io.decode_image(
+          image, channels=channels, expand_animations=False)
+
+  return _decode
+
+
+@Registry.register("preprocess_ops.resize")
+@utils.InKeyOutKey()
+def get_resize(size, method="bilinear", antialias=False):
+  """Resizes image to a given size.
+
+  Args:
+    size: either an integer H, where H is both the new height and width
+      of the resized image, or a list or tuple [H, W] of integers, where H and
+      W are the new image's height and width respectively.
+    method: resize method, see tf.image.resize docs for options.
+    antialias: see tf.image.resize. Ideally set to True for all new configs.
+
+  Returns:
+    A function for resizing an image.
+  """
+  size = utils.maybe_repeat(size, 2)
+
+  def _resize(image):
+    """Resizes image to a given size."""
+    # Note: use TF-2 version of tf.image.resize as the version in TF-1 is
+    # buggy: https://github.com/tensorflow/tensorflow/issues/6720.
+    # In particular it was not equivariant with rotation and led the network
+    # to learn a shortcut in the self-supervised rotation task, if rotation
+    # was applied after resize.
+    dtype = image.dtype
+    tf_dtype = tf.type_spec_from_value(image).dtype
+    image = tf.image.resize(image, size, method=method, antialias=antialias)
+    return tf.cast(tf.clip_by_value(image, tf_dtype.min, tf_dtype.max), dtype)
+
+  return _resize
+
+
+# This functionality is used by resize_small and resize_long. But we're not
+# registering it as a pp op yet, as there is no need for it. However, it can
+# probably be slightly generalized into "scale augmentation" eventually.
+def _resize_factor(image, factor, method="area", antialias=True):
+  """Resizes the image by a (float) `factor`, keeping the aspect ratio fixed."""
+  h, w = tf.shape(image)[0], tf.shape(image)[1]
+
+  h = tf.cast(tf.round(tf.cast(h, tf.float32) * factor), tf.int32)
+  w = tf.cast(tf.round(tf.cast(w, tf.float32) * factor), tf.int32)
+
+  dtype = image.dtype
+  tf_dtype = tf.type_spec_from_value(image).dtype
+  image = tf.image.resize(image, (h, w), method=method, antialias=antialias)
+  return tf.cast(tf.clip_by_value(image, tf_dtype.min, tf_dtype.max), dtype)
+
+
+@Registry.register("preprocess_ops.resize_small")
+@utils.InKeyOutKey()
+def get_resize_small(smaller_size, method="area", antialias=False):
+  """Resizes the smaller side to `smaller_size`, keeping the aspect ratio.
+
+  Args:
+    smaller_size: an integer, that represents a new size of the smaller side
+      of an input image.
+    method: the resize method. `area` is a meaningful, bwd-compat default.
+    antialias: see tf.image.resize. Ideally set to True for all new configs.
+
+  Returns:
+    A function, that resizes an image and preserves its aspect ratio.
+
+  Note:
+    backwards-compat for "area"+antialias tested here: (internal link)
+  """
+
+  def _resize_small(image):  # pylint: disable=missing-docstring
+    h, w = tf.shape(image)[0], tf.shape(image)[1]
+    factor = (
+        tf.cast(smaller_size, tf.float32) /
+        tf.cast(tf.minimum(h, w), tf.float32))
+    return _resize_factor(image, factor, method=method, antialias=antialias)
+  return _resize_small
+
+
+@Registry.register("preprocess_ops.resize_long")
+@utils.InKeyOutKey()
+def get_resize_long(longer_size, method="area", antialias=True):
+  """Resizes the longer side to `longer_size`, keeping the aspect ratio.
+
+  Args:
+    longer_size: an integer, that represents a new size of the longer side of
+      an input image.
+    method: the resize method. `area` is a meaningful, bwd-compat default.
+    antialias: see tf.image.resize. Ideally set to True for all new configs.
+
+  Returns:
+    A function, that resizes an image and preserves its aspect ratio.
+  """
+
+  def _resize_long(image):  # pylint: disable=missing-docstring
+    h, w = tf.shape(image)[0], tf.shape(image)[1]
+    factor = (
+        tf.cast(longer_size, tf.float32) /
+        tf.cast(tf.maximum(h, w), tf.float32))
+    return _resize_factor(image, factor, method=method, antialias=antialias)
+  return _resize_long
+
+
+@Registry.register("preprocess_ops.inception_crop")
+@utils.InKeyOutKey()
+def get_inception_crop(size=None, area_min=5, area_max=100,
+                       method="bilinear", antialias=False):
+  """Makes inception-style image crop.
+
+  Inception-style crop is a random image crop (its size and aspect ratio are
+  random) that was used for training Inception models, see
+  https://www.cs.unc.edu/~wliu/papers/GoogLeNet.pdf.
+
+  Args:
+    size: Resize image to [size, size] after crop.
+    area_min: minimal crop area.
+    area_max: maximal crop area.
+    method: resize method, see tf.image.resize docs for options.
+    antialias: see tf.image.resize. Ideally set to True for all new configs.
+
+  Returns:
+    A function, that applies inception crop.
+  """
+
+  def _inception_crop(image):  # pylint: disable=missing-docstring
+    begin, crop_size, _ = tf.image.sample_distorted_bounding_box(
+        tf.shape(image),
+        tf.zeros([0, 0, 4], tf.float32),
+        area_range=(area_min / 100, area_max / 100),
+        min_object_covered=0,  # Don't enforce a minimum area.
+        use_image_if_no_bounding_boxes=True)
+    crop = tf.slice(image, begin, crop_size)
+    # Unfortunately, the above operation loses the depth-dimension. So we need
+    # to restore it the manual way.
+    crop.set_shape([None, None, image.shape[-1]])
+    if size:
+      crop = get_resize(size, method, antialias)({"image": crop})["image"]
+    return crop
+
+  return _inception_crop
+
+
+@Registry.register("preprocess_ops.decode_jpeg_and_inception_crop")
+@utils.InKeyOutKey()
+def get_decode_jpeg_and_inception_crop(size=None, area_min=5, area_max=100,
+                                       ratio_min=0.75, ratio_max=1.33,
+                                       method="bilinear", antialias=False):
+  """Decode jpeg string and make inception-style image crop.
+
+  Inception-style crop is a random image crop (its size and aspect ratio are
+  random) that was used for training Inception models, see
+  https://www.cs.unc.edu/~wliu/papers/GoogLeNet.pdf.
+
+  Args:
+    size: Resize image to [size, size] after crop.
+    area_min: minimal crop area.
+    area_max: maximal crop area.
+    ratio_min: minimal aspect ratio.
+    ratio_max: maximal aspect ratio.
+    method: resize method, see tf.image.resize docs for options.
+    antialias: see tf.image.resize. Ideally set to True for all new configs.
+
+  Returns:
+    A function, that applies inception crop.
+  """
+
+  def _inception_crop(image_data):  # pylint: disable=missing-docstring
+    shape = tf.image.extract_jpeg_shape(image_data)
+    begin, crop_size, _ = tf.image.sample_distorted_bounding_box(
+        shape,
+        tf.zeros([0, 0, 4], tf.float32),
+        area_range=(area_min / 100, area_max / 100),
+        aspect_ratio_range=(ratio_min, ratio_max),
+        min_object_covered=0,  # Don't enforce a minimum area.
+        use_image_if_no_bounding_boxes=True)
+
+    # Crop the image to the specified bounding box.
+    offset_y, offset_x, _ = tf.unstack(begin)
+    target_height, target_width, _ = tf.unstack(crop_size)
+    crop_window = tf.stack([offset_y, offset_x, target_height, target_width])
+    image = tf.image.decode_and_crop_jpeg(image_data, crop_window, channels=3)
+
+    if size:
+      image = get_resize(size, method, antialias)({"image": image})["image"]
+
+    return image
+
+  return _inception_crop
+
+
+@Registry.register("preprocess_ops.random_crop")
+@utils.InKeyOutKey()
+def get_random_crop(crop_size):
+  """Makes a random crop of a given size.
+
+  Args:
+    crop_size: either an integer H, where H is both the height and width of
+      the random crop, or a list or tuple [H, W] of integers, where H and W
+      are height and width of the random crop respectively.
+
+  Returns:
+    A function, that applies random crop.
+  """
+  crop_size = utils.maybe_repeat(crop_size, 2)
+
+  def _crop(image):
+    return tf.image.random_crop(image, (*crop_size, image.shape[-1]))
+
+  return _crop
+
+
+@Registry.register("preprocess_ops.central_crop")
+@utils.InKeyOutKey()
+def get_central_crop(crop_size=None):
+  """Makes central crop of a given size.
+
+  Args:
+    crop_size: either an integer H, where H is both the height and width of
+      the central crop, or a list or tuple [H, W] of integers, where H and W
+      are height and width of the central crop respectively. If `crop_size` is
+      not specified, then the largest possible center crop will be taken.
+
+  Returns:
+    A function, that applies central crop.
+  """
+  if crop_size:
+    crop_size = utils.maybe_repeat(crop_size, 2)
+
+  def _crop(image):
+    if crop_size:
+      h, w = crop_size[0], crop_size[1]
+    else:
+      h = w = tf.minimum(tf.shape(image)[0], tf.shape(image)[1])
+    dy = (tf.shape(image)[0] - h) // 2
+    dx = (tf.shape(image)[1] - w) // 2
+    return tf.image.crop_to_bounding_box(image, dy, dx, h, w)
+
+  return _crop
+
+
+@Registry.register("preprocess_ops.flip_lr")
+@utils.InKeyOutKey()
+def get_random_flip_lr():
+  """Flips an image horizontally with probability 50%."""
+
+  def _random_flip_lr_pp(image):
+    return tf.image.random_flip_left_right(image)
+
+  return _random_flip_lr_pp
+
+
+@Registry.register("preprocess_ops.vgg_value_range")
+@utils.InKeyOutKey()
+def get_vgg_value_range(
+    mean=(0.485 * 255, 0.456 * 255, 0.406 * 255),
+    std=(0.229 * 255, 0.224 * 255, 0.225 * 255),
+):
+  """VGG-style preprocessing, subtracts mean and divides by stddev.
+
+  This preprocessing is very common for ImageNet pre-trained models since VGG,
+  and to this day the standard for models coming from most PyTorch codes.
+
+  Args:
+    mean: Tuple of values to be subtracted. Defaults to widespread VGG values.
+    std: Tuple of values to be divided by. Defaults to widespread VGG values.
+
+  Returns:
+    A function to rescale the values.
+  """
+  mean = tf.constant(mean, tf.float32)
+  std = tf.constant(std, tf.float32)
+
+  def _vgg_value_range(image):
+    return (tf.cast(image, tf.float32) - mean) / std
+  return _vgg_value_range
+
+
+@Registry.register("preprocess_ops.clip_value_range")
+@utils.InKeyOutKey()
+def get_clip_value_range():
+  mean = (0.48145466 * 255, 0.4578275 * 255, 0.40821073 * 255)
+  std = (0.26862954 * 255, 0.26130258 * 255, 0.27577711 * 255)
+
+  def _clip_value_range(image):
+    return (tf.cast(image, tf.float32) - mean) / std
+  return _clip_value_range
+
+
+@Registry.register("preprocess_ops.convert_to_video")
+@utils.InKeyOutKey()
+def get_convert_to_video(num_frames):
+  """Converts an image to a video with zero padded frames.
+
+  Args:
+    num_frames: total number of frames that the video should have.
+
+  Returns:
+    A function for converting an image to a video.
+  """
+
+  def _convert_to_video(image):
+    return tf.pad(
+        tf.expand_dims(image, axis=0),
+        [[0, num_frames - 1], [0, 0], [0, 0], [0, 0]],
+    )
+
+  return _convert_to_video
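
Usage note (illustrative, not part of the diff): these ops are normally composed through the registry from a pp string. A minimal sketch, assuming the `get_preprocess_fn` builder from big_vision/pp/builder.py (added elsewhere in this upload) and the `keep` op from ops_general:

    from big_vision.pp.builder import get_preprocess_fn

    # Decode bytes, shortest side to 256, center-crop 224, VGG-normalize.
    pp_fn = get_preprocess_fn(
        "decode|resize_small(256)|central_crop(224)|vgg_value_range|keep('image')")
    # pp_fn maps {"image": <jpeg bytes>} -> {"image": float32 (224, 224, 3)}.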
Tipsomaly/model/big_vision/pp/ops_image_test.py
ADDED
@@ -0,0 +1,82 @@
+# Copyright 2024 Big Vision Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Tests for ops_image."""
+
+import copy
+import io
+
+import big_vision.pp.ops_image as pp
+import matplotlib.pyplot as plt
+import numpy as np
+import tensorflow as tf
+
+
+def get_image_data():
+  img = tf.random.uniform((640, 480, 3), 0, 255, tf.int32)  # Can't ask uint8!?
+  return {"image": tf.cast(img, tf.uint8)}
+
+
+class PreprocessOpsTest(tf.test.TestCase):
+
+  def tfrun(self, ppfn, data):
+    # Run once as standalone, as could happen eg in colab.
+    yield {k: np.array(v) for k, v in ppfn(copy.deepcopy(data)).items()}
+
+    # And then once again as part of tfdata pipeline.
+    # You'd be surprised how much these two differ!
+    tfdata = tf.data.Dataset.from_tensors(copy.deepcopy(data))
+    for npdata in tfdata.map(ppfn).as_numpy_iterator():
+      yield npdata
+
+  def test_resize(self):
+    for data in self.tfrun(pp.get_resize([120, 80]), get_image_data()):
+      self.assertEqual(data["image"].shape, (120, 80, 3))
+
+  def test_resize_small(self):
+    for data in self.tfrun(pp.get_resize_small(240), get_image_data()):
+      self.assertEqual(data["image"].shape, (320, 240, 3))
+
+  def test_resize_long(self):
+    for data in self.tfrun(pp.get_resize_long(320), get_image_data()):
+      self.assertEqual(data["image"].shape, (320, 240, 3))
+
+  def test_inception_crop(self):
+    for data in self.tfrun(pp.get_inception_crop(), get_image_data()):
+      self.assertEqual(data["image"].shape[-1], 3)
+
+  def test_decode_jpeg_and_inception_crop(self):
+    f = io.BytesIO()
+    plt.imsave(f, get_image_data()["image"].numpy(), format="jpg")
+    data = {"image": tf.cast(f.getvalue(), tf.string)}
+    for data in self.tfrun(pp.get_decode_jpeg_and_inception_crop(), data):
+      self.assertEqual(data["image"].shape[-1], 3)
+
+  def test_random_crop(self):
+    for data in self.tfrun(pp.get_random_crop([120, 80]), get_image_data()):
+      self.assertEqual(data["image"].shape, (120, 80, 3))
+
+  def test_central_crop(self):
+    for data in self.tfrun(pp.get_central_crop([20, 80]), get_image_data()):
+      self.assertEqual(data["image"].shape, (20, 80, 3))
+
+  def test_random_flip_lr(self):
+    data_orig = get_image_data()
+    for data in self.tfrun(pp.get_random_flip_lr(), data_orig):
+      self.assertTrue(
+          np.all(data_orig["image"].numpy() == data["image"]) or
+          np.all(data_orig["image"].numpy() == data["image"][:, ::-1]))
+
+
+if __name__ == "__main__":
+  tf.test.main()
Tipsomaly/model/big_vision/pp/ops_text.py
ADDED
@@ -0,0 +1,411 @@
+# Copyright 2024 Big Vision Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Text-centric preprocessing ops.
+
+All preprocessing ops should return a data processing functor. Data is
+represented as a dictionary of (TF) tensors. The functors output a modified
+dictionary.
+
+A commonly used key for the tokenized output is "labels".
+"""
+import functools
+import importlib
+import string
+
+from absl import logging
+from big_vision.datasets.imagenet import class_names as imagenet_class_names
+from big_vision.pp import ops_general
+from big_vision.pp import tokenizer as bv_tok
+from big_vision.pp import utils
+from big_vision.pp.registry import Registry
+import tensorflow as tf
+
+from tensorflow.io import gfile
+
+import sentencepiece
+SPProcessor = sentencepiece.SentencePieceProcessor
+
+import os
+os.environ['PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION'] = 'python'
+import sentencepiece.sentencepiece_model_pb2
+del os.environ['PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION']
+SPModelProto = sentencepiece.sentencepiece_model_pb2.ModelProto
+
+
+# TODO: b/lbeyer - softly introduce and move to new tokenizer API.
+
+KNOWN_TOKENIZERS = {
+    "mc4":  # used in multilingual models (mT5, PaLI), vocab_size=250_000
+        "gs://t5-data/vocabs/mc4.250000.100extra/sentencepiece.model",
+    "cc_all":  # vocab_size=32_000
+        "gs://t5-data/vocabs/cc_all.32000/sentencepiece.model",
+    "c4_en":  # vocab_size=32_000
+        "gs://t5-data/vocabs/cc_en.32000/sentencepiece.model",
+    "t5":  # same as cc_all, but with 100 extra dummy tokens used by T5 models
+        "gs://t5-data/vocabs/cc_all.32000.100extra/sentencepiece.model",
+    "mt5":  # same as mc4, but with 100 extra dummy tokens used by T5 models
+        "gs://t5-data/vocabs/mc4.250000.100extra/sentencepiece.model",
+}
+
+
+def create_tokenizer(model="c4_en", add_eos=True, add_bos=False):
+  """Creates a tokenizer which can be used in tfds."""
+  logging.info("Creating tokenizer: %s", model)
+  with gfile.GFile(KNOWN_TOKENIZERS.get(model, model), "rb") as f:
+    model = f.read()
+
+  # Lazy import of tensorflow_text so it is an optional dependency for
+  # the users of this file.
+  import tensorflow_text
+  return tensorflow_text.SentencepieceTokenizer(
+      model=model, add_eos=add_eos, add_bos=add_bos
+  )
+
+
+def tokenize(input_text, tokenizer, max_len, *, pad_value, force_eos,
+             multi_text=False):
+  """Tokenizes string, and adds `pad_value` if longer than `max_len`."""
+
+  def pad(tokens):
+    # Truncate/pad to max_len.
+    if force_eos:
+      tokens = tf.cond(
+          tf.shape(tokens)[0] >= max_len,
+          lambda: tf.concat(
+              # For too long, cut them off, but do keep the final EOS token.
+              [tokens[:max_len - 1], tokens[-1:]], axis=0),
+          lambda: tf.pad(
+              tokens, [(0, max_len - tf.shape(tokens)[0])],
+              constant_values=pad_value),
+      )
+    else:
+      tokens = tokens[:max_len]
+      tokens = tf.pad(
+          tokens, [(0, max_len - tf.shape(tokens)[0])],
+          constant_values=pad_value)
+    tokens.set_shape([max_len])
+    return tokens
+
+  tokens = tokenizer.tokenize(input_text)
+
+  if multi_text:
+    tokens = tokens.to_tensor(pad_value)  # tf.RaggedTensor to tf.Tensor
+    tokens = tf.reshape(tokens, [-1, tf.shape(tokens)[-1]])
+    tokens = tf.map_fn(pad, tokens)  # `map_fn` only maps on axis 0
+
+    final_shape = tf.concat([tf.shape(input_text), [max_len]], axis=0)
+    return tf.reshape(tokens, final_shape)
+  else:
+    return pad(tokens)
+
+
+@Registry.register("preprocess_ops.tokenize")
+@utils.InKeyOutKey(indefault=None, outdefault="labels")
+def get_pp_tokenize(
+    max_len,
+    eos,
+    model="c4_en",
+    lower=True,
+    sample_if_multi=True,
+    pad_value="<pad>",
+    add_bos=False
+):
+  """Tokenizes a text.
+
+  Let's assume max_len=3 and id("</s>")=1, id("a")=2, then we have
+
+  1. `eos="none", pad_value=0`:
+     - "a"   -> [2, 0, 0]
+     - "aa"  -> [2, 2, 0]
+     - "aaa" -> [2, 2, 2]
+
+  2. `eos="yes", pad_value=0`:
+     - "a"   -> [2, 1, 0]
+     - "aa"  -> [2, 2, 1]
+     - "aaa" -> [2, 2, 2]
+
+  This is usually used with generative models that need to learn when to
+  properly predict a "</s>" (when the sentence is finished) and when to
+  abstain (when the sentence is truncated).
+
+  3. `eos="sticky", pad_value=0`:
+     - "a"   -> [2, 1, 0]
+     - "aa"  -> [2, 2, 1]
+     - "aaa" -> [2, 2, 1]
+
+  4. `eos="sticky", pad_value=1`:
+     - "a"   -> [2, 1, 1]
+     - "aa"  -> [2, 2, 1]
+     - "aaa" -> [2, 2, 1]
+
+  This is traditionally used with contrastive models that use the last token
+  for embeddings, similarly to "cls" tokens in BERT-style models.
+
+  Args:
+    max_len: maximum length of the tokenized text.
+    eos: Whether to add an "</s>" (end of sentence) token and whether to keep
+      it when the sequence is longer than `max_len - 1`. See examples above
+      for details. Valid values: "none", "yes", "sticky".
+    model: a path to the pretrained sentencepiece model.
+    lower: lowercase the text before tokenizing.
+    sample_if_multi: If there's more than one, randomly pick one if this is
+      True; otherwise pick all texts and keep the input's batch shape in
+      result.
+    pad_value: which token to pad the sequence with. If a string (for example
+      `"<pad>"`), tokenize it and use its first token. Note that there is no
+      guarantee to have any padding at the end of the sentence, if the
+      sentence is longer than `max_len`.
+    add_bos: adds beginning of sentence symbol.
+
+  Returns:
+    an op that outputs tokenized text.
+  """
+
+  if eos not in ("yes", "none", "sticky"):
+    raise ValueError(f"Invalid value for eos: '{eos}'.")
+
+  tokenizer = create_tokenizer(model, add_eos=eos != "none", add_bos=add_bos)
+
+  if isinstance(pad_value, str):
+    pad_value = tokenizer.string_to_id(pad_value)
+
+  def _pp_tokenize(txt):
+    if sample_if_multi and tf.convert_to_tensor(txt).ndim:
+      # TODO: I wish this code-path could die.
+      logging.warning("sample_if_multi is deprecated and will be removed. "
+                      "Call `choice` (and maybe `setdefault`) instead.")
+      txt = ops_general.get_choice(key="t")(
+          ops_general.get_setdefault("t", "")({"t": txt}))["t"]
+
+    if lower:
+      txt = tf.strings.lower(txt) if sample_if_multi else tf.map_fn(
+          tf.strings.lower, txt)
+
+    return tokenize(
+        txt,
+        tokenizer,
+        max_len,
+        pad_value=pad_value,
+        force_eos=eos == "sticky",
+        multi_text=not sample_if_multi)
+
+  return _pp_tokenize
+
+
+@Registry.register("preprocess_ops.coco_captions")
+def get_coco_captions(outkey="captions"):
+  """Extracts coco's captions from nested dict."""
+
+  def _pp_coco_captions(data):
+    data[outkey] = data["captions"]["text"]
+    return data
+
+  return _pp_coco_captions
+
+
+@Registry.register("preprocess_ops.clip_i1k_label_names")
+@utils.InKeyOutKey(indefault="label", outdefault="labels")
+def get_pp_clip_i1k_label_names():
+  """Convert i1k label numbers to strings, using CLIP's class names."""
+
+  def _pp_imagenet_labels(label):
+    return tf.gather(imagenet_class_names.CLIP_IMAGENET_CLASS_NAMES, label)
+
+  return _pp_imagenet_labels
+
+
+@Registry.register("preprocess_ops.i21k_label_names")
+@utils.InKeyOutKey(indefault="label", outdefault="labels")
+def get_pp_i21k_label_names():
+  """Converts i21k label ids to strings."""
+
+  def _pp_imagenet_labels(label):
+    return tf.gather(imagenet_class_names.IMAGENET21k_CLASS_NAMES, label)
+
+  return _pp_imagenet_labels
+
+
+@Registry.register("preprocess_ops.lower")
+@utils.InKeyOutKey(indefault="text", outdefault="text")
+def get_lower():
+  """Lowercases text feature."""
+
+  def _pp_lower(text):
+    return tf.strings.lower(text)
+
+  return _pp_lower
+
+
+@Registry.register("preprocess_ops.strfmt")
+def get_strfmt(template, outkey="text"):
+  """Formats a string template with content from the data dict."""
+
+  def _template(data):
+    outputs = []
+    parts = string.Formatter().parse(template)
+    for (literal_text, field_name, format_spec, conversion) in parts:
+      # For now, we keep it simple and don't support fancy format specs.
+      # But we can add support to that via py_func as soon as we need it.
+      assert not format_spec and not conversion
+      outputs.append(tf.constant(literal_text))
+      if field_name:
+        value = data[field_name]
+        # Convert any non-strings (numbers, vectors) to a string.
+        if tf.convert_to_tensor(value).dtype != tf.string:
+          value = tf.strings.format("{}", value, summarize=-1)
+        outputs.append(value)
+    data[outkey] = tf.strings.join(outputs)
+    return data
+
+  return _template
+
+
+def _add_pieces(model_bytes, extra_pieces):
+  """Adds extra pieces to the sentencepiece model specified by `model_bytes`."""
+
+  model = SPProcessor()
+  model.LoadFromSerializedProto(model_bytes)
+  unk_idx = model.PieceToId("<unk>")
+  assert model.IdToPiece(unk_idx) == "<unk>", model.IdToPiece(unk_idx)
+
+  model_proto = SPModelProto.FromString(model_bytes)
+  idx_to_updated_piece = {}
+  for piece in extra_pieces:
+    # The SentencePieceModel proto stores whitespaces as the special
+    # character '▁'. We perform the conversion here.
+    piece = piece.replace(" ", "▁")
+    spiece = model_proto.SentencePiece(
+        piece=piece,
+        # We set the highest score to force priority on user defined tokens.
+        score=0.0,
+        type=model_proto.SentencePiece().Type.USER_DEFINED,
+    )
+    existing_idx = model.PieceToId(piece)
+    if (existing_idx != unk_idx) ^ (piece == "<unk>"):
+      idx_to_updated_piece[existing_idx] = spiece
+      logging.info("Updating token at idx %d: %s", existing_idx, spiece.piece)
+    else:
+      model_proto.pieces.append(spiece)
+
+  # Replace duplicated pieces with updated ones.
+  updated_pieces = [
+      idx_to_updated_piece.get(i, piece)
+      for i, piece in enumerate(model_proto.pieces)
+  ]
+  del model_proto.pieces[:]
+  model_proto.pieces.extend(updated_pieces)
+
+  return model_proto.SerializeToString()
+
+
+def _iterable(x):
+  if isinstance(x, tf.RaggedTensor):
+    return True
+  if getattr(x, "ndim", 0) > 1:  # np, jnp
+    return True
+  if isinstance(x, (list, tuple)) and not isinstance(x[0], (int, float)):
+    return True
+  return False
+
+
+@Registry.register("tokenizers.sp")
+class SentencepieceTokenizer(bv_tok.Tokenizer):
+  """Wraps a `tftext.SentencepieceTokenizer`.
+
+  If you plan to use this tokenizer, please familiarize yourself with the test
+  cases first. This is likely to save you a lot of trouble down the road,
+  trust me!
+  """
+
+  def __init__(self, model, tokensets=()):
+    with gfile.GFile(KNOWN_TOKENIZERS.get(model, model), "rb") as f:
+      model_bytes = f.read()
+    extras = bv_tok.get_extra_tokens(tokensets)
+    model_bytes = _add_pieces(model_bytes, extras)
+    self._tok_sp = SPProcessor()
+    self._tok_sp.LoadFromSerializedProto(model_bytes)
+    self.extras = {self._tok_sp.PieceToId(x): x for x in extras}
+
+  def to_int(self, text, *, bos=False, eos=False):
+    def _single(s):
+      return (
+          ([self.bos_token] if bos else []) +
+          self._tok_sp.EncodeAsIds(s) +
+          ([self.eos_token] if eos else [])
+      )
+    if isinstance(text, str):
+      return _single(text)
+    return type(text)([_single(s) for s in text])
+
+  def to_str(self, tokens, *, stop_at_eos=True):
+    def _single(toks):
+      toks = [int(t) for t in toks]  # We really need this for DecodeIds.
+      if stop_at_eos:
+        try:  # The SentencePiece strips eos, but does not stop at it, so we do.
+          toks = toks[:toks.index(self.eos_token)]
+        except ValueError:  # No eos token found, nothing to do.
+          pass
+      return self._tok_sp.DecodeIds(toks)
+    if _iterable(tokens):
+      return [_single(toks) for toks in tokens]
+    return _single(tokens)
+
+  def _check_known(self, piece):
+    if (id_ := self._tok_sp.PieceToId(piece)) == self._tok_sp.unk_id():
+      logging.error("Piece '%s' is not known (unk=%s)!", piece, id_)
+    return id_
+
+  def to_piece(self, idx):
+    return self._tok_sp.IdToPiece(int(idx))
+
+  @property
+  def pad_token(self):
+    return self._tok_sp.pad_id()
+
+  @property
+  def eos_token(self):
+    return self._tok_sp.eos_id()
+
+  @property
+  def bos_token(self):
+    return self._tok_sp.bos_id()
+
+  @property
+  def vocab_size(self):
+    return self._tok_sp.GetPieceSize()
+
+  # For the _tf_op variants, we need a lot of wrapping boilerplate.
+
+  def to_int_tf_op(self, text, *, bos=False, eos=False):
+    text = tf.convert_to_tensor(text)
+    if text.ndim == 0:
+      def fn(txt):
+        s = txt.numpy().decode()
+        return tf.constant(self.to_int(s, bos=bos, eos=eos), tf.int32)
+      return tf.py_function(fn, [text], tf.int32)
+    else:
+      def fn(txt):
+        strings = [s.decode() for s in txt.numpy().tolist()]
+        toks = self.to_int(strings, bos=bos, eos=eos)
+        return tf.ragged.constant(toks)
+      out_type = tf.RaggedTensorSpec([tf.shape(text)[0], None], tf.int32)
+      return tf.py_function(fn, [text], Tout=out_type)
+
+  def to_str_tf_op(self, tokens, *, stop_at_eos=True):
+    def single(t):
+      fn = functools.partial(self.to_str, stop_at_eos=stop_at_eos)
+      return tf.numpy_function(fn, [t], tf.string, stateful=False)
+    if _iterable(tokens):
+      return tf.map_fn(single, tokens, tf.string)
+    return single(tokens)
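
Usage note (illustrative, not part of the diff): the `tokenize` pp op above turns a string field into a fixed-length id sequence. A minimal sketch; the `inkey`/`outkey` kwargs are assumed to come from the `InKeyOutKey` decorator, exact ids depend on the vocabulary, and fetching "c4_en" requires access to its gs:// path:

    import tensorflow as tf
    import big_vision.pp.ops_text as ops_text

    op = ops_text.get_pp_tokenize(max_len=16, eos="yes", model="c4_en",
                                  inkey="text", outkey="labels")
    out = op({"text": tf.constant("a small example")})
    # out["labels"] is int32 of shape (16,), truncated/padded according to
    # the eos/pad_value rules documented in get_pp_tokenize.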
Tipsomaly/model/big_vision/pp/ops_text_test.py
ADDED
@@ -0,0 +1,200 @@
+# Copyright 2024 Big Vision Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Tests for ops_text."""
+
+import copy
+
+from absl.testing import parameterized
+import big_vision.pp.ops_text as pp
+from big_vision.pp.registry import Registry
+import numpy as np
+import tensorflow as tf
+
+
+class PyToTfWrapper:
+  """Allows using `to_{int,str}_tf()` via `to_{int,str}()`."""
+
+  def __init__(self, tok):
+    self.tok = tok
+    self.bos_token = tok.bos_token
+    self.eos_token = tok.eos_token
+    self.vocab_size = tok.vocab_size
+
+  def to_int(self, text, *, bos=False, eos=False):
+    ret = self.tok.to_int_tf_op(text, bos=bos, eos=eos)
+    if isinstance(ret, tf.RaggedTensor):
+      return [t.numpy().tolist() for t in ret]
+    return ret.numpy().tolist()
+
+  def to_str(self, tokens, stop_at_eos=True):
+    ret = self.tok.to_str_tf_op(
+        tf.ragged.constant(tokens),
+        stop_at_eos=stop_at_eos,
+    )
+    if ret.ndim == 0:
+      return ret.numpy().decode()
+    return [t.numpy().decode() for t in ret]
+
+
+class PpOpsTest(tf.test.TestCase, parameterized.TestCase):
+
+  def tfrun(self, ppfn, data):
+    # Run once as standalone, as could happen eg in colab.
+    yield {k: np.array(v) for k, v in ppfn(copy.deepcopy(data)).items()}
+
+    # And then once again as part of tfdata pipeline.
+    # You'd be surprised how much these two differ!
+    tfdata = tf.data.Dataset.from_tensors(copy.deepcopy(data))
+    for npdata in tfdata.map(ppfn).as_numpy_iterator():
+      yield npdata
+
+  def testtok(self):
+    # https://github.com/google/sentencepiece/blob/master/python/test/test_model.model
+    return "test_model.model"  # Should we just commit it? It's 200kB
+
+  def test_get_pp_clip_i1k_label_names(self):
+    op = pp.get_pp_clip_i1k_label_names()
+    labels = op({"label": tf.constant([0, 1])})["labels"].numpy().tolist()
+    self.assertAllEqual(labels, ["tench", "goldfish"])
+
+  def test_get_pp_i21k_label_names(self):
+    op = pp.get_pp_i21k_label_names()
+    labels = op({"label": tf.constant([0, 1])})["labels"].numpy().tolist()
+    self.assertAllEqual(labels, ["organism", "benthos"])
+
+  @parameterized.parameters((b"Hello world ScAlAr!", b"hello world scalar!"),
+                            (["Decoded Array!"], ["decoded array!"]),
+                            ([b"aA", "bB"], [b"aa", "bb"]))
+  def test_get_lower(self, inputs, expected_output):
+    op = pp.get_lower()
+    out = op({"text": tf.constant(inputs)})
+    self.assertAllEqual(out["text"].numpy(), np.array(expected_output))
+
+  @parameterized.named_parameters(
+      ("py", False),
+      ("tf", True),
+  )
+  def test_sentencepiece_tokenizer(self, wrap_tok):
+    tok = pp.SentencepieceTokenizer(self.testtok())
+    if wrap_tok:
+      tok = PyToTfWrapper(tok)
+    self.assertEqual(tok.vocab_size, 1000)
+    bos, eos = tok.bos_token, tok.eos_token
+    self.assertEqual(bos, 1)
+    self.assertEqual(eos, 2)
+    # Note: test model does NOT have a <pad> token (similar to e.g. "mistral").
+    # `.to_int()` wraps `.to_int_tf_ops` which is thus also tested
+    self.assertEqual(tok.to_int("blah"), [80, 180, 60])
+    self.assertEqual(tok.to_int("blah", bos=True), [bos, 80, 180, 60])
+    self.assertEqual(tok.to_int("blah", eos=True), [80, 180, 60, eos])
+    self.assertEqual(
+        tok.to_int("blah", bos=True, eos=True), [bos, 80, 180, 60, eos]
+    )
+    self.assertEqual(
+        tok.to_int(["blah", "blah blah"]),
+        [[80, 180, 60], [80, 180, 60, 80, 180, 60]],
+    )
+    # inverse of above
+    # `.to_str()` wraps `.to_str_tf_ops` which is thus also tested
+    self.assertEqual(tok.to_str([80, 180, 60]), "blah")
+    self.assertEqual(tok.to_str([1, 80, 180, 60]), "blah")
+    self.assertEqual(tok.to_str([80, 180, 60, 2]), "blah")
+    self.assertEqual(
+        tok.to_str([[80, 180, 60], [80, 180, 60, 80, 180, 60]]),
+        ["blah", "blah blah"],
+    )
+
+  def test_sentencepiece_tokenizer_tf_op_ndarray_input(self):
+    tok = pp.SentencepieceTokenizer(self.testtok())
+    bos, eos = tok.bos_token, tok.eos_token
+    arr = np.array([[bos, 80, 180, 60, eos]] * 2, dtype=np.int32)
+    self.assertEqual(tok.to_str_tf_op(arr).numpy().tolist(), [b"blah"] * 2)
+
+  def test_sentencepiece_tokenizer_tokensets(self):
+    tok = pp.SentencepieceTokenizer(self.testtok(), tokensets=["loc"])
+    self.assertEqual(tok.vocab_size, 2024)
+    self.assertEqual(
+        tok.to_int("blah<loc0000><loc1023>"), [80, 180, 60, 1000, 2023]
+    )
+
+  def test_sentencepiece_stop_at_eos(self):
+    tok = pp.SentencepieceTokenizer(self.testtok())
+    self.assertEqual(tok.to_str([80, 180, 60], stop_at_eos=False), "blah")
+    eos = tok.eos_token
+    self.assertEqual(tok.to_str([80, eos, 180, 60], stop_at_eos=False), "blah")
+    self.assertEqual(tok.to_str([80, eos, 180, 60], stop_at_eos=True), "b")
+    self.assertEqual(
+        tok.to_str([[80, eos, 180, 60], [80, 180, eos, 60]], stop_at_eos=True),
+        ["b", "bla"]
+    )
+
+  def test_sentencepiece_extra_tokens(self):
+    tok = pp.SentencepieceTokenizer(self.testtok())
+    self.assertEqual(tok.to_str([1, 80, 180, 60, 2], stop_at_eos=False), "blah")
+    tok = pp.SentencepieceTokenizer(
+        self.testtok(), tokensets=["sp_extra_tokens"]
+    )
+    self.assertEqual(tok.vocab_size, 1001)  # Also added the <pad> token.
+    self.assertEqual(
+        tok.to_str([1, 80, 180, 60, 2], stop_at_eos=False), "<s> blah</s>"
+    )
+
+  def test_strfmt(self):
+    data = {
+        "int": tf.constant(42, tf.uint8),
+        "float": tf.constant(3.14, tf.float32),
+        "vec": tf.range(3),
+        "empty_str": tf.constant(""),
+        "regex_problem1": tf.constant(r"no \replace pattern"),
+        "regex_problem2": tf.constant(r"yes \1 pattern"),
+    }
+    for out in self.tfrun(pp.get_strfmt("Nothing"), data):
+      self.assertEqual(out["text"], b"Nothing")
+    for out in self.tfrun(pp.get_strfmt("{int}"), data):
+      self.assertEqual(out["text"], b"42")
+    for out in self.tfrun(pp.get_strfmt("A{int}"), data):
+      self.assertEqual(out["text"], b"A42")
+    for out in self.tfrun(pp.get_strfmt("{int}A"), data):
+      self.assertEqual(out["text"], b"42A")
+    for out in self.tfrun(pp.get_strfmt("{int}{int}"), data):
+      self.assertEqual(out["text"], b"4242")
+    for out in self.tfrun(pp.get_strfmt("A{int}A{int}A"), data):
+      self.assertEqual(out["text"], b"A42A42A")
+    for out in self.tfrun(pp.get_strfmt("A{float}A"), data):
+      self.assertEqual(out["text"], b"A3.14A")
+    for out in self.tfrun(pp.get_strfmt("A{float}A{int}"), data):
+      self.assertEqual(out["text"], b"A3.14A42")
+    for out in self.tfrun(pp.get_strfmt("A{vec}A"), data):
+      self.assertEqual(out["text"], b"A[0 1 2]A")
+    for out in self.tfrun(pp.get_strfmt("A{empty_str}A"), data):
+      self.assertEqual(out["text"], b"AA")
+    for out in self.tfrun(pp.get_strfmt("{empty_str}"), data):
+      self.assertEqual(out["text"], b"")
+    for out in self.tfrun(pp.get_strfmt("A{regex_problem1}A"), data):
+      self.assertEqual(out["text"], br"Ano \replace patternA")
+    for out in self.tfrun(pp.get_strfmt("A{regex_problem2}A"), data):
+      self.assertEqual(out["text"], br"Ayes \1 patternA")
+
+
+@Registry.register("tokensets.sp_extra_tokens")
+def _get_sp_extra_tokens():
+  # For sentencepiece, adding these tokens will make them visible when decoding.
+  # If a token is not found (e.g. "<pad>" is not found in "mistral"), then it is
+  # added to the vocabulary, increasing the vocab_size accordingly.
+  return ["<s>", "</s>", "<pad>"]
+
+
+if __name__ == "__main__":
+  tf.test.main()
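
Extension note (illustrative): the `tokensets.sp_extra_tokens` registration above is the general mechanism for growing a vocabulary. A hypothetical tokenset mirroring the "loc" set the tests rely on (1024 location tokens; the name `my_locs` is made up):

    from big_vision.pp.registry import Registry

    @Registry.register("tokensets.my_locs")
    def _get_my_locs():
      return [f"<loc{i:04d}>" for i in range(1024)]

    # SentencepieceTokenizer(model, tokensets=["my_locs"]) then appends any
    # unknown pieces, growing vocab_size by the number of new tokens.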
Tipsomaly/model/big_vision/pp/registry.py
ADDED
@@ -0,0 +1,163 @@
| 1 | +# Copyright 2024 Big Vision Authors.
| 2 | +#
| 3 | +# Licensed under the Apache License, Version 2.0 (the "License");
| 4 | +# you may not use this file except in compliance with the License.
| 5 | +# You may obtain a copy of the License at
| 6 | +#
| 7 | +#     http://www.apache.org/licenses/LICENSE-2.0
| 8 | +#
| 9 | +# Unless required by applicable law or agreed to in writing, software
| 10 | +# distributed under the License is distributed on an "AS IS" BASIS,
| 11 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
| 12 | +# See the License for the specific language governing permissions and
| 13 | +# limitations under the License.
| 14 | +
| 15 | +"""Global Registry for big_vision pp ops.
| 16 | +
| 17 | +Author: Joan Puigcerver (jpuigcerver@)
| 18 | +"""
| 19 | +
| 20 | +from __future__ import absolute_import
| 21 | +from __future__ import division
| 22 | +from __future__ import print_function
| 23 | +
| 24 | +import ast
| 25 | +import contextlib
| 26 | +import functools
| 27 | +
| 28 | +
| 29 | +def parse_name(string_to_parse):
| 30 | +  """Parses input to the registry's lookup function.
| 31 | +
| 32 | +  Args:
| 33 | +    string_to_parse: can be either an arbitrary name or function call
| 34 | +      (optionally with positional and keyword arguments).
| 35 | +      e.g. "multiclass", "resnet50_v2(filters_factor=8)".
| 36 | +
| 37 | +  Returns:
| 38 | +    A tuple of input name, argument tuple and a keyword argument dictionary.
| 39 | +    Examples:
| 40 | +      "multiclass" -> ("multiclass", (), {})
| 41 | +      "resnet50_v2(9, filters_factor=4)" ->
| 42 | +          ("resnet50_v2", (9,), {"filters_factor": 4})
| 43 | +
| 44 | +  Author: Joan Puigcerver (jpuigcerver@)
| 45 | +  """
| 46 | +  expr = ast.parse(string_to_parse, mode="eval").body  # pytype: disable=attribute-error
| 47 | +  if not isinstance(expr, (ast.Attribute, ast.Call, ast.Name)):
| 48 | +    raise ValueError(
| 49 | +        "The given string should be a name or a call, but a {} was parsed from "
| 50 | +        "the string {!r}".format(type(expr), string_to_parse))
| 51 | +
| 52 | +  # Notes:
| 53 | +  # name="some_name" -> type(expr) = ast.Name
| 54 | +  # name="module.some_name" -> type(expr) = ast.Attribute
| 55 | +  # name="some_name()" -> type(expr) = ast.Call
| 56 | +  # name="module.some_name()" -> type(expr) = ast.Call
| 57 | +
| 58 | +  if isinstance(expr, ast.Name):
| 59 | +    return string_to_parse, (), {}
| 60 | +  elif isinstance(expr, ast.Attribute):
| 61 | +    return string_to_parse, (), {}
| 62 | +
| 63 | +  def _get_func_name(expr):
| 64 | +    if isinstance(expr, ast.Attribute):
| 65 | +      return _get_func_name(expr.value) + "." + expr.attr
| 66 | +    elif isinstance(expr, ast.Name):
| 67 | +      return expr.id
| 68 | +    else:
| 69 | +      raise ValueError(
| 70 | +          "Type {!r} is not supported in a function name, the string to parse "
| 71 | +          "was {!r}".format(type(expr), string_to_parse))
| 72 | +
| 73 | +  def _get_func_args_and_kwargs(call):
| 74 | +    args = tuple([ast.literal_eval(arg) for arg in call.args])
| 75 | +    kwargs = {
| 76 | +        kwarg.arg: ast.literal_eval(kwarg.value) for kwarg in call.keywords
| 77 | +    }
| 78 | +    return args, kwargs
| 79 | +
| 80 | +  func_name = _get_func_name(expr.func)
| 81 | +  func_args, func_kwargs = _get_func_args_and_kwargs(expr)
| 82 | +
| 83 | +  return func_name, func_args, func_kwargs
| 84 | +
| 85 | +
| 86 | +class Registry(object):
| 87 | +  """Implements global Registry.
| 88 | +
| 89 | +  Authors: Joan Puigcerver (jpuigcerver@), Alexander Kolesnikov (akolesnikov@)
| 90 | +  """
| 91 | +
| 92 | +  _GLOBAL_REGISTRY = {}
| 93 | +
| 94 | +  @staticmethod
| 95 | +  def global_registry():
| 96 | +    return Registry._GLOBAL_REGISTRY
| 97 | +
| 98 | +  @staticmethod
| 99 | +  def register(name, replace=False):
| 100 | +    """Creates a function that registers its input."""
| 101 | +
| 102 | +    def _register(item):
| 103 | +      if name in Registry.global_registry() and not replace:
| 104 | +        raise KeyError("The name {!r} was already registered.".format(name))
| 105 | +
| 106 | +      Registry.global_registry()[name] = item
| 107 | +      return item
| 108 | +
| 109 | +    return _register
| 110 | +
| 111 | +  @staticmethod
| 112 | +  def lookup(lookup_string, kwargs_extra=None):
| 113 | +    """Lookup a name in the registry."""
| 114 | +
| 115 | +    try:
| 116 | +      name, args, kwargs = parse_name(lookup_string)
| 117 | +    except ValueError as e:
| 118 | +      raise ValueError(f"Error parsing:\n{lookup_string}") from e
| 119 | +    if kwargs_extra:
| 120 | +      kwargs.update(kwargs_extra)
| 121 | +    item = Registry.global_registry()[name]
| 122 | +    return functools.partial(item, *args, **kwargs)
| 123 | +
| 124 | +  @staticmethod
| 125 | +  def knows(lookup_string):
| 126 | +    try:
| 127 | +      name, _, _ = parse_name(lookup_string)
| 128 | +    except ValueError as e:
| 129 | +      raise ValueError(f"Error parsing:\n{lookup_string}") from e
| 130 | +    return name in Registry.global_registry()
| 131 | +
| 132 | +
| 133 | +@contextlib.contextmanager
| 134 | +def temporary_ops(**kw):
| 135 | +  """Registers specified pp ops for use in a `with` block.
| 136 | +
| 137 | +  Example use:
| 138 | +
| 139 | +    with pp_registry.temporary_ops(
| 140 | +        pow=lambda alpha: lambda d: {k: v**alpha for k, v in d.items()}):
| 141 | +      pp = pp_builder.get_preprocess_fn("pow(alpha=2.0)|pow(alpha=0.5)")
| 142 | +      features = pp(features)
| 143 | +
| 144 | +  Args:
| 145 | +    **kw: Names are preprocess string function names to be used to specify the
| 146 | +      preprocess function. Values are functions that can be called with params
| 147 | +      (e.g. the `alpha` param in above example) and return functions to be used
| 148 | +      to transform features.
| 149 | +
| 150 | +  Yields:
| 151 | +    A context manager to be used in a `with` statement.
| 152 | +  """
| 153 | +  reg = Registry.global_registry()
| 154 | +  kw = {f"preprocess_ops.{k}": v for k, v in kw.items()}
| 155 | +  for k in kw:
| 156 | +    assert k not in reg
| 157 | +  for k, v in kw.items():
| 158 | +    reg[k] = v
| 159 | +  try:
| 160 | +    yield
| 161 | +  finally:
| 162 | +    for k in kw:
| 163 | +      del reg[k]
|
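Together, `parse_name`, `register`, and `lookup` implement a string-driven factory: the lookup string is parsed as a Python expression, its literal arguments are extracted with `ast.literal_eval`, and they are bound onto the registered callable with `functools.partial`. A minimal usage sketch (the op name `preprocess_ops.scale` is invented for illustration, not part of the diff):

    from big_vision.pp.registry import Registry

    @Registry.register("preprocess_ops.scale")  # hypothetical op name
    def get_scale(factor=1.0):
      return lambda x: x * factor

    # Arguments embedded in the lookup string are bound via functools.partial;
    # the trailing () then calls get_scale(factor=2.0).
    fn = Registry.lookup("preprocess_ops.scale(factor=2.0)")()
    assert fn(3.0) == 6.0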
Tipsomaly/model/big_vision/pp/registry_test.py
ADDED
|
@@ -0,0 +1,128 @@
| 1 | +# Copyright 2024 Big Vision Authors.
| 2 | +#
| 3 | +# Licensed under the Apache License, Version 2.0 (the "License");
| 4 | +# you may not use this file except in compliance with the License.
| 5 | +# You may obtain a copy of the License at
| 6 | +#
| 7 | +#     http://www.apache.org/licenses/LICENSE-2.0
| 8 | +#
| 9 | +# Unless required by applicable law or agreed to in writing, software
| 10 | +# distributed under the License is distributed on an "AS IS" BASIS,
| 11 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
| 12 | +# See the License for the specific language governing permissions and
| 13 | +# limitations under the License.
| 14 | +
| 15 | +"""Tests for registry."""
| 16 | +
| 17 | +from __future__ import absolute_import
| 18 | +from __future__ import division
| 19 | +from __future__ import print_function
| 20 | +
| 21 | +from unittest import mock
| 22 | +
| 23 | +from absl.testing import absltest
| 24 | +from big_vision.pp import registry
| 25 | +
| 26 | +
| 27 | +class RegistryTest(absltest.TestCase):
| 28 | +
| 29 | +  def setUp(self):
| 30 | +    super(RegistryTest, self).setUp()
| 31 | +    # Mock global registry in each test to keep them isolated and allow for
| 32 | +    # concurrent tests.
| 33 | +    self.addCleanup(mock.patch.stopall)
| 34 | +    self.global_registry = dict()
| 35 | +    self.mocked_method = mock.patch.object(
| 36 | +        registry.Registry, "global_registry",
| 37 | +        return_value=self.global_registry).start()
| 38 | +
| 39 | +  def test_parse_name(self):
| 40 | +    name, args, kwargs = registry.parse_name("f")
| 41 | +    self.assertEqual(name, "f")
| 42 | +    self.assertEqual(args, ())
| 43 | +    self.assertEqual(kwargs, {})
| 44 | +
| 45 | +    name, args, kwargs = registry.parse_name("f()")
| 46 | +    self.assertEqual(name, "f")
| 47 | +    self.assertEqual(args, ())
| 48 | +    self.assertEqual(kwargs, {})
| 49 | +
| 50 | +    name, args, kwargs = registry.parse_name("func(a=0,b=1,c='s')")
| 51 | +    self.assertEqual(name, "func")
| 52 | +    self.assertEqual(args, ())
| 53 | +    self.assertEqual(kwargs, {"a": 0, "b": 1, "c": "s"})
| 54 | +
| 55 | +    name, args, kwargs = registry.parse_name("func(1,'foo',3)")
| 56 | +    self.assertEqual(name, "func")
| 57 | +    self.assertEqual(args, (1, "foo", 3))
| 58 | +    self.assertEqual(kwargs, {})
| 59 | +
| 60 | +    name, args, kwargs = registry.parse_name("func(1,'2',a=3,foo='bar')")
| 61 | +    self.assertEqual(name, "func")
| 62 | +    self.assertEqual(args, (1, "2"))
| 63 | +    self.assertEqual(kwargs, {"a": 3, "foo": "bar"})
| 64 | +
| 65 | +    name, args, kwargs = registry.parse_name("foo.bar.func(a=0,b=(1),c='s')")
| 66 | +    self.assertEqual(name, "foo.bar.func")
| 67 | +    self.assertEqual(kwargs, dict(a=0, b=1, c="s"))
| 68 | +
| 69 | +    with self.assertRaises(SyntaxError):
| 70 | +      registry.parse_name("func(0")
| 71 | +    with self.assertRaises(SyntaxError):
| 72 | +      registry.parse_name("func(a=0,,b=0)")
| 73 | +    with self.assertRaises(SyntaxError):
| 74 | +      registry.parse_name("func(a=0,b==1,c='s')")
| 75 | +    with self.assertRaises(ValueError):
| 76 | +      registry.parse_name("func(a=0,b=undefined_name,c='s')")
| 77 | +
| 78 | +  def test_register(self):
| 79 | +    # pylint: disable=unused-variable
| 80 | +    @registry.Registry.register("func1")
| 81 | +    def func1():
| 82 | +      pass
| 83 | +
| 84 | +    self.assertLen(registry.Registry.global_registry(), 1)
| 85 | +
| 86 | +  def test_lookup_function(self):
| 87 | +
| 88 | +    @registry.Registry.register("func1")
| 89 | +    def func1(arg1, arg2, arg3):  # pylint: disable=unused-variable
| 90 | +      return arg1, arg2, arg3
| 91 | +
| 92 | +    self.assertTrue(callable(registry.Registry.lookup("func1")))
| 93 | +    self.assertEqual(registry.Registry.lookup("func1")(1, 2, 3), (1, 2, 3))
| 94 | +    self.assertEqual(
| 95 | +        registry.Registry.lookup("func1(arg3=9)")(1, 2), (1, 2, 9))
| 96 | +    self.assertEqual(
| 97 | +        registry.Registry.lookup("func1(arg2=9,arg1=99)")(arg3=3), (99, 9, 3))
| 98 | +    self.assertEqual(
| 99 | +        registry.Registry.lookup("func1(arg2=9,arg1=99)")(arg1=1, arg3=3),
| 100 | +        (1, 9, 3))
| 101 | +
| 102 | +    self.assertEqual(
| 103 | +        registry.Registry.lookup("func1(1)")(1, 2), (1, 1, 2))
| 104 | +    self.assertEqual(
| 105 | +        registry.Registry.lookup("func1(1)")(arg3=3, arg2=2), (1, 2, 3))
| 106 | +    self.assertEqual(
| 107 | +        registry.Registry.lookup("func1(1, 2)")(3), (1, 2, 3))
| 108 | +    self.assertEqual(
| 109 | +        registry.Registry.lookup("func1(1, 2)")(arg3=3), (1, 2, 3))
| 110 | +    self.assertEqual(
| 111 | +        registry.Registry.lookup("func1(1, arg2=2)")(arg3=3), (1, 2, 3))
| 112 | +    self.assertEqual(
| 113 | +        registry.Registry.lookup("func1(1, arg3=2)")(arg2=3), (1, 3, 2))
| 114 | +    self.assertEqual(
| 115 | +        registry.Registry.lookup("func1(1, arg3=2)")(3), (1, 3, 2))
| 116 | +
| 117 | +    with self.assertRaises(TypeError):
| 118 | +      registry.Registry.lookup("func1(1, arg2=2)")(3)
| 119 | +    with self.assertRaises(TypeError):
| 120 | +      registry.Registry.lookup("func1(1, arg3=3)")(arg3=3)
| 121 | +    with self.assertRaises(TypeError):
| 122 | +      registry.Registry.lookup("func1(1, arg3=3)")(arg1=3)
| 123 | +    with self.assertRaises(SyntaxError):
| 124 | +      registry.Registry.lookup("func1(arg1=1, 3)")(arg2=3)
| 125 | +
| 126 | +
| 127 | +if __name__ == "__main__":
| 128 | +  absltest.main()
|
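The tests above cover permanent registration; `temporary_ops` in registry.py instead scopes a registration to a `with` block and deletes the names again on exit. A small sketch (not part of the diff) following the docstring's `pow` example:

    from big_vision.pp import registry

    with registry.temporary_ops(
        pow=lambda alpha: lambda d: {k: v**alpha for k, v in d.items()}):
      # temporary_ops prefixes names with "preprocess_ops.".
      op = registry.Registry.lookup("preprocess_ops.pow(alpha=2.0)")()
      assert op({"x": 3}) == {"x": 9}
    # Outside the block the temporary name is gone again.
    assert not registry.Registry.knows("preprocess_ops.pow")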
Tipsomaly/model/big_vision/pp/tokenizer.py
ADDED
|
@@ -0,0 +1,103 @@
| 1 | +# Copyright 2024 Big Vision Authors.
| 2 | +#
| 3 | +# Licensed under the Apache License, Version 2.0 (the "License");
| 4 | +# you may not use this file except in compliance with the License.
| 5 | +# You may obtain a copy of the License at
| 6 | +#
| 7 | +#     http://www.apache.org/licenses/LICENSE-2.0
| 8 | +#
| 9 | +# Unless required by applicable law or agreed to in writing, software
| 10 | +# distributed under the License is distributed on an "AS IS" BASIS,
| 11 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
| 12 | +# See the License for the specific language governing permissions and
| 13 | +# limitations under the License.
| 14 | +
| 15 | +"""The tokenizer API for big_vision, and central registration place."""
| 16 | +import functools
| 17 | +import importlib
| 18 | +from typing import Protocol
| 19 | +
| 20 | +from absl import logging
| 21 | +from big_vision.pp import registry
| 22 | +import big_vision.utils as u
| 23 | +import numpy as np
| 24 | +
| 25 | +
| 26 | +class Tokenizer(Protocol):
| 27 | +  """Just to unify on the API as we now have many different ones."""
| 28 | +
| 29 | +  def to_int(self, text, *, bos=False, eos=False):
| 30 | +    """Tokenizes `text` into a list of integer tokens.
| 31 | +
| 32 | +    Args:
| 33 | +      text: can be a single string, or a list of strings.
| 34 | +      bos: Whether a beginning-of-sentence token should be prepended.
| 35 | +      eos: Whether an end-of-sentence token should be appended.
| 36 | +
| 37 | +    Returns:
| 38 | +      List or list-of-list of tokens.
| 39 | +    """
| 40 | +
| 41 | +  def to_int_tf_op(self, text, *, bos=False, eos=False):
| 42 | +    """Same as `to_int()`, but as TF ops to be used in pp."""
| 43 | +
| 44 | +  def to_str(self, tokens, *, stop_at_eos=True):
| 45 | +    """Inverse of `to_int()`.
| 46 | +
| 47 | +    Args:
| 48 | +      tokens: list of tokens, or list of lists of tokens.
| 49 | +      stop_at_eos: remove everything that may come after the first EOS.
| 50 | +
| 51 | +    Returns:
| 52 | +      A string (if `tokens` is a list of tokens), or a list of strings.
| 53 | +      Note that most tokenizers strip a select few control tokens like
| 54 | +      eos/bos/pad/unk from the output string.
| 55 | +    """
| 56 | +
| 57 | +  def to_str_tf_op(self, tokens, *, stop_at_eos=True):
| 58 | +    """Same as `to_str()`, but as TF ops to be used in pp."""
| 59 | +
| 60 | +  @property
| 61 | +  def pad_token(self):
| 62 | +    """Token id of padding token."""
| 63 | +
| 64 | +  @property
| 65 | +  def eos_token(self):
| 66 | +    """Token id of end-of-sentence token."""
| 67 | +
| 68 | +  @property
| 69 | +  def bos_token(self):
| 70 | +    """Token id of beginning-of-sentence token."""
| 71 | +
| 72 | +  @property
| 73 | +  def vocab_size(self):
| 74 | +    """Returns the size of the vocabulary."""
| 75 | +
| 76 | +
| 77 | +@functools.cache
| 78 | +def get_tokenizer(name):
| 79 | +  with u.chrono.log_timing(f"z/secs/tokenizer/{name}"):
| 80 | +    if not registry.Registry.knows(f"tokenizers.{name}"):
| 81 | +      raw_name, *_ = registry.parse_name(name)
| 82 | +      logging.info("Tokenizer %s not registered, "
| 83 | +                   "trying import big_vision.pp.%s", name, raw_name)
| 84 | +      importlib.import_module(f"big_vision.pp.{raw_name}")
| 85 | +
| 86 | +    return registry.Registry.lookup(f"tokenizers.{name}")()
| 87 | +
| 88 | +
| 89 | +def get_extra_tokens(tokensets):
| 90 | +  extra_tokens = []
| 91 | +  for tokenset in tokensets:
| 92 | +    extra_tokens.extend(registry.Registry.lookup(f"tokensets.{tokenset}")())
| 93 | +  return list(np.unique(extra_tokens))  # Sorted and deduplicated; dups make no sense.
| 94 | +
| 95 | +
| 96 | +@registry.Registry.register("tokensets.loc")
| 97 | +def _get_loc1024(n=1024):
| 98 | +  return [f"<loc{i:04d}>" for i in range(n)]
| 99 | +
| 100 | +
| 101 | +@registry.Registry.register("tokensets.seg")
| 102 | +def _get_seg(n=128):
| 103 | +  return [f"<seg{i:03d}>" for i in range(n)]
|
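`get_extra_tokens` merges all requested tokensets through the registry; because `np.unique` returns a sorted, deduplicated array, the result comes back in lexicographic order rather than input order. A quick sketch (not part of the diff) using the `loc` and `seg` tokensets registered above:

    from big_vision.pp import tokenizer

    toks = tokenizer.get_extra_tokens(["seg", "loc"])
    # 1024 "<locNNNN>" tokens plus 128 "<segNNN>" tokens, sorted and deduped.
    assert len(toks) == 1152
    assert toks[0] == "<loc0000>"  # "<loc...>" sorts before "<seg...>"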