AlirezaSalehi99 committed on
Commit 95cc73b · verified · 1 Parent(s): 1954c27

Add files using upload-large-folder tool

This view is limited to 50 files because the commit contains too many changes; see the raw diff for the full list.
Files changed (50)
  1. .gitattributes +3 -0
  2. Tipsomaly/.gitignore +2 -0
  3. Tipsomaly/imgs/Models_Architecture_page-0001.jpg +3 -0
  4. Tipsomaly/imgs/Qualitative_results_page-0001.jpg +3 -0
  5. Tipsomaly/imgs/results-table.png +3 -0
  6. Tipsomaly/model/big_vision/__pycache__/__init__.cpython-39.pyc +0 -0
  7. Tipsomaly/model/big_vision/__pycache__/load_siglip.cpython-39.pyc +0 -0
  8. Tipsomaly/model/big_vision/__pycache__/utils.cpython-39.pyc +0 -0
  9. Tipsomaly/model/big_vision/configs/__init__.py +0 -0
  10. Tipsomaly/model/big_vision/configs/bit_i1k.py +102 -0
  11. Tipsomaly/model/big_vision/configs/bit_i21k.py +85 -0
  12. Tipsomaly/model/big_vision/configs/common.py +188 -0
  13. Tipsomaly/model/big_vision/configs/common_fewshot.py +60 -0
  14. Tipsomaly/model/big_vision/configs/load_and_eval.py +143 -0
  15. Tipsomaly/model/big_vision/configs/mlp_mixer_i1k.py +120 -0
  16. Tipsomaly/model/big_vision/configs/proj/paligemma/transfers/vertexai_l4.py +115 -0
  17. Tipsomaly/model/big_vision/configs/proj/paligemma/transfers/vqav2.py +160 -0
  18. Tipsomaly/model/big_vision/configs/transfer.py +186 -0
  19. Tipsomaly/model/big_vision/configs/vit_i1k.py +177 -0
  20. Tipsomaly/model/big_vision/configs/vit_i21k.py +145 -0
  21. Tipsomaly/model/big_vision/configs/vit_s16_i1k.py +105 -0
  22. Tipsomaly/model/big_vision/datasets/core.py +77 -0
  23. Tipsomaly/model/big_vision/datasets/jsonl.py +177 -0
  24. Tipsomaly/model/big_vision/datasets/sequence_packing.py +77 -0
  25. Tipsomaly/model/big_vision/datasets/tfds.py +94 -0
  26. Tipsomaly/model/big_vision/evaluators/__init__.py +0 -0
  27. Tipsomaly/model/big_vision/evaluators/classification.py +76 -0
  28. Tipsomaly/model/big_vision/evaluators/common.py +228 -0
  29. Tipsomaly/model/big_vision/evaluators/fewshot_lsr.py +245 -0
  30. Tipsomaly/model/big_vision/evaluators/mean.py +80 -0
  31. Tipsomaly/model/big_vision/evaluators/save.py +121 -0
  32. Tipsomaly/model/big_vision/models/__init__.py +0 -0
  33. Tipsomaly/model/big_vision/models/bit.py +162 -0
  34. Tipsomaly/model/big_vision/models/bit_paper.py +260 -0
  35. Tipsomaly/model/big_vision/models/common.py +133 -0
  36. Tipsomaly/model/big_vision/models/mlp_mixer.py +177 -0
  37. Tipsomaly/model/big_vision/models/vit.py +505 -0
  38. Tipsomaly/model/big_vision/pp/__init__.py +0 -0
  39. Tipsomaly/model/big_vision/pp/autoaugment.py +700 -0
  40. Tipsomaly/model/big_vision/pp/builder.py +85 -0
  41. Tipsomaly/model/big_vision/pp/builder_test.py +72 -0
  42. Tipsomaly/model/big_vision/pp/ops_general.py +468 -0
  43. Tipsomaly/model/big_vision/pp/ops_general_test.py +236 -0
  44. Tipsomaly/model/big_vision/pp/ops_image.py +361 -0
  45. Tipsomaly/model/big_vision/pp/ops_image_test.py +82 -0
  46. Tipsomaly/model/big_vision/pp/ops_text.py +411 -0
  47. Tipsomaly/model/big_vision/pp/ops_text_test.py +200 -0
  48. Tipsomaly/model/big_vision/pp/registry.py +163 -0
  49. Tipsomaly/model/big_vision/pp/registry_test.py +128 -0
  50. Tipsomaly/model/big_vision/pp/tokenizer.py +103 -0
.gitattributes CHANGED
@@ -33,3 +33,6 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+Tipsomaly/imgs/Models_Architecture_page-0001.jpg filter=lfs diff=lfs merge=lfs -text
+Tipsomaly/imgs/results-table.png filter=lfs diff=lfs merge=lfs -text
+Tipsomaly/imgs/Qualitative_results_page-0001.jpg filter=lfs diff=lfs merge=lfs -text
Tipsomaly/.gitignore ADDED
@@ -0,0 +1,2 @@
+__pycache__/
+*.pyc
Tipsomaly/imgs/Models_Architecture_page-0001.jpg ADDED

Git LFS Details

  • SHA256: 6e793f62366a11789d2f93727d36730378d7cff7e89d2f53a179d3799eb1ddfe
  • Pointer size: 131 Bytes
  • Size of remote file: 565 kB
Tipsomaly/imgs/Qualitative_results_page-0001.jpg ADDED

Git LFS Details

  • SHA256: 23581ba2e6b0fd8fee121395adb9eb4249c5088f23d255dd99c850cb881679ed
  • Pointer size: 132 Bytes
  • Size of remote file: 1.07 MB
Tipsomaly/imgs/results-table.png ADDED

Git LFS Details

  • SHA256: 58efcca11d4ea3e7f418d4450b895ac0cae26cd719aefad91ef2f83d9f91eeef
  • Pointer size: 131 Bytes
  • Size of remote file: 317 kB
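Note: for each of these images, Git LFS stores only a small text pointer in the repository; the binary lives in LFS storage. A sketch of what such a pointer file looks like for results-table.png, using the SHA256 shown above (the `size` field holds the exact byte count, which this page rounds to 317 kB, so the value below is approximate):

version https://git-lfs.github.com/spec/v1
oid sha256:58efcca11d4ea3e7f418d4450b895ac0cae26cd719aefad91ef2f83d9f91eeef
size 317000

The "Pointer size: 131 Bytes" reported above is simply the length of this three-line text file.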
Tipsomaly/model/big_vision/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (163 Bytes).
 
Tipsomaly/model/big_vision/__pycache__/load_siglip.cpython-39.pyc ADDED
Binary file (5.4 kB).
 
Tipsomaly/model/big_vision/__pycache__/utils.cpython-39.pyc ADDED
Binary file (52.4 kB).
 
Tipsomaly/model/big_vision/configs/__init__.py ADDED
File without changes
Tipsomaly/model/big_vision/configs/bit_i1k.py ADDED
@@ -0,0 +1,102 @@
+# Copyright 2024 Big Vision Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# pylint: disable=line-too-long
+r"""Pre-training BiT on ILSVRC-2012 as in https://arxiv.org/abs/1912.11370
+
+Run training of a BiT-ResNet-50x1 variant, which takes ~32min on v3-128:
+
+big_vision.train \
+    --config big_vision/configs/bit_i1k.py \
+    --workdir gs://[your_bucket]/big_vision/`date '+%m-%d_%H%M'` \
+    --config.model.depth 50 --config.model.width 1
+"""
+
+# from big_vision.configs.common_fewshot import get_fewshot_lsr
+import ml_collections as mlc
+
+
+def get_config(runlocal=False):
+  """Config for training on ImageNet-1k."""
+  config = mlc.ConfigDict()
+
+  config.seed = 0
+  config.total_epochs = 90
+  config.num_classes = 1000
+  config.loss = 'softmax_xent'
+
+  config.input = dict()
+  config.input.data = dict(
+      name='imagenet2012',
+      split='train[:99%]',
+  )
+  config.input.batch_size = 4096
+  config.input.cache_raw = True  # Needs up to 120GB of RAM!
+  config.input.shuffle_buffer_size = 250_000  # Per host.
+
+  pp_common = '|onehot(1000, key="{lbl}", key_result="labels")'
+  pp_common += '|value_range(-1, 1)|keep("image", "labels")'
+  config.input.pp = 'decode_jpeg_and_inception_crop(224)|flip_lr' + pp_common.format(lbl='label')
+  pp_eval = 'decode|resize_small(256)|central_crop(224)' + pp_common
+
+  config.log_training_steps = 50
+  config.ckpt_steps = 1000
+
+  # Model section
+  config.model_name = 'bit'
+  config.model = dict(
+      depth=50,  # You can also pass e.g. [3, 5, 10, 2]
+      width=1.0,
+  )
+
+  # Optimizer section
+  config.optax_name = 'big_vision.momentum_hp'
+  config.grad_clip_norm = 1.0
+
+  # linear scaling rule. Don't forget to sweep if sweeping batch_size.
+  config.wd = (1e-4 / 256) * config.input.batch_size
+  config.lr = (0.1 / 256) * config.input.batch_size
+  config.schedule = dict(decay_type='cosine', warmup_steps=1000)
+
+  # Eval section
+  def get_eval(split, dataset='imagenet2012'):
+    return dict(
+        type='classification',
+        data=dict(name=dataset, split=split),
+        pp_fn=pp_eval.format(lbl='label'),
+        loss_name=config.loss,
+        log_steps=1000,  # Very fast O(seconds) so it's fine to run it often.
+        cache='final_data',
+    )
+  config.evals = {}
+  config.evals.train = get_eval('train[:2%]')
+  config.evals.minival = get_eval('train[99%:]')
+  config.evals.val = get_eval('validation')
+  config.evals.v2 = get_eval('test', dataset='imagenet_v2')
+  config.evals.real = get_eval('validation', dataset='imagenet2012_real')
+  config.evals.real.pp_fn = pp_eval.format(lbl='real_label')
+
+  # config.evals.fewshot = get_fewshot_lsr(runlocal=runlocal)
+  # config.evals.fewshot.log_steps = 1000
+
+  if runlocal:
+    config.input.batch_size = 32
+    config.input.cache_raw = False
+    config.input.shuffle_buffer_size = 100
+
+    local_eval = config.evals.val
+    config.evals = {'val': local_eval}
+    config.evals.val.cache = 'none'
+
+  return config
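As a quick check of the linear scaling rule in this config (a standalone sketch, not part of the committed file): with the default batch_size of 4096, the per-256-examples base values scale up by a factor of 16.

batch_size = 4096                  # config.input.batch_size
print((0.1 / 256) * batch_size)    # lr -> 1.6
print((1e-4 / 256) * batch_size)   # wd -> 0.0016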
Tipsomaly/model/big_vision/configs/bit_i21k.py ADDED
@@ -0,0 +1,85 @@
+# Copyright 2024 Big Vision Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# pylint: disable=line-too-long
+r"""A config for pre-training BiT on ImageNet-21k.
+
+This config relies on the Imagenet-21k tfds dataset, which is not yet
+available publicly in TFDS. We intend to add the dataset to public TFDS soon,
+and this config will then be runnable.
+"""
+
+from big_vision.configs.common_fewshot import get_fewshot_lsr
+import ml_collections as mlc
+
+
+def get_config():
+  """Config for training on imagenet-21k."""
+  config = mlc.ConfigDict()
+
+  config.seed = 0
+  config.total_epochs = 90
+  config.num_classes = 21843
+  config.init_head_bias = -10.0
+  config.loss = 'sigmoid_xent'
+
+  config.input = dict()
+  config.input.data = dict(
+      name='imagenet21k',
+      split='full[51200:]',
+  )
+  config.input.batch_size = 4096
+  config.input.shuffle_buffer_size = 250_000  # Per host, so small-ish is ok.
+
+  pp_common = '|value_range(-1, 1)|onehot({onehot_args})|keep("image", "labels")'
+  pp_common_i21k = pp_common.format(onehot_args=f'{config.num_classes}')
+  pp_common_i1k = pp_common.format(onehot_args='1000, key="label", key_result="labels"')
+  config.input.pp = 'decode_jpeg_and_inception_crop(224)|flip_lr' + pp_common_i21k
+  pp_eval = 'decode|resize_small(256)|central_crop(224)'
+
+  config.log_training_steps = 50
+  config.ckpt_steps = 1000
+
+  # Model section
+  config.model_name = 'bit_paper'
+  config.model = dict(depth=50, width=1.0)
+
+  # Optimizer section
+  config.optax_name = 'big_vision.momentum_hp'
+  config.grad_clip_norm = 1.0
+
+  # linear scaling rule. Don't forget to sweep if sweeping batch_size.
+  config.lr = (0.03 / 256) * config.input.batch_size
+  config.wd = (3e-5 / 256) * config.input.batch_size
+  config.schedule = dict(decay_type='cosine', warmup_steps=5000)
+
+  # Evaluations on i21k itself.
+  def eval_i21k(split):
+    return dict(
+        type='classification',
+        data={**config.input.data, 'split': split},
+        pp_fn=pp_eval + pp_common_i21k,
+        loss_name=config.loss,
+        log_steps=1000,  # Very fast O(seconds) so it's fine to run it often.
+    )
+  config.evals = {}
+  config.evals.test = eval_i21k('full[:25_600]')
+  config.evals.val = eval_i21k('full[25_600:51_200]')
+  config.evals.train = eval_i21k('full[51_200:76_800]')
+
+  # Few-shot evaluators
+  config.evals.fewshot = get_fewshot_lsr()
+  config.evals.fewshot.log_steps = 25_000
+
+  return config
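The init_head_bias of -10.0 pairs with the sigmoid_xent loss: it starts every one of the 21843 class logits near the uniform prior, which this standalone sketch (not part of the commit) illustrates:

import math
# sigmoid(-10) ~ 4.54e-5, close to the uniform prior 1/21843 ~ 4.58e-5,
# so the freshly initialized head predicts roughly "no class" everywhere.
print(1 / (1 + math.exp(10.0)))  # ~4.54e-05
print(1 / 21843)                 # ~4.58e-05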
Tipsomaly/model/big_vision/configs/common.py ADDED
@@ -0,0 +1,188 @@
+# Copyright 2024 Big Vision Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""A few things commonly used across A LOT of config files."""
+
+import string
+
+import ml_collections as mlc
+
+
+def input_for_quicktest(config_input, quicktest):
+  if quicktest:
+    config_input.batch_size = 8
+    config_input.shuffle_buffer_size = 10
+    config_input.cache_raw = False
+
+
+def parse_arg(arg, lazy=False, **spec):
+  """Makes ConfigDict's get_config single-string argument more usable.
+
+  Example use in the config file:
+
+    import big_vision.configs.common as bvcc
+    def get_config(arg):
+      arg = bvcc.parse_arg(arg,
+          res=(224, int),
+          runlocal=False,
+          schedule='short',
+      )
+
+      # ...
+
+      config.shuffle_buffer = 250_000 if not arg.runlocal else 50
+
+  Ways that values can be passed when launching:
+
+    --config amazing.py:runlocal,schedule=long,res=128
+    --config amazing.py:res=128
+    --config amazing.py:runlocal  # A boolean needs no value for "true".
+    --config amazing.py:runlocal=False  # Explicit false boolean.
+    --config amazing.py:128  # The first spec entry may be passed unnamed alone.
+
+  Uses strict bool conversion (converting 'True', 'true' to True, and 'False',
+  'false', '' to False).
+
+  Args:
+    arg: the string argument that's passed to get_config.
+    lazy: allow lazy parsing of arguments, which are not in spec. For these,
+      the type is auto-extracted as the most specific type that parses.
+    **spec: the name and default values of the expected options.
+      If the value is a tuple, the value's first element is the default value,
+      and the second element is a function called to convert the string.
+      Otherwise the type is automatically extracted from the default value.
+
+  Returns:
+    ConfigDict object with extracted type-converted values.
+  """
+  # Normalize arg and spec layout.
+  arg = arg or ''  # Normalize None to empty string
+  spec = {k: get_type_with_default(v) for k, v in spec.items()}
+
+  result = mlc.ConfigDict(type_safe=False)  # For convenient dot-access only.
+
+  # Expand convenience-cases for a single parameter without = sign.
+  if arg and ',' not in arg and '=' not in arg:
+    # (think :runlocal) If it's the name of sth in the spec (or there is no
+    # spec), it's that in bool.
+    if arg in spec or not spec:
+      arg = f'{arg}=True'
+    # Otherwise, it is the value for the first entry in the spec.
+    else:
+      arg = f'{list(spec.keys())[0]}={arg}'
+      # Yes, we rely on Py3.7 insertion order!
+
+  # Now, expand the `arg` string into a dict of keys and values:
+  raw_kv = {raw_arg.split('=')[0]:
+                raw_arg.split('=', 1)[-1] if '=' in raw_arg else 'True'
+            for raw_arg in arg.split(',') if raw_arg}
+
+  # And go through the spec, using provided or default value for each:
+  for name, (default, type_fn) in spec.items():
+    val = raw_kv.pop(name, None)
+    result[name] = type_fn(val) if val is not None else default
+
+  if raw_kv:
+    if lazy:  # Process args which are not in spec.
+      for k, v in raw_kv.items():
+        result[k] = autotype(v)
+    else:
+      raise ValueError(f'Unhandled config args remain: {raw_kv}')
+
+  return result
+
+
+def get_type_with_default(v):
+  """Returns (v, string_to_v_type) with lenient bool parsing."""
+  # For bool, do safe string conversion.
+  if isinstance(v, bool):
+    def strict_bool(x):
+      assert x.lower() in {'true', 'false', ''}
+      return x.lower() == 'true'
+    return (v, strict_bool)
+  # If already a (default, type) tuple, use that.
+  if isinstance(v, (tuple, list)):
+    assert len(v) == 2 and isinstance(v[1], type), (
+        'List or tuple types are currently not supported because we use `,` as'
+        ' dumb delimiter. Contributions (probably using ast) welcome. You can'
+        ' unblock by using a string with eval(s.replace(";", ",")) or similar')
+    return (v[0], v[1])
+  # Otherwise, derive the type from the default value.
+  return (v, type(v))
+
+
+def autotype(x):
+  """Auto-converts string to bool/int/float if possible."""
+  assert isinstance(x, str)
+  if x.lower() in {'true', 'false'}:
+    return x.lower() == 'true'  # Returns as bool.
+  try:
+    return int(x)  # Returns as int.
+  except ValueError:
+    try:
+      return float(x)  # Returns as float.
+    except ValueError:
+      return x  # Returns as str.
+
+
+def pack_arg(**kw):
+  """Packs key-word args as a string to be parsed by `parse_arg()`."""
+  for v in kw.values():
+    assert ',' not in f'{v}', f"Can't use `,` in config_arg value: {v}"
+  return ','.join([f'{k}={v}' for k, v in kw.items()])
+
+
+def arg(**kw):
+  """Use like `add(**bvcc.arg(res=256, foo=bar), lr=0.1)` to pass config_arg."""
+  return {'config_arg': pack_arg(**kw), **kw}
+
+
+def _get_field_ref(config_dict, field_name):
+  path = field_name.split('.')
+  for field in path[:-1]:
+    config_dict = getattr(config_dict, field)
+  return config_dict.get_ref(path[-1])
+
+
+def format_str(format_string, config):
+  """Format string with reference fields from config.
+
+  This makes it easy to build preprocess strings that contain references to
+  fields that are edited after. E.g.:
+
+  ```
+  config = mlc.ConfigDict()
+  config.res = (256, 256)
+  config.pp = bvcc.format_str('resize({res})', config)
+  ...
+  # if config.res is modified (e.g. via sweeps) it will propagate to pp field:
+  config.res = (512, 512)
+  assert config.pp == 'resize((512, 512))'
+  ```
+
+  Args:
+    format_string: string to format with references.
+    config: ConfigDict to get references to format the string.
+
+  Returns:
+    A reference field which renders a string using references to config fields.
+  """
+  output = ''
+  parts = string.Formatter().parse(format_string)
+  for (literal_text, field_name, format_spec, conversion) in parts:
+    assert not format_spec and not conversion
+    output += literal_text
+    if field_name:
+      output += _get_field_ref(config, field_name).to_str()
+  return output
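A minimal usage sketch for parse_arg above (assumes big_vision is on PYTHONPATH; the argument strings mirror the docstring's launch examples):

import big_vision.configs.common as bvcc

# Spec: res defaults to 224 (converted via int), runlocal defaults to False.
arg = bvcc.parse_arg('runlocal,res=128', res=(224, int), runlocal=False)
print(arg.res, arg.runlocal)  # -> 128 True

# A lone value binds to the first spec entry; a lone name becomes a boolean.
print(bvcc.parse_arg('128', res=(224, int)).res)            # -> 128
print(bvcc.parse_arg('runlocal', runlocal=False).runlocal)  # -> True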
Tipsomaly/model/big_vision/configs/common_fewshot.py ADDED
@@ -0,0 +1,60 @@
+# Copyright 2024 Big Vision Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Most common few-shot eval configuration."""
+
+import ml_collections as mlc
+
+
+def get_fewshot_lsr(target_resolution=224, resize_resolution=256,
+                    runlocal=False, pp=None, **kw):
+  """Returns a standard-ish fewshot eval configuration."""
+  kw.setdefault('representation_layer', 'pre_logits')
+  kw.setdefault('shots', (1, 5, 10, 25))
+  kw.setdefault('l2_reg', 2.0 ** 10)
+  kw.setdefault('num_seeds', 3)
+  kw.setdefault('prefix', '')  # No prefix as we already use a/ z/ and zz/
+
+  # Backward-compatible default:
+  if not any(f'log_{x}' in kw for x in ['steps', 'percent', 'examples', 'epochs']):  # pylint: disable=line-too-long
+    kw['log_steps'] = 25_000
+
+  config = mlc.ConfigDict(kw)
+  config.type = 'fewshot_lsr'
+  config.datasets = {
+      'caltech': ('caltech101', 'train', 'test'),  # copybara:srtip
+      'cars': ('cars196:2.1.0', 'train', 'test'),
+      'cifar100': ('cifar100', 'train', 'test'),
+      'dtd': ('dtd', 'train', 'test'),
+      # The first 65000 ImageNet samples have at least 30 shots per any class.
+      # Commented out by default because needs manual download.
+      # 'imagenet': ('imagenet2012', 'train[:65000]', 'validation'),
+      'pets': ('oxford_iiit_pet', 'train', 'test'),
+      'uc_merced': ('uc_merced', 'train[:1000]', 'train[1000:]'),
+  } if not runlocal else {
+      'pets': ('oxford_iiit_pet', 'train', 'test'),
+  }
+
+  pp = pp or '|'.join([
+      'decode',
+      f'resize({resize_resolution})',
+      f'central_crop({target_resolution})',
+      'value_range(-1,1)'
+  ])
+  pp += '|keep("image", "label")'
+  config.pp_train = pp
+  config.pp_eval = pp
+  config.display_first = [('imagenet', 10)] if not runlocal else [('pets', 10)]
+
+  return config
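For orientation, with the default resolutions the helper above assembles this pp string for both pp_train and pp_eval (a standalone sketch of the same join, not an import from the commit):

pp = '|'.join([
    'decode', 'resize(256)', 'central_crop(224)', 'value_range(-1,1)',
]) + '|keep("image", "label")'
print(pp)
# -> decode|resize(256)|central_crop(224)|value_range(-1,1)|keep("image", "label")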
Tipsomaly/model/big_vision/configs/load_and_eval.py ADDED
@@ -0,0 +1,143 @@
+# Copyright 2024 Big Vision Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# pytype: disable=not-writable,attribute-error
+# pylint: disable=line-too-long,missing-function-docstring
+r"""A config to load and eval key models using the core train.py.
+
+The runtime varies widely depending on the model, but each one should reproduce
+the corresponding paper's numbers.
+This configuration makes use of the "arg" to get_config to select which model
+to run, so a few examples are given below:
+
+Run and evaluate a BiT-M ResNet-50x1 model that was transferred to i1k:
+
+big_vision.train \
+    --config big_vision/configs/load_and_eval.py:name=bit_paper,batch_size=8 \
+    --config.model_init M-imagenet2012 --config.model.width 1 --config.model.depth 50
+
+Run and evaluate the recommended ViT-B/32 from "how to train your vit" paper:
+
+big_vision.train \
+    --config big_vision/configs/load_and_eval.py:name=vit_i21k,batch_size=8 \
+    --config.model.variant B/32 --config.model_init howto-i21k-B/32
+"""
+
+import big_vision.configs.common as bvcc
+from big_vision.configs.common_fewshot import get_fewshot_lsr
+
+
+def eval_only(config, batch_size, spec_for_init):
+  """Set a few configs that turn trainer into (almost) eval-only."""
+  config.total_steps = 0
+  config.input = {}
+  config.input.batch_size = batch_size
+  config.input.data = dict(name='bv:dummy', spec=spec_for_init)
+  config.optax_name = 'identity'
+  config.lr = 0.0
+
+  config.mesh = [('data', -1)]
+  config.sharding_strategy = [('params/.*', 'fsdp(axis="data")')]
+  config.sharding_rules = [('act_batch', ('data',))]
+
+  return config
+
+
+def get_config(arg=''):
+  config = bvcc.parse_arg(arg, name='bit_paper', batch_size=4)
+
+  # Make the config eval-only by setting some dummies.
+  eval_only(config, config.batch_size, spec_for_init=dict(
+      image=dict(shape=(224, 224, 3), dtype='float32'),
+  ))
+
+  config.evals = dict(fewshot=get_fewshot_lsr())
+
+  # Just calls the function with the name given as `config`.
+  # Could also be a giant if-block if you're into that kind of thing.
+  globals()[config.name](config)
+  return config
+
+
+def bit_paper(config):
+  config.num_classes = 1000
+
+  config.model_name = 'bit_paper'
+  config.model_init = 'M-imagenet2012'  # M = i21k, -imagenet2012 = fine-tuned
+  config.model = dict(width=1, depth=50)
+
+  def get_eval(split, lbl, dataset='imagenet2012_real'):
+    return dict(
+        type='classification',
+        data=dict(name=dataset, split=split),
+        loss_name='softmax_xent',
+        cache='none',  # Only run once, on low-mem machine.
+        pp_fn=(
+            'decode|resize(384)|value_range(-1, 1)'
+            f'|onehot(1000, key="{lbl}", key_result="labels")'
+            '|keep("image", "labels")'
+        ),
+    )
+  config.evals.test = get_eval('validation', 'original_label')
+  config.evals.real = get_eval('validation', 'real_label')
+  config.evals.v2 = get_eval('test', 'label', 'imagenet_v2')
+
+
+def vit_i1k(config):
+  config.num_classes = 1000
+
+  config.model_name = 'vit'
+  config.model_init = ''  # Will be set in sweep.
+  config.model = dict(variant='S/16', pool_type='gap', posemb='sincos2d',
+                      rep_size=True)
+
+  config.evals.val = dict(
+      type='classification',
+      data=dict(name='imagenet2012', split='validation'),
+      pp_fn='decode|resize_small(256)|central_crop(224)|value_range(-1, 1)|onehot(1000, key="label", key_result="labels")|keep("image", "labels")',
+      loss_name='softmax_xent',
+      cache='none',  # Only run once, on low-mem machine.
+  )
+
+
+def mlp_mixer_i1k(config):
+  config.num_classes = 1000
+
+  config.model_name = 'mlp_mixer'
+  config.model_init = ''  # Will be set in sweep.
+  config.model = dict(variant='L/16')
+
+  config.evals.val = dict(
+      type='classification',
+      data=dict(name='imagenet2012', split='validation'),
+      pp_fn='decode|resize_small(256)|central_crop(224)|value_range(-1, 1)|onehot(1000, key="label", key_result="labels")|keep("image", "labels")',
+      loss_name='softmax_xent',
+      cache='none',  # Only run once, on low-mem machine.
+  )
+
+
+def vit_i21k(config):
+  config.num_classes = 21843
+
+  config.model_name = 'vit'
+  config.model_init = ''  # Will be set in sweep.
+  config.model = dict(variant='B/32', pool_type='tok')
+
+  config.evals.val = dict(
+      type='classification',
+      data=dict(name='imagenet21k', split='full[:51200]'),
+      pp_fn='decode|resize_small(256)|central_crop(224)|value_range(-1, 1)|onehot(21843)|keep("image", "labels")',
+      loss_name='sigmoid_xent',
+      cache='none',  # Only run once, on low-mem machine.
+  )
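The `globals()[config.name](config)` line dispatches on the name argument to one of the module-level setup functions. A hedged usage sketch (assumes the module is importable as in its docstring):

# Dispatch sketch: `name` selects bit_paper / vit_i1k / mlp_mixer_i1k / vit_i21k.
cfg = get_config('name=vit_i21k,batch_size=8')
print(cfg.model_name, cfg.num_classes)  # -> vit 21843
print(cfg.total_steps)                  # -> 0 (the trainer runs eval-only)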
Tipsomaly/model/big_vision/configs/mlp_mixer_i1k.py ADDED
@@ -0,0 +1,120 @@
+# Copyright 2024 Big Vision Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# pylint: disable=line-too-long
+r"""A config for training MLP-Mixer-B/16 model on ILSVRC-2012 ("ImageNet-1k").
+
+Achieves 76.3% top-1 accuracy on the test split in 2h11m on TPU v3-128
+with 300 epochs. A shorter 60 epochs run is expected to get to 70.5% in 27m.
+
+big_vision.train \
+    --config big_vision/configs/mlp_mixer_i1k.py \
+    --workdir gs://[your_bucket]/big_vision/`date '+%m-%d_%H%M'`
+"""
+
+from big_vision.configs.common_fewshot import get_fewshot_lsr
+import ml_collections as mlc
+
+
+def get_config(mode=None):
+  """Config for training Mixer on i1k."""
+  config = mlc.ConfigDict()
+
+  config.seed = 0
+  config.total_epochs = 300
+  config.num_classes = 1000
+  config.loss = 'sigmoid_xent'
+  config.init_head_bias = -6.9
+
+  config.input = dict()
+  config.input.data = dict(
+      name='imagenet2012',
+      split='train[:99%]',
+  )
+  config.input.batch_size = 4096
+  config.input.cache_raw = True  # Needs up to 120GB of RAM!
+  config.input.shuffle_buffer_size = 250_000
+
+  config.input.pp = (
+      'decode_jpeg_and_inception_crop(224)'
+      '|flip_lr'
+      '|randaug(2,15)'
+      '|value_range(-1, 1)'
+      '|onehot(1000, key="label", key_result="labels")'
+      '|keep("image", "labels")'
+  )
+  pp_eval = (
+      'decode'
+      '|resize_small(256)|central_crop(224)'
+      '|value_range(-1, 1)'
+      '|onehot(1000, key="{lbl}", key_result="labels")'
+      '|keep("image", "labels")'
+  )
+
+  # To continue using the near-defunct randaug op.
+  config.pp_modules = ['ops_general', 'ops_image', 'ops_text', 'archive.randaug']
+
+  config.log_training_steps = 50
+  config.ckpt_steps = 1000
+
+  config.prefetch_to_device = 2
+
+  # Model section
+  config.model_name = 'mlp_mixer'
+  config.model = dict()
+  config.model.variant = 'B/16'
+  config.model.stoch_depth = 0.1
+
+  config.mixup = dict(fold_in=None, p=0.5)
+
+  # Optimizer section
+  config.optax_name = 'scale_by_adam'
+  config.grad_clip_norm = 1.
+
+  config.lr = 0.001
+  config.wd = 1e-4
+  config.schedule = dict(
+      decay_type='linear',
+      warmup_steps=10_000,
+      linear_end=1e-5,
+  )
+
+  # Eval section
+  def get_eval(split, dataset='imagenet2012'):
+    return dict(
+        type='classification',
+        data=dict(name=dataset, split=split),
+        pp_fn=pp_eval.format(lbl='label'),
+        loss_name=config.loss,
+        log_steps=2500,  # Very fast O(seconds) so it's fine to run it often.
+        cache_final=mode != 'gpu8',
+    )
+  config.evals = {}
+  config.evals.train = get_eval('train[:2%]')
+  config.evals.minival = get_eval('train[99%:]')
+  config.evals.val = get_eval('validation')
+  config.evals.v2 = get_eval('test', dataset='imagenet_v2')
+  config.evals.real = get_eval('validation', dataset='imagenet2012_real')
+  config.evals.real.pp_fn = pp_eval.format(lbl='real_label')
+
+  config.fewshot = get_fewshot_lsr()
+
+  if mode == 'gpu8':
+    config.total_epochs = 60
+    config.input.batch_size = 512
+    config.input.cache_raw = False
+  if mode == 'regression_test':
+    config.total_epochs = 60
+
+  return config
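Same trick as in the other sigmoid_xent configs: init_head_bias = -6.9 is roughly log(1/1000), so the 1000-way sigmoid head starts out at the uniform prior (standalone sketch, not part of the commit):

import math
print(math.log(1 / 1000))       # -> -6.907...
print(1 / (1 + math.exp(6.9)))  # -> ~1.006e-3, i.e. ~1/1000 per class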
Tipsomaly/model/big_vision/configs/proj/paligemma/transfers/vertexai_l4.py ADDED
@@ -0,0 +1,115 @@
+# Copyright 2024 Big Vision Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# pylint: disable=line-too-long
+r"""PaliGemma transfer to a task stored in JSON-L, designed to fit on an L4 GPU.
+"""
+
+import big_vision.configs.common as bvcc
+
+
+def training_data(res, text_len):
+  """Creates training data config."""
+  c = bvcc.parse_arg('')  # Just make a configdict without extra import.
+  c.data = dict(
+      name='bv:jsonl',
+      fname='gs://longcap100/data_train90.jsonl',
+      fopen_keys={'image': 'gs://longcap100/'},
+      # See docstring in datasets/jsonl.py for further details.
+      # download_keys=['image'],  # If jsonl contains external paths.
+  )
+  c.pp = '|'.join([
+      # Read and prepare the image by just resizing it:
+      f'decode|resize({res}, antialias=True)|value_range(-1, 1)',
+      # The texts are already prepared in `prefix` and `suffix` keys.
+      'strfmt("caption en", outkey="prefix")',
+      combine_and_keep(text_len),
+  ])
+  # Keep the whole dataset in RAM after first pass. Useful optimization for
+  # small/mid-size datasets, but risks a host OOM for large datasets.
+  c.cache_raw = True
+  return c
+
+
+def get_config(arg=None):
+  """Config for training."""
+  # You probably do NOT want to add settings here. The `arg` way of settings is
+  # really only for things you'd want to sweep and which affect MULTIPLE config
+  # settings at once or go into the pp string.
+  c = bvcc.parse_arg(arg, res=224, text_len=128, batch_size=4,
+                     freeze_vit=False, freeze_llm=False)
+
+  c.input = training_data(c.res, c.text_len)
+
+  # These settings are suited for fitting in a single L4.
+  c.total_epochs = 1
+  c.input.batch_size = c.batch_size
+  c.optax_name = 'big_vision.sgd'  # Without momentum, so really low-memory.
+  c.lr = 0.1
+  c.wd = 0.0
+  c.grad_clip_norm = 1.0
+  c.label_smoothing = 0.0
+
+  # Learning-rate schedule. Probably is fine like this.
+  sched = dict(decay_type='cosine', warmup_percent=0.05)
+  c.schedule = [
+      ('img/.*', None if c.freeze_vit else sched),
+      ('llm/.*', None if c.freeze_llm else sched),
+  ]
+
+  c.evals = {}
+
+  # Model section.
+  c.model_name = 'proj.paligemma.paligemma'
+  c.model = {}
+  # TODO: b/lbeyer - no scan and no remat might be better on 1-GPU machines?
+  c.model.img = dict(variant='So400m/14', pool_type='none', scan=True)
+  c.model.llm = dict(vocab_size=256_000 + 1024 + 128, dropout=0.0)
+  c.model_init = f'pt_{c.res}'
+
+  # FSDP strategy.
+  c.mesh = [('data', -1)]
+  c.sharding_strategy = [('.*', 'fsdp(axis="data")')]
+  c.sharding_rules = [('act_batch', ('data',))]
+
+  c.input.shuffle_buffer_size = 1000
+  c.log_training_steps = 1
+  c.ckpt_steps = 1_000
+  c.pp_modules = ['ops_general', 'ops_image', 'ops_text', 'proj.paligemma.ops']
+
+  c.seed = 0
+  return c
+
+
+def tok(**kw):
+  """Creates the tokenization preprocessing string."""
+  # Single entry point so that it's consistent everywhere and easier to switch.
+  kw.setdefault('model', 'gemma(tokensets=("loc", "seg"))')
+  kw = ', '.join(f'{k}={repr(v)}' for k, v in kw.items())
+  return f'tok({kw})'
+
+
+def combine_and_keep(text_len):
+  return '|'.join([
+      tok(key='prefix', bos='yes'),
+      tok(key='suffix', eos='yes'),
+      tok(key='septok', text='\n'),
+      # If masks confuse you, see (internal link)
+      'masked_concat(["prefix", "septok", "suffix"], mask_ar=[0, 0, 1], mask_loss=[0, 0, 1])',
+      # For training, we +1 because the trainer removes EOS.
+      f'tolen({text_len+1}, pad_value=0, key="text")',  # For text, value doesn't matter.
+      f'tolen({text_len+1}, pad_value=1, key="mask_ar")',
+      f'tolen({text_len+1}, pad_value=0, key="mask_loss")',
+      'keep("image", "text", "mask_ar", "mask_loss")',
+  ])
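To make the string-building above concrete, here is what tok() emits with its defaults (a standalone reproduction of the helper, not an import from the commit):

def tok(**kw):
  kw.setdefault('model', 'gemma(tokensets=("loc", "seg"))')
  kw = ', '.join(f'{k}={repr(v)}' for k, v in kw.items())
  return f'tok({kw})'

print(tok(key='prefix', bos='yes'))
# -> tok(key='prefix', bos='yes', model='gemma(tokensets=("loc", "seg"))')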
Tipsomaly/model/big_vision/configs/proj/paligemma/transfers/vqav2.py ADDED
@@ -0,0 +1,160 @@
+# Copyright 2024 Big Vision Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# pylint: disable=line-too-long
+r"""PaliGemma transfer to VQAv2.
+"""
+
+import big_vision.configs.common as bvcc
+from big_vision.configs.proj.paligemma.transfers.common import combine_and_keep_train, combine_and_keep_eval, TOKENIZER
+
+
+def training_data(res, final_split, text_len=32):
+  """Creates training data config.
+
+  See (internal link)
+  You can add more arguments beside `res`, but give them good defaults.
+
+  Args:
+    res: The requested image resolution (eg 224).
+    final_split: Whether to use all of the validation data.
+    text_len: sequence length
+
+  Returns:
+    The ConfigDict for the input section.
+  """
+  c = bvcc.parse_arg('')  # Just make a configdict without extra import.
+  c.data = dict(
+      name='vqa',
+      split='train + validation' if final_split else 'train + validation[:-10240]',
+  )
+  c.pp = '|'.join([
+      f'decode|resize({res}, antialias=True)|value_range(-1, 1)',
+      'strfmt("answer en {question_text}", outkey="prefix")',
+      'choice_no_replacement(inkey="answers", outkey="suffix")',
+      combine_and_keep_train(text_len),
+  ])
+  return c
+
+
+def add_eval(c, res, text_len=32, **kw):
+  """VQAv2 evaluators."""
+  pp = '|'.join([
+      f'decode|resize({res}, antialias=True)|value_range(-1, 1)',
+      'strfmt("answer en {question_text}", outkey="prefix")',
+      combine_and_keep_eval(text_len, keep=('answers', 'answer_type', 'question_type', 'question_id')),
+  ])
+
+  for freq, name, split in [
+      (1/4, 'minitrain', 'train[:5120]'),  # To gauge memorization.
+      (1/4, 'minival', 'validation[-10240:]'),  # To tune hparams.
+      # To generate final predictions. Test sets combined since 2021 challenge.
+      (1.0, 'test', 'test + test-dev'),
+  ]:
+    c.evals[f'vqav2/{name}'] = dict(
+        type='proj.paligemma.transfers.vqav2',
+        pred='decode', pred_kw={'max_decode_len': text_len},
+        outfile=f'{{workdir}}/vqav2_{name}.json',
+        data={**training_data(res, True, text_len).data, 'split': split},
+        log_percent=freq, skip_first=freq == 1, tokenizer=TOKENIZER, pp_fn=pp)
+    c.evals[f'vqav2/{name}'].update(kw)
+
+
+def add_eval_pplx(c, res, text_len=32):
+  """Perplexity evaluator to test runs before implementing the real deal."""
+  c_train = training_data(res, True, text_len)  # Use mostly same settings as training.
+
+  for name, split in [
+      ('minitrain', 'train[:20_864]'),  # To gauge memorization.
+      ('minival', 'validation[-10240:]'),  # To tune hparams.
+  ]:
+    c.evals[f'vqav2/{name}/pplx'] = dict(
+        type='proj.paligemma.perplexity', pred='logits',
+        key='text', shift_labels=True,
+        log_percent=1/4,  # Not too cheap, do 4x per run.
+        data={**c_train.data, 'split': split},
+        pp_fn=c_train.pp,
+    )
+
+
+def sweep_best(add, arg=None):
+  """Train with best hyper-params."""
+  c = bvcc.parse_arg(arg, final_split=False)
+  # NOTE: lr was highest in sweep.
+  add(total_epochs=10, lr=1e-5, wd=1e-6, **bvcc.arg(res=224, **c))
+  add(total_epochs=10, lr=1e-5, wd=0.00, **bvcc.arg(res=448, **c))
+
+
+sweep = sweep_best  # Choose which sweep to run.
+
+
+def get_config(arg=None):
+  """Config for training."""
+  c = bvcc.parse_arg(arg, mode='xm', res=224, final_split=False)
+
+  c.input = training_data(c.res, c.final_split)
+
+  # Instead of epochs, you can also use `total_examples` or `total_steps`.
+  c.total_epochs = 10
+  c.input.batch_size = 256
+  c.optax_name = 'scale_by_adam'
+  c.lr = 3e-6
+  c.wd = 3e-7
+  c.grad_clip_norm = 1.0
+  c.label_smoothing = 0.0
+  c.schedule = dict(decay_type='cosine', warmup_percent=0.05)
+
+  # Add evaluators.
+  c.evals = {}
+  add_eval(c, c.res, batch_size=1024)
+  add_eval_pplx(c, c.res)
+
+  # Model section.
+  c.model_name = 'proj.paligemma.paligemma'
+  c.model = {}
+  c.model.img = dict(variant='So400m/14', pool_type='none', scan=True)
+  c.model.llm = dict(vocab_size=256_000 + 1024 + 128, dropout=0.0)
+  c.model_init = f'pt_{c.res}'
+
+  # FSDP strategy.
+  c.mesh = [('data', -1)]
+  c.sharding_strategy = [('.*', 'fsdp(axis="data")')]
+  c.sharding_rules = [('act_batch', ('data',))]
+
+  # These probably do not need any change/tuning
+  c.input.shuffle_buffer_size = 50_000
+  c.log_training_steps = 50
+  c.ckpt_steps = 1_000
+  c.pp_modules = ['ops_general', 'ops_image', 'ops_text', 'proj.paligemma.ops']
+
+  # Update configs for quicker local runs and avoid swapping.
+  if c.mode in ('runlocal', 'mock'):
+    c.input.shuffle_buffer_size = None
+    for ev in c.evals.values():
+      ev.data.split = ev.data.split.split('[')[0] + '[:16]'
+
+  if c.mode == 'runlocal':
+    c.log_training_steps = 1
+    c.input.batch_size = 2
+
+  c.seed = 0
+  return c
+
+
+def metrics(arg=None):  # pylint: disable=unused-argument
+  m = ['training_loss']
+  for split in ('minival', 'minitrain'):
+    m.append(f'vqav2/{split}/acc')
+    m.append(f'vqav2/{split}/pplx/avg')
+  return m
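Taken together, add_eval and add_eval_pplx register five evaluators; a sketch of the resulting keys (assumes the big_vision repo with the paligemma transfers module is importable):

c = get_config('mode=runlocal')
print(sorted(c.evals))
# -> ['vqav2/minitrain', 'vqav2/minitrain/pplx',
#     'vqav2/minival', 'vqav2/minival/pplx', 'vqav2/test']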
Tipsomaly/model/big_vision/configs/transfer.py ADDED
@@ -0,0 +1,186 @@
+# Copyright 2024 Big Vision Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# pylint: disable=line-too-long,missing-function-docstring
+r"""A config for transferring vit-augreg.
+
+Best HP selected on (mini)val, expected test results (repeated 5 times):
+
+ViT-Augreg-B/32:
+  Dataset, crop, learning rate, mean (%), range (%)
+  - ImageNet, inception_crop, 0.03, 83.27, [83.22...83.33]
+  - Cifar10, resmall_crop, 0.003, 98.55, [98.46...98.6]
+  - Cifar100, resmall_crop, 0.01, 91.35, [91.09...91.62]
+  - Pets, inception_crop, 0.003, 93.78, [93.62...94.00]
+  - Flowers, inception_crop, 0.003, 99.43, [99.42...99.45]
+
+
+Command to run:
+big_vision.train \
+    --config big_vision/configs/transfer.py:model=vit-i21k-augreg-b/32,dataset=cifar10,crop=resmall_crop \
+    --workdir gs://$GS_BUCKET_NAME/big_vision/workdir/`date '+%m-%d_%H%M'` --config.lr=0.03
+"""
+
+import big_vision.configs.common as bvcc
+import ml_collections as mlc
+
+
+def _set_model(config, model):
+  """Load pre-trained models: vit or bit."""
+  # Reset the head to init (of zeros) when transferring.
+  config.model_load = dict(dont_load=['head/kernel', 'head/bias'])
+
+  if model == 'vit-i21k-augreg-b/32':
+    # Load "recommended" upstream B/32 from https://arxiv.org/abs/2106.10270
+    config.model_name = 'vit'
+    config.model_init = 'howto-i21k-B/32'
+    config.model = dict(variant='B/32', pool_type='tok')
+  elif model == 'vit-i21k-augreg-l/16':
+    config.model_name = 'vit'
+    config.model_init = 'howto-i21k-L/16'
+    config.model = dict(variant='L/16', pool_type='tok')
+  elif model == 'vit-s16':
+    config.model_name = 'vit'
+    config.model_init = 'i1k-s16-300ep'
+    config.model = dict(variant='S/16', pool_type='gap', posemb='sincos2d',
+                        rep_size=True)
+  elif model == 'bit-m-r50x1':
+    config.model_name = 'bit_paper'
+    config.model_init = 'M'
+    config.model = dict(depth=50, width=1)
+  else:
+    raise ValueError(f'Unknown model: {model}, please define customized model.')
+
+
+def _set_dataset(config, dataset, crop='inception_crop', h_res=448, l_res=384):
+  if dataset == 'cifar10':
+    _set_task(config, 'cifar10', 'train[:98%]', 'train[98%:]', 'test', 10, steps=10_000, warmup=500, crop=crop, h_res=h_res, l_res=l_res)
+  elif dataset == 'cifar100':
+    _set_task(config, 'cifar100', 'train[:98%]', 'train[98%:]', 'test', 100, steps=10_000, warmup=500, crop=crop, h_res=h_res, l_res=l_res)
+  elif dataset == 'imagenet2012':
+    _set_task(config, 'imagenet2012', 'train[:99%]', 'train[99%:]', 'validation', 1000, steps=20_000, warmup=500, crop=crop, h_res=h_res, l_res=l_res)
+    _set_imagenet_variants(config)
+  elif dataset == 'oxford_iiit_pet':
+    _set_task(config, 'oxford_iiit_pet', 'train[:90%]', 'train[90%:]', 'test', 37, steps=500, warmup=100, crop=crop, h_res=h_res, l_res=l_res)
+  elif dataset == 'oxford_flowers102':
+    _set_task(config, 'oxford_flowers102', 'train[:90%]', 'train[90%:]', 'test', 102, steps=500, warmup=100, crop=crop, h_res=h_res, l_res=l_res)
+  else:
+    raise ValueError(
+        f'Unknown dataset: {dataset}, please define customized dataset.')
+
+
+def _set_task(config, dataset, train, val, test, n_cls,
+              steps=20_000, warmup=500, lbl='label', crop='resmall_crop',
+              flip=True, h_res=448, l_res=384):
+  """Vision task with val and test splits."""
+  config.total_steps = steps
+  config.schedule = dict(
+      warmup_steps=warmup,
+      decay_type='cosine',
+  )
+
+  config.input.data = dict(name=dataset, split=train)
+  pp_common = (
+      '|value_range(-1, 1)|'
+      f'onehot({n_cls}, key="{lbl}", key_result="labels")|'
+      'keep("image", "labels")'
+  )
+
+  if crop == 'inception_crop':
+    pp_train = f'decode|inception_crop({l_res})'
+  elif crop == 'resmall_crop':
+    pp_train = f'decode|resize_small({h_res})|random_crop({l_res})'
+  elif crop == 'resize_crop':
+    pp_train = f'decode|resize({h_res})|random_crop({l_res})'
+  else:
+    raise ValueError(f'Unknown crop: {crop}. Must be one of: '
+                     'inception_crop, resmall_crop, resize_crop')
+  if flip:
+    pp_train += '|flip_lr'
+  config.input.pp = pp_train + pp_common
+
+  pp = f'decode|resize_small({h_res})|central_crop({l_res})' + pp_common
+  config.num_classes = n_cls
+
+  def get_eval(split):
+    return dict(
+        type='classification',
+        data=dict(name=dataset, split=split),
+        loss_name='softmax_xent',
+        log_steps=100,
+        pp_fn=pp,
+    )
+  config.evals = dict(val=get_eval(val), test=get_eval(test))
+
+
+def _set_imagenet_variants(config, h_res=448, l_res=384):
+  """Evaluation tasks on ImageNet variants: v2 and real."""
+  pp = (f'decode|resize_small({h_res})|central_crop({l_res})'
+        '|value_range(-1, 1)|onehot(1000, key="{lbl}", key_result="labels")|'
+        'keep("image", "labels")'
+        )
+
+  # Special-case rename for i1k (val+test -> minival+val)
+  config.evals.minival = config.evals.val
+  config.evals.val = config.evals.test
+  # NOTE: keep test == val for convenience in subsequent analysis.
+
+  config.evals.real = dict(type='classification')
+  config.evals.real.data = dict(name='imagenet2012_real', split='validation')
+  config.evals.real.pp_fn = pp.format(lbl='real_label')
+  config.evals.real.loss_name = config.loss
+  config.evals.real.log_steps = 100
+
+  config.evals.v2 = dict(type='classification')
+  config.evals.v2.data = dict(name='imagenet_v2', split='test')
+  config.evals.v2.pp_fn = pp.format(lbl='label')
+  config.evals.v2.loss_name = config.loss
+  config.evals.v2.log_steps = 100
+
+
+def get_config(arg=None):
+  """Config for adaptation."""
+  arg = bvcc.parse_arg(arg, model='vit', dataset='cifar10', crop='resmall_crop',
+                       h_res=448, l_res=384, batch_size=512, fsdp=False,
+                       runlocal=False)
+  config = mlc.ConfigDict()
+
+  config.input = {}
+  config.input.batch_size = arg.batch_size if not arg.runlocal else 8
+  config.input.shuffle_buffer_size = 50_000 if not arg.runlocal else 100
+
+  config.log_training_steps = 10
+  config.ckpt_steps = 1000
+  config.ckpt_timeout = 600
+
+  # Optimizer section
+  config.optax_name = 'big_vision.momentum_hp'
+  config.grad_clip_norm = 1.0
+  config.wd = None  # That's our default, but just being explicit here!
+  config.loss = 'softmax_xent'
+  config.lr = 0.01
+  config.mixup = dict(p=0.0)
+
+  config.seed = 0
+
+  _set_dataset(config, arg.dataset, arg.crop, arg.h_res, arg.l_res)
+
+  _set_model(config, arg.model)
+  if arg.fsdp:
+    config.mesh = [('data', -1)]
+    config.sharding_strategy = [('.*', 'fsdp(axis="data")')]
+    config.sharding_rules = [('act_batch', ('data',))]
+    config.model.scan = True
+
+  return config
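The three crop choices in _set_task expand to these train-time pp prefixes with the default h_res=448, l_res=384 (a standalone sketch of the resulting strings; |flip_lr is appended afterwards when flip=True):

crops = {
    'inception_crop': 'decode|inception_crop(384)',
    'resmall_crop': 'decode|resize_small(448)|random_crop(384)',
    'resize_crop': 'decode|resize(448)|random_crop(384)',
}
for name, pp in crops.items():
  print(f'{name}: {pp}')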
Tipsomaly/model/big_vision/configs/vit_i1k.py ADDED
@@ -0,0 +1,177 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 Big Vision Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ # pylint: disable=line-too-long
16
+ r"""Pre-training ViT on ILSVRC-2012 as in https://arxiv.org/abs/2106.10270
17
+
18
+ This config does NOT include regularization (dropout, stochastic depth), which
19
+ was shown to help with B/32, B/16, L/16 models in the paper (Figure 4).
20
+
21
+ This configuration makes use of the "arg" to get_config to select which model
22
+ to run, so a few examples are given below:
23
+
24
+ Run training of a B/16 model:
25
+
26
+ big_vision.train \
27
+ --config big_vision/configs/vit_i1k.py:variant=B/16 \
28
+ --workdir gs://[your_bucket]/big_vision/`date '+%m-%d_%H%M'`
29
+
30
+ Run training of a B/32 model with custom aug-strenght and 300ep:
31
+
32
+ big_vision.train \
33
+ --config big_vision/configs/vit_i1k.py:variant=B/32,aug=light1 \
34
+ --workdir gs://[your_bucket]/big_vision/`date '+%m-%d_%H%M'` \
35
+ --config.total_epochs 300
36
+ """
37
+
38
+ import big_vision.configs.common as bvcc
39
+ from big_vision.configs.common_fewshot import get_fewshot_lsr
40
+ import ml_collections as mlc
41
+
42
+ MIXUP_DEF = {
43
+ 'none': dict(p=0.0, fold_in=None),
44
+ 'light1': dict(p=0.0, fold_in=None),
45
+ 'light2': dict(p=0.2, fold_in=None),
46
+ 'medium1': dict(p=0.2, fold_in=None),
47
+ 'medium2': dict(p=0.5, fold_in=None),
48
+ 'strong1': dict(p=0.5, fold_in=None),
49
+ 'strong2': dict(p=0.8, fold_in=None),
50
+ }
51
+
52
+ RANDAUG_DEF = {
53
+ 'none': '',
54
+ 'light1': 'randaug(2,0)', # Actually not nothing!
55
+ 'light2': 'randaug(2,10)',
56
+ 'medium1': 'randaug(2,15)',
57
+ 'medium2': 'randaug(2,15)',
58
+ 'strong1': 'randaug(2,20)',
59
+ 'strong2': 'randaug(2,20)',
60
+ }
61
+
62
+
63
+ def get_config(arg=None):
64
+ """Config for training."""
65
+ arg = bvcc.parse_arg(arg, variant='B/16', runlocal=False, aug='')
66
+ config = mlc.ConfigDict()
67
+
68
+ config.seed = 0
69
+ config.total_epochs = 300
70
+ config.num_classes = 1000
71
+ config.loss = 'sigmoid_xent'
72
+ config.init_head_bias = -6.9
73
+
74
+ # If this gives a KeyError, lookup Fig4 of the paper and add an entry.
75
+ # Note, this here is a good average between 30ep and 300ep, sometimes you coud
76
+ # find a slightly better setting for either of them.
77
+ aug_setting = arg.aug or {
78
+ 'Ti/16': 'light1',
79
+ 'S/32': 'medium1',
80
+ 'S/16': 'medium2',
81
+ 'B/32': 'medium2',
82
+ 'B/16': 'medium2',
83
+ 'L/16': 'medium2',
84
+ }[arg.variant]
85
+
86
+ config.input = dict()
87
+ config.input.data = dict(
88
+ name='imagenet2012',
89
+ split='train[:99%]',
90
+ )
91
+ config.input.batch_size = 4096
92
+ config.input.cache = 'raw_data' if arg.runlocal else 'none' # Needs up to 120GB of RAM!
93
+ config.input.shuffle_buffer_size = 250_000
94
+
95
+ pp_common = (
96
+ '|value_range(-1, 1)'
97
+ '|onehot(1000, key="{lbl}", key_result="labels")'
98
+ '|keep("image", "labels")'
99
+ )
100
+ config.input.pp = (
101
+ 'decode_jpeg_and_inception_crop(224)|flip_lr|' +
102
+ RANDAUG_DEF[aug_setting] +
103
+ pp_common.format(lbl='label')
104
+ )
105
+ pp_eval = 'decode|resize_small(256)|central_crop(224)' + pp_common
106
+
107
+ # To continue using the near-defunct randaug op.
108
+ config.pp_modules = ['ops_general', 'ops_image', 'ops_text', 'archive.randaug']
109
+
110
+ # Aggressive pre-fetching because our models here are small, so we not only
111
+ # can afford it, but we also need it for the smallest models to not be
112
+ # bottle-necked by the input pipeline. Play around with it for -L models tho.
113
+ config.input.prefetch = 8
114
+ config.prefetch_to_device = 4
115
+
116
+ config.log_training_steps = 50
117
+ config.ckpt_steps = 1000
118
+
119
+ # Model section
120
+ config.model_name = 'vit'
121
+ config.model = dict(
122
+ variant=arg.variant,
123
+ rep_size=True,
124
+ pool_type='tok',
125
+ )
126
+
127
+ # Optimizer section
128
+ config.grad_clip_norm = 1.0
129
+ config.optax_name = 'scale_by_adam'
130
+ config.optax = dict(mu_dtype='bfloat16')
131
+ # The modified AdaFactor we introduced in https://arxiv.org/abs/2106.04560
132
+ # almost always behaves exactly like adam, but at a fraction of the memory
133
+ # cost (specifically, adam_bf16 = +1.5M, adafactor = +0.5M), hence it is a
134
+ # good idea to try it when you are memory-bound!
135
+ # config.optax_name = 'big_vision.scale_by_adafactor'
136
+ # A good flag to play with when hitting instabilities, is the following:
137
+ # config.optax = dict(beta2_cap=0.95)
138
+
139
+ config.lr = 0.001
140
+ config.wd = 0.0001
141
+ config.schedule = dict(warmup_steps=10_000, decay_type='cosine')
142
+
143
+ config.mixup = MIXUP_DEF[aug_setting]
144
+
145
+ # Eval section
146
+ def get_eval(split, dataset='imagenet2012'):
147
+ return dict(
148
+ type='classification',
149
+ data=dict(name=dataset, split=split),
150
+ pp_fn=pp_eval.format(lbl='label'),
151
+ loss_name=config.loss,
152
+ log_steps=2500, # Very fast O(seconds) so it's fine to run it often.
153
+ cache='final_data' if arg.runlocal else 'none',
154
+ )
155
+ config.evals = {}
156
+ config.evals.train = get_eval('train[:2%]')
157
+ config.evals.minival = get_eval('train[99%:]')
158
+ config.evals.val = get_eval('validation')
159
+ config.evals.v2 = get_eval('test', dataset='imagenet_v2')
160
+ config.evals.real = get_eval('validation', dataset='imagenet2012_real')
161
+ config.evals.real.pp_fn = pp_eval.format(lbl='real_label')
162
+
163
+ config.fewshot = get_fewshot_lsr(runlocal=arg.runlocal)
164
+ config.fewshot.log_steps = 10_000
165
+
166
+ # Make a few things much smaller for quick local debugging testruns.
167
+ if arg.runlocal:
168
+ config.input.shuffle_buffer_size = 10
169
+ config.input.batch_size = 8
170
+ config.input.cache = 'none'
171
+ config.evals.train.data.split = 'train[:16]'
172
+ config.evals.minival.data.split = 'train[:16]'
173
+ config.evals.val.data.split = 'validation[:16]'
174
+ config.evals.v2.data.split = 'test[:16]'
175
+ config.evals.real.data.split = 'validation[:16]'
176
+
177
+ return config
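A standalone worked example (plain Python, not part of the diff): why `init_head_bias = -6.9` above is the natural companion of `sigmoid_xent` with 1000 classes; it starts every per-class sigmoid at the marginal label frequency of 1/1000.

```python
# Sketch: the logit (inverse sigmoid) of a 1/1000 class prior is about -6.9,
# so initializing the head bias there makes the untrained model output
# calibrated near-zero probabilities for every class under sigmoid_xent.
import math

p = 1 / 1000                      # prior probability of any single class
b = math.log(p / (1 - p))         # logit of p
print(round(b, 2))                # -6.91, i.e. the -6.9 used above

# Sanity check: sigmoid(-6.9) is ~1/1000.
print(1 / (1 + math.exp(6.9)))    # ~0.00101
```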
Tipsomaly/model/big_vision/configs/vit_i21k.py ADDED
@@ -0,0 +1,145 @@
1
+ # Copyright 2024 Big Vision Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ # pylint: disable=line-too-long
16
+ r"""Pre-training ViT on ImageNet-21k as in https://arxiv.org/abs/2106.10270
17
+
18
+ This config relies on the ImageNet-21k tfds dataset, which is not yet
19
+ available publicly in TFDS. We intend to add the dataset to public TFDS soon,
20
+ and this config will then be runnable.
21
+
22
+ Note that regularization (dropout, stochastic depth) is not currently
23
+ implemented. This was not beneficial for ImageNet-21k pre-training.
24
+ """
25
+
26
+ import big_vision.configs.common as bvcc
27
+ from big_vision.configs.common_fewshot import get_fewshot_lsr
28
+ import ml_collections as mlc
29
+
30
+ MIXUP_DEF = {
31
+ 'none': dict(p=0.0, fold_in=None),
32
+ 'light1': dict(p=0.0, fold_in=None),
33
+ 'light2': dict(p=0.2, fold_in=None),
34
+ 'medium1': dict(p=0.2, fold_in=None),
35
+ 'medium2': dict(p=0.5, fold_in=None),
36
+ 'strong1': dict(p=0.5, fold_in=None),
37
+ 'strong2': dict(p=0.8, fold_in=None),
38
+ }
39
+
40
+ RANDAUG_DEF = {
41
+ 'none': '',
42
+ 'light1': 'randaug(2,0)', # Actually not nothing!
43
+ 'light2': 'randaug(2,10)',
44
+ 'medium1': 'randaug(2,15)',
45
+ 'medium2': 'randaug(2,15)',
46
+ 'strong1': 'randaug(2,20)',
47
+ 'strong2': 'randaug(2,20)',
48
+ }
49
+
50
+
51
+ def get_config(arg=None):
52
+ """Config for training."""
53
+ arg = bvcc.parse_arg(arg, variant='B/16', runlocal=False, aug=None)
54
+ config = mlc.ConfigDict()
55
+
56
+ config.seed = 0
57
+ config.total_epochs = 300
58
+ config.num_classes = 21843
59
+ config.init_head_bias = -10.0
60
+ config.loss = 'sigmoid_xent'
61
+
62
+ # If this gives a KeyError, lookup Fig4 of the paper and add an entry.
63
+ # Note: this is a good average between 30ep and 300ep; sometimes you could
64
+ # find a slightly better setting for either of them.
65
+ aug_setting = arg.aug or {
66
+ 'Ti/16': 'none',
67
+ 'S/32': 'none',
68
+ 'S/16': 'light1',
69
+ 'B/32': 'light2',
70
+ 'B/16': 'light2',
71
+ 'L/16': 'medium2',
72
+ }[arg.variant]
73
+
74
+ config.input = dict()
75
+ config.input.data = dict(
76
+ name='imagenet21k',
77
+ split='full[51200:]',
78
+ )
79
+ config.input.batch_size = 4096
80
+ config.input.shuffle_buffer_size = 250_000 # Per host, so small-ish is ok.
81
+
82
+ pp_common = '|value_range(-1, 1)|onehot({onehot_args})|keep("image", "labels")'
83
+ pp_common_i21k = pp_common.format(onehot_args=f'{config.num_classes}')
84
+ pp_common_i1k = pp_common.format(onehot_args='1000, key="label", key_result="labels"')
85
+ config.input.pp = f'decode_jpeg_and_inception_crop(224)|flip_lr|{RANDAUG_DEF[aug_setting]}' + pp_common_i21k
86
+ pp_eval = 'decode|resize_small(256)|central_crop(224)'
87
+
88
+ # To continue using the near-defunct randaug op.
89
+ config.pp_modules = ['ops_general', 'ops_image', 'ops_text', 'archive.randaug']
90
+
91
+ # Aggressive pre-fetching because our models here are small, so we not only
92
+ # can afford it, but we also need it for the smallest models to not be
93
+ # bottle-necked by the input pipeline. Play around with it for -L models tho.
94
+ config.input.prefetch = 8
95
+ config.prefetch_to_device = 4
96
+
97
+ config.log_training_steps = 50
98
+ config.ckpt_steps = 1000
99
+
100
+ # Model section
101
+ config.model_name = 'vit'
102
+ config.model = dict(variant=arg.variant, pool_type='gap', posemb='learn')
103
+
104
+ # Optimizer section
105
+ config.optax_name = 'scale_by_adam'
106
+ config.optax = dict(mu_dtype='bfloat16')
107
+ config.grad_clip_norm = 1.0
108
+
109
+ config.lr = 0.001
110
+ config.wd = 0.0001
111
+ config.schedule = dict(warmup_steps=10_000, decay_type='cosine')
112
+
113
+ config.mixup = MIXUP_DEF[aug_setting]
114
+
115
+ # Evaluations on i21k itself.
116
+ def eval_i21k(split):
117
+ return dict(
118
+ type='classification',
119
+ data={**config.input.data, 'split': split},
120
+ pp_fn=pp_eval + pp_common_i21k,
121
+ loss_name=config.loss,
122
+ log_steps=1000, # Very fast O(seconds) so it's fine to run it often.
123
+ )
124
+ config.evals = {}
125
+ config.evals.test = eval_i21k('full[:25_600]')
126
+ config.evals.val = eval_i21k('full[25_600:51_200]')
127
+ config.evals.train = eval_i21k('full[51_200:76_800]')
128
+
129
+ # Few-shot evaluators
130
+ config.evals.fewshot = get_fewshot_lsr(runlocal=arg.runlocal)
131
+ config.evals.fewshot.log_steps = 25_000
132
+
133
+ # Make a few things much smaller for quick local debugging testruns.
134
+ if arg.runlocal:
135
+ config.input.shuffle_buffer_size = 10
136
+ config.input.batch_size = 8
137
+ config.evals.test.data.split = 'full[:16]'
138
+ config.evals.train.data.split = 'full[:16]'
139
+ config.evals.val.data.split = 'full[:16]'
144
+
145
+ return config
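A small standalone sketch (plain Python) of how the shared `pp_common` template above is specialized for the i21k and i1k label spaces via `str.format`; the printed strings are what the pp builder would receive.

```python
pp_common = '|value_range(-1, 1)|onehot({onehot_args})|keep("image", "labels")'
num_classes = 21843

pp_common_i21k = pp_common.format(onehot_args=f'{num_classes}')
pp_common_i1k = pp_common.format(
    onehot_args='1000, key="label", key_result="labels"')

print(pp_common_i21k)
# |value_range(-1, 1)|onehot(21843)|keep("image", "labels")
print(pp_common_i1k)
# |value_range(-1, 1)|onehot(1000, key="label", key_result="labels")|keep("image", "labels")
```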
Tipsomaly/model/big_vision/configs/vit_s16_i1k.py ADDED
@@ -0,0 +1,105 @@
1
+ # Copyright 2024 Big Vision Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ # pylint: disable=line-too-long
16
+ r"""Pre-training ViT-S/16 on ILSVRC-2012 following https://arxiv.org/abs/2205.01580.
17
+
18
+ This should take 6-7h to finish 90ep on a TPU-v3-8 and reach 76.5%,
19
+ see the tech report for more details.
20
+
21
+ Command to run:
22
+
23
+ big_vision.train \
24
+ --config big_vision/configs/vit_s16_i1k.py \
25
+ --workdir gs://[your_bucket]/big_vision/`date '+%m-%d_%H%M'`
26
+
27
+ To run for 300ep, add `--config.total_epochs 300` to the command.
28
+ """
29
+
30
+ import ml_collections as mlc
31
+
32
+
33
+ def get_config():
34
+ """Config for training."""
35
+ config = mlc.ConfigDict()
36
+
37
+ config.seed = 0
38
+ config.total_epochs = 90
39
+ config.num_classes = 1000
40
+ config.loss = 'softmax_xent'
41
+
42
+ config.input = {}
43
+ config.input.data = dict(
44
+ name='imagenet2012',
45
+ split='train[:99%]',
46
+ )
47
+ config.input.batch_size = 1024
48
+ config.input.cache_raw = True # Needs up to 120GB of RAM!
49
+ config.input.shuffle_buffer_size = 250_000
50
+
51
+ pp_common = (
52
+ '|value_range(-1, 1)'
53
+ '|onehot(1000, key="{lbl}", key_result="labels")'
54
+ '|keep("image", "labels")'
55
+ )
56
+ config.input.pp = (
57
+ 'decode_jpeg_and_inception_crop(224)|flip_lr|randaug(2,10)' +
58
+ pp_common.format(lbl='label')
59
+ )
60
+ pp_eval = 'decode|resize_small(256)|central_crop(224)' + pp_common
61
+
62
+ # To continue using the near-defunct randaug op.
63
+ config.pp_modules = ['ops_general', 'ops_image', 'ops_text', 'archive.randaug']
64
+
65
+ config.log_training_steps = 50
66
+ config.ckpt_steps = 1000
67
+
68
+ # Model section
69
+ config.model_name = 'vit'
70
+ config.model = dict(
71
+ variant='S/16',
72
+ rep_size=True,
73
+ pool_type='gap',
74
+ posemb='sincos2d',
75
+ )
76
+
77
+ # Optimizer section
78
+ config.grad_clip_norm = 1.0
79
+ config.optax_name = 'scale_by_adam'
80
+ config.optax = dict(mu_dtype='bfloat16')
81
+
82
+ config.lr = 0.001
83
+ config.wd = 0.0001
84
+ config.schedule = dict(warmup_steps=10_000, decay_type='cosine')
85
+
86
+ config.mixup = dict(p=0.2, fold_in=None)
87
+
88
+ # Eval section
89
+ def get_eval(split, dataset='imagenet2012'):
90
+ return dict(
91
+ type='classification',
92
+ data=dict(name=dataset, split=split),
93
+ pp_fn=pp_eval.format(lbl='label'),
94
+ loss_name=config.loss,
95
+ log_steps=2500, # Very fast O(seconds) so it's fine to run it often.
96
+ )
97
+ config.evals = {}
98
+ config.evals.train = get_eval('train[:2%]')
99
+ config.evals.minival = get_eval('train[99%:]')
100
+ config.evals.val = get_eval('validation')
101
+ config.evals.v2 = get_eval('test', dataset='imagenet_v2')
102
+ config.evals.real = get_eval('validation', dataset='imagenet2012_real')
103
+ config.evals.real.pp_fn = pp_eval.format(lbl='real_label')
104
+
105
+ return config
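A minimal sketch of the learning-rate schedule this config requests (`lr=0.001`, `warmup_steps=10_000`, cosine decay). It assumes linear warmup followed by cosine decay to zero, which is the usual reading of these settings; the exact big_vision implementation may differ in details such as a final lr floor.

```python
import math

def lr_at(step, total_steps, base_lr=0.001, warmup_steps=10_000):
    if step < warmup_steps:
        return base_lr * step / warmup_steps          # linear warmup
    frac = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return base_lr * 0.5 * (1 + math.cos(math.pi * frac))  # cosine decay

total = 90 * 1_281_167 // 1024   # ~112k steps: 90 ImageNet epochs at bs=1024
for s in (0, 5_000, 10_000, total // 2, total):
    print(s, f"{lr_at(s, total):.6f}")
```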
Tipsomaly/model/big_vision/datasets/core.py ADDED
@@ -0,0 +1,77 @@
1
+ # Copyright 2024 Big Vision Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """Core data functions, dispatch calls to the requested dataset."""
16
+ import importlib
17
+
18
+
19
+ # Note: intentionally not using ABC to avoid forcing implementation of every
20
+ # method, since one can imagine train-only datasets for example.
21
+ class DataSource:
22
+ """The API that any data source should implement."""
23
+
24
+ def get_tfdata(self, ordered, *, process_split=True, allow_cache=True):
25
+ """Creates this data object as a tf.data.Dataset.
26
+
27
+ This will be called separately in each process, and it is up to the dataset
28
+ implementation to shard it accordingly if desired!
29
+
30
+ Args:
31
+ ordered: if True, the dataset should use deterministic ordering, if False
32
+ it may have undefined ordering. Think of True == val, False == train.
33
+ process_split: if False then every process receives the entire dataset
34
+ (e.g. for evaluators running in a single process).
35
+ allow_cache: whether to allow caching the opened data or not.
36
+
37
+ Returns:
38
+ A tf.data.Dataset object.
39
+
40
+ Raises:
41
+ RuntimeError: if not implemented by the dataset, but called.
42
+ """
43
+ raise RuntimeError(f"not implemented for {self.__class__.__name__}")
44
+
45
+ @property
46
+ def total_examples(self):
47
+ """Returns number of examples in the dataset, regardless of sharding."""
48
+ raise RuntimeError(f"not implemented for {self.__class__.__name__}")
49
+
50
+ def num_examples_per_process(self):
51
+ """Returns a list of the numer of examples for each process.
52
+
53
+ This is only needed for datasets that should go through make_for_inference.
54
+
55
+ Returns:
56
+ Returns a list of the number of examples for each process.
57
+
58
+ Ideally, this would always be `[total_examples / nprocess] * nprocess`, but in
59
+ reality we can almost never perfectly shard a dataset across an arbitrary
60
+ number of processes.
61
+
62
+ One alternative option that can work in some cases is to not even shard
63
+ the dataset and thus return `[total_examples] * nprocess`.
64
+
65
+ Raises:
66
+ RuntimeError: if not implemented by the dataset, but called.
67
+ """
68
+ raise RuntimeError(f"not implemented for {self.__class__.__name__}")
69
+
70
+
71
+ def get(name, **kw):
72
+ if name.startswith("bv:"):
73
+ mod = importlib.import_module(f"big_vision.datasets.{name[3:]}")
74
+ return mod.DataSource(**kw)
75
+ else:
76
+ mod = importlib.import_module("big_vision.datasets.tfds")
77
+ return mod.DataSource(name, **kw)
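Usage sketch for the dispatch above (assumes big_vision and its data dependencies are importable; the dataset names and paths are illustrative, not shipped defaults):

```python
from big_vision.datasets import core as ds_core

# A plain name dispatches to the TFDS-backed source in big_vision.datasets.tfds:
train = ds_core.get(name="imagenet2012", split="train[:99%]")

# A "bv:" prefix dispatches to big_vision.datasets.<module>, e.g. the jsonl one:
custom = ds_core.get(name="bv:jsonl", fname="/tmp/my_data.jsonl",
                     fopen_keys={"image": "/tmp/images"})
```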
Tipsomaly/model/big_vision/datasets/jsonl.py ADDED
@@ -0,0 +1,177 @@
1
+ # Copyright 2024 Big Vision Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """Simple data input from .jsonl files."""
16
+
17
+ import hashlib
18
+ import json
19
+ from multiprocessing.pool import ThreadPool
20
+ import os
21
+ import tempfile
22
+ import urllib.request
23
+
24
+ from absl import logging
25
+ import big_vision.datasets.core as ds_core
26
+ import jax
27
+ import numpy as np
28
+ import overrides
29
+ import tensorflow as tf
30
+
31
+
32
+ def cached_download(url, dest=None, verbose=True):
33
+ """Download `url` to local file and return path to that, but with caching."""
34
+ # NOTE: there is a small chance of saving corrupted data if the process is
35
+ # interrupted in the middle of writing the file. Then, reading in the input
36
+ # pipeline will fail, and the fix is to nuke the temp folder.
37
+
38
+ # Compute a temp name based on the URL, so we can check if we already
39
+ # downloaded it before.
40
+ dest = dest or os.path.join(tempfile.gettempdir(), "bv")
41
+ os.makedirs(dest, exist_ok=True)
42
+ dest = os.path.join(dest, hashlib.md5(url.encode()).hexdigest())
43
+
44
+ # NOTE: we should use the last-modified header to know whether to re-download.
45
+ if os.path.isfile(dest):
46
+ return dest
47
+
48
+ if verbose:
49
+ print(f"\rRetrieving {url} into {dest}", end="", flush=True)
50
+
51
+ with urllib.request.urlopen(url) as f:
52
+ data = f.read()
53
+ with open(dest, "wb+") as f:
54
+ f.write(data)
55
+ return dest
56
+
57
+
58
+ class DataSource(ds_core.DataSource):
59
+ """.jsonl DataSource."""
60
+
61
+ def __init__(self, fname, *, fopen_keys=(), download_keys=(),
62
+ start=0, stop=float("inf")):
63
+ """Create data-source that's jsonl + data files (eg images).
64
+
65
+ This correctly supports multi-host in that each host only reads a subset of
66
+ the dataset automatically. However, currently, all hosts download all items
67
+ if `download_keys` is specified. TODO: b/lbeyer - This can be improved.
68
+
69
+ Args:
70
+ fname: str, the path to the jsonl file that holds the dataset.
71
+ fopen_keys: collection of str or dict, the keys in the dataset whose
72
+ string value actually is a file-path that should be opened and read,
73
+ and its content is what goes into the batch (eg image filenames
74
+ commonly ["image"]).
75
+ If a dict, the values are folders prefixed to the filenames.
76
+ Supports gs:// for reading from buckets.
77
+ download_keys: collection of str, the keys in the dataset whose string
78
+ value actually is a URL from which the file should be downloaded first.
79
+ files are downloaded to a persistent tmp folder using the URL hash as
80
+ filename. If the file already exists, the download is skipped.
81
+ Must be a subset of `fopen_keys`.
82
+ start: int, index of the first row to use; use for slicing the data.
83
+ stop: int or inf, index of the row after the last one to use.
84
+
85
+ Note:
86
+ This simple data input does not allow for nested/hierarchical values,
87
+ or in any way more complicated values like vectors. Use TFDS for that.
88
+
89
+ The way start/stop arguments are used is as in list slicing[start:stop].
90
+ """
91
+ self.examples = []
92
+
93
+ with tf.io.gfile.GFile(fname) as f:
94
+ for i, line in enumerate(f):
95
+ if (start or 0) <= i < (stop or float("inf")):
96
+ try:
97
+ self.examples.append(json.loads(line))
98
+ except json.decoder.JSONDecodeError as e:
99
+ raise ValueError(f"Invalid JSON in line {i}:\n{line}") from e
100
+
101
+ if download_keys:
102
+ for k in download_keys:
103
+ assert k in fopen_keys, (
104
+ f"{k} in download_keys but missing from fopen_keys {fopen_keys}")
105
+
106
+ # TODO: b/lbeyer - use info from trainer instead, move that to utils.
107
+ logging.info( # pylint: disable=logging-fstring-interpolation
108
+ f"\u001b[33mNOTE\u001b[0m: Downloading {download_keys} "
109
+ f"for dataset {fname} ({len(self.examples)} examples) ...")
110
+
111
+ def _dl_one(ex):
112
+ for k in download_keys:
113
+ ex[k] = cached_download(ex[k])
114
+
115
+ ThreadPool(100).map(_dl_one, self.examples)
116
+ print("Done")
117
+ logging.info("\u001b[33mNOTE\u001b[0m: Done downloading.")
118
+
119
+ # Normalize.
120
+ if isinstance(fopen_keys, (list, tuple)):
121
+ self.fopen_keys = {k: "" for k in fopen_keys}
122
+ else:
123
+ self.fopen_keys = fopen_keys or {}
124
+
125
+ # We need to apply the fopen path prefix here already, because by the time
126
+ # the files are actually read in TF, paths are only symbolic tensors :(
127
+ for ex in self.examples:
128
+ for k, dirname in self.fopen_keys.items():
129
+ ex[k] = os.path.join(dirname, ex[k])
130
+
131
+ def _indices(self, *, process_split=True, process_index=None):
132
+ indices = np.arange(len(self.examples))
133
+
134
+ if not process_split:
135
+ return list(indices)
136
+
137
+ pid = jax.process_index() if process_index is None else process_index
138
+ return list(np.array_split(indices, jax.process_count())[pid])
139
+
140
+ @overrides.overrides
141
+ def get_tfdata(self, ordered=False, *, process_split=True, allow_cache=True):
142
+ del allow_cache # We don't cache anything anyways.
143
+ assert not process_split or len(self.examples) >= jax.process_count(), (
144
+ "Process splitting the data with fewer examples than processes!?")
145
+
146
+ my_idxs = self._indices(process_split=process_split)
147
+ if not ordered:
148
+ np.random.shuffle(my_idxs)
149
+
150
+ dataset = tf.data.Dataset.from_generator(
151
+ generator=lambda: ({"id": str(i), **self.examples[i]} for i in my_idxs),
152
+ output_signature={
153
+ "id": _guess_signature("0"),
154
+ **{k: _guess_signature(v) for k, v in self.examples[0].items()},
155
+ })
156
+
157
+ def _read_files(example):
158
+ for k in self.fopen_keys:
159
+ example[k] = tf.io.read_file(example[k])
160
+ return example
161
+ dataset = dataset.map(_read_files)
162
+
163
+ return dataset
164
+
165
+ @property
166
+ @overrides.overrides
167
+ def total_examples(self):
168
+ return len(self.examples)
169
+
170
+ @overrides.overrides
171
+ def num_examples_per_process(self):
172
+ return [len(self._indices(process_index=pid))
173
+ for pid in range(jax.process_count())]
174
+
175
+
176
+ def _guess_signature(value):
177
+ return tf.TensorSpec.from_tensor(tf.constant(value))
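A minimal end-to-end sketch of the on-disk format this source expects (assumes tensorflow and jax are installed; the file names and contents are made up). `DataSource` refers to the class defined above.

```python
import json, os, tempfile

root = tempfile.mkdtemp()
with open(os.path.join(root, "data.jsonl"), "w") as f:
    f.write(json.dumps({"image": "a.jpg", "caption": "a cat"}) + "\n")
    f.write(json.dumps({"image": "b.jpg", "caption": "a dog"}) + "\n")
for name in ("a.jpg", "b.jpg"):
    with open(os.path.join(root, name), "wb") as f:
        f.write(b"placeholder bytes; a real dataset would store jpegs")

ds = DataSource(os.path.join(root, "data.jsonl"),
                fopen_keys={"image": root})  # "image" values become file bytes
print(ds.total_examples)                     # 2
for ex in ds.get_tfdata(ordered=True).take(1):
    print(sorted(ex.keys()))                 # ['caption', 'id', 'image']
```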
Tipsomaly/model/big_vision/datasets/sequence_packing.py ADDED
@@ -0,0 +1,77 @@
1
+ # Copyright 2024 Big Vision Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """Packed Sequence Op."""
16
+
17
+ # Forked from
18
+ # https://github.com/google/maxtext/blob/main/MaxText/sequence_packing.py.
19
+
20
+
21
+ from typing import Dict, Optional, List, Union
22
+
23
+ from flax import traverse_util
24
+ import tensorflow as tf
25
+
26
+ AUTOTUNE = tf.data.experimental.AUTOTUNE
27
+ FLATTEN_SEPARATOR = "<|sep|>"
28
+
29
+
30
+ def pack_dataset(
31
+ dataset: tf.data.Dataset,
32
+ batch_size: int | None,
33
+ key2length: Union[int, Dict[str, int]],
34
+ keys: Optional[List[str | tuple[str, ...]]] = None) -> tf.data.Dataset:
35
+ """Creates a 'packed' version of a dataset on-the-fly.
36
+
37
+ Wraps `tensorflow.grain` ops.
38
+
39
+ This is meant to replace the irritation of having to create a separate
40
+ "packed" version of a dataset to train efficiently on TPU.
41
+ Each example in the output dataset represents several examples in the
42
+ input dataset.
43
+
44
+ For each key in the input dataset, two additional keys are created:
45
+ <key>_segment_ids: an int32 tensor identifying the parts
46
+ representing the original example.
47
+ <key>_positions: an int32 tensor identifying the position within the original
48
+ example.
49
+
50
+ Example:
51
+ Two input examples get combined to form an output example.
52
+ The input examples are:
53
+ {"inputs": [8, 7, 1, 0], "targets":[4, 1, 0]}
54
+ {"inputs": [2, 3, 4, 1], "targets":[5, 6, 1]}
55
+ The output example is:
56
+ {
57
+ "inputs": [8, 7, 1, 2, 3, 4, 1, 0, 0, 0]
58
+ "inputs_seg": [1, 1, 1, 2, 2, 2, 2, 0, 0, 0]
59
+ "inputs_pos": [0, 1, 2, 0, 1, 2, 3, 0, 0, 0]
60
+ "targets": [4, 1, 5, 6, 1, 0, 0, 0, 0, 0]
61
+ "targets_seg": [1, 1, 2, 2, 2, 0, 0, 0, 0, 0]
62
+ "targets_pos": [0, 1, 0, 1, 2, 0, 0, 0, 0, 0]
63
+ }
64
+ 0 represents padding in both the inputs and the outputs.
65
+ Sequences in the incoming examples are truncated to the length given by
66
+ key2length, and sequences in the output examples are all padded to that fixed length.
67
+
68
+ Args:
69
+ dataset: A `tf.data.Dataset`.
70
+ batch_size: Batch size of the packed dataset.
71
+ key2length: An integer, or a dict from feature-key to integer.
72
+ keys: A list of strings (e.g. ["inputs", "targets"]).
73
+
74
+ Returns:
75
+ A `tf.data.Dataset`.
76
+ """
77
+ raise ValueError("Not implemented in OSS yet.")
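Since the OSS function above only raises, here is a hedged pure-Python sketch of the packing semantics the docstring describes (greedy concatenation into one padded row), reproducing the documented example for the "inputs" key:

```python
def pack(sequences, length):
    tokens, seg, pos = [], [], []
    for seg_id, s in enumerate(sequences, start=1):
        s = [t for t in s if t != 0]      # drop padding from the input example
        tokens += s
        seg += [seg_id] * len(s)          # which original example a token is from
        pos += list(range(len(s)))        # position within that original example
    pad = length - len(tokens)
    return tokens + [0] * pad, seg + [0] * pad, pos + [0] * pad

inputs, segment_ids, positions = pack([[8, 7, 1, 0], [2, 3, 4, 1]], 10)
print(inputs)       # [8, 7, 1, 2, 3, 4, 1, 0, 0, 0]
print(segment_ids)  # [1, 1, 1, 2, 2, 2, 2, 0, 0, 0]
print(positions)    # [0, 1, 2, 0, 1, 2, 3, 0, 0, 0]
```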
Tipsomaly/model/big_vision/datasets/tfds.py ADDED
@@ -0,0 +1,94 @@
1
+ # Copyright 2024 Big Vision Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """TensorFlow Datasets as data source for big_vision."""
16
+ import functools
17
+
18
+ import big_vision.datasets.core as ds_core
19
+ import jax
20
+ import numpy as np
21
+ import overrides
22
+ import tensorflow as tf
23
+ import tensorflow_datasets as tfds
24
+
25
+
26
+ class DataSource(ds_core.DataSource):
27
+ """Use TFDS as a data source."""
28
+
29
+ def __init__(self, name, split, data_dir=None, skip_decode=("image",)):
30
+ self.builder = _get_builder(name, data_dir)
31
+ self.split = split
32
+ # Each host is responsible for a fixed subset of data
33
+ process_splits = tfds.even_splits(split, jax.process_count())
34
+ self.process_split = process_splits[jax.process_index()]
35
+ self.skip_decode = skip_decode
36
+
37
+ @overrides.overrides
38
+ def get_tfdata(
39
+ self, ordered=False, *, process_split=True, allow_cache=True, **kw):
40
+ # The tf.data may use a lot of RAM, so we need to expose the option of not
41
+ # keeping this in memory when we use lots of input pipelines, such as when
42
+ # having many ephemeral evaluators.
43
+ return (_cached_get_dataset if allow_cache else _get_dataset)(
44
+ self.builder, self.skip_decode,
45
+ split=self.process_split if process_split else self.split,
46
+ shuffle_files=not ordered,
47
+ **kw)
48
+
49
+ @property
50
+ @overrides.overrides
51
+ def total_examples(self):
52
+ return self.builder.info.splits[self.split].num_examples
53
+
54
+ @overrides.overrides
55
+ def num_examples_per_process(self):
56
+ splits = tfds.even_splits(self.split, jax.process_count())
57
+ return [self.builder.info.splits[s].num_examples for s in splits]
58
+
59
+
60
+ @functools.cache
61
+ def _get_builder(dataset, data_dir):
62
+ if dataset == "from_data_dir":
63
+ return tfds.builder_from_directory(data_dir)
64
+ else:
65
+ return tfds.builder(dataset, data_dir=data_dir, try_gcs=True)
66
+
67
+
68
+ # Cache as it may well take 1-2min on large datasets, and we may use the same
69
+ # multiple times (eg various evaluators).
70
+ def _get_dataset(builder, skip_decode, shuffle_files, split=None, **rckw):
71
+ """Returns a tf.data to be used."""
72
+ ds = builder.as_dataset(
73
+ split=split, shuffle_files=shuffle_files,
74
+ read_config=tfds.ReadConfig(
75
+ skip_prefetch=True, # We prefetch after pipeline.
76
+ try_autocache=False, # We control this, esp. for few-shot.
77
+ add_tfds_id=True,
78
+ **rckw,
79
+ ),
80
+ decoders={
81
+ f: tfds.decode.SkipDecoding()
82
+ for f in skip_decode if f in builder.info.features
83
+ })
84
+
85
+ def _hash_tfds_id(example):
86
+ id_ = tf.strings.to_hash_bucket_strong(
87
+ example["tfds_id"],
88
+ np.iinfo(np.uint32).max, # Max value
89
+ [3714561454027272724, 8800639020734831960]) # Magic.
90
+ example["_id"] = tf.bitcast(id_, tf.int32)[0] # good device dtype.
91
+ return example
92
+
93
+ return ds.map(_hash_tfds_id)
94
+ _cached_get_dataset = functools.cache(_get_dataset)
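A standalone demo of the `tfds_id` hashing trick above (assumes tensorflow and numpy are installed; the id string is made up). A keyed strong hash maps the string id into the uint32 range, and the bitcast to int32 yields a device-friendly dtype:

```python
import numpy as np
import tensorflow as tf

tfds_id = tf.constant("imagenet2012-train.tfrecord-00000-of-01024__42")
id_ = tf.strings.to_hash_bucket_strong(
    tfds_id, np.iinfo(np.uint32).max,
    [3714561454027272724, 8800639020734831960])  # Same magic key as above.
print(int(tf.bitcast(id_, tf.int32)[0]))  # Deterministic across runs.
```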
Tipsomaly/model/big_vision/evaluators/__init__.py ADDED
File without changes
Tipsomaly/model/big_vision/evaluators/classification.py ADDED
@@ -0,0 +1,76 @@
1
+ # Copyright 2024 Big Vision Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """Evaluator for the classfication task."""
16
+ # pylint: disable=consider-using-from-import
17
+
18
+ import functools
19
+
20
+ from big_vision.evaluators import common
21
+ import big_vision.utils as u
22
+ import jax
23
+ import jax.numpy as jnp
24
+
25
+
26
+ # Temporary global flag to facilitate backwards compatibility. Will be removed
27
+ # by the end of year 2023.
28
+ API = 'jit'
29
+
30
+
31
+ # To avoid re-compiling the function for every new instance of the same
32
+ # evaluator on a different dataset!
33
+ @functools.cache
34
+ def get_eval_fn(predict_fn, loss_name):
35
+ """Produces eval function, also applies pmap."""
36
+ @jax.jit
37
+ def _eval_fn(train_state, batch, labels, mask):
38
+ logits, *_ = predict_fn(train_state, batch)
39
+
40
+ # Ignore the entries with all zero labels for evaluation.
41
+ mask *= labels.max(axis=1)
42
+
43
+ loss = getattr(u, loss_name)(
44
+ logits=logits, labels=labels, reduction=False)
45
+ loss = jnp.sum(loss * mask)
46
+
47
+ top1_idx = jnp.argmax(logits, axis=1)
48
+ # Extracts the label at the highest logit index for each image.
49
+ top1_correct = jnp.take_along_axis(
50
+ labels, top1_idx[:, None], axis=1)[:, 0]
51
+ ncorrect = jnp.sum(top1_correct * mask)
52
+ nseen = jnp.sum(mask)
53
+ return ncorrect, loss, nseen
54
+ return _eval_fn
55
+
56
+
57
+ class Evaluator:
58
+ """Classification evaluator."""
59
+
60
+ def __init__(self, predict_fn, loss_name, label_key='labels', **kw):
61
+ self.get_data_iter, self.steps = common.eval_input_pipeline(**kw)
62
+ self.eval_fn = get_eval_fn(predict_fn, loss_name)
63
+ self.label_key = label_key
64
+
65
+ def run(self, train_state):
66
+ """Computes all metrics."""
67
+ ncorrect, loss, nseen = 0, 0, 0
68
+ for _, batch in zip(range(self.steps), self.get_data_iter()):
69
+ labels, mask = batch.pop(self.label_key), batch.pop('_mask')
70
+ batch_ncorrect, batch_losses, batch_nseen = jax.device_get(
71
+ self.eval_fn(train_state, batch, labels, mask))
72
+ ncorrect += batch_ncorrect
73
+ loss += batch_losses
74
+ nseen += batch_nseen
75
+ yield ('prec@1', ncorrect / nseen)
76
+ yield ('loss', loss / nseen)
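A toy worked example of the masked top-1 logic above (assumes jax is installed; the numbers are made up). Folding `labels.max(axis=1)` into the mask drops examples whose one-hot labels are all zero:

```python
import jax.numpy as jnp

logits = jnp.array([[2.0, 0.1, 0.3],    # argmax 0, label 0 -> correct
                    [0.2, 1.5, 0.1],    # argmax 1, label 2 -> wrong
                    [9.0, 0.0, 0.0]])   # all-zero label row -> ignored
labels = jnp.array([[1., 0., 0.],
                    [0., 0., 1.],
                    [0., 0., 0.]])      # e.g. an unannotated padding example
mask = jnp.array([1., 1., 1.]) * labels.max(axis=1)

top1 = jnp.argmax(logits, axis=1)
correct = jnp.take_along_axis(labels, top1[:, None], axis=1)[:, 0]
print(float((correct * mask).sum() / mask.sum()))  # 0.5
```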
Tipsomaly/model/big_vision/evaluators/common.py ADDED
@@ -0,0 +1,228 @@
1
+ # Copyright 2024 Big Vision Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """Utils for evaluators in general."""
16
+
17
+ import dataclasses
18
+ import functools
19
+ import importlib
20
+ import json
21
+ import os
22
+ from typing import Any, Callable
23
+
24
+ from absl import flags
25
+ from big_vision import input_pipeline
26
+ from big_vision.datasets import core as ds_core
27
+ from big_vision.pp import builder as pp_builder
28
+ import big_vision.utils as u
29
+ import flax
30
+ import jax
31
+ import numpy as np
32
+
33
+ from tensorflow.io import gfile
34
+
35
+
36
+ def from_config(config, predict_fns,
37
+ write_note=lambda s: s,
38
+ get_steps=lambda key, cfg: cfg[f"{key}_steps"],
39
+ devices=None):
40
+ """Creates a list of evaluators based on `config`."""
41
+ evaluators = []
42
+ specs = config.get("evals", {})
43
+
44
+ for name, cfg in specs.items():
45
+ write_note(name)
46
+
47
+ # Pop all generic settings off so we're left with eval's kwargs in the end.
48
+ cfg = cfg.to_dict()
49
+ module = cfg.pop("type", name)
50
+ pred_key = cfg.pop("pred", "predict")
51
+ pred_kw = cfg.pop("pred_kw", None)
52
+ prefix = cfg.pop("prefix", f"{name}/")
53
+ cfg.pop("skip_first", None)
54
+ logsteps = get_steps("log", cfg)
55
+ for typ in ("steps", "epochs", "examples", "percent"):
56
+ cfg.pop(f"log_{typ}", None)
57
+
58
+ # Use same batch_size as eval by default, to reduce fragmentation.
59
+ # TODO: eventually remove all the deprecated names...
60
+ cfg["batch_size"] = cfg.get("batch_size") or config.get("batch_size_eval") or config.get("input.batch_size") or config.get("batch_size") # pylint: disable=line-too-long
61
+
62
+ module = importlib.import_module(f"big_vision.evaluators.{module}")
63
+
64
+ if devices is not None:
65
+ cfg["devices"] = devices
66
+
67
+ api_type = getattr(module, "API", "pmap")
68
+ if api_type == "pmap" and "devices" in cfg:
69
+ raise RuntimeError(
70
+ "You are seemingly using the old pmap-based evaluator, but with "
71
+ "jit-based train loop, see (internal link) for more details.")
72
+ if api_type == "jit" and "devices" not in cfg:
73
+ raise RuntimeError(
74
+ "You are seemingly using new jit-based evaluator, but with "
75
+ "old pmap-based train loop, see (internal link) for more details.")
76
+
77
+ try:
78
+ predict_fn = predict_fns[pred_key]
79
+ except KeyError as e:
80
+ raise ValueError(
81
+ f"Unknown predict_fn '{pred_key}'. Available predict_fns are:\n"
82
+ + "\n".join(predict_fns)) from e
83
+ if pred_kw is not None:
84
+ predict_fn = _CacheablePartial(predict_fn, flax.core.freeze(pred_kw))
85
+ evaluator = module.Evaluator(predict_fn, **cfg)
86
+ evaluators.append((name, evaluator, logsteps, prefix))
87
+
88
+ return evaluators
89
+
90
+
91
+ @dataclasses.dataclass(frozen=True, eq=True)
92
+ class _CacheablePartial:
93
+ """partial(fn, **kwargs) that defines hash and eq - to help with jit caches.
94
+
95
+ This is particularly common in evaluators when one has many evaluator
96
+ instances that run on difference slices of data.
97
+
98
+ Example:
99
+
100
+ ```
101
+ f1 = _CacheablePartial(fn, a=1)
102
+ jax.jit(f1)(...)
103
+ jax.jit(_CacheablePartial(fn, a=1))(...) # fn won't be retraced.
104
+ del f1
105
+ jax.jit(_CacheablePartial(fn, a=1))(...) # fn will be retraced.
106
+ ```
107
+ """
108
+ fn: Callable[..., Any]
109
+ kwargs: flax.core.FrozenDict
110
+
111
+ def __call__(self, *args, **kwargs):
112
+ return functools.partial(self.fn, **self.kwargs)(*args, **kwargs)
113
+
114
+
115
+ def eval_input_pipeline(
116
+ data, pp_fn, batch_size, devices, keep_on_cpu=(),
117
+ cache="pipeline", prefetch=1, warmup=False,
118
+ ):
119
+ """Create an input pipeline in the way used by most evaluators.
120
+
121
+ Args:
122
+ data: The configuration to create the data source (like for training).
123
+ pp_fn: A string representing the preprocessing to be performed.
124
+ batch_size: The batch size to use.
125
+ devices: The devices that the batches are sharded and pre-fetched onto.
126
+ keep_on_cpu: See input_pipeline.start_global. Entries in the batch that
127
+ should be kept on the CPU, hence could be ragged or of string type.
128
+ cache: One of "none", "pipeline", "raw_data", "final_data". Determines what
129
+ part of the input stream should be cached across evaluator runs. They use
130
+ more and more RAM, but make evals faster, in that order.
131
+ - "none": Entirely re-create and destroy the input pipeline each run.
132
+ - "pipeline": Keep the (tf.data) pipeline object alive across runs.
133
+ - "raw_data": Cache the full raw data before pre-processing.
134
+ - "final_data": Cache the full raw data after pre-processing.
135
+ prefetch: How many batches to fetch ahead.
136
+ warmup: Start fetching the first batch at creation time (right now),
137
+ instead of once the iteration starts.
138
+
139
+ Returns:
140
+ A tuple (get_iter, steps), the first element is a function that returns
141
+ the iterator to be used for an evaluation, the second one is how many steps
142
+ should be iterated for doing one evaluation.
143
+ """
144
+ assert (
145
+ cache is None
146
+ or cache.lower() in ("none", "pipeline", "raw_data", "final_data")
147
+ ), f"Unknown value for cache: {cache}"
148
+ data_source = ds_core.get(**data)
149
+ tfdata, steps = input_pipeline.make_for_inference(
150
+ data_source.get_tfdata(ordered=True, allow_cache=cache.lower() != "none"),
151
+ batch_size=batch_size,
152
+ num_ex_per_process=data_source.num_examples_per_process(),
153
+ preprocess_fn=pp_builder.get_preprocess_fn(pp_fn, str(data)),
154
+ cache_final=cache == "raw_data",
155
+ cache_raw=cache == "final_data")
156
+ get_data_iter = lambda: input_pipeline.start_global(
157
+ tfdata, devices, prefetch, keep_on_cpu, warmup)
158
+
159
+ # Possibly create one persistent iterator:
160
+ if cache in ("pipeline", "raw_data", "final_data"):
161
+ data_iter = get_data_iter()
162
+ get_data_iter = lambda: data_iter
163
+
164
+ return get_data_iter, steps
165
+
166
+
167
+ def process_sum(tree):
168
+ """Sums the pytree across all processes."""
169
+ if jax.process_count() == 1: # Avoids corner-cases on donuts.
170
+ return tree
171
+
172
+ with jax.transfer_guard_device_to_host("allow"):
173
+ gathered = jax.experimental.multihost_utils.process_allgather(tree)
174
+ return jax.tree.map(functools.partial(np.sum, axis=0), gathered)
175
+
176
+
177
+ def resolve_outfile(outfile, split="", **kw):
178
+ if not outfile:
179
+ return None
180
+
181
+ # A caveat: when workdir doesn't exist but is in the `outfile`, we should
182
+ # skip. This is common in small runs or runlocal debugging.
183
+ if "{workdir}" in outfile and not flags.FLAGS.workdir:
184
+ return None
185
+
186
+ return outfile.format(
187
+ workdir=flags.FLAGS.workdir,
188
+ split="".join(c if c not in "[]%:" else "_" for c in (split or "")),
189
+ step=getattr(u.chrono, "prev_step", None),
190
+ **kw,
191
+ )
192
+
193
+
194
+ def multiprocess_write_json(outfile, jobj): # jobj = "json object"
195
+ """Write a single json file combining all processes' `jobj`s."""
196
+ if not outfile:
197
+ return
198
+
199
+ outfile = resolve_outfile(outfile)
200
+ gfile.makedirs(os.path.dirname(outfile))
201
+
202
+ if isinstance(jobj, list):
203
+ combine_fn = list.extend
204
+ elif isinstance(jobj, dict):
205
+ combine_fn = dict.update
206
+ else:
207
+ raise TypeError(f"Can only write list or dict jsons, but got {type(jobj)}")
208
+
209
+ # First, each process writes its own file.
210
+ with gfile.GFile(outfile + f".p{jax.process_index()}", "w+") as f:
211
+ f.write(json.dumps(jobj))
212
+
213
+ u.sync() # Wait for all files to be written; `with` above does close/flush.
214
+
215
+ # Have process 0 collect, concat, and write final output.
216
+ all_json = type(jobj)()
217
+ if jax.process_index() == 0:
218
+ for pid in range(jax.process_count()):
219
+ with gfile.GFile(outfile + f".p{pid}", "r") as f:
220
+ combine_fn(all_json, json.loads(f.read()))
221
+ with gfile.GFile(outfile, "w+") as f:
222
+ f.write(json.dumps(all_json))
223
+
224
+ # Cleanup time
225
+ u.sync()
226
+ gfile.remove(outfile + f".p{jax.process_index()}")
227
+
228
+ return all_json
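A hedged toy re-implementation of the caching idea behind `_CacheablePartial` (assumes jax is installed; the names are made up). Frozen-dataclass equality lets two distinct-but-equal callables share one jit cache entry, which is the behaviour the docstring above documents:

```python
import dataclasses
from typing import Any, Callable

import jax
import jax.numpy as jnp

TRACES = 0

def scale(x, a):
    global TRACES
    TRACES += 1  # Runs while tracing only, not on cached executions.
    return a * x

@dataclasses.dataclass(frozen=True, eq=True)
class Partial:
    fn: Callable[..., Any]
    a: float
    def __call__(self, x):
        return self.fn(x, self.a)

p = Partial(scale, 2.0)  # Keep a live reference (see the docstring's `del` note).
jax.jit(p)(jnp.ones(3))
jax.jit(Partial(scale, 2.0))(jnp.ones(3))  # Equal instance: expect a cache hit.
print(TRACES)  # 1 if the jit cache dedupes on equality, as documented above.
```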
Tipsomaly/model/big_vision/evaluators/fewshot_lsr.py ADDED
@@ -0,0 +1,245 @@
1
+ # Copyright 2024 Big Vision Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """Utils for few-shot evaluation."""
16
+ # pylint: disable=consider-using-from-import,g-importing-member
17
+
18
+ import functools
19
+
20
+ import big_vision.datasets.core as ds_core
21
+ import big_vision.input_pipeline as input_pipeline
22
+ import big_vision.pp.builder as pp_builder
23
+ import big_vision.utils as u
24
+ import jax
25
+ import jax.numpy as jnp
26
+ from jax.sharding import NamedSharding as Sharding
27
+ from jax.sharding import PartitionSpec as P
28
+ import numpy as np
29
+
30
+ BIAS_CONSTANT = 100.0
31
+
32
+ # Temporary global flag to facilitate backwards compatability. Will be removed
33
+ # by the end of year 2023.
34
+ API = "jit"
35
+
36
+
37
+ # Setup function for few-shot regression on CPU to avoid "polluting" the TPU.
38
+ @u.jit_cpu(static_argnums=(2,))
39
+ def _precompute_cache(x, y, num_classes):
40
+ """Cache quantities to speed-up the computation of L2-regularized least-sq."""
41
+ # Whiten
42
+ mean = jnp.mean(x, axis=0, keepdims=True)
43
+ std = jnp.std(x, axis=0, keepdims=True) + 1e-5
44
+ x = (x - mean) / std
45
+
46
+ # Add a constant feature for the bias, large so it's almost unregularized:
47
+ x = jnp.pad(x, ((0, 0), (0, 1)), constant_values=BIAS_CONSTANT)
48
+
49
+ # To one-hot representation rescaled into {-1, 1}
50
+ y = 2.0 * jax.nn.one_hot(y, num_classes) - 1.0
51
+
52
+ num_points, dim = x.shape
53
+ # Let N be the number of points, D the dimension and C the number of classes.
54
+ # We have x of shape (N, D) and y of shape (N, C).
55
+ # For least-squares, we can compute
56
+ #
57
+ # (A) when N >= D, (x^T x + l2 Id)^{-1} x^T y
58
+ # (B) when D > N, x^T (x x^T + l2 Id)^{-1} y
59
+ #
60
+ # We pre-compute the eigen-decomposition of either x^T x or x x^T which
61
+ # becomes q diag(eigs) q^T with q unitary matrix either (D, D) or (N, N)
62
+ # and eigs a vector (D,) or (N,).
63
+ #
64
+ # For any l2 > 0, we can compute (x^T x + l2 Id)^{-1} or (x x^T + l2 Id)^{-1}
65
+ # by simply computing q (diag(eigs) + l2 Id)^{-1} q^T.
66
+ # (SVD would be more natural here, but it proved slower, so we use eigh)
67
+ #
68
+ # Both cases (A) and (B) can be viewed as lhs (diag(eigs) + l2 Id)^{-1} rhs,
69
+ # where lhs/rhs are pre-computed left/right-hand sides to specify.
70
+ #
71
+ # Detailed evaluation in terms of time and fewshot metrics can be found in
72
+ # (internal link)
73
+ #
74
+ # Implemented by Rodolphe Jenatton.
75
+ if num_points >= dim:
76
+ eigs, q = jnp.linalg.eigh(x.T @ x)
77
+ rhs = q.T @ (x.T @ y)
78
+ lhs = q
79
+ else:
80
+ eigs, q = jnp.linalg.eigh(x @ x.T)
81
+ rhs = q.T @ y
82
+ lhs = x.T @ q
83
+
84
+ cache = {
85
+ "eigs": eigs,
86
+ "rhs": rhs,
87
+ "lhs": lhs,
88
+ "mean": mean,
89
+ "std": std
90
+ }
91
+ return cache
92
+
93
+
94
+ @u.jit_cpu()
95
+ def _eig_fewshot_acc_fn(cache, x_test, y_test, l2_reg):
96
+ """Computes (x,y) linear regression accuracy on (x_test, y_test)."""
97
+
98
+ x_test = (x_test - cache["mean"]) / cache["std"]
99
+ x_test = jnp.pad(x_test, ((0, 0), (0, 1)), constant_values=BIAS_CONSTANT)
100
+
101
+ rhs = cache["rhs"]
102
+ lhs = cache["lhs"]
103
+ eigs = cache["eigs"]
104
+
105
+ # See comments in _precompute_cache for context about the formula.
106
+ scaling = 1.0 / (eigs + l2_reg * jnp.ones_like(eigs))
107
+ scaling = scaling.reshape((1, -1))
108
+ w = (lhs * scaling) @ rhs
109
+ # Predict test-set values and measure their accuracy
110
+ preds = jnp.argmax(x_test @ w, axis=1)
111
+ return jnp.mean(preds == y_test)
112
+
113
+
114
+ class Evaluator:
115
+ """Class for few-shot evaluation."""
116
+
117
+ def __init__(self, predict_fn, batch_size,
118
+ datasets, shots, l2_reg,
119
+ pp_train, pp_eval, display_first,
120
+ representation_layer=None, num_seeds=3,
121
+ label_key="label", mask_key="_mask", data_dir=None, *,
122
+ devices):
123
+ self.datasets = datasets
124
+ self.shots = shots
125
+ self.l2_reg = l2_reg
126
+ self.batch_size = batch_size
127
+ self.pp_tr = pp_train
128
+ self.pp_te = pp_eval
129
+ self.display_first = display_first
130
+ self._datasets = {} # Cache for tfds data. Persists while object is alive.
131
+ self._repr = {} # Cache for precomputed repr. Persists within the run call.
132
+ self.num_seeds = num_seeds
133
+ self.label_key = label_key
134
+ self.mask_key = mask_key
135
+ self.data_dir = data_dir
136
+ self.devices = devices
137
+ self.mesh = jax.sharding.Mesh(devices, ("devices",))
138
+ self.repr_fn = self.get_representation_fn(
139
+ predict_fn, representation_layer)
140
+
141
+ def get_representation_fn(self, predict_fn, representation_layer):
142
+ # `out_shardings=Sharding(self.mesh, P())` will "all_gather" the outputs.
143
+ @functools.partial(jax.jit, out_shardings=Sharding(self.mesh, P()))
144
+ def _repr_fn(train_state, batch, labels, mask):
145
+ zimg, *_, out = predict_fn(train_state, batch)
146
+ if representation_layer is not None:
147
+ rep = u.tree_get(out, representation_layer)
148
+ else:
149
+ rep = zimg
150
+ return rep, labels, mask
151
+ return _repr_fn
152
+
153
+ # Setup input pipeline.
154
+ def _get_dataset(self, dataset, train_split, test_split):
155
+ """Lazy-loads given dataset."""
156
+ key = (dataset, train_split, test_split)
157
+ try:
158
+ return self._datasets[key]
159
+ except KeyError:
160
+ # NOTE: only supporting TFDS data for now for bwd compat/lazyness.
161
+ train_data = ds_core.get(
162
+ name=dataset, split=train_split, data_dir=self.data_dir
163
+ )
164
+ test_data = ds_core.get(
165
+ name=dataset, split=test_split, data_dir=self.data_dir
166
+ )
167
+ train_ds, batches_tr = input_pipeline.make_for_inference(
168
+ train_data.get_tfdata(ordered=True),
169
+ num_ex_per_process=train_data.num_examples_per_process(),
170
+ batch_size=self.batch_size,
171
+ preprocess_fn=pp_builder.get_preprocess_fn(self.pp_tr))
172
+ test_ds, batches_te = input_pipeline.make_for_inference(
173
+ test_data.get_tfdata(ordered=True),
174
+ num_ex_per_process=test_data.num_examples_per_process(),
175
+ batch_size=self.batch_size,
176
+ preprocess_fn=pp_builder.get_preprocess_fn(self.pp_te))
177
+
178
+ num_classes = train_data.builder.info.features[self.label_key].num_classes
179
+ return self._datasets.setdefault(
180
+ key, (train_ds, batches_tr, test_ds, batches_te, num_classes))
181
+
182
+ def _get_repr(self, params, data, steps):
183
+ """Compute representation for the whole dataset."""
184
+ pre_logits_list = []
185
+ labels_list = []
186
+ for batch, _ in zip(
187
+ input_pipeline.start_global(data, self.devices, 0), range(steps)):
188
+ labels, mask = batch.pop(self.label_key), batch.pop(self.mask_key)
189
+ pre_logits, labels, mask = jax.device_get(self.repr_fn(
190
+ params, batch, labels, mask))
191
+ mask = mask.astype(bool)
192
+ pre_logits_list.append(pre_logits[mask])
193
+ labels_list.append(labels[mask])
194
+ pre_logits = np.concatenate(pre_logits_list, axis=0)
195
+ labels = np.concatenate(labels_list, axis=0)
196
+
197
+ return pre_logits, labels
198
+
199
+ def compute_fewshot_metrics(self, train_state, seed,
200
+ dataset, train_split, test_split):
201
+ """Compute few-shot metrics on one dataset."""
202
+ if dataset in self._repr:
203
+ repr_train, labels_train, repr_test, labels_test, num_classes = (
204
+ self._repr[dataset])
205
+ else:
206
+ train_ds, steps_tr, test_ds, steps_te, num_classes = self._get_dataset(
207
+ dataset, train_split, test_split)
208
+ repr_train, labels_train = self._get_repr(train_state, train_ds, steps_tr)
209
+ repr_test, labels_test = self._get_repr(train_state, test_ds, steps_te)
210
+ self._repr[dataset] = (repr_train, labels_train,
211
+ repr_test, labels_test,
212
+ num_classes)
213
+
214
+ # Collect where we have samples of which classes.
215
+ rng = np.random.default_rng(seed)
216
+ class_indices = [rng.permutation(np.where(labels_train == cls_i)[0])
217
+ for cls_i in range(num_classes)]
218
+
219
+ results = {}
220
+ for shots in self.shots:
221
+ all_idx = [indices[:shots] for indices in class_indices]
222
+ all_idx = np.concatenate(all_idx, axis=0)
223
+ x = u.put_cpu(repr_train[all_idx])
224
+ y = u.put_cpu(labels_train[all_idx])
225
+ repr_test, labels_test = u.put_cpu((repr_test, labels_test))
226
+
227
+ # Note the code is optimized to solve multiple LSR tasks for changing l2
228
+ # strength, even though we currently used the fixed l2_reg constant.
229
+ cache = _precompute_cache(x, y, num_classes)
230
+ acc = _eig_fewshot_acc_fn(
231
+ cache, repr_test, labels_test, u.put_cpu(self.l2_reg))
232
+ results[shots] = jax.device_get(acc)
233
+
234
+ return results
235
+
236
+ def run(self, train_state):
237
+ """New API executed in terms of old API."""
238
+ self._repr = {}
239
+ for seed in range(self.num_seeds):
240
+ for name, dataset_args in self.datasets.items():
241
+ result = self.compute_fewshot_metrics(train_state, seed, *dataset_args)
242
+ for shots, v in result.items():
243
+ prefix = "a/" if (name, shots) in self.display_first else "z/"
244
+ suffix = f"-seed-{seed}"
245
+ yield f"{prefix}{name}_{shots}shot{suffix}", v
Tipsomaly/model/big_vision/evaluators/mean.py ADDED
@@ -0,0 +1,80 @@
1
+ # Copyright 2024 Big Vision Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """Evaluator for computing mean of per-example metrics.
16
+
17
+ This evaluator can be used in two ways:
18
+ 1. Create a new evaluator with reduced boilerplate by inheriting from it.
19
+ 2. For quick prototyping, use this with predict_fns which return the metrics.
20
+ """
21
+ from functools import partial
22
+ from typing import Mapping
23
+
24
+ from big_vision.evaluators import common
25
+
26
+ import jax
27
+ import jax.numpy as jnp
28
+ import numpy as np
29
+
30
+
31
+ # Temporary global flag to facilitate backwards compatibility. Will be removed
32
+ # by the end of year 2023.
33
+ API = 'jit'
34
+
35
+
36
+ # Note: global to avoid jax re-compiling across different evaluator instances.
37
+ @partial(jax.jit, static_argnums=0)
38
+ def _run_predict_fn(predict_fn, train_state, batch):
39
+ """Sum per-example metrics weighted by `_mask`."""
40
+ metrics = predict_fn(train_state, batch)
41
+ mask = batch['_mask']
42
+ # Sanity check output format of predict_fn.
43
+ assert isinstance(metrics, Mapping), 'predict_fn must return a dict'
44
+ for y in jax.tree.leaves(metrics):
45
+ if y.shape != mask.shape:
46
+ raise ValueError(
47
+ f'Expected per-example metrics of shape {mask.shape} found '
48
+ f'{jax.tree.map(lambda x: x.shape, metrics)}.')
49
+ metrics = {**metrics, '_mask': mask}
50
+ return jax.tree.map(lambda x: jnp.sum(jnp.where(mask, x, 0)), metrics)
51
+
52
+
53
+ class Evaluator:
54
+ """Report the mean of per-example metrics computed by predict_fn.
55
+
56
+ `predict_fn(params, batch)` must return a dict from metric name to
57
+ per-example metrics of shape [batch_size].
58
+ """
59
+
60
+ def __init__(self, predict_fn, **kw):
61
+ self.get_data_iter, self.steps = common.eval_input_pipeline(**kw)
62
+ self.predict_fn = partial(_run_predict_fn, predict_fn)
63
+
64
+ def run(self, train_state):
65
+ """Computes all metrics."""
66
+ metrics = []
67
+
68
+ # Compute batch metrics without blocking.
69
+ for _, batch in zip(range(self.steps), self.get_data_iter()):
70
+ batch_metrics = self.predict_fn(train_state, batch)
71
+ metrics.append(batch_metrics)
72
+
73
+ # Transfer metrics (blocking).
74
+ metrics = jax.device_get(metrics)
75
+
76
+ # Accumulate metrics across batches.
77
+ metrics_sum = jax.tree.map(lambda *x: np.sum(x), *metrics)
78
+ mask_sum = metrics_sum.pop('_mask')
79
+ for key, value_sum in metrics_sum.items():
80
+ yield (key, value_sum / mask_sum)
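A minimal sketch of the `predict_fn` contract this evaluator expects (assumes jax is installed; the metric, keys and shapes are made up). It must map a batch to a dict of per-example values with the same shape as `_mask`, which then get mask-averaged:

```python
import jax.numpy as jnp

def predict_fn(train_state, batch):
    del train_state                    # A real model would use its params here.
    return {"mae": jnp.abs(batch["x"] - batch["y"])}  # Shape [batch_size].

batch = {"x": jnp.array([1.0, 2.0, 3.0]),
         "y": jnp.array([1.5, 2.0, 9.0]),
         "_mask": jnp.array([1, 1, 0])}  # The last example is padding.
m = predict_fn(None, batch)
mask = batch["_mask"]
print(float((m["mae"] * mask).sum() / mask.sum()))  # 0.25
```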
Tipsomaly/model/big_vision/evaluators/save.py ADDED
@@ -0,0 +1,121 @@
1
+ # Copyright 2024 Big Vision Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """Evaluator that save inputs and outputs of prediction functions."""
16
+ import functools
17
+
18
+ from absl import flags
19
+ from absl import logging
20
+
21
+ from big_vision import input_pipeline
22
+ from big_vision import optax as bv_optax
23
+ from big_vision import utils
24
+ from big_vision.datasets import core as ds_core
25
+ from big_vision.pp import builder as pp_builder
26
+
27
+ import jax
28
+ import numpy as np
29
+
30
+ # Temporary global flag to facilitate backwards compatibility. Will be removed
31
+ # by the end of year 2023.
32
+ API = 'jit'
33
+
34
+
35
+ # Note: global to avoid jax re-compiling across different evaluator instances.
36
+ def _run_predict_fn(predict_fn, train_state, batch):
37
+ """Run predict_fn and gather all outputs on all devices."""
38
+ y = predict_fn(train_state, batch)
39
+ return {'inputs': batch, 'outputs': y}
40
+
41
+
42
+ class Evaluator:
43
+ """Evaluator that saves the inputs and outputs of a prediction function.
44
+
45
+ Example configuration:
46
+
47
+ ```
48
+ config.evals.save_pred = {
49
+ 'type': 'save',
50
+ 'pred': 'inference',
51
+ 'outfile': '{workdir}/inference-{step:09d}.npz',
52
+ 'data': ..., 'pp_fn': ..., 'log_steps': ...,
53
+ }
54
+ ```
55
+
56
+ Results can then be easily inspected in a notebook such as:
57
+
58
+ ```
59
+ results = utils.load_checkpoint("<full_path_to_outfile>")
60
+ inputs, outputs = (results["inputs"], results["outputs"])
61
+ ```
62
+ """
63
+
64
+ def __init__(self, predict_fn, data, pp_fn, batch_size, outfile,
65
+ cache_final=True, cache_raw=False, prefetch=1, *, devices):
66
+ replicate = jax.sharding.NamedSharding(
67
+ jax.sharding.Mesh(devices, ('devices',)),
68
+ jax.sharding.PartitionSpec()
69
+ )
70
+ self.predict_fn = functools.partial(
71
+ jax.jit(_run_predict_fn, static_argnums=0, out_shardings=replicate),
72
+ predict_fn,
73
+ )
74
+
75
+ data = ds_core.get(**data)
76
+ self.dataset, self.steps = input_pipeline.make_for_inference(
77
+ data.get_tfdata(ordered=True),
78
+ batch_size=batch_size,
79
+ num_ex_per_process=data.num_examples_per_process(),
80
+ preprocess_fn=pp_builder.get_preprocess_fn(pp_fn),
81
+ cache_final=cache_final,
82
+ cache_raw=cache_raw,
83
+ )
84
+ self.data_iter = input_pipeline.start_global(
85
+ self.dataset, devices, prefetch
86
+ )
87
+
88
+ self.outfile = outfile
89
+
90
+ def run(self, train_state):
91
+ """Compute all predictions, gather in main host and save in outfile."""
92
+ step = jax.device_get(bv_optax.get_count(train_state['opt'], jittable=True))
93
+ outfile = self.outfile.format(workdir=flags.FLAGS.workdir, step=step)
94
+
95
+ count = 0
96
+ outputs = []
97
+ for _, batch in zip(range(self.steps), self.data_iter):
98
+ out = self.predict_fn(train_state, batch)
99
+ if jax.process_index():
100
+ continue
101
+
102
+ out = jax.device_get(out)
103
+ mask = out['inputs']['_mask']
104
+ out = jax.tree.map(lambda x: x[mask == 1], out) # pylint: disable=cell-var-from-loop
105
+ count += mask.shape[0]
106
+ out['inputs'].pop('_mask')
107
+ outputs.append(out)
108
+
109
+ logging.log_every_n_seconds(
110
+ logging.INFO, 'Processed %i examples so far.', 60,
111
+ count)
112
+
113
+ if jax.process_index():
114
+ return
115
+
116
+ logging.info('Saving %d examples in %s', count, outfile)
117
+ outputs = jax.tree.map(lambda *x: np.concatenate(x, axis=0), *outputs)
118
+ utils.save_checkpoint(outputs, outfile, compressed=True)
119
+ return
120
+
121
+ yield None # pylint: disable=unreachable
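This evaluator only assumes that `predict_fn(train_state, batch)` returns a pytree of arrays. A hedged sketch of such a function follows; the `model.apply` call and the `'image'`/`'params'` keys are assumptions about the surrounding training code, not something this evaluator fixes:

```
def make_inference_fn(model):
  """Builds a hypothetical predict_fn for the save evaluator above."""
  def predict_fn(train_state, batch):
    # Any pytree of arrays works; it is stored under results['outputs'].
    logits, _ = model.apply({'params': train_state['params']}, batch['image'])
    return {'logits': logits}
  return predict_fn
```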
Tipsomaly/model/big_vision/models/__init__.py ADDED
File without changes
Tipsomaly/model/big_vision/models/bit.py ADDED
@@ -0,0 +1,162 @@
+ # Copyright 2024 Big Vision Authors.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ """ResNet V1 with GroupNorm."""
+
+ from typing import Optional, Sequence, Union
+
+ from big_vision import utils
+ from big_vision.models import common
+ import flax
+ import flax.linen as nn
+ import flax.training.checkpoints
+ import jax.numpy as jnp
+ import numpy as np
+
+
+ def weight_standardize(w, axis, eps):
+   w = w - jnp.mean(w, axis=axis)
+   w = w / (jnp.std(w, axis=axis) + eps)
+   return w
+
+
+ class StdConv(nn.Conv):
+
+   def param(self, name, *a, **kw):
+     param = super().param(name, *a, **kw)
+     if name == "kernel":
+       param = weight_standardize(param, axis=[0, 1, 2], eps=1e-5)
+     return param
+
+
+ class ResidualUnit(nn.Module):
+   """Bottleneck ResNet block."""
+   nmid: Optional[int] = None
+   strides: Sequence[int] = (1, 1)
+
+   @nn.compact
+   def __call__(self, x):
+     nmid = self.nmid or x.shape[-1] // 4
+     nout = nmid * 4
+
+     residual = x
+     if x.shape[-1] != nout or self.strides != (1, 1):
+       residual = StdConv(nout, (1, 1), self.strides, use_bias=False,
+                          name="conv_proj")(residual)
+       residual = nn.GroupNorm(name="gn_proj")(residual)
+
+     y = StdConv(nmid, (1, 1), use_bias=False, name="conv1")(x)
+     y = nn.GroupNorm(name="gn1")(y)
+     y = nn.relu(y)
+     y = StdConv(nmid, (3, 3), self.strides, use_bias=False, name="conv2")(y)
+     y = nn.GroupNorm(name="gn2")(y)
+     y = nn.relu(y)
+     y = StdConv(nout, (1, 1), use_bias=False, name="conv3")(y)
+
+     y = nn.GroupNorm(name="gn3", scale_init=nn.initializers.zeros)(y)
+     y = nn.relu(residual + y)
+     return y
+
+
+ class ResNetStage(nn.Module):
+   """One stage of ResNet."""
+   block_size: int
+   first_stride: Sequence[int] = (1, 1)
+   nmid: Optional[int] = None
+
+   @nn.compact
+   def __call__(self, x):
+     x = ResidualUnit(self.nmid, strides=self.first_stride, name="unit1")(x)
+     for i in range(1, self.block_size):
+       x = ResidualUnit(self.nmid, name=f"unit{i + 1}")(x)
+     return x
+
+
+ class Model(nn.Module):
+   """ResNetV1."""
+   num_classes: Optional[int] = None
+   width: float = 1
+   depth: Union[int, Sequence[int]] = 50
+
+   @nn.compact
+   def __call__(self, image, *, train=False):
+     del train  # Unused
+     blocks = get_block_desc(self.depth)
+     width = int(64 * self.width)
+
+     out = {}
+
+     # Root block
+     x = StdConv(width, (7, 7), (2, 2), use_bias=False, name="conv_root")(image)
+     x = nn.GroupNorm(name="gn_root")(x)
+     x = nn.relu(x)
+     x = nn.max_pool(x, (3, 3), strides=(2, 2), padding="SAME")
+     out["stem"] = x
+
+     # Stages
+     x = ResNetStage(blocks[0], nmid=width, name="block1")(x)
+     out["stage1"] = x
+     for i, block_size in enumerate(blocks[1:], 1):
+       x = ResNetStage(block_size, nmid=width * 2 ** i,
+                       first_stride=(2, 2), name=f"block{i + 1}")(x)
+       out[f"stage{i + 1}"] = x
+     out["pre_logits_2d"] = x
+
+     # Head
+     x = out["pre_logits"] = jnp.mean(x, axis=(1, 2))
+
+     if self.num_classes:
+       head = nn.Dense(self.num_classes, name="head",
+                       kernel_init=nn.initializers.zeros)
+       out["logits_2d"] = head(out["pre_logits_2d"])
+       x = out["logits"] = head(out["pre_logits"])
+
+     return x, out
+
+
+ # A dictionary mapping the number of layers in a resnet to the number of
+ # blocks in each stage of the model.
+ # NOTE: Does not include 18/34 as they also need non-bottleneck block!
+ def get_block_desc(depth):
+   if isinstance(depth, list):  # Be robust to silly mistakes.
+     depth = tuple(depth)
+   return {
+       26: [2, 2, 2, 2],  # From timm, gets ~75% on ImageNet.
+       50: [3, 4, 6, 3],
+       101: [3, 4, 23, 3],
+       152: [3, 8, 36, 3],
+       200: [3, 24, 36, 3]
+   }.get(depth, depth)
+
+
+ def fix_old_checkpoints(params):
+   """Modifies params from old checkpoints to run with current implementation."""
+   params = flax.core.unfreeze(
+       flax.training.checkpoints.convert_pre_linen(params))
+   # Old linen used to store non-squeezed GN params.
+   params = flax.traverse_util.unflatten_dict({
+       k: np.squeeze(v) if (set(k)
+                            & {"gn_root", "gn_proj", "gn1", "gn2", "gn3"}) else v
+       for k, v in flax.traverse_util.flatten_dict(params).items()
+   })
+   return params
+
+
+ def load(init_params, init_file, model_cfg, dont_load=()):
+   """Load init from checkpoint."""
+   del model_cfg  # Unused
+   params = utils.load_params(init_file)
+   params = common.merge_params(params, init_params, dont_load)
+   params = fix_old_checkpoints(params)
+   return params
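A quick smoke-test sketch of the `Model` class above (illustrative only; the `pre_logits` width follows from the bottleneck expansion `nout = nmid * 4` in the last stage):

```
import jax
import jax.numpy as jnp

model = Model(num_classes=10, depth=50, width=1)
dummy = jnp.zeros([1, 224, 224, 3])
params = model.init(jax.random.PRNGKey(0), dummy)
logits, out = model.apply(params, dummy)
print(logits.shape)             # (1, 10)
print(out['pre_logits'].shape)  # (1, 2048) -- stage-4 nmid=512, times 4
```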
Tipsomaly/model/big_vision/models/bit_paper.py ADDED
@@ -0,0 +1,260 @@
+ # Copyright 2024 Big Vision Authors.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ """BiT models as in the paper (ResNet V2) w/ loading of public weights.
+
+ See reproduction proof: http://(internal link)/qY70qs6j944
+ """
+
+ import functools
+ import re
+ from typing import Optional, Sequence, Union
+
+ from big_vision import utils as u
+ from big_vision.models import bit
+ from big_vision.models import common
+ import flax.linen as nn
+ import jax.numpy as jnp
+
+
+ def standardize(x, axis, eps):
+   x = x - jnp.mean(x, axis=axis, keepdims=True)
+   x = x / jnp.sqrt(jnp.mean(jnp.square(x), axis=axis, keepdims=True) + eps)
+   return x
+
+
+ # Defined our own, because we compute normalizing variance slightly
+ # differently, which does affect performance when loading pre-trained weights!
+ class GroupNorm(nn.Module):
+   """Group normalization (arxiv.org/abs/1803.08494)."""
+   ngroups: int = 32
+
+   @nn.compact
+   def __call__(self, x):
+
+     input_shape = x.shape
+     group_shape = x.shape[:-1] + (self.ngroups, x.shape[-1] // self.ngroups)
+
+     x = x.reshape(group_shape)
+
+     # Standardize along spatial and group dimensions
+     x = standardize(x, axis=[1, 2, 4], eps=1e-5)
+     x = x.reshape(input_shape)
+
+     bias_scale_shape = tuple([1, 1, 1] + [input_shape[-1]])
+     x = x * self.param('scale', nn.initializers.ones, bias_scale_shape)
+     x = x + self.param('bias', nn.initializers.zeros, bias_scale_shape)
+     return x
+
+
+ class StdConv(nn.Conv):
+
+   def param(self, name, *a, **kw):
+     param = super().param(name, *a, **kw)
+     if name == 'kernel':
+       param = standardize(param, axis=[0, 1, 2], eps=1e-10)
+     return param
+
+
+ class RootBlock(nn.Module):
+   """Root block of ResNet."""
+   width: int
+
+   @nn.compact
+   def __call__(self, x):
+     x = StdConv(self.width, (7, 7), (2, 2), padding=[(3, 3), (3, 3)],
+                 use_bias=False, name='conv_root')(x)
+     x = nn.max_pool(x, (3, 3), strides=(2, 2), padding=[(1, 1), (1, 1)])
+     return x
+
+
+ class ResidualUnit(nn.Module):
+   """Bottleneck ResNet block."""
+   nmid: Optional[int] = None
+   strides: Sequence[int] = (1, 1)
+
+   @nn.compact
+   def __call__(self, x):
+     nmid = self.nmid or x.shape[-1] // 4
+     nout = nmid * 4
+     conv = functools.partial(StdConv, use_bias=False)
+
+     residual = x
+     x = GroupNorm(name='gn1')(x)
+     x = nn.relu(x)
+
+     if x.shape[-1] != nout or self.strides != (1, 1):
+       residual = conv(nout, (1, 1), self.strides, name='conv_proj')(x)
+
+     x = conv(nmid, (1, 1), name='conv1')(x)
+     x = GroupNorm(name='gn2')(x)
+     x = nn.relu(x)
+     x = conv(nmid, (3, 3), self.strides, padding=[(1, 1), (1, 1)],
+              name='conv2')(x)
+     x = GroupNorm(name='gn3')(x)
+     x = nn.relu(x)
+     x = conv(nout, (1, 1), name='conv3')(x)
+
+     return x + residual
+
+
+ class ResNetStage(nn.Module):
+   """A stage (sequence of same-resolution blocks)."""
+   block_size: int
+   nmid: Optional[int] = None
+   first_stride: Sequence[int] = (1, 1)
+
+   @nn.compact
+   def __call__(self, x):
+     out = {}
+     x = out['unit01'] = ResidualUnit(
+         self.nmid, strides=self.first_stride, name='unit01')(x)
+     for i in range(1, self.block_size):
+       x = out[f'unit{i+1:02d}'] = ResidualUnit(
+           self.nmid, name=f'unit{i+1:02d}')(x)
+     return x, out
+
+
+ class Model(nn.Module):
+   """ResNetV2."""
+   num_classes: Optional[int] = None
+   width: int = 1
+   depth: Union[int, Sequence[int]] = 50  # 50/101/152, or list of block depths.
+   head_zeroinit: bool = True
+
+   @nn.compact
+   def __call__(self, image, *, train=False):
+     blocks = bit.get_block_desc(self.depth)
+     width = int(64 * self.width)
+     out = {}
+
+     x = out['stem'] = RootBlock(width=width, name='root_block')(image)
+
+     # Blocks
+     x, out['stage1'] = ResNetStage(blocks[0], nmid=width, name='block1')(x)
+     for i, block_size in enumerate(blocks[1:], 1):
+       x, out[f'stage{i + 1}'] = ResNetStage(
+           block_size, width * 2 ** i,
+           first_stride=(2, 2), name=f'block{i + 1}')(x)
+
+     # Pre-head
+     x = out['norm_pre_head'] = GroupNorm(name='norm-pre-head')(x)
+     x = out['pre_logits_2d'] = nn.relu(x)
+     x = out['pre_logits'] = jnp.mean(x, axis=(1, 2))
+
+     # Head
+     if self.num_classes:
+       kw = {'kernel_init': nn.initializers.zeros} if self.head_zeroinit else {}
+       head = nn.Dense(self.num_classes, name='head', **kw)
+       out['logits_2d'] = head(out['pre_logits_2d'])
+       x = out['logits'] = head(out['pre_logits'])
+
+     return x, out
+
+
+ def load(init_params, init_file, model_cfg, dont_load=()):
+   """Loads the TF-dumped NumPy or big_vision checkpoint.
+
+   Args:
+     init_params: random init params from which the new head is taken.
+     init_file: comes from `config.model_init`, can either be an absolute
+       path (i.e. starts with /) to the checkpoint, or a string like
+       "L-imagenet2012" describing one of the variants from the paper.
+     model_cfg: the model configuration.
+     dont_load: list of param names to be reset to init.
+
+   Returns:
+     The loaded parameters.
+   """
+
+   # Support for vanity model names from the paper.
+   vanity = {
+       'FunMatch-224px-i1k82.8': 'gs://bit_models/distill/R50x1_224.npz',
+       'FunMatch-160px-i1k80.5': 'gs://bit_models/distill/R50x1_160.npz',
+   }
+   if init_file[0] in ('L', 'M', 'S'):  # The models from the original paper.
+     # Supported names are of the following type:
+     # - 'M' or 'S': the original "upstream" model without fine-tuning.
+     # - 'M-ILSVRC2012': i21k model fine-tuned on i1k.
+     # - 'M-run0-caltech101': i21k model fine-tuned on VTAB's caltech101.
+     #   each VTAB fine-tuning was run 3x, so there's run0, run1, run2.
+     if '-' in init_file:
+       up, down = init_file[0], init_file[1:]
+     else:
+       up, down = init_file, ''
+     down = {'-imagenet2012': '-ILSVRC2012'}.get(down, down)  # normalize
+     fname = f'BiT-{up}-R{model_cfg.depth}x{model_cfg.width}{down}.npz'
+     fname = f'gs://bit_models/{fname}'
+   else:
+     fname = vanity.get(init_file, init_file)
+
+   params = u.load_params(fname)
+   params = maybe_convert_big_transfer_format(params)
+   return common.merge_params(params, init_params, dont_load)
+
+
+ def maybe_convert_big_transfer_format(params_tf):
+   """If the checkpoint comes from legacy codebase, convert it."""
+
+   # Only do anything at all if we recognize the format.
+   if 'resnet' not in params_tf:
+     return params_tf
+
+   # For ease of processing and backwards compatibility, flatten again:
+   params_tf = dict(u.tree_flatten_with_names(params_tf)[0])
+
+   # Works around some files containing weird naming of variables:
+   for k in list(params_tf):
+     k2 = re.sub('/standardized_conv2d_\\d+/', '/standardized_conv2d/', k)
+     if k2 != k:
+       params_tf[k2] = params_tf[k]
+       del params_tf[k]
+
+   params = {
+       'root_block': {'conv_root': {'kernel': params_tf[
+           'resnet/root_block/standardized_conv2d/kernel']}},
+       'norm-pre-head': {
+           'bias': params_tf['resnet/group_norm/beta'][None, None, None],
+           'scale': params_tf['resnet/group_norm/gamma'][None, None, None],
+       },
+       'head': {
+           'kernel': params_tf['resnet/head/conv2d/kernel'][0, 0],
+           'bias': params_tf['resnet/head/conv2d/bias'],
+       }
+   }
+
+   for block in ('block1', 'block2', 'block3', 'block4'):
+     params[block] = {}
+     units = set([re.findall(r'unit\d+', p)[0] for p in params_tf.keys()
+                  if p.find(block) >= 0])
+     for unit in units:
+       params[block][unit] = {}
+       for i, group in enumerate('abc', 1):
+         params[block][unit][f'conv{i}'] = {
+             'kernel': params_tf[f'resnet/{block}/{unit}/{group}/standardized_conv2d/kernel']  # pylint: disable=line-too-long
+         }
+         params[block][unit][f'gn{i}'] = {
+             'bias': params_tf[f'resnet/{block}/{unit}/{group}/group_norm/beta'][None, None, None],  # pylint: disable=line-too-long
+             'scale': params_tf[f'resnet/{block}/{unit}/{group}/group_norm/gamma'][None, None, None],  # pylint: disable=line-too-long
+         }
+
+       projs = [p for p in params_tf.keys()
+                if p.find(f'{block}/{unit}/a/proj') >= 0]
+       assert len(projs) <= 1
+       if projs:
+         params[block][unit]['conv_proj'] = {
+             'kernel': params_tf[projs[0]]
+         }
+
+   return params
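For illustration, here is a standalone re-statement of the paper-name resolution performed by `load` above (a sketch only: it skips the FunMatch vanity names and assumes `depth=50`, `width=1`):

```
def resolve_bit_name(init_file, depth=50, width=1):
  """Mirrors the gs://bit_models naming logic from `load` above."""
  if init_file[0] in ('L', 'M', 'S'):
    if '-' in init_file:
      up, down = init_file[0], init_file[1:]
    else:
      up, down = init_file, ''
    down = {'-imagenet2012': '-ILSVRC2012'}.get(down, down)  # normalize
    return f'gs://bit_models/BiT-{up}-R{depth}x{width}{down}.npz'
  return init_file  # Treated as a plain path.

print(resolve_bit_name('M-imagenet2012'))
# gs://bit_models/BiT-M-R50x1-ILSVRC2012.npz
```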
Tipsomaly/model/big_vision/models/common.py ADDED
@@ -0,0 +1,133 @@
+ # Copyright 2024 Big Vision Authors.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ """Utilities shared across models."""
+
+ from absl import logging
+ import big_vision.utils as u
+ import flax.linen as nn
+ import jax
+ import jax.numpy as jnp
+
+
+ def merge_params(loaded, inited, dont_load=(), match_dtype=False):
+   """Makes `loaded` pytree match `inited`, warning or failing on mismatch.
+
+   Args:
+     loaded: pytree of parameters, typically loaded from a checkpoint.
+     inited: pytree of parameters, typically coming from model init.
+     dont_load: List of regexes for parameters which shall not be taken
+       from `loaded`, either because they should remain at their init value,
+       or because they are missing on either side.
+     match_dtype: if True, leaves of the returned pytree are converted to the
+       dtype of the corresponding leaf in `inited`.
+
+   Returns:
+     If successful, a new pytree which matches the structure of `inited`
+     but contains values from `loaded`, except for `dont_load`.
+
+     If structures don't match and mismatches are not covered by regexes in
+     the `dont_load` argument, then raises an exception with more information.
+   """
+   if inited is None:  # A useful shortcut for example for colabs.
+     return loaded
+
+   dont_load = u.check_and_compile_patterns(dont_load)
+
+   def should_merge(name):
+     return not any(pattern.fullmatch(name) for pattern in dont_load)
+
+   loaded_flat, _ = u.tree_flatten_with_names(loaded)
+   inited_flat, _ = u.tree_flatten_with_names(inited)
+   loaded_flat = {k: v for k, v in loaded_flat}
+   inited_flat = {k: v for k, v in inited_flat}
+
+   # Let's first build the pytree from all common keys.
+   merged = {}
+   for name, init_val in inited_flat.items():
+     # param is present in both. Load or ignore it!
+     if name in loaded_flat and should_merge(name):
+       merged[name] = loaded_flat[name]
+       if match_dtype:
+         merged[name] = loaded_flat[name].astype(init_val.dtype)
+     else:
+       logging.info("Ignoring checkpoint and using init value for %s", name)
+       merged[name] = init_val
+
+   def pp(title, names, indent="  "):  # Just pretty-printing
+     if names:
+       return f"{title}:\n" + "\n".join(f"{indent}{k}" for k in sorted(names))
+     else:
+       return ""
+
+   # Now, if there are keys that only exist in inited or loaded, be helpful:
+   not_in_loaded = inited_flat.keys() - loaded_flat.keys()
+   not_in_inited = loaded_flat.keys() - inited_flat.keys()
+   logging.info(pp("Parameters in model but not in checkpoint", not_in_loaded))
+   logging.info(pp("Parameters in checkpoint but not in model", not_in_inited))
+
+   # And now see if any of them are not explicitly ignored => an error
+   not_in_loaded = {k for k in not_in_loaded if should_merge(k)}
+   not_in_inited = {k for k in not_in_inited if should_merge(k)}
+
+   if not_in_loaded or not_in_inited:
+     raise ValueError(
+         pp("Params in checkpoint", loaded_flat.keys()) + "\n" +
+         pp("Params in model (code)", inited_flat.keys()) + "\n" +
+         pp("Params in model (code) but not in checkpoint and not `dont_load`ed",
+            not_in_loaded, indent=" - ") + "\n" +  # Special indent for tests.
+         pp("Params in checkpoint but not in model (code) and not `dont_load`ed",
+            not_in_inited, indent=" + "))  # Special indent for tests.
+
+   return u.recover_tree(merged.keys(), merged.values())
+
+
+ class AddPositionEmbs(nn.Module):
+   """Adds positional embeddings to the inputs, supports caching for decode.
+
+   Attributes:
+     decode: whether to run in single-position autoregressive mode.
+   """
+   decode: bool = False
+
+   @nn.compact
+   def __call__(self, inputs, posemb):
+     """Applies AddPositionEmbs module.
+
+     Adds posemb to the inputs, supports single-position autoregressive mode.
+
+     Args:
+       inputs: input data [batch_size, seq_len, emb_dim].
+       posemb: positional embeddings.
+
+     Returns:
+       output: inputs modulated by pos-embeddings [batch_size, seq_len, emb_dim].
+     """
+     assert inputs.ndim == 3, f"Unexpected inputs shape: {inputs.shape}"
+     _, seq_len, emb_dim = inputs.shape
+     pe = posemb[:, :seq_len, :]
+
+     if self.decode:
+       is_initialized = self.has_variable("cache", "cache_index")
+       # We use a cache position index for tracking decoding position.
+       cache_index = self.variable("cache", "cache_index",
+                                   lambda: jnp.array(0, dtype=jnp.uint32))
+       if is_initialized:
+         i = cache_index.value
+         cache_index.value = i + 1
+         # Returns posemb[0, i, :], the positional embedding for the
+         # current decoding position.
+         pe = jax.lax.dynamic_slice(posemb,
+                                    start_indices=jnp.array((0, i, 0)),
+                                    slice_sizes=(1, 1, emb_dim))
+     return inputs + pe
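A toy sketch of `merge_params` semantics (hypothetical two-leaf pytrees; the regex in `dont_load` keeps the freshly-initialized head while everything else comes from the checkpoint):

```
loaded = {'embedding': {'kernel': 1}, 'head': {'kernel': 2}}
inited = {'embedding': {'kernel': 0}, 'head': {'kernel': 0}}

# 'head/.*' fullmatches the flattened name 'head/kernel', so that leaf is
# ignored from the checkpoint and stays at its init value.
merged = merge_params(loaded, inited, dont_load=('head/.*',))
assert merged == {'embedding': {'kernel': 1}, 'head': {'kernel': 0}}
```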
Tipsomaly/model/big_vision/models/mlp_mixer.py ADDED
@@ -0,0 +1,177 @@
+ # Copyright 2024 Big Vision Authors.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ """MLP-Mixer model."""
+
+ from typing import Optional, Tuple
+ from absl import logging
+
+ from big_vision import utils
+ from big_vision.models import common
+
+ import einops
+ import flax.linen as nn
+ import flax.training.checkpoints
+ import jax
+ import jax.numpy as jnp
+
+
+ class MlpBlock(nn.Module):
+   mlp_dim: int
+
+   @nn.compact
+   def __call__(self, x):
+     y = nn.Dense(self.mlp_dim)(x)
+     y = nn.gelu(y)
+     return nn.Dense(x.shape[-1])(y)
+
+
+ class MixerBlock(nn.Module):
+   """Mixer block layer."""
+   tokens_mlp_dim: int
+   channels_mlp_dim: int
+   drop_p: float
+
+   @nn.compact
+   def __call__(self, x, *, train=False):
+     y = nn.LayerNorm()(x)
+     y = jnp.swapaxes(y, 1, 2)
+     y = MlpBlock(self.tokens_mlp_dim, name="token_mixing")(y)
+     y = jnp.swapaxes(y, 1, 2)
+     x = x + y * _stoch_depth_mask(x, self.drop_p, not train, self.make_rng)
+     y = nn.LayerNorm()(x)
+     y = MlpBlock(self.channels_mlp_dim, name="channel_mixing")(y)
+     return x + y * _stoch_depth_mask(x, self.drop_p, not train, self.make_rng)
+
+
+ class MlpMixer(nn.Module):
+   """Mixer architecture."""
+   patch_size: Tuple[int, int]
+   num_classes: Optional[int]
+   num_blocks: int
+   hidden_dim: int
+   tokens_mlp_dim: int
+   channels_mlp_dim: int
+   model_name: Optional[str] = None
+   stoch_depth: float = 0.0
+
+   @nn.compact
+   def __call__(self, image, *, train=False):
+     out = {}
+     x = out["stem"] = nn.Conv(self.hidden_dim, self.patch_size,
+                               strides=self.patch_size, name="stem")(image)
+     x = out["input_tokens"] = einops.rearrange(x, "n h w c -> n (h w) c")
+     for i in range(self.num_blocks):
+       drop_p = (i / max(self.num_blocks - 1, 1)) * self.stoch_depth
+       x = out[f"block_{i}"] = MixerBlock(
+           self.tokens_mlp_dim, self.channels_mlp_dim, drop_p)(x, train=train)
+     x = nn.LayerNorm(name="pre_head_layer_norm")(x)
+     x = out["pre_logits"] = jnp.mean(x, axis=1)
+     if self.num_classes:
+       x = out["logits"] = nn.Dense(
+           self.num_classes, kernel_init=nn.initializers.zeros, name="head")(x)
+     return x, out
+
+
+ def Model(num_classes=None, *, variant=None, **kw):  # pylint: disable=invalid-name
+   """Factory function to easily create a Model variant like "L/16"."""
+
+   if variant is not None:
+     model_size, patch = variant.split("/")
+     kw.setdefault("patch_size", (int(patch), int(patch)))
+     config = {
+         "S": {
+             "hidden_dim": 512,
+             "num_blocks": 8,
+             "channels_mlp_dim": 2048,
+             "tokens_mlp_dim": 256
+         },
+         "B": {
+             "hidden_dim": 768,
+             "num_blocks": 12,
+             "channels_mlp_dim": 3072,
+             "tokens_mlp_dim": 384
+         },
+         "L": {
+             "hidden_dim": 1024,
+             "num_blocks": 24,
+             "channels_mlp_dim": 4096,
+             "tokens_mlp_dim": 512
+         },
+         "H": {
+             "hidden_dim": 1280,
+             "num_blocks": 32,
+             "channels_mlp_dim": 5120,
+             "tokens_mlp_dim": 640
+         },
+     }[model_size]
+
+     for k, v in config.items():
+       kw.setdefault(k, v)
+
+   logging.info("Mixer config: %s", kw)
+   return MlpMixer(num_classes=num_classes, **kw)
+
+
+ def load(init_params, init_file, model_cfg, dont_load=()):
+   """Load checkpoint."""
+
+   del model_cfg
+   # Shortcut names for some canonical paper checkpoints:
+   init_file = {
+       # pylint: disable=line-too-long
+       # Pretrained models from the MLP-Mixer paper: https://arxiv.org/abs/2105.01601.
+       "B-i1k/16": "gs://mixer_models/imagenet1k/Mixer-B_16.npz",
+       "L-i1k/16": "gs://mixer_models/imagenet1k/Mixer-L_16.npz",
+       "B-i21k/16": "gs://mixer_models/imagenet21k/Mixer-B_16.npz",
+       "L-i21k/16": "gs://mixer_models/imagenet21k/Mixer-L_16.npz",
+       # pylint: enable=line-too-long
+   }.get(init_file, init_file)
+   restored_params = utils.load_params(init_file)
+   restored_params = flax.training.checkpoints.convert_pre_linen(restored_params)
+
+   if "Mixer" in restored_params:
+     restored_params["pre_head_layer_norm"] = restored_params["Mixer"].pop(
+         "encoder_norm"
+     )
+     restored_params["stem"] = restored_params.pop("embedding")
+     def unflatten_dense(d):
+       return {
+           "Dense_0": {
+               "bias": d["bias1"].squeeze(),
+               "kernel": d["kernel1"].squeeze(),
+           },
+           "Dense_1": {
+               "bias": d["bias2"].squeeze(),
+               "kernel": d["kernel2"].squeeze(),
+           },
+       }
+     for k, v in restored_params["Mixer"].items():
+       assert k.startswith("encoderblock_"), k
+       v["token_mixing"] = unflatten_dense(v.pop("token_mixing_phase_0"))
+       v["channel_mixing"] = unflatten_dense(v.pop("channel_mixing_phase_0"))
+       restored_params["MixerBlock_" + k[len("encoderblock_"):]] = v
+     del restored_params["Mixer"]
+
+   # possibly use the random init for some of the params (such as the head).
+   restored_params = common.merge_params(restored_params, init_params, dont_load)
+
+   return restored_params
+
+
+ def _stoch_depth_mask(x, drop_p, deterministic, make_rng):
+   if not deterministic and drop_p:
+     shape = (x.shape[0],) + (1,) * (x.ndim - 1)
+     return 1.0 - jax.random.bernoulli(make_rng("dropout"), drop_p, shape)
+   return 1.0
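A short usage sketch of the variant factory above (illustrative; it assumes a 224px input, so the 16x16 stem produces a 14x14 grid, i.e. 196 tokens):

```
import jax
import jax.numpy as jnp

mixer = Model(num_classes=1000, variant="B/16")  # hidden_dim=768, 12 blocks
dummy = jnp.zeros([1, 224, 224, 3])
params = mixer.init(jax.random.PRNGKey(0), dummy)
logits, out = mixer.apply(params, dummy)
print(logits.shape)               # (1, 1000)
print(out["input_tokens"].shape)  # (1, 196, 768)
```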
Tipsomaly/model/big_vision/models/vit.py ADDED
@@ -0,0 +1,505 @@
+ # Copyright 2024 Big Vision Authors.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ """A refactored and simplified ViT.
+
+ However, the names of modules are made to match the old ones for easy loading.
+ """
+
+ from typing import Optional, Sequence, Union
+
+ from absl import logging
+ from big_vision import utils
+ from big_vision.models import common
+ import flax
+ import flax.linen as nn
+ import flax.training.checkpoints
+ import jax
+ import jax.numpy as jnp
+ import numpy as np
+ import scipy.ndimage
+
+
+ def posemb_sincos_2d(h, w, width, temperature=10_000., dtype=jnp.float32):
+   """Follows the MoCo v3 logic."""
+   y, x = jnp.mgrid[:h, :w]
+
+   assert width % 4 == 0, "Width must be mult of 4 for sincos posemb"
+   omega = jnp.arange(width // 4) / (width // 4 - 1)
+   omega = 1. / (temperature**omega)
+   y = jnp.einsum("m,d->md", y.flatten(), omega)
+   x = jnp.einsum("m,d->md", x.flatten(), omega)
+   pe = jnp.concatenate([jnp.sin(x), jnp.cos(x), jnp.sin(y), jnp.cos(y)], axis=1)
+   return jnp.asarray(pe, dtype)[None, :, :]
+
+
+ def get_posemb(self, typ, seqshape, width, name, dtype=jnp.float32):
+   if typ == "learn":
+     return self.param(name, nn.initializers.normal(stddev=1/np.sqrt(width)),
+                       (1, np.prod(seqshape), width), dtype)
+   elif typ == "sincos2d":
+     return posemb_sincos_2d(*seqshape, width, dtype=dtype)
+   else:
+     raise ValueError(f"Unknown posemb type: {typ}")
+
+
+ class MlpBlock(nn.Module):
+   """Transformer MLP / feed-forward block."""
+   mlp_dim: Optional[int] = None  # Defaults to 4x input dim
+   dropout: float = 0.0
+   dtype_mm: str = "float32"
+
+   @nn.compact
+   def __call__(self, x, deterministic=True):
+     """Applies Transformer MlpBlock module."""
+     inits = dict(
+         kernel_init=nn.initializers.xavier_uniform(),
+         bias_init=nn.initializers.normal(stddev=1e-6),
+     )
+
+     d = x.shape[-1]
+     x = nn.Dense(self.mlp_dim or 4 * d, dtype=self.dtype_mm, **inits)(x)
+     # In some extreme batch-size cases, this is needed as of Sept 2024:
+     x = nn.with_logical_constraint(x, ("act_batch", "act_len", "act_emb"))
+     x = nn.gelu(x)
+     x = nn.Dropout(rate=self.dropout)(x, deterministic)
+     x = nn.Dense(d, dtype=self.dtype_mm, **inits)(x)
+     return x
+
+
+ class Encoder1DBlock(nn.Module):
+   """Single transformer encoder block (MHSA + MLP)."""
+   mlp_dim: Optional[int] = None  # Defaults to 4x input dim
+   num_heads: int = 12
+   dropout: float = 0.0
+   dtype_mm: str = "float32"
+
+   @nn.compact
+   def __call__(self, x, deterministic=True):
+     out = {}
+     x = nn.with_logical_constraint(x, ("act_batch", "act_len", "act_emb"))
+     y = nn.LayerNorm()(x)
+     y = out["sa"] = nn.MultiHeadDotProductAttention(
+         num_heads=self.num_heads,
+         kernel_init=nn.initializers.xavier_uniform(),
+         deterministic=deterministic,
+         dtype=self.dtype_mm,
+     )(y, y)
+     y = nn.with_logical_constraint(y, ("act_batch", "act_len", "act_emb"))
+     y = nn.Dropout(rate=self.dropout)(y, deterministic)
+     x = out["+sa"] = x + y
+
+     y = nn.LayerNorm()(x)
+     y = out["mlp"] = MlpBlock(
+         mlp_dim=self.mlp_dim, dropout=self.dropout,
+         dtype_mm=self.dtype_mm,
+     )(y, deterministic)
+     y = nn.with_logical_constraint(y, ("act_batch", "act_len", "act_emb"))
+     y = nn.Dropout(rate=self.dropout)(y, deterministic)
+     x = out["+mlp"] = x + y
+     x = nn.with_logical_constraint(x, ("act_batch", "act_len", "act_emb"))
+     return x, out
+
+
+ class Encoder(nn.Module):
+   """Transformer Model Encoder for sequence to sequence translation."""
+   depth: int
+   mlp_dim: Optional[int] = None  # Defaults to 4x input dim
+   num_heads: int = 12
+   dropout: float = 0.0
+   scan: bool = False
+   remat_policy: str = "nothing_saveable"
+   dtype_mm: str = "float32"
+
+   @nn.compact
+   def __call__(self, x, deterministic=True):
+     out = {}
+
+     if self.scan:
+       block = nn.remat(
+           Encoder1DBlock,
+           prevent_cse=False,
+           static_argnums=(2,),  # 0=self, 2=deterministic
+           policy=getattr(jax.checkpoint_policies, self.remat_policy, None),
+       )
+       x, scan_out = nn.scan(
+           block,
+           variable_axes={"params": 0},
+           split_rngs={"params": True, "dropout": True},
+           in_axes=nn.broadcast,
+           length=self.depth)(
+               name="encoderblock",
+               dtype_mm=self.dtype_mm,
+               mlp_dim=self.mlp_dim,
+               num_heads=self.num_heads,
+               dropout=self.dropout)(x, deterministic)
+       for lyr in range(self.depth):
+         out[f"block{lyr:02d}"] = jax.tree.map(lambda o, l=lyr: o[l], scan_out)
+     else:
+       # Input Encoder
+       for lyr in range(self.depth):
+         block_cur = Encoder1DBlock(
+             name=f"encoderblock_{lyr}",
+             dtype_mm=self.dtype_mm,
+             mlp_dim=self.mlp_dim, num_heads=self.num_heads,
+             dropout=self.dropout)
+         x, out[f"block{lyr:02d}"] = block_cur(x, deterministic)
+       out["pre_ln"] = x  # Alias for last block, but without the number in it.
+
+     return nn.LayerNorm(name="encoder_norm")(x), out
+
+
+ class MAPHead(nn.Module):
+   """Multihead Attention Pooling."""
+   mlp_dim: Optional[int] = None  # Defaults to 4x input dim
+   num_heads: int = 12
+
+   @nn.compact
+   def __call__(self, x):
+     # TODO
+     n, l, d = x.shape  # pylint: disable=unused-variable
+     probe = self.param("probe", nn.initializers.xavier_uniform(),
+                        (1, 1, d), x.dtype)
+     probe = jnp.tile(probe, [n, 1, 1])
+
+     x = nn.MultiHeadDotProductAttention(
+         num_heads=self.num_heads,
+         kernel_init=nn.initializers.xavier_uniform())(probe, x)
+
+     # TODO: dropout on head?
+     y = nn.LayerNorm()(x)
+     x = x + MlpBlock(mlp_dim=self.mlp_dim)(y)
+     return x[:, 0]
+
+
+ class _Model(nn.Module):
+   """ViT model."""
+
+   num_classes: Optional[int] = None
+   patch_size: Sequence[int] = (16, 16)
+   width: int = 768
+   depth: int = 12
+   mlp_dim: Optional[int] = None  # Defaults to 4x input dim
+   num_heads: int = 12
+   posemb: str = "learn"  # Can also be "sincos2d"
+   rep_size: Union[int, bool] = False
+   dropout: float = 0.0
+   pool_type: str = "gap"  # Can also be "map" or "tok"
+   head_zeroinit: bool = True
+   scan: bool = False
+   # or "dots_with_no_batch_dims_saveable" for more speed (memory costly)
+   remat_policy: str = "nothing_saveable"
+   dtype_mm: str = "float32"
+
+   @nn.compact
+   def __call__(self, image, *, train=False):
+     out = {}
+
+     image = jnp.asarray(image, self.dtype_mm)
+
+     # Patch extraction
+     x = out["stem"] = nn.Conv(
+         self.width, self.patch_size, strides=self.patch_size,
+         padding="VALID", name="embedding", dtype=self.dtype_mm)(image)
+
+     n, h, w, c = x.shape
+     x = jnp.reshape(x, [n, h * w, c])
+
+     # Add posemb before adding extra token.
+     x = out["with_posemb"] = x + get_posemb(
+         self, self.posemb, (h, w), c, "pos_embedding", x.dtype)
+
+     if self.pool_type == "tok":
+       cls = self.param("cls", nn.initializers.zeros, (1, 1, c), x.dtype)
+       x = jnp.concatenate([jnp.tile(cls, [n, 1, 1]), x], axis=1)
+
+     n, l, c = x.shape  # pylint: disable=unused-variable
+     x = nn.Dropout(rate=self.dropout)(x, not train)
+
+     x, out["encoder"] = Encoder(
+         depth=self.depth,
+         mlp_dim=self.mlp_dim,
+         num_heads=self.num_heads,
+         dropout=self.dropout,
+         scan=self.scan,
+         remat_policy=self.remat_policy,
+         dtype_mm=self.dtype_mm,
+         name="Transformer")(
+             x, deterministic=not train)
+     encoded = out["encoded"] = x
+
+     if self.pool_type == "map":
+       x = out["head_input"] = MAPHead(
+           num_heads=self.num_heads, mlp_dim=self.mlp_dim)(x)
+     elif self.pool_type == "gap":
+       x = out["head_input"] = jnp.mean(x, axis=1)
+     elif self.pool_type == "0":
+       x = out["head_input"] = x[:, 0]
+     elif self.pool_type == "tok":
+       x = out["head_input"] = x[:, 0]
+       encoded = encoded[:, 1:]
+     elif self.pool_type == "none":
+       pass
+     else:
+       raise ValueError(f"Unknown pool type: '{self.pool_type}'")
+
+     x_2d = jnp.reshape(encoded, [n, h, w, -1])
+
+     if self.rep_size:
+       rep_size = self.width if self.rep_size is True else self.rep_size
+       hid = nn.Dense(rep_size, name="pre_logits")
+       # NOTE: In the past we did not include tanh in pre_logits.
+       # For few-shot, it should not matter much, as it whitens anyways.
+       x_2d = nn.tanh(hid(x_2d))
+       x = nn.tanh(hid(x))
+
+     out["pre_logits_2d"] = x_2d
+     out["pre_logits"] = x
+
+     if self.num_classes:
+       kw = {"kernel_init": nn.initializers.zeros} if self.head_zeroinit else {}
+       head = nn.Dense(self.num_classes, name="head", **kw)
+       x_2d = out["logits_2d"] = head(x_2d)
+       x = out["logits"] = head(x)
+
+     return x, out
+
+
+ def Model(num_classes=None, *, variant=None, **kw):  # pylint: disable=invalid-name
+   """Factory function, because linen really doesn't like what I'm doing!"""
+   return _Model(num_classes, **{**decode_variant(variant), **kw})
+
+
+ def decode_variant(variant):
+   """Converts a string like "B" or "B/32" into a params dict."""
+   if variant is None:
+     return {}
+
+   v, patch = variant, {}
+   if "/" in variant:
+     v, patch = variant.split("/")
+     patch = {"patch_size": (int(patch), int(patch))}
+
+   return {
+       # pylint:disable=line-too-long
+       # Reference: Table 2 of https://arxiv.org/abs/2106.04560.
+       "width": {"mu": 32, "Ti": 192, "S": 384, "M": 512, "B": 768, "L": 1024, "So400m": 1152, "H": 1280, "g": 1408, "g-opt": 1536, "G": 1664, "G-opt": 1536, "e": 1792}[v],
+       "depth": {"mu": 1, "Ti": 12, "S": 12, "M": 12, "B": 12, "L": 24, "So400m": 27, "H": 32, "g": 40, "g-opt": 40, "G": 48, "G-opt": 48, "e": 56}[v],
+       "mlp_dim": {"mu": 128, "Ti": 768, "S": 1536, "M": 2048, "B": 3072, "L": 4096, "So400m": 4304, "H": 5120, "g": 6144, "g-opt": 6144, "G": 8192, "G-opt": 8192, "e": 15360}[v],
+       "num_heads": {"mu": 2, "Ti": 3, "S": 6, "M": 8, "B": 12, "L": 16, "So400m": 16, "H": 16, "g": 16, "g-opt": 16, "G": 16, "G-opt": 16, "e": 16}[v],
+       # pylint:enable=line-too-long
+       **patch
+   }
+
+
+ def resample_posemb(old, new):
+   """This function implements "high-res finetuning" for transformer models."""
+   # Rescale the grid of position embeddings. Param shape is (1,N,1024)
+   if old.shape == new.shape:
+     return old
+
+   logging.info("ViT: resize %s to %s", old.shape, new.shape)
+   gs_old = int(np.sqrt(old.shape[1]))
+   gs_new = int(np.sqrt(new.shape[1]))
+   logging.info("ViT: grid-size from %s to %s", gs_old, gs_new)
+   grid = old.reshape(gs_old, gs_old, -1)
+
+   zoom = (gs_new/gs_old, gs_new/gs_old, 1)
+   grid = scipy.ndimage.zoom(grid, zoom, order=1)
+   grid = grid.reshape(1, gs_new*gs_new, -1)
+   return grid
+
+
+ def fix_old_checkpoints(params):
+   """Fix small bwd incompat that can't be resolved with names in model def."""
+
+   params = flax.core.unfreeze(
+       flax.training.checkpoints.convert_pre_linen(params))
+
+   # Original ViT paper variant had posemb in a module:
+   if "posembed_input" in params["Transformer"]:
+     logging.info("ViT: Loading and fixing VERY old posemb")
+     posemb = params["Transformer"].pop("posembed_input")
+     params["pos_embedding"] = posemb["pos_embedding"]
+
+   # Widely used version before 2022 had posemb in Encoder:
+   if "pos_embedding" in params["Transformer"]:
+     logging.info("ViT: Loading and fixing old posemb")
+     params["pos_embedding"] = params["Transformer"].pop("pos_embedding")
+
+   # Old vit.py used to first concat [cls] token, then add posemb.
+   # This means a B/32@224px would have 7x7+1 posembs. This is useless and clumsy
+   # so we changed to add posemb then concat [cls]. We can recover the old
+   # checkpoint by manually summing [cls] token and its posemb entry.
+   if "pos_embedding" in params:
+     pe = params["pos_embedding"]
+     if int(np.sqrt(pe.shape[1])) ** 2 + 1 == int(pe.shape[1]):
+       logging.info("ViT: Loading and fixing combined cls+posemb")
+       pe_cls, params["pos_embedding"] = pe[:, :1], pe[:, 1:]
+       if "cls" in params:
+         params["cls"] += pe_cls
+
+   # MAP-head variants during ViT-G development had it inlined:
+   if "probe" in params:
+     params["MAPHead_0"] = {
+         k: params.pop(k) for k in
+         ["probe", "MlpBlock_0", "MultiHeadDotProductAttention_0", "LayerNorm_0"]
+     }
+
+   return params
+
+
+ def pyloop_to_scan(params_pyloop):
+   """Converts a python for-loop ViT checkpoint to a lax.scan based one."""
+   # On a high level, they are the same except that the for loop has separate
+   # array pytrees for each encoderblock, while the scan one has just one
+   # encoderblock pytree, with all block's params concatenated.
+
+   params_scan = jax.tree.map(lambda x: x, params_pyloop)  # Structural copy
+   t = params_scan["Transformer"]
+
+   # Find highest index of encoderblocks in the checkpoint (they start at 0):
+   encoderblocks = {k for k in t if k.startswith("encoderblock_")}
+   depth = 1 + max({int(k.split("_")[-1]) for k in encoderblocks})
+
+   def stack(*values):
+     return np.stack(values)
+
+   # Stack all encoderblocks into a single one:
+   t["encoderblock"] = jax.tree.map(
+       stack, *[t[f"encoderblock_{lyr}"] for lyr in range(depth)])
+
+   for lyr in range(depth):
+     del t[f"encoderblock_{lyr}"]
+
+   return params_scan
+
+
+ def scan_to_pyloop(params_scan):
+   """Converts a lax.scan ViT checkpoint to a python for-loop based one."""
+   # See comment in pyloop_to_scan.
+
+   params_scan = jax.tree.map(lambda x: x, params_scan)  # Structural copy
+   t = params_scan["Transformer"]
+
+   # Find out how many encoderblocks there are
+   depth = len(t["encoderblock"]["LayerNorm_0"]["bias"])
+
+   # Create that many encoderblocks, each with their slice of their sub-pytree.
+   for lyr in range(depth):
+     block = jax.tree.map(lambda x, lyr=lyr: x[lyr], t["encoderblock"])
+     t[f"encoderblock_{lyr}"] = block
+
+   del t["encoderblock"]
+   return params_scan
+
+
+ def load(init_params, init_file, model_cfg, dont_load=()):  # pylint: disable=invalid-name because we had to CamelCase above.
+   """Load init from checkpoint, both old model and this one. +Hi-res posemb."""
+   init_file = VANITY_NAMES.get(init_file, init_file)
+   restored_params = utils.load_params(init_file)
+
+   restored_params = fix_old_checkpoints(restored_params)
+
+   # Detect attempts to load non-scan checkpoint into scan model.
+   if (model_cfg.get("scan") and
+       "encoderblock" not in restored_params["Transformer"]):
+     restored_params = pyloop_to_scan(restored_params)
+   if (not model_cfg.get("scan")
+       and "encoderblock" in restored_params["Transformer"]):
+     restored_params = scan_to_pyloop(restored_params)
+
+   # possibly use the random init for some of the params (such as the head).
+   restored_params = common.merge_params(restored_params, init_params, dont_load)
+
+   # resample posemb if needed.
+   # TODO: Take this from model_cfg to avoid need for init_params.
+   if init_params and "pos_embedding" in init_params:
+     restored_params["pos_embedding"] = resample_posemb(
+         old=restored_params["pos_embedding"],
+         new=init_params["pos_embedding"])
+
+   return restored_params
+
+
+ # Shortcut names for some canonical paper checkpoints:
+ VANITY_NAMES = {
+     # pylint: disable=line-too-long
+     # Recommended models from https://arxiv.org/abs/2106.10270
+     # Many more models at https://github.com/google-research/vision_transformer
+     "howto-i21k-Ti/16": "gs://vit_models/augreg/Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0.npz",
+     "howto-i21k-S/32": "gs://vit_models/augreg/S_32-i21k-300ep-lr_0.001-aug_none-wd_0.1-do_0.0-sd_0.0.npz",
+     "howto-i21k-S/16": "gs://vit_models/augreg/S_16-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0.npz",
+     "howto-i21k-B/32": "gs://vit_models/augreg/B_32-i21k-300ep-lr_0.001-aug_light1-wd_0.1-do_0.0-sd_0.0.npz",
+     "howto-i21k-B/16": "gs://vit_models/augreg/B_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0.npz",
+     "howto-i21k-B/8": "gs://vit_models/augreg/B_8-i21k-300ep-lr_0.001-aug_medium2-wd_0.1-do_0.0-sd_0.0.npz",
+     "howto-i21k-L/16": "gs://vit_models/augreg/L_16-i21k-300ep-lr_0.001-aug_strong1-wd_0.1-do_0.0-sd_0.0.npz",
+
+     # Better plain vit-s16 baselines from https://arxiv.org/abs/2205.01580
+     "i1k-s16-90ep": "gs://big_vision/vit_s16_i1k_90ep.npz",
+     "i1k-s16-150ep": "gs://big_vision/vit_s16_i1k_150ep.npz",
+     "i1k-s16-300ep": "gs://big_vision/vit_s16_i1k_300ep.npz",
+
+     # DeiT-3 checkpoints from https://github.com/facebookresearch/deit/blob/main/README_revenge.md
+     # First layer converted to take inputs in [-1,1]
+     "deit3_S_224_1k": "gs://big_vision/zoo/deit3/bv_deit_3_small_224_1k.npz",
+     "deit3_S_224_21k": "gs://big_vision/zoo/deit3/bv_deit_3_small_224_21k.npz",
+     "deit3_S_384_1k": "gs://big_vision/zoo/deit3/bv_deit_3_small_384_1k.npz",
+     "deit3_S_384_21k": "gs://big_vision/zoo/deit3/bv_deit_3_small_384_21k.npz",
+     "deit3_B_224_1k": "gs://big_vision/zoo/deit3/bv_deit_3_base_224_1k.npz",
+     "deit3_B_224_21k": "gs://big_vision/zoo/deit3/bv_deit_3_base_224_21k.npz",
+     "deit3_B_384_1k": "gs://big_vision/zoo/deit3/bv_deit_3_base_384_1k.npz",
+     "deit3_B_384_21k": "gs://big_vision/zoo/deit3/bv_deit_3_base_384_21k.npz",
+     "deit3_L_224_1k": "gs://big_vision/zoo/deit3/bv_deit_3_large_224_1k.npz",
+     "deit3_L_224_21k": "gs://big_vision/zoo/deit3/bv_deit_3_large_224_21k.npz",
+     "deit3_L_384_1k": "gs://big_vision/zoo/deit3/bv_deit_3_large_384_1k.npz",
+     "deit3_L_384_21k": "gs://big_vision/zoo/deit3/bv_deit_3_large_384_21k.npz",
+
+     # SigLIP image encoder checkpoints from https://arxiv.org/abs/2303.15343
+     "SigLIP B/16 224": "gs://big_vision/siglip/webli_en_b16_224_63724782.npz:img",
+     "SigLIP B/16 256": "gs://big_vision/siglip/webli_en_b16_256_60500360.npz:img",
+     "SigLIP B/16 384": "gs://big_vision/siglip/webli_en_b16_384_68578854.npz:img",
+     "SigLIP B/16 512": "gs://big_vision/siglip/webli_en_b16_512_68580893.npz:img",
+     "SigLIP L/16 256": "gs://big_vision/siglip/webli_en_l16_256_60552751.npz:img",
+     "SigLIP L/16 384": "gs://big_vision/siglip/webli_en_l16_384_63634585.npz:img",
+     "SigLIP So400m/14 224": "gs://big_vision/siglip/webli_en_so400m_224_57633886.npz:img",
+     "SigLIP So400m/14 384": "gs://big_vision/siglip/webli_en_so400m_384_58765454.npz:img",
+     "SigLIP B/16-i18n 256": "gs://big_vision/siglip/webli_i18n_b16_256_66117334.npz:img",
+
+     # SigLIP 2 image encoder checkpoints from https://arxiv.org/abs/2502.14786
+     "SigLIP2 B/16 224": "gs://big_vision/siglip2/siglip2_b16_224.npz:img",
+     "SigLIP2 B/16 256": "gs://big_vision/siglip2/siglip2_b16_256.npz:img",
+     "SigLIP2 B/16 384": "gs://big_vision/siglip2/siglip2_b16_384.npz:img",
+     "SigLIP2 B/16 512": "gs://big_vision/siglip2/siglip2_b16_512.npz:img",
+     "SigLIP2 B/32 256": "gs://big_vision/siglip2/siglip2_b32_256.npz:img",
+     "SigLIP2 L/16 256": "gs://big_vision/siglip2/siglip2_l16_256.npz:img",
+     "SigLIP2 L/16 384": "gs://big_vision/siglip2/siglip2_l16_384.npz:img",
+     "SigLIP2 L/16 512": "gs://big_vision/siglip2/siglip2_l16_512.npz:img",
+     "SigLIP2 So400m/14 224": "gs://big_vision/siglip2/siglip2_so400m14_224.npz:img",
+     "SigLIP2 So400m/14 384": "gs://big_vision/siglip2/siglip2_so400m14_384.npz:img",
+     "SigLIP2 So400m/16 256": "gs://big_vision/siglip2/siglip2_so400m16_256.npz:img",
+     "SigLIP2 So400m/16 384": "gs://big_vision/siglip2/siglip2_so400m16_384.npz:img",
+     "SigLIP2 So400m/16 512": "gs://big_vision/siglip2/siglip2_so400m16_512.npz:img",
+     "SigLIP2 g-opt/16 256": "gs://big_vision/siglip2/siglip2_g-opt16_256.npz:img",
+     "SigLIP2 g-opt/16 384": "gs://big_vision/siglip2/siglip2_g-opt16_384.npz:img",
+     # SigLIP 2 NaFlex image encoder checkpoints.
+     # These need `proj.image_text.naflex_vit.py` as the image encoder model
+     # and a non-standard preprocessing, see configs/proj/image_text/README_siglip2.md.
+     "SigLIP2 B/16 NaFlex": "gs://big_vision/siglip2/siglip2_b16_naflex.npz:img",
+     "SigLIP2 So400m/16 NaFlex": "gs://big_vision/siglip2/siglip2_so400m16_naflex.npz:img",
+     # pylint: enable=line-too-long
+ }
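A brief sketch of how `decode_variant` and the `Model` factory above fit together (illustrative only; `sincos2d` requires the width to be a multiple of 4, which holds for every table entry):

```
import jax
import jax.numpy as jnp

print(decode_variant("B/16"))
# {'width': 768, 'depth': 12, 'mlp_dim': 3072, 'num_heads': 12,
#  'patch_size': (16, 16)}

vit = Model(num_classes=10, variant="S/16", posemb="sincos2d")
dummy = jnp.zeros([1, 224, 224, 3])
params = vit.init(jax.random.PRNGKey(0), dummy)
logits, _ = vit.apply(params, dummy)
print(logits.shape)  # (1, 10)
```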
Tipsomaly/model/big_vision/pp/__init__.py ADDED
File without changes
Tipsomaly/model/big_vision/pp/autoaugment.py ADDED
@@ -0,0 +1,700 @@
1
+ # Copyright 2023 Big Vision Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """AutoAugment and RandAugment policies for enhanced image preprocessing.
16
+
17
+ AutoAugment Reference: https://arxiv.org/abs/1805.09501
18
+ RandAugment Reference: https://arxiv.org/abs/1909.13719
19
+
20
+ This code is forked from
21
+ https://github.com/tensorflow/tpu/blob/11d0db15cf1c3667f6e36fecffa111399e008acd/models/official/efficientnet/autoaugment.py
22
+ """
23
+
24
+ from __future__ import absolute_import
25
+ from __future__ import division
26
+ from __future__ import print_function
27
+
28
+ import dataclasses
29
+ import inspect
30
+ import math
31
+ import tensorflow.compat.v1 as tf
32
+ from tensorflow_addons import image as contrib_image
33
+
34
+ # This signifies the max integer that the controller RNN could predict for the
35
+ # augmentation scheme.
36
+ _MAX_LEVEL = 10.
37
+
38
+
39
+ @dataclasses.dataclass
40
+ class HParams:
41
+ """Parameters for AutoAugment and RandAugment."""
42
+ cutout_const: int
43
+ translate_const: int
44
+
45
+
46
+ def policy_v0():
47
+ """Autoaugment policy that was used in AutoAugment Paper."""
48
+ # Each tuple is an augmentation operation of the form
49
+ # (operation, probability, magnitude). Each element in policy is a
50
+ # sub-policy that will be applied sequentially on the image.
51
+ policy = [
52
+ [('Equalize', 0.8, 1), ('ShearY', 0.8, 4)],
53
+ [('Color', 0.4, 9), ('Equalize', 0.6, 3)],
54
+ [('Color', 0.4, 1), ('Rotate', 0.6, 8)],
55
+ [('Solarize', 0.8, 3), ('Equalize', 0.4, 7)],
56
+ [('Solarize', 0.4, 2), ('Solarize', 0.6, 2)],
57
+ [('Color', 0.2, 0), ('Equalize', 0.8, 8)],
58
+ [('Equalize', 0.4, 8), ('SolarizeAdd', 0.8, 3)],
59
+ [('ShearX', 0.2, 9), ('Rotate', 0.6, 8)],
60
+ [('Color', 0.6, 1), ('Equalize', 1.0, 2)],
61
+ [('Invert', 0.4, 9), ('Rotate', 0.6, 0)],
62
+ [('Equalize', 1.0, 9), ('ShearY', 0.6, 3)],
63
+ [('Color', 0.4, 7), ('Equalize', 0.6, 0)],
64
+ [('Posterize', 0.4, 6), ('AutoContrast', 0.4, 7)],
65
+ [('Solarize', 0.6, 8), ('Color', 0.6, 9)],
66
+ [('Solarize', 0.2, 4), ('Rotate', 0.8, 9)],
67
+ [('Rotate', 1.0, 7), ('TranslateY', 0.8, 9)],
68
+ [('ShearX', 0.0, 0), ('Solarize', 0.8, 4)],
69
+ [('ShearY', 0.8, 0), ('Color', 0.6, 4)],
70
+ [('Color', 1.0, 0), ('Rotate', 0.6, 2)],
71
+ [('Equalize', 0.8, 4), ('Equalize', 0.0, 8)],
72
+ [('Equalize', 1.0, 4), ('AutoContrast', 0.6, 2)],
73
+ [('ShearY', 0.4, 7), ('SolarizeAdd', 0.6, 7)],
74
+ [('Posterize', 0.8, 2), ('Solarize', 0.6, 10)],
75
+ [('Solarize', 0.6, 8), ('Equalize', 0.6, 1)],
76
+ [('Color', 0.8, 6), ('Rotate', 0.4, 5)],
77
+ ]
78
+ return policy
79
+
80
+
81
+ def policy_vtest():
82
+ """Autoaugment test policy for debugging."""
83
+ # Each tuple is an augmentation operation of the form
84
+ # (operation, probability, magnitude). Each element in policy is a
85
+ # sub-policy that will be applied sequentially on the image.
86
+ policy = [
87
+ [('TranslateX', 1.0, 4), ('Equalize', 1.0, 10)],
88
+ ]
89
+ return policy
90
+
91
+
92
+ def blend(image1, image2, factor):
93
+ """Blend image1 and image2 using 'factor'.
94
+ Factor can be above 0.0. A value of 0.0 means only image1 is used.
95
+ A value of 1.0 means only image2 is used. A value between 0.0 and
96
+ 1.0 means we linearly interpolate the pixel values between the two
97
+ images. A value greater than 1.0 "extrapolates" the difference
98
+ between the two pixel values, and we clip the results to values
99
+ between 0 and 255.
100
+ Args:
101
+ image1: An image Tensor of type uint8.
102
+ image2: An image Tensor of type uint8.
103
+ factor: A floating point value above 0.0.
104
+ Returns:
105
+ A blended image Tensor of type uint8.
106
+ """
107
+ if factor == 0.0:
108
+ return tf.convert_to_tensor(image1)
109
+ if factor == 1.0:
110
+ return tf.convert_to_tensor(image2)
111
+
112
+ image1 = tf.to_float(image1)
113
+ image2 = tf.to_float(image2)
114
+
115
+ difference = image2 - image1
116
+ scaled = factor * difference
117
+
118
+ # Do addition in float.
119
+ temp = tf.to_float(image1) + scaled
120
+
121
+ # Interpolate
122
+ if factor > 0.0 and factor < 1.0:
123
+ # Interpolation means we always stay within 0 and 255.
124
+ return tf.cast(temp, tf.uint8)
125
+
126
+ # Extrapolate:
127
+ #
128
+ # We need to clip and then cast.
129
+ return tf.cast(tf.clip_by_value(temp, 0.0, 255.0), tf.uint8)
130
+
131
+
132
+ def cutout(image, pad_size, replace=0):
133
+ """Apply cutout (https://arxiv.org/abs/1708.04552) to image.
134
+ This operation applies a (2*pad_size x 2*pad_size) mask of zeros to
135
+ a random location within `image`. The pixel values filled in will be of the
136
+ value `replace`. The location where the mask will be applied is randomly
137
+ chosen uniformly over the whole image.
138
+ Args:
139
+ image: An image Tensor of type uint8.
140
+ pad_size: Specifies how big the zero mask that will be generated is that
141
+ is applied to the image. The mask will be of size
142
+ (2*pad_size x 2*pad_size).
143
+ replace: What pixel value to fill in the image in the area that has
144
+ the cutout mask applied to it.
145
+ Returns:
146
+ An image Tensor that is of type uint8.
147
+ """
148
+ image_height = tf.shape(image)[0]
149
+ image_width = tf.shape(image)[1]
150
+
151
+ # Sample the center location in the image where the zero mask will be applied.
152
+ cutout_center_height = tf.random_uniform(
153
+ shape=[], minval=0, maxval=image_height,
154
+ dtype=tf.int32)
155
+
156
+ cutout_center_width = tf.random_uniform(
157
+ shape=[], minval=0, maxval=image_width,
158
+ dtype=tf.int32)
159
+
160
+ lower_pad = tf.maximum(0, cutout_center_height - pad_size)
161
+ upper_pad = tf.maximum(0, image_height - cutout_center_height - pad_size)
162
+ left_pad = tf.maximum(0, cutout_center_width - pad_size)
163
+ right_pad = tf.maximum(0, image_width - cutout_center_width - pad_size)
164
+
165
+ cutout_shape = [image_height - (lower_pad + upper_pad),
166
+ image_width - (left_pad + right_pad)]
167
+ padding_dims = [[lower_pad, upper_pad], [left_pad, right_pad]]
168
+ mask = tf.pad(
169
+ tf.zeros(cutout_shape, dtype=image.dtype),
170
+ padding_dims, constant_values=1)
171
+ mask = tf.expand_dims(mask, -1)
172
+ mask = tf.tile(mask, [1, 1, 3])
173
+ image = tf.where(
174
+ tf.equal(mask, 0),
175
+ tf.ones_like(image, dtype=image.dtype) * replace,
176
+ image)
177
+ return image
178
+
179
+
180
+ def solarize(image, threshold=128):
181
+ # For each pixel in the image, select the pixel
182
+ # if the value is less than the threshold.
183
+ # Otherwise, subtract 255 from the pixel.
184
+ return tf.where(image < threshold, image, 255 - image)
185
+
186
+
187
+ def solarize_add(image, addition=0, threshold=128):
188
+ # For each pixel in the image less than threshold
189
+ # we add 'addition' amount to it and then clip the
190
+ # pixel value to be between 0 and 255. The value
191
+ # of 'addition' is between -128 and 128.
192
+ added_image = tf.cast(image, tf.int64) + addition
193
+ added_image = tf.cast(tf.clip_by_value(added_image, 0, 255), tf.uint8)
194
+ return tf.where(image < threshold, added_image, image)
195
+
196
+
197
+ def color(image, factor):
198
+ """Equivalent of PIL Color."""
199
+ degenerate = tf.image.grayscale_to_rgb(tf.image.rgb_to_grayscale(image))
200
+ return blend(degenerate, image, factor)
201
+
202
+
203
+ def contrast(image, factor):
204
+ """Equivalent of PIL Contrast."""
205
+ degenerate = tf.image.rgb_to_grayscale(image)
206
+ # Cast before calling tf.histogram.
207
+ degenerate = tf.cast(degenerate, tf.int32)
208
+
209
+ # Compute the grayscale histogram, then compute the mean pixel value,
210
+ # and create a constant image of that mean value. Use that as the
211
+ # blending degenerate target of the original image.
212
+ hist = tf.histogram_fixed_width(degenerate, [0, 255], nbins=256)
213
+ mean = tf.reduce_sum(tf.cast(hist, tf.float32)) / 256.0  # == num_pixels / 256, as in the original implementation.
214
+ degenerate = tf.ones_like(degenerate, dtype=tf.float32) * mean
215
+ degenerate = tf.clip_by_value(degenerate, 0.0, 255.0)
216
+ degenerate = tf.image.grayscale_to_rgb(tf.cast(degenerate, tf.uint8))
217
+ return blend(degenerate, image, factor)
218
+
219
+
220
+ def brightness(image, factor):
221
+ """Equivalent of PIL Brightness."""
222
+ degenerate = tf.zeros_like(image)
223
+ return blend(degenerate, image, factor)
224
+
225
+
226
+ def posterize(image, bits):
227
+ """Equivalent of PIL Posterize."""
228
+ shift = 8 - bits
229
+ return tf.bitwise.left_shift(tf.bitwise.right_shift(image, shift), shift)
230
+
231
+
232
+ def rotate(image, degrees, replace):
233
+ """Rotates the image by degrees either clockwise or counterclockwise.
234
+ Args:
235
+ image: An image Tensor of type uint8.
236
+ degrees: Float, a scalar angle in degrees to rotate all images by. If
237
+ degrees is positive the image will be rotated clockwise otherwise it will
238
+ be rotated counterclockwise.
239
+ replace: A one or three value 1D tensor to fill empty pixels caused by
240
+ the rotate operation.
241
+ Returns:
242
+ The rotated version of image.
243
+ """
244
+ # Convert from degrees to radians.
245
+ degrees_to_radians = math.pi / 180.0
246
+ radians = degrees * degrees_to_radians
247
+
248
+ # In practice, we should randomize the rotation degrees by flipping
249
+ # it negatively half the time, but that's done on 'degrees' outside
250
+ # of the function.
251
+ image = contrib_image.rotate(wrap(image), radians)
252
+ return unwrap(image, replace)
253
+
254
+
255
+ def translate_x(image, pixels, replace):
256
+ """Equivalent of PIL Translate in X dimension."""
257
+ image = contrib_image.translate(wrap(image), [-pixels, 0])
258
+ return unwrap(image, replace)
259
+
260
+
261
+ def translate_y(image, pixels, replace):
262
+ """Equivalent of PIL Translate in Y dimension."""
263
+ image = contrib_image.translate(wrap(image), [0, -pixels])
264
+ return unwrap(image, replace)
265
+
266
+
267
+ def shear_x(image, level, replace):
268
+ """Equivalent of PIL Shearing in X dimension."""
269
+ # Shear parallel to x axis is a projective transform
270
+ # with a matrix form of:
271
+ # [1 level
272
+ # 0 1].
273
+ image = contrib_image.transform(
274
+ wrap(image), [1., level, 0., 0., 1., 0., 0., 0.])
275
+ return unwrap(image, replace)
276
+
277
+
278
+ def shear_y(image, level, replace):
279
+ """Equivalent of PIL Shearing in Y dimension."""
280
+ # Shear parallel to y axis is a projective transform
281
+ # with a matrix form of:
282
+ # [1 0
283
+ # level 1].
284
+ image = contrib_image.transform(
285
+ wrap(image), [1., 0., 0., level, 1., 0., 0., 0.])
286
+ return unwrap(image, replace)
287
+
288
+
289
+ def autocontrast(image):
290
+ """Implements Autocontrast function from PIL using TF ops.
291
+ Args:
292
+ image: A 3D uint8 tensor.
293
+ Returns:
294
+ The image after it has had autocontrast applied to it and will be of type
295
+ uint8.
296
+ """
297
+
298
+ def scale_channel(image):
299
+ """Scale the 2D image using the autocontrast rule."""
300
+ # A possibly cheaper version can be done using cumsum/unique_with_counts
301
+ # over the histogram values, rather than iterating over the entire image
302
+ # to compute mins and maxes.
303
+ lo = tf.to_float(tf.reduce_min(image))
304
+ hi = tf.to_float(tf.reduce_max(image))
305
+
306
+ # Scale the image, making the lowest value 0 and the highest value 255.
307
+ def scale_values(im):
308
+ scale = 255.0 / (hi - lo)
309
+ offset = -lo * scale
310
+ im = tf.to_float(im) * scale + offset
311
+ im = tf.clip_by_value(im, 0.0, 255.0)
312
+ return tf.cast(im, tf.uint8)
313
+
314
+ result = tf.cond(hi > lo, lambda: scale_values(image), lambda: image)
315
+ return result
316
+
317
+ # Assumes RGB for now. Scales each channel independently
318
+ # and then stacks the result.
319
+ s1 = scale_channel(image[:, :, 0])
320
+ s2 = scale_channel(image[:, :, 1])
321
+ s3 = scale_channel(image[:, :, 2])
322
+ image = tf.stack([s1, s2, s3], 2)
323
+ return image
324
+
325
+
326
+ def sharpness(image, factor):
327
+ """Implements Sharpness function from PIL using TF ops."""
328
+ orig_image = image
329
+ image = tf.cast(image, tf.float32)
330
+ # Make image 4D for conv operation.
331
+ image = tf.expand_dims(image, 0)
332
+ # SMOOTH PIL Kernel.
333
+ kernel = tf.constant(
334
+ [[1, 1, 1], [1, 5, 1], [1, 1, 1]], dtype=tf.float32,
335
+ shape=[3, 3, 1, 1]) / 13.
336
+ # Tile across channel dimension.
337
+ kernel = tf.tile(kernel, [1, 1, 3, 1])
338
+ strides = [1, 1, 1, 1]
339
+ with tf.device('/cpu:0'):
340
+ # Some augmentation that uses depth-wise conv will cause crashing when
341
+ # training on GPU. See ((internal link)) for details.
342
+ degenerate = tf.nn.depthwise_conv2d(
343
+ image, kernel, strides, padding='VALID', rate=[1, 1])
344
+ degenerate = tf.clip_by_value(degenerate, 0.0, 255.0)
345
+ degenerate = tf.squeeze(tf.cast(degenerate, tf.uint8), [0])
346
+
347
+ # For the borders of the resulting image, fill in the values of the
348
+ # original image.
349
+ mask = tf.ones_like(degenerate)
350
+ padded_mask = tf.pad(mask, [[1, 1], [1, 1], [0, 0]])
351
+ padded_degenerate = tf.pad(degenerate, [[1, 1], [1, 1], [0, 0]])
352
+ result = tf.where(tf.equal(padded_mask, 1), padded_degenerate, orig_image)
353
+
354
+ # Blend the final result.
355
+ return blend(result, orig_image, factor)
356
+
357
+
358
+ def equalize(image):
359
+ """Implements Equalize function from PIL using TF ops."""
360
+ def scale_channel(im, c):
361
+ """Scale the data in the channel to implement equalize."""
362
+ im = tf.cast(im[:, :, c], tf.int32)
363
+ # Compute the histogram of the image channel.
364
+ histo = tf.histogram_fixed_width(im, [0, 255], nbins=256)
365
+
366
+ # For the purposes of computing the step, filter out the nonzeros.
367
+ nonzero = tf.where(tf.not_equal(histo, 0))
368
+ nonzero_histo = tf.reshape(tf.gather(histo, nonzero), [-1])
369
+ step = (tf.reduce_sum(nonzero_histo) - nonzero_histo[-1]) // 255
370
+
371
+ def build_lut(histo, step):
372
+ # Compute the cumulative sum, shifting by step // 2
373
+ # and then normalization by step.
374
+ lut = (tf.cumsum(histo) + (step // 2)) // step
375
+ # Shift lut, prepending with 0.
376
+ lut = tf.concat([[0], lut[:-1]], 0)
377
+ # Clip the counts to be in range. This is done
378
+ # in the C code for image.point.
379
+ return tf.clip_by_value(lut, 0, 255)
380
+
381
+ # If step is zero, return the original image. Otherwise, build
382
+ # lut from the full histogram and step and then index from it.
383
+ result = tf.cond(tf.equal(step, 0),
384
+ lambda: im,
385
+ lambda: tf.gather(build_lut(histo, step), im))
386
+
387
+ return tf.cast(result, tf.uint8)
388
+
389
+ # Assumes RGB for now. Scales each channel independently
390
+ # and then stacks the result.
391
+ s1 = scale_channel(image, 0)
392
+ s2 = scale_channel(image, 1)
393
+ s3 = scale_channel(image, 2)
394
+ image = tf.stack([s1, s2, s3], 2)
395
+ return image
396
+
397
+
398
+ def invert(image):
399
+ """Inverts the image pixels."""
400
+ image = tf.convert_to_tensor(image)
401
+ return 255 - image
402
+
403
+
404
+ def wrap(image):
405
+ """Returns 'image' with an extra channel set to all 1s."""
406
+ shape = tf.shape(image)
407
+ extended_channel = tf.ones([shape[0], shape[1], 1], image.dtype)
408
+ extended = tf.concat([image, extended_channel], 2)
409
+ return extended
410
+
411
+
412
+ def unwrap(image, replace):
413
+ """Unwraps an image produced by wrap.
414
+ Where there is a 0 in the last channel for every spatial position,
415
+ the rest of the three channels in that spatial position are grayed
416
+ (set to 128). Operations like translate and shear on a wrapped
417
+ Tensor will leave 0s in empty locations. Some transformations look
418
+ at the intensity of values to do preprocessing, and we want these
419
+ empty pixels to assume the 'average' value, rather than pure black.
420
+ Args:
421
+ image: A 3D Image Tensor with 4 channels.
422
+ replace: A one or three value 1D tensor to fill empty pixels.
423
+ Returns:
424
+ image: A 3D image Tensor with 3 channels.
425
+ """
426
+ image_shape = tf.shape(image)
427
+ # Flatten the spatial dimensions.
428
+ flattened_image = tf.reshape(image, [-1, image_shape[2]])
429
+
430
+ # Find all pixels where the last channel is zero.
431
+ alpha_channel = flattened_image[:, 3]
432
+
433
+ replace = tf.concat([replace, tf.ones([1], image.dtype)], 0)
434
+
435
+ # Where they are zero, fill them in with 'replace'.
436
+ flattened_image = tf.where(
437
+ tf.equal(alpha_channel, 0),
438
+ tf.ones_like(flattened_image, dtype=image.dtype) * replace,
439
+ flattened_image)
440
+
441
+ image = tf.reshape(flattened_image, image_shape)
442
+ image = tf.slice(image, [0, 0, 0], [image_shape[0], image_shape[1], 3])
443
+ return image
444
+
445
+
446
+ NAME_TO_FUNC = {
447
+ 'AutoContrast': autocontrast,
448
+ 'Equalize': equalize,
449
+ 'Invert': invert,
450
+ 'Rotate': rotate,
451
+ 'Posterize': posterize,
452
+ 'Solarize': solarize,
453
+ 'SolarizeAdd': solarize_add,
454
+ 'Color': color,
455
+ 'Contrast': contrast,
456
+ 'Brightness': brightness,
457
+ 'Sharpness': sharpness,
458
+ 'ShearX': shear_x,
459
+ 'ShearY': shear_y,
460
+ 'TranslateX': translate_x,
461
+ 'TranslateY': translate_y,
462
+ 'Cutout': cutout,
463
+ }
464
+
465
+
466
+ def _randomly_negate_tensor(tensor):
467
+ """With 50% prob turn the tensor negative."""
468
+ should_flip = tf.cast(tf.floor(tf.random_uniform([]) + 0.5), tf.bool)
469
+ final_tensor = tf.cond(should_flip, lambda: tensor, lambda: -tensor)
470
+ return final_tensor
471
+
472
+
473
+ def _rotate_level_to_arg(level):
474
+ level = (level/_MAX_LEVEL) * 30.
475
+ level = _randomly_negate_tensor(level)
476
+ return (level,)
477
+
478
+
479
+ def _shrink_level_to_arg(level):
480
+ """Converts level to ratio by which we shrink the image content."""
481
+ if level == 0:
482
+ return (1.0,) # if level is zero, do not shrink the image
483
+ # Maximum shrinking ratio is 2.9.
484
+ level = 2. / (_MAX_LEVEL / level) + 0.9
485
+ return (level,)
486
+
487
+
488
+ def _enhance_level_to_arg(level):
489
+ return ((level/_MAX_LEVEL) * 1.8 + 0.1,)
490
+
491
+
492
+ def _shear_level_to_arg(level):
493
+ level = (level/_MAX_LEVEL) * 0.3
494
+ # Flip level to negative with 50% chance.
495
+ level = _randomly_negate_tensor(level)
496
+ return (level,)
497
+
498
+
499
+ def _translate_level_to_arg(level, translate_const):
500
+ level = (level/_MAX_LEVEL) * float(translate_const)
501
+ # Flip level to negative with 50% chance.
502
+ level = _randomly_negate_tensor(level)
503
+ return (level,)
504
+
505
+
506
+ def level_to_arg(hparams):
507
+ return {
508
+ 'AutoContrast': lambda level: (),
509
+ 'Equalize': lambda level: (),
510
+ 'Invert': lambda level: (),
511
+ 'Rotate': _rotate_level_to_arg,
512
+ 'Posterize': lambda level: (int((level/_MAX_LEVEL) * 4),),
513
+ 'Solarize': lambda level: (int((level/_MAX_LEVEL) * 256),),
514
+ 'SolarizeAdd': lambda level: (int((level/_MAX_LEVEL) * 110),),
515
+ 'Color': _enhance_level_to_arg,
516
+ 'Contrast': _enhance_level_to_arg,
517
+ 'Brightness': _enhance_level_to_arg,
518
+ 'Sharpness': _enhance_level_to_arg,
519
+ 'ShearX': _shear_level_to_arg,
520
+ 'ShearY': _shear_level_to_arg,
521
+ 'Cutout': lambda level: (int((level/_MAX_LEVEL) * hparams.cutout_const),),
522
+ 'TranslateX': lambda level: _translate_level_to_arg(
523
+ level, hparams.translate_const),
524
+ 'TranslateY': lambda level: _translate_level_to_arg(
525
+ level, hparams.translate_const),
526
+ # pylint:enable=g-long-lambda
527
+ }
528
+
529
+
530
+ def _parse_policy_info(name, prob, level, replace_value, augmentation_hparams):
531
+ """Return the function that corresponds to `name` and update `level` param."""
532
+ func = NAME_TO_FUNC[name]
533
+ args = level_to_arg(augmentation_hparams)[name](level)
534
+
535
+ # Check to see if prob is passed into function. This is used for operations
536
+ # where we alter bboxes independently.
537
+ # pytype:disable=wrong-arg-types
538
+ if 'prob' in inspect.getfullargspec(func).args:
539
+ args = tuple([prob] + list(args))
540
+ # pytype:enable=wrong-arg-types
541
+
542
+ # Add in replace arg if it is required for the function that is being called.
543
+ # pytype:disable=wrong-arg-types
544
+ if 'replace' in inspect.getfullargspec(func).args:
545
+ # Make sure replace is the final argument
546
+ assert 'replace' == inspect.getfullargspec(func).args[-1]
547
+ args = tuple(list(args) + [replace_value])
548
+ # pytype:enable=wrong-arg-types
549
+
550
+ return (func, prob, args)
551
+
552
+
553
+ def _apply_func_with_prob(func, image, args, prob):
554
+ """Apply `func` to image w/ `args` as input with probability `prob`."""
555
+ assert isinstance(args, tuple)
556
+
557
+ # If prob is a function argument, then this randomness is being handled
558
+ # inside the function, so make sure it is always called.
559
+ # pytype:disable=wrong-arg-types
560
+ if 'prob' in inspect.getfullargspec(func).args:
561
+ prob = 1.0
562
+ # pytype:enable=wrong-arg-types
563
+
564
+ # Apply the function with probability `prob`.
565
+ should_apply_op = tf.cast(
566
+ tf.floor(tf.random_uniform([], dtype=tf.float32) + prob), tf.bool)
567
+ augmented_image = tf.cond(
568
+ should_apply_op,
569
+ lambda: func(image, *args),
570
+ lambda: image)
571
+ return augmented_image
572
+
573
+
574
+ def select_and_apply_random_policy(policies, image):
575
+ """Select a random policy from `policies` and apply it to `image`."""
576
+ policy_to_select = tf.random_uniform([], maxval=len(policies), dtype=tf.int32)
577
+ # Note that using tf.case instead of tf.conds would result in significantly
578
+ # larger graphs and would even break export for some larger policies.
579
+ for (i, policy) in enumerate(policies):
580
+ image = tf.cond(
581
+ tf.equal(i, policy_to_select),
582
+ lambda selected_policy=policy: selected_policy(image),
583
+ lambda: image)
584
+ return image
585
+
586
+
587
+ def build_and_apply_nas_policy(policies, image,
588
+ augmentation_hparams):
589
+ """Build a policy from the given policies passed in and apply to image.
590
+ Args:
591
+ policies: list of lists of tuples in the form `(func, prob, level)`, `func`
592
+ is a string name of the augmentation function, `prob` is the probability
593
+ of applying the `func` operation, `level` is the input argument for
594
+ `func`.
595
+ image: tf.Tensor that the resulting policy will be applied to.
596
+ augmentation_hparams: Hparams associated with the NAS learned policy.
597
+ Returns:
598
+ A version of image that now has data augmentation applied to it based on
599
+ the `policies` pass into the function.
600
+ """
601
+ replace_value = [128, 128, 128]
602
+
603
+ # func is the string name of the augmentation function, prob is the
604
+ # probability of applying the operation and level is the parameter associated
605
+ # with the tf op.
606
+
607
+ # tf_policies are functions that take in an image and return an augmented
608
+ # image.
609
+ tf_policies = []
610
+ for policy in policies:
611
+ tf_policy = []
612
+ # Link string name to the correct python function and make sure the correct
613
+ # argument is passed into that function.
614
+ for policy_info in policy:
615
+ policy_info = list(policy_info) + [replace_value, augmentation_hparams]
616
+
617
+ tf_policy.append(_parse_policy_info(*policy_info))
618
+ # Now build the tf policy that will apply the augmentation procedure
619
+ # on image.
620
+ def make_final_policy(tf_policy_):
621
+ def final_policy(image_):
622
+ for func, prob, args in tf_policy_:
623
+ image_ = _apply_func_with_prob(
624
+ func, image_, args, prob)
625
+ return image_
626
+ return final_policy
627
+ tf_policies.append(make_final_policy(tf_policy))
628
+
629
+ augmented_image = select_and_apply_random_policy(
630
+ tf_policies, image)
631
+ return augmented_image
632
+
633
+
634
+ def distort_image_with_autoaugment(image, augmentation_name):
635
+ """Applies the AutoAugment policy to `image`.
636
+ AutoAugment is from the paper: https://arxiv.org/abs/1805.09501.
637
+ Args:
638
+ image: `Tensor` of shape [height, width, 3] representing an image.
639
+ augmentation_name: The name of the AutoAugment policy to use. The available
640
+ options are `v0` and `test`. `v0` is the policy used for
641
+ all of the results in the paper and was found to achieve the best results
642
+ on the COCO dataset. (The original codebase also provides `v1`, `v2` and
643
+ `v3` policies with slight variations in which operations were used during
644
+ the search procedure and in how many operations are applied in parallel
645
+ to a single image; those policies are not included in this fork.)
646
+ Returns:
647
+ The augmented version of `image`.
648
+ """
649
+ available_policies = {'v0': policy_v0,
650
+ 'test': policy_vtest}
651
+ if augmentation_name not in available_policies:
652
+ raise ValueError('Invalid augmentation_name: {}'.format(augmentation_name))
653
+
654
+ policy = available_policies[augmentation_name]()
655
+ # Hparams that will be used for AutoAugment.
656
+ augmentation_hparams = HParams(
657
+ cutout_const=100, translate_const=250)
658
+
659
+ return build_and_apply_nas_policy(policy, image, augmentation_hparams)
660
+
661
+
662
+ def distort_image_with_randaugment(image, num_layers, magnitude):
663
+ """Applies the RandAugment policy to `image`.
664
+ RandAugment is from the paper https://arxiv.org/abs/1909.13719.
665
+ Args:
666
+ image: `Tensor` of shape [height, width, 3] representing an image.
667
+ num_layers: Integer, the number of augmentation transformations to apply
668
+ sequentially to an image. Represented as (N) in the paper. Usually best
669
+ values will be in the range [1, 3].
670
+ magnitude: Integer, shared magnitude across all augmentation operations.
671
+ Represented as (M) in the paper. Usually best values are in the range
672
+ [5, 30].
673
+ Returns:
674
+ The augmented version of `image`.
675
+ """
676
+ replace_value = [128] * 3
677
+ tf.logging.info('Using RandAug.')
678
+ augmentation_hparams = HParams(
679
+ cutout_const=40, translate_const=100)
680
+ available_ops = [
681
+ 'AutoContrast', 'Equalize', 'Invert', 'Rotate', 'Posterize',
682
+ 'Solarize', 'Color', 'Contrast', 'Brightness', 'Sharpness',
683
+ 'ShearX', 'ShearY', 'TranslateX', 'TranslateY', 'Cutout', 'SolarizeAdd']
684
+
685
+ for layer_num in range(num_layers):
686
+ op_to_select = tf.random_uniform(
687
+ [], maxval=len(available_ops), dtype=tf.int32)
688
+ random_magnitude = float(magnitude)
689
+ with tf.name_scope('randaug_layer_{}'.format(layer_num)):
690
+ for (i, op_name) in enumerate(available_ops):
691
+ prob = tf.random_uniform([], minval=0.2, maxval=0.8, dtype=tf.float32)
692
+ func, _, args = _parse_policy_info(op_name, prob, random_magnitude,
693
+ replace_value, augmentation_hparams)
694
+ image = tf.cond(
695
+ tf.equal(i, op_to_select),
696
+ lambda selected_func=func, selected_args=args: selected_func(
697
+ image, *selected_args),
698
+ # pylint:enable=g-long-lambda
699
+ lambda: image)
700
+ return image
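
A minimal usage sketch for the two entry points above, assuming the module is importable as `big_vision.pp.autoaugment` (per the repository layout) and given a decoded uint8 [H, W, 3] image tensor:

```python
import tensorflow.compat.v1 as tf
from big_vision.pp.autoaugment import (
    distort_image_with_autoaugment, distort_image_with_randaugment)

# Stand-in for a real decoded image; any uint8 [H, W, 3] tensor works.
image = tf.zeros([224, 224, 3], dtype=tf.uint8)

# AutoAugment: applies one randomly selected sub-policy of `v0` per call.
aug_auto = distort_image_with_autoaugment(image, 'v0')

# RandAugment: N=2 random ops applied sequentially at magnitude M=10.
aug_rand = distort_image_with_randaugment(image, num_layers=2, magnitude=10)
```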
Tipsomaly/model/big_vision/pp/builder.py ADDED
@@ -0,0 +1,85 @@
1
+ # Copyright 2024 Big Vision Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """Preprocessing builder."""
16
+
17
+ from absl import logging
18
+ from big_vision.pp import registry
19
+ import tensorflow as tf
20
+
21
+
22
+ def get_preprocess_fn(pp_pipeline, log_data=True, log_steps=False):
23
+ """Transform an input string into the preprocessing function.
24
+
25
+ The minilanguage is as follows:
26
+
27
+ fn1|fn2(arg, arg2,...)|...
28
+
29
+ And describes the successive application of the various `fn`s to the input,
30
+ where each function can optionally have one or more arguments, which are
31
+ either positional or key/value, as dictated by the `fn`.
32
+
33
+ The output preprocessing function expects a dictionary as input. This
34
+ dictionary should have a key "image" that corresponds to a 3D tensor
35
+ (height x width x channel).
36
+
37
+ Args:
38
+ pp_pipeline: A string describing the pre-processing pipeline. If empty or
39
+ None, no preprocessing will be executed.
40
+ log_data: Whether to log the data before and after preprocessing. Can also
41
+ be a string to show in the log for debugging, for example dataset name.
42
+ log_steps: Whether to log the steps of the preprocessing pipeline.
43
+
44
+ Returns:
45
+ preprocessing function.
46
+
47
+ Raises:
48
+ ValueError: if preprocessing function name is unknown
49
+ """
50
+
51
+ names, ops, spec_strings = [], [], []
52
+ if pp_pipeline:
53
+ for op_spec in pp_pipeline.split("|"):
54
+ if not op_spec: continue # Skip empty section instead of error.
55
+ try:
56
+ ops.append(registry.Registry.lookup(f"preprocess_ops.{op_spec}")())
57
+ names.append(registry.parse_name(op_spec)[0])
58
+ spec_strings.append(op_spec)
59
+ except SyntaxError as err:
60
+ raise ValueError(f"Syntax error on: {op_spec}") from err
61
+
62
+ def _preprocess_fn(data):
63
+ """The preprocessing function that is returned."""
64
+ nonlocal log_data, log_steps
65
+
66
+ # Apply all the individual steps in sequence.
67
+ if log_data:
68
+ logging.info("Data before pre-processing (%s):\n%s", log_data, data)
69
+ for name, op, spec in zip(names, ops, spec_strings):
70
+ if log_steps:
71
+ logging.info("Pre-processing step (%s): %s\n%s", name, spec, data)
72
+ with tf.name_scope(name):
73
+ data = op(data)
74
+
75
+ # Validate the result: the ops must keep `data` a dictionary.
76
+ if not isinstance(data, dict):
77
+ raise ValueError("Argument `data` must be a dictionary, "
78
+ "not %s" % str(type(data)))
79
+
80
+ if log_data:
81
+ logging.info("Data after pre-processing (%s):\n%s", log_data, data)
82
+ log_data = False # For eager&pygrain: only log first one of each pipeline.
83
+ return data
84
+
85
+ return _preprocess_fn
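
A minimal sketch of the minilanguage in action; it assumes `ops_image.py` (which registers `resize`) and `ops_general.py` (which registers `value_range`) have been imported, exactly as the tests below do:

```python
import numpy as np
from big_vision.pp import builder
from big_vision.pp import ops_general  # registers value_range, keep, ...
from big_vision.pp import ops_image    # registers resize, inception_crop, ...

pp_fn = builder.get_preprocess_fn("resize(256)|value_range(-1, 1)")
out = pp_fn({"image": np.random.randint(0, 256, (480, 640, 3))})
# out["image"] is a 256x256x3 float tensor with values in [-1, 1].
```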
Tipsomaly/model/big_vision/pp/builder_test.py ADDED
@@ -0,0 +1,72 @@
1
+ # Copyright 2024 Big Vision Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """Tests for builder."""
16
+
17
+ from __future__ import absolute_import
18
+ from __future__ import division
19
+ from __future__ import print_function
20
+
21
+ from big_vision.pp import builder
22
+ from big_vision.pp import ops_general # pylint: disable=unused-import
23
+ from big_vision.pp import ops_image # pylint: disable=unused-import
24
+ import numpy as np
25
+ import tensorflow.compat.v1 as tf
26
+
27
+
28
+ class BuilderTest(tf.test.TestCase):
29
+
30
+ def testSingle(self):
31
+ pp_fn = builder.get_preprocess_fn("resize(256)")
32
+ x = np.random.randint(0, 256, [640, 480, 3])
33
+ image = pp_fn({"image": x})["image"]
34
+ self.assertEqual(image.numpy().shape, (256, 256, 3))
35
+
36
+ def testEmpty(self):
37
+ pp_fn = builder.get_preprocess_fn("||inception_crop|||resize(256)||")
38
+
39
+ # Typical image input
40
+ x = np.random.randint(0, 256, [640, 480, 3])
41
+ image = pp_fn({"image": x})["image"]
42
+ self.assertEqual(image.numpy().shape, (256, 256, 3))
43
+
44
+ def testPreprocessingPipeline(self):
45
+ pp_str = ("inception_crop|resize(256)|resize((256, 256))|"
46
+ "central_crop((80, 120))|flip_lr|value_range(0,1)|"
47
+ "value_range(-1,1)")
48
+ pp_fn = builder.get_preprocess_fn(pp_str)
49
+
50
+ # Typical image input
51
+ x = np.random.randint(0, 256, [640, 480, 3])
52
+ image = pp_fn({"image": x})["image"]
53
+ self.assertEqual(image.numpy().shape, (80, 120, 3))
54
+ self.assertLessEqual(np.max(image.numpy()), 1)
55
+ self.assertGreaterEqual(np.min(image.numpy()), -1)
56
+
57
+ def testNumArgsException(self):
58
+
59
+ x = np.random.randint(0, 256, [640, 480, 3])
60
+ for pp_str in [
61
+ "inception_crop(1)",
62
+ "resize()",
63
+ "resize(1, 1, 1)"
64
+ "flip_lr(1)",
65
+ "central_crop()",
66
+ ]:
67
+ with self.assertRaises(BaseException):
68
+ builder.get_preprocess_fn(pp_str)(x)
69
+
70
+
71
+ if __name__ == "__main__":
72
+ tf.test.main()
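
These tests exercise ops that other modules register; for context, this is roughly how a new op is hooked into the registry so the builder can find it (a sketch following the `Registry.register` pattern used in `ops_general.py`; the `times` op is hypothetical):

```python
import tensorflow as tf
from big_vision.pp import builder
from big_vision.pp.registry import Registry


@Registry.register("preprocess_ops.times")  # hypothetical op, for illustration
def get_times(factor, key="image"):
  """Returns an op that multiplies data[key] by `factor`."""
  def _times(data):
    data[key] = data[key] * factor
    return data
  return _times


pp_fn = builder.get_preprocess_fn("times(2)")
out = pp_fn({"image": tf.ones([2, 2, 3])})  # out["image"] is all 2s.
```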
Tipsomaly/model/big_vision/pp/ops_general.py ADDED
@@ -0,0 +1,468 @@
1
+ # Copyright 2024 Big Vision Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """Generic tensor preprocessing ops.
16
+
17
+ All preprocessing ops should return a data processing functor. Data
18
+ is represented as a dictionary of (TF) tensors. The functors output a modified
19
+ dictionary.
20
+ """
21
+
22
+ import collections
23
+
24
+ from big_vision.pp import utils
25
+ from big_vision.pp.registry import Registry
26
+ import big_vision.utils as bv_utils
27
+ import jax
28
+ import numpy as np
29
+ import tensorflow as tf
30
+
31
+
32
+ @Registry.register("preprocess_ops.value_range")
33
+ @utils.InKeyOutKey()
34
+ def get_value_range(vmin=-1, vmax=1, in_min=0, in_max=255.0, clip_values=False):
35
+ """Transforms a [in_min,in_max] image to [vmin,vmax] range.
36
+
37
+ Input ranges in_min/in_max can be equal-size lists to rescale the individual
38
+ channels independently.
39
+
40
+ Args:
41
+ vmin: A scalar. Output min value.
42
+ vmax: A scalar. Output max value.
43
+ in_min: A scalar or a list of input min values to scale. If a list, the
44
+ length should match to the number of channels in the image.
45
+ in_max: A scalar or a list of input max values to scale. If a list, the
46
+ length should match to the number of channels in the image.
47
+ clip_values: Whether to clip the output values to the provided ranges.
48
+
49
+ Returns:
50
+ A function to rescale the values.
51
+ """
52
+
53
+ def _value_range(image):
54
+ """Scales values in given range."""
55
+ in_min_t = tf.constant(in_min, tf.float32)
56
+ in_max_t = tf.constant(in_max, tf.float32)
57
+ image = tf.cast(image, tf.float32)
58
+ image = (image - in_min_t) / (in_max_t - in_min_t)
59
+ image = vmin + image * (vmax - vmin)
60
+ if clip_values:
61
+ image = tf.clip_by_value(image, vmin, vmax)
62
+ return image
63
+
64
+ return _value_range
65
+
66
+
67
+ @Registry.register("preprocess_ops.lookup")
68
+ @utils.InKeyOutKey()
69
+ def get_lookup(mapping, npzkey="fnames", sep=None):
70
+ """Map string to number."""
71
+
72
+ # For NumPy files, we use the `npzkey` array in that file as the list of
73
+ # strings which are mapped to their index in that array.
74
+ # This is especially useful when other data (eg precomputed predictions)
75
+ # goes along with this mapping, to have everything in one place (the npz).
76
+ if mapping.endswith(".npz"):
77
+ with tf.io.gfile.GFile(mapping, "rb") as f:
78
+ keys = np.array(np.load(f, allow_pickle=False)[npzkey])
79
+ vals = np.arange(len(keys))
80
+
81
+ # Otherwise, we simply use the file as a text file, with either of:
82
+ # - a string per line, mapped to its line-number
83
+ # - a pair, separated by `sep` per line, first value being the string, second
84
+ # value being the integer that the string is mapped to.
85
+ else:
86
+ with tf.io.gfile.GFile(mapping, "r") as f:
87
+ buf = f.read()
88
+ if sep is None: # values are the line numbers
89
+ keys = buf.splitlines()
90
+ vals = np.arange(len(keys))
91
+ else: # each line is key<sep>val, also make val int
92
+ keys, vals = zip(*[l.split(sep) for l in buf.splitlines()])
93
+ vals = [int(v) for v in vals]
94
+
95
+ def _do_the_mapping(needle):
96
+ """Map string to number."""
97
+ with tf.init_scope(): # (Originally added for performance reasons.)
98
+ table = tf.lookup.StaticHashTable(
99
+ tf.lookup.KeyValueTensorInitializer(keys, vals), -1)
100
+ return table.lookup(needle)
101
+
102
+ return _do_the_mapping
103
+
104
+
105
+ @Registry.register("preprocess_ops.onehot")
106
+ def get_onehot(depth,
107
+ key="labels",
108
+ key_result=None,
109
+ multi=True,
110
+ on=1.0,
111
+ off=0.0):
112
+ """One-hot encodes the input.
113
+
114
+ Args:
115
+ depth: Length of the one-hot vector (how many classes).
116
+ key: Key of the data to be one-hot encoded.
117
+ key_result: Key under which to store the result (same as `key` if None).
118
+ multi: If there are multiple labels, whether to merge them into the same
119
+ "multi-hot" vector (True) or keep them as an extra dimension (False).
120
+ on: Value to fill in for the positive label (default: 1).
121
+ off: Value to fill in for negative labels (default: 0).
122
+
123
+ Returns:
124
+ Data dictionary.
125
+ """
126
+
127
+ def _onehot(data):
128
+ # When there's more than one label, this is significantly more efficient
129
+ # than using tf.one_hot followed by tf.reduce_max; we tested.
130
+ labels = data[key]
131
+ labels = tf.cast(labels, tf.int64) # both scatter and one_hot expect this
132
+ if labels.shape.rank > 0 and multi:
133
+ x = tf.scatter_nd(labels[:, None], tf.ones(tf.shape(labels)[0]), (depth,))
134
+ x = tf.clip_by_value(x, 0, 1) * (on - off) + off
135
+ else:
136
+ x = tf.one_hot(labels, depth, on_value=on, off_value=off)
137
+ data[key_result or key] = x
138
+ return data
139
+
140
+ return _onehot
141
+
142
+
143
+ @Registry.register("preprocess_ops.keep")
144
+ def get_keep(*keys):
145
+ """Keeps only the given keys."""
146
+
147
+ def _keep(data):
148
+ return {k: v for k, v in data.items() if k in keys}
149
+
150
+ return _keep
151
+
152
+
153
+ @Registry.register("preprocess_ops.drop")
154
+ def get_drop(*keys):
155
+ """Drops the given keys."""
156
+
157
+ def _drop(data):
158
+ return {k: v for k, v in data.items() if k not in keys}
159
+
160
+ return _drop
161
+
162
+
163
+ @Registry.register("preprocess_ops.copy")
164
+ def get_copy(inkey, outkey):
165
+ """Copies value of `inkey` into `outkey`."""
166
+
167
+ def _copy(data):
168
+ # A "semi-deep" copy. deepcopy doesn't work when tf tensors are part of the
169
+ # game. What we want is to only copy the python structure (dicts, lists)
170
+ # and keep tensors as they are, since we never modify them in-place anyways.
171
+ # The following achieves exactly that.
172
+ data[outkey] = jax.tree.map(lambda x: x, data[inkey])
173
+ return data
174
+
175
+ return _copy
176
+
177
+
178
+ @Registry.register("preprocess_ops.squeeze_last_dim")
179
+ @utils.InKeyOutKey()
180
+ def get_squeeze_last_dim():
181
+ def _squeeze_last_dim(x):
182
+ return tf.squeeze(x, axis=-1)
183
+ return _squeeze_last_dim
184
+
185
+
186
+ @Registry.register("preprocess_ops.concat")
187
+ def get_concat(inkeys, outkey=None, axis=-1):
188
+ """Concatenates elements along some axis."""
189
+
190
+ def _concat(data):
191
+ data[outkey or inkeys[0]] = tf.concat([data[k] for k in inkeys], axis)
192
+ return data
193
+
194
+ return _concat
195
+
196
+
197
+ @Registry.register("preprocess_ops.rag_tensor")
198
+ @utils.InKeyOutKey()
199
+ def get_rag_tensor():
200
+ """Converts the specified feature to ragged tensor."""
201
+
202
+ def rag_tensor(raw_tensor):
203
+ # Note: Add one more dimension as `from_tensor` requires at least rank 2.
204
+ return tf.RaggedTensor.from_tensor(raw_tensor[None])
205
+
206
+ return rag_tensor
207
+
208
+
209
+ @Registry.register("preprocess_ops.pad_to_shape")
210
+ @utils.InKeyOutKey()
211
+ def get_pad_to_shape(shape, pad_value=0, where="after"):
212
+ """Pads tensor to specified `shape`."""
213
+
214
+ def _pads(cur, tgt):
215
+ if tgt is None:
216
+ return [0, 0]
217
+ diff = tgt - cur
218
+ return {
219
+ "before": [diff, 0],
220
+ "after": [0, diff],
221
+ "both": [diff // 2, diff - diff // 2],
222
+ }[where]
223
+
224
+ def _pad_to_shape(x):
225
+ assert len(x.shape.as_list()) == len(shape)
226
+ paddings = [_pads(tgt=shape[i], cur=tf.shape(x)[i])
227
+ for i in range(len(shape))]
228
+ constant_value = tf.constant(pad_value, x.dtype)
229
+ ret = tf.pad(x, paddings, constant_values=constant_value)
230
+ ret.set_shape(shape)
231
+ return ret
232
+
233
+ return _pad_to_shape
234
+
235
+
236
+ @Registry.register("preprocess_ops.flatten")
237
+ def get_flatten(keys=None):
238
+ """Flattens the keys of data with separator '/'."""
239
+
240
+ def _flatten(data):
241
+ flatten_keys = keys or list(data.keys())
242
+ not_flattened = {k: v for k, v in data.items() if k not in flatten_keys}
243
+ flattened = {k: v for k, v in data.items() if k in flatten_keys}
244
+ flattened, _ = bv_utils.tree_flatten_with_names(flattened)
245
+ return {**dict(flattened), **not_flattened}
246
+
247
+ return _flatten
248
+
249
+
250
+ @Registry.register("preprocess_ops.reshape")
251
+ @utils.InKeyOutKey()
252
+ def get_reshape(new_shape):
253
+ """Reshapes tensor to a given new shape.
254
+
255
+ Args:
256
+ new_shape: new shape for the tensor.
257
+
258
+ Returns:
259
+ A function for reshaping a tensor.
260
+
261
+ """
262
+
263
+ def _reshape(tensor):
264
+ """Reshapes a tensor to a given shape."""
265
+ dtype = tensor.dtype
266
+ tensor = tf.reshape(tensor, new_shape)
267
+ return tf.cast(tensor, dtype)
268
+
269
+ return _reshape
270
+
271
+
272
+ @Registry.register("preprocess_ops.setdefault")
273
+ def get_setdefault(key, value):
274
+ """If `key` is an empty tensor or missing, set it to `value`."""
275
+ def _setdefault(data):
276
+ x = data.get(key, tf.constant(value))
277
+ v = tf.constant(value, dtype=x.dtype)
278
+ v = tf.broadcast_to(v, [s or 1 for s in x.shape])
279
+ data[key] = tf.cond(tf.size(x) > 0, lambda: x, lambda: v)
280
+ return data
281
+ return _setdefault
282
+
283
+
284
+ @Registry.register("preprocess_ops.choice")
285
+ def get_choice(n="single", key=None, fewer_ok=False, inkey=None, outkey=None):
286
+ """Chooses the same `n` random entries of all `keys`.
287
+
288
+ Args:
289
+ n: how many entries to randomly sample (without repeat). Possible values:
290
+ - int: that many entries (or fewer if there's fewer, see `fewer_ok`.)
291
+ - "single": The string "single" only chooses one and drop the leading dim.
292
+ - [min, max]: A pair means randomly take between min/max examples (incl.).
293
+ key: str or list of str: See Note.
294
+ fewer_ok: whether to fail when there's fewer than `n` elements to choose
295
+ from (and hence set static shape to `n`), or whether to allow it
296
+ (and hence have unknown static shape).
297
+ inkey: str or list of str: See Note.
298
+ outkey: str or list of str: See Note.
299
+
300
+ Note:
301
+ If key/inkey/outkey is a list, then the same random entries are chosen for
302
+ all of the keys. Other than that, they function the same as InKeyOutKey.
303
+
304
+ The outkey can also contain the placeholder `{key}` that'll be replaced by the inkey name.
305
+
306
+ Examples:
307
+ choice(key="alt_text/text")
308
+ choice(n=128, key=["patches", "positions"])
309
+ choice(inkey=["questions_i18n", "answers_i18n"], outkey=["q", "a"])
310
+
311
+ Returns:
312
+ The pp op.
313
+ """
314
+
315
+ # Normalize keys:
316
+ inkeys = utils.maybe_repeat(inkey or key, 1)
317
+ outkeys = utils.maybe_repeat(outkey or key, 1)
318
+ outkeys = [ok.format(key=ik) for ok, ik in zip(outkeys, inkeys)]
319
+
320
+ # Let's DRY on this condition and give it a name.
321
+ is_varlen = isinstance(n, (list, tuple))
322
+ min_n = n[0] if is_varlen else 1 if n == "single" else n
323
+
324
+ def _choice(data):
325
+ # Catch a hard to identify/understand user error:
326
+ assert data[inkeys[0]].ndim > 0, (
327
+ f"You're calling `choice_no_replacement` on {inkeys}, a scalar."
328
+ " That's probably a mistake ; double-check and then just don't."
329
+ )
330
+
331
+ nitems = tf.shape(data[inkeys[0]])[0]
332
+
333
+ # Sanity check that all keys have same leading dimension, and that is at
334
+ # least as large as the minimum requested output.
335
+ lengths = [tf.shape(data[k])[0] for k in inkeys]
336
+ checks = [tf.debugging.assert_equal(l, nitems) for l in lengths]
337
+ if not fewer_ok: # Since we check for all-same, a single suffices here.
338
+ checks.append(tf.debugging.assert_greater_equal(nitems, min_n))
339
+ with tf.control_dependencies(checks):
340
+ nitems = tf.identity(nitems)
341
+
342
+ if n == "single":
343
+ index = tf.random.uniform([], 0, nitems, dtype=tf.int32)
344
+ else:
345
+ # Subsample by shuffling and taking first n, but...
346
+ indices = tf.random.shuffle(tf.range(nitems))
347
+ end = n
348
+ if is_varlen:
349
+ end = tf.random.uniform([], n[0], n[1] + 1, dtype=tf.int32)
350
+ # ...keep the order while subsampling (it might have a meaning, eg boxes)
351
+ indices = tf.sort(indices[:end])
352
+
353
+ for ik, ok in zip(inkeys, outkeys):
354
+ if n == "single":
355
+ result = data[ik][index]
356
+ else:
357
+ result = tf.gather(data[ik], indices, axis=0)
358
+ if not is_varlen: # Give static shape when we can.
359
+ result = tf.ensure_shape(result, [n] + [None] * (result.ndim - 1))
360
+ data[ok] = result
361
+
362
+ return data
363
+ return _choice
364
+
365
+
366
+ def _shuffled_index(count, nitems, seed):
367
+ """Returns index from a shuffled sequence (items only repeat after epoch)."""
368
+ nitems = tf.cast(nitems, count.dtype)
369
+ item_epoch, item_offset = (count // nitems, count % nitems)
370
+ shuffled_indices = tf.random.experimental.stateless_shuffle(
371
+ tf.range(nitems), seed=tf.random.fold_in(seed, item_epoch))
372
+ return shuffled_indices[item_offset]
373
+
374
+
375
+ @Registry.register("preprocess_ops.choice_no_replacement")
376
+ def get_choice_no_replacement(key=None, inkey=None, outkey=None):
377
+ """Chooses the same random (no replacement) entry of all `keys`.
378
+
379
+ Note: Consider using this for iterating over small datasets with a small
380
+ number of epochs. It differs from `choice(n='single')` in that if an example,
381
+ as identified by its `_id` field, is seen N times then it will cycle through
382
+ all the inkey values before repeating them. Additionally, each repetition uses
383
+ a different order.
384
+
385
+ Caveats: requires the dataset to provide an `_id` field and uses host RAM to
386
+ keep a counter of how often each id is seen. It is also not robust to preemptions.
387
+
388
+ Args:
389
+ key: str or list of str: See Note.
390
+ inkey: str or list of str: See Note.
391
+ outkey: str or list of str: See Note.
392
+
393
+ Note:
394
+ If key/inkey/outkey is a list, then the same random entries are chosen for
395
+ all of the keys. Other than that, they function the same as InKeyOutKey.
396
+
397
+ The outkey can also contain the placeholder `{key}` that'll be replaced
398
+ by the inkey name.
399
+
400
+ Examples:
401
+ choice(key="alt_text/text")
402
+ choice(key=["patches", "positions"])
403
+ choice(inkey=["questions_i18n", "answers_i18n"], outkey=["q", "a"])
404
+
405
+ Returns:
406
+ The pp op.
407
+ """
408
+ # Normalize keys:
409
+ inkeys = utils.maybe_repeat(inkey or key, 1)
410
+ outkeys = utils.maybe_repeat(outkey or key, 1)
411
+ outkeys = [ok.format(key=ik) for ok, ik in zip(outkeys, inkeys)]
412
+
413
+ # TODO: Ideally the data pipeline should provide us with an epoch
414
+ # counter. For now, count how often we see a given example id and don't worry
415
+ # about memory consumption. Counter returns 0 the first time an example is seen.
416
+ counter = collections.defaultdict(lambda: -1)
417
+ def _seen_count(example_id):
418
+ example_id = example_id.item()
419
+ counter[example_id] += 1
420
+ return counter[example_id]
421
+
422
+ # We need a seed to deterministically decide on a shuffled sequence and use
423
+ # the number of times an example was seen to iterate through it. The seed
424
+ # should be different for every instance of a create preprocessing function
425
+ # but it has to be fixed for each instance.
426
+ seed = tf.random.uniform(
427
+ [2], minval=tf.int32.min, maxval=tf.int32.max, dtype=tf.int32)
428
+
429
+ def _choice(data):
430
+ # Catch a hard to identify/understand user error:
431
+ assert data[inkeys[0]].ndim > 0, (
432
+ f"You're calling `choice` on {inkeys}, a scalar."
433
+ " That's probably a mistake ; double-check and then just don't."
434
+ )
435
+
436
+ nitems = tf.shape(data[inkeys[0]])[0]
437
+
438
+ # Sanity check that all keys have same leading dimension.
439
+ checks = [
440
+ tf.debugging.assert_equal(tf.shape(data[k])[0], nitems)
441
+ for k in inkeys
442
+ ]
443
+ with tf.control_dependencies(checks):
444
+ nitems = tf.identity(nitems)
445
+
446
+ # Using the seed, example id and the number of times an example was seen
447
+ # pick an `index` such that items are only repeated after all items are seen
448
+ # an equal number of times. E.g. it could return indexes from this sequence:
449
+ # [0, 1, 2, 1, 2, 0, 2, 0, 1, 0, 2, 1, ...].
450
+ count = tf.numpy_function(
451
+ _seen_count, (data["_id"],), Tout=tf.int64, stateful=True)
452
+ count = tf.cast(count, tf.int32)
453
+ nitems = tf.cast(nitems, tf.int32)
454
+ shuffle_epoch = count // nitems
455
+ shuffle_offset = count % nitems
456
+
457
+ example_seed = tf.random.fold_in(seed, data["_id"])
458
+ shuffle_seed = tf.random.fold_in(example_seed, shuffle_epoch)
459
+ shuffle = tf.random.experimental.stateless_shuffle(
460
+ tf.range(nitems), seed=shuffle_seed)
461
+ index = shuffle[shuffle_offset]
462
+
463
+ # Select item[index] for all keys.
464
+ for ik, ok in zip(inkeys, outkeys):
465
+ data[ok] = data[ik][index]
466
+ return data
467
+
468
+ return _choice
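
To make the `multi` flag of `get_onehot` above concrete, a small worked example (run in eager mode with `get_onehot` in scope from this module; the outputs match the unit tests below):

```python
import tensorflow as tf

data = {"labels": tf.constant([2, 3, 0], dtype=tf.int64)}

# multi=True merges all labels into a single depth-4 multi-hot vector:
get_onehot(4, key="labels", multi=True)(dict(data))["labels"]
# -> [1., 0., 1., 1.]

# multi=False keeps one row per label, yielding a [3, 4] one-hot matrix:
get_onehot(4, key="labels", multi=False)(dict(data))["labels"]
# -> [[0., 0., 1., 0.], [0., 0., 0., 1.], [1., 0., 0., 0.]]
```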
Tipsomaly/model/big_vision/pp/ops_general_test.py ADDED
@@ -0,0 +1,236 @@
1
+ # Copyright 2024 Big Vision Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """Tests for ops_general."""
16
+
17
+ import copy
18
+
19
+ import big_vision.pp.ops_general as pp
20
+ import numpy as np
21
+ import tensorflow as tf
22
+
23
+
24
+ class PreprocessOpsTest(tf.test.TestCase):
25
+
26
+ def tfrun(self, ppfn, data):
27
+ # Run once as standalone, as could happen eg in colab.
28
+ yield {k: np.array(v) for k, v in ppfn(copy.deepcopy(data)).items()}
29
+
30
+ # And then once again as part of tfdata pipeline.
31
+ # You'd be surprised how much these two differ!
32
+ tfdata = tf.data.Dataset.from_tensors(copy.deepcopy(data))
33
+ for npdata in tfdata.map(ppfn).as_numpy_iterator():
34
+ yield npdata
35
+
36
+ def test_value_range(self):
37
+ img = tf.random.uniform((640, 480, 3), 0, 255, tf.int32)
38
+ data = {"image": tf.cast(img, tf.uint8)}
39
+ for out in self.tfrun(pp.get_value_range(-0.5, 0.5), data):
40
+ self.assertLessEqual(np.max(out["image"]), 0.5)
41
+ self.assertGreaterEqual(np.min(out["image"]), -0.5)
42
+
43
+ def test_value_range_custom_input_range(self):
44
+ img = tf.random.uniform((640, 480, 3), 0, 255, tf.int32)
45
+ data = {"image": tf.cast(img, tf.uint8)}
46
+ for out in self.tfrun(pp.get_value_range(-0.5, 0.5, -256, 255, True), data):
47
+ self.assertLessEqual(np.max(out["image"]), 0.5)
48
+ self.assertGreaterEqual(np.min(out["image"]), 0.0)
49
+
50
+ def test_get_keep_drop(self):
51
+ data = {"image": 1, "labels": 2, "something": 3}
52
+
53
+ for data_keep in self.tfrun(pp.get_keep("image", "labels"), data):
54
+ self.assertAllEqual(set(data_keep.keys()), {"image", "labels"})
55
+
56
+ for data_drop in self.tfrun(pp.get_drop("image", "labels"), data):
57
+ self.assertAllEqual(set(data_drop.keys()), {"something"})
58
+
59
+ def test_onehot(self):
60
+ data = {"labels": tf.constant(2, dtype=tf.int64)}
61
+ for out in self.tfrun(pp.get_onehot(4, "labels", multi=True), data):
62
+ self.assertAllClose(out["labels"], [0., 0., 1., 0.])
63
+
64
+ def test_onehot_multi(self):
65
+ data = {"labels": tf.constant([2, 3, 0], dtype=tf.int64)}
66
+ for out in self.tfrun(pp.get_onehot(4, "labels", multi=False), data):
67
+ self.assertAllClose(out["labels"], [
68
+ [0., 0., 1., 0.],
69
+ [0., 0., 0., 1.],
70
+ [1., 0., 0., 0.]])
71
+
72
+ for out in self.tfrun(pp.get_onehot(4, "labels", multi=True), data):
73
+ self.assertAllClose(out["labels"], [1., 0., 1., 1.])
74
+
75
+ def test_onehot_2d(self):
76
+ data = {"labels": tf.constant([[2, 3], [0, 1]], dtype=tf.int64)}
77
+ for out in self.tfrun(pp.get_onehot(4, "labels", multi=False), data):
78
+      self.assertAllClose(out["labels"], [
+          [[0., 0., 1., 0.], [0., 0., 0., 1.]],
+          [[1., 0., 0., 0.], [0., 1., 0., 0.]]])
+
+  def test_onehot_smoothing(self):
+    data = {"labels": tf.constant([2, 3, 0], dtype=tf.int64)}
+    for out in self.tfrun(
+        pp.get_onehot(4, "labels", multi=False, on=0.8, off=0.1), data):
+      self.assertAllClose(out["labels"], [
+          [0.1, 0.1, 0.8, 0.1],
+          [0.1, 0.1, 0.1, 0.8],
+          [0.8, 0.1, 0.1, 0.1]])
+
+    for out in self.tfrun(
+        pp.get_onehot(4, "labels", multi=True, on=0.8, off=0.1), data):
+      self.assertAllClose(out["labels"], [0.8, 0.1, 0.8, 0.8])
+
+  def test_squeeze_last_dim(self):
+    data = {"image": tf.constant(np.zeros((32, 32, 3, 1)))}
+    for out in self.tfrun(pp.get_squeeze_last_dim(), data):
+      self.assertAllEqual(out["image"].shape, [32, 32, 3])
+
+  def test_pad_to_shape(self):
+    desired_shape = (8, 10)
+    for input_shape in [(8, 4), (8, 3), (8, 10), (8, 1)]:
+      data = {"x": tf.ones(input_shape, dtype=tf.float32)}
+      for out in self.tfrun(
+          pp.get_pad_to_shape(desired_shape, pad_value=-1, key="x"), data):
+        self.assertEqual(
+            tf.reduce_sum(out["x"]),
+            2 * np.prod(input_shape) - np.prod(desired_shape))
+
+  def test_pad_to_shape_none(self):
+    data = {"x": tf.ones((8, 4), dtype=tf.float32)}
+    for out in self.tfrun(
+        pp.get_pad_to_shape((None, 6), pad_value=-1, key="x"), data):
+      self.assertEqual(out["x"].shape, (8, 6))
+      self.assertEqual(tf.reduce_sum(out["x"]), 8*4 - 8*2)
+
+  def test_pad_to_shape_which_side(self):
+    data = {"x": tf.ones((8, 4), dtype=tf.float32)}
+    for where, idxs in [("before", [0]), ("both", [0, -1]), ("after", [-1])]:
+      for out in self.tfrun(
+          pp.get_pad_to_shape((8, 6), key="x", where=where), data):
+        self.assertEqual(out["x"].shape, (8, 6))
+        self.assertEqual(tf.reduce_sum(out["x"]), 8*4)
+        for i in idxs:
+          self.assertEqual(out["x"][0, i], 0)
+
+  def test_flatten(self):
+    d = {"a": {"b": tf.constant([1, 2, 3])}, "c": "str"}
+    self.assertEqual(pp.get_flatten()(d), {
+        "a/b": tf.constant([1, 2, 3]),
+        "c": "str"
+    })
+
+  def test_reshape(self):
+    data = {"image": tf.constant(np.zeros((8, 32 * 32 * 3)))}
+    for out in self.tfrun(pp.get_reshape(new_shape=(8, 32, 32, 3)), data):
+      self.assertAllEqual(out["image"].shape, [8, 32, 32, 3])
+
+  def test_setdefault(self):
+    data = {
+        "empty_image": tf.zeros([0, 0, 0]),
+        "image": tf.constant(np.arange(9).reshape(3, 3)),
+        "empty_text": tf.zeros([0], tf.string),
+        "text": tf.constant(["Hello", "World"], tf.string),
+    }
+    for out in self.tfrun(pp.get_setdefault("empty_image", 1), data):
+      self.assertAllEqual(out["empty_image"], np.array([[[1]]]))
+    for out in self.tfrun(pp.get_setdefault("image", 1), data):
+      self.assertAllEqual(out["image"], data["image"])
+    for out in self.tfrun(pp.get_setdefault("empty_text", "Lucas"), data):
+      self.assertAllEqual(out["empty_text"], np.array(["Lucas"]))
+    for out in self.tfrun(pp.get_setdefault("text", "Lucas"), data):
+      self.assertAllEqual(out["text"], data["text"])
+
+  def _data_for_choice(self):
+    return {
+        "one_f32": tf.constant([0.42], tf.float32),
+        "two_f32": tf.constant([3.14, 0.42], tf.float32),
+        "one_str": tf.constant(["Hi"], tf.string),
+        "two_str": tf.constant(["Hi", "Lucas"], tf.string),
+        "one_vec": tf.reshape(tf.range(2, dtype=tf.float32), (1, 2)),
+        "two_vec": tf.reshape(tf.range(4, dtype=tf.float32), (2, 2)),
+    }
+
+  def test_choice(self):
+    # Test for the default call (n="single")
+    data = self._data_for_choice()
+    self.assertEqual(
+        pp.get_choice(inkey="one_f32", outkey="choice")(data)["choice"], 0.42)
+    self.assertEqual(
+        pp.get_choice(inkey="one_str", outkey="choice")(data)["choice"], "Hi")
+    self.assertIn(
+        pp.get_choice(inkey="two_f32", outkey="choice")(data)["choice"],
+        [3.14, 0.42])
+    self.assertIn(
+        pp.get_choice(inkey="two_str", outkey="choice")(data)["choice"],
+        ["Hi", "Lucas"])
+
+  def test_choice_nmax(self):
+    # n == nelems should be identity (and keep ordering!)
+    data = self._data_for_choice()
+    for k in ("one_f32", "one_str", "one_vec"):
+      for out in self.tfrun(pp.get_choice(n=1, key=[k]), data):
+        self.assertAllEqual(out[k], data[k])
+      for out in self.tfrun(pp.get_choice(n=[1, 1], key=[k]), data):
+        self.assertAllEqual(out[k], data[k])
+    for k in ("two_f32", "two_str", "two_vec"):
+      for out in self.tfrun(pp.get_choice(n=2, key=[k]), data):
+        self.assertAllEqual(out[k], data[k])
+      for out in self.tfrun(pp.get_choice(n=[2, 2], key=[k]), data):
+        self.assertAllEqual(out[k], data[k])
+
+  def test_choice_n(self):
+    # n < nelems should be one of them:
+    data = self._data_for_choice()
+    for k in ("two_f32", "two_str"):
+      for out in self.tfrun(pp.get_choice(n=1, key=[k]), data):
+        self.assertIn(out[k], data[k])
+
+    # Special testing for vectors.
+    for out in self.tfrun(pp.get_choice(n=1, key=["two_vec"]), data):
+      self.assertTrue(tf.logical_or(
+          tf.reduce_all(out["two_vec"][0] == data["two_vec"][0]),
+          tf.reduce_all(out["two_vec"][0] == data["two_vec"][1]),
+      ))
+
+  def test_choice_multi(self):
+    # Select consistently across multiple keys.
+    data = self._data_for_choice()
+    op = pp.get_choice(n=1, key=["two_f32", "two_str"])
+    for out in self.tfrun(op, data):
+      self.assertTrue(tf.logical_or(
+          tf.logical_and(
+              tf.reduce_all(out["two_f32"][0] == data["two_f32"][0]),
+              tf.reduce_all(out["two_str"][0] == data["two_str"][0]),
+          ),
+          tf.logical_and(
+              tf.reduce_all(out["two_f32"][0] == data["two_f32"][1]),
+              tf.reduce_all(out["two_str"][0] == data["two_str"][1]),
+          ),
+      ))
+
+  def test_choice_n_range(self):
+    # n sampled from a [min, max] range should yield a contiguous subsequence:
+    data = self._data_for_choice()
+    for k in ("two_f32", "two_str", "two_vec"):
+      for out in self.tfrun(pp.get_choice(n=[1, 2], key=[k]), data):
+        self.assertTrue(tf.reduce_any([
+            tf.reduce_all(out[k] == data[k][0:1]),
+            tf.reduce_all(out[k] == data[k][1:2]),
+            tf.reduce_all(out[k] == data[k][0:2]),
+        ]))
+
+
+if __name__ == "__main__":
+  tf.test.main()
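
The tests above double as documentation for the `choice` op. As a quick orientation for readers of this diff, here is a minimal sketch (not part of the commit; the key name `txt` is made up) of calling the op directly:

```python
import tensorflow as tf
import big_vision.pp.ops_general as pp

data = {"txt": tf.constant(["a", "b", "c"])}
# Default n="single": picks one element and drops the leading axis.
picked = pp.get_choice(inkey="txt", outkey="pick")(dict(data))["pick"]
# n equal to the number of elements is the identity (ordering kept),
# mirroring test_choice_nmax above.
same = pp.get_choice(n=3, key=["txt"])(dict(data))["txt"]
```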
Tipsomaly/model/big_vision/pp/ops_image.py ADDED
@@ -0,0 +1,361 @@
+# Copyright 2024 Big Vision Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Image-centric preprocessing ops.
+
+All preprocessing ops should return a data-processing functor. A data example
+is represented as a dictionary of (TF) tensors. The functors output a modified
+dictionary.
+
+The key named "image" is commonly used for the image, and is a 3D tensor of
+shape (height x width x channels).
+"""
+
+from big_vision.pp import utils
+from big_vision.pp.registry import Registry
+
+import tensorflow as tf
+
+
+@Registry.register("preprocess_ops.decode")
+@utils.InKeyOutKey()
+def get_decode(channels=3, precise=False):
+  """Decode an encoded image string, see tf.io.decode_image.
+
+  Args:
+    channels: see tf.io.decode_image.
+    precise: if False, use default TF image decoding algorithm.
+      If True, change DCT method for JPEG decoding to match PIL/cv2/PyTorch.
+      See also (internal link) for a concrete example.
+
+  Returns:
+    A function that decodes an image.
+  """
+
+  def _decode(image):
+    if precise:
+      return tf.image.decode_jpeg(  # Also supports png btw.
+          image, channels=channels, dct_method="INTEGER_ACCURATE")
+    else:
+      return tf.io.decode_image(
+          image, channels=channels, expand_animations=False)
+
+  return _decode
+
+
+@Registry.register("preprocess_ops.resize")
+@utils.InKeyOutKey()
+def get_resize(size, method="bilinear", antialias=False):
+  """Resizes image to a given size.
+
+  Args:
+    size: either an integer H, where H is both the new height and width
+      of the resized image, or a list or tuple [H, W] of integers, where H
+      and W are the new image's height and width respectively.
+    method: resize method, see tf.image.resize docs for options.
+    antialias: see tf.image.resize. Ideally set to True for all new configs.
+
+  Returns:
+    A function for resizing an image.
+  """
+  size = utils.maybe_repeat(size, 2)
+
+  def _resize(image):
+    """Resizes image to a given size."""
+    # Note: use TF-2 version of tf.image.resize as the version in TF-1 is
+    # buggy: https://github.com/tensorflow/tensorflow/issues/6720.
+    # In particular it was not equivariant with rotation and led the network
+    # to learn a shortcut in the self-supervised rotation task, if rotation
+    # was applied after resize.
+    dtype = image.dtype
+    tf_dtype = tf.type_spec_from_value(image).dtype
+    image = tf.image.resize(image, size, method=method, antialias=antialias)
+    return tf.cast(tf.clip_by_value(image, tf_dtype.min, tf_dtype.max), dtype)
+
+  return _resize
+
+
+# This functionality is used by resize_small and resize_long. But we're not
+# registering it as a pp op yet, as there is no need for it. However, it can
+# probably be slightly generalized into "scale augmentation" eventually.
+def _resize_factor(image, factor, method="area", antialias=True):
+  """Resizes the image by a (float) `factor`, keeping the aspect ratio fixed."""
+  h, w = tf.shape(image)[0], tf.shape(image)[1]
+
+  h = tf.cast(tf.round(tf.cast(h, tf.float32) * factor), tf.int32)
+  w = tf.cast(tf.round(tf.cast(w, tf.float32) * factor), tf.int32)
+
+  dtype = image.dtype
+  tf_dtype = tf.type_spec_from_value(image).dtype
+  image = tf.image.resize(image, (h, w), method=method, antialias=antialias)
+  return tf.cast(tf.clip_by_value(image, tf_dtype.min, tf_dtype.max), dtype)
+
+
+@Registry.register("preprocess_ops.resize_small")
+@utils.InKeyOutKey()
+def get_resize_small(smaller_size, method="area", antialias=False):
+  """Resizes the smaller side to `smaller_size` keeping aspect ratio.
+
+  Args:
+    smaller_size: an integer representing the new size of the smaller side
+      of an input image.
+    method: the resize method. `area` is a meaningful, bwd-compat default.
+    antialias: see tf.image.resize. Ideally set to True for all new configs.
+
+  Returns:
+    A function that resizes an image and preserves its aspect ratio.
+
+  Note:
+    backwards-compat for "area"+antialias tested here:
+    (internal link)
+  """
+
+  def _resize_small(image):  # pylint: disable=missing-docstring
+    h, w = tf.shape(image)[0], tf.shape(image)[1]
+    factor = (
+        tf.cast(smaller_size, tf.float32) /
+        tf.cast(tf.minimum(h, w), tf.float32))
+    return _resize_factor(image, factor, method=method, antialias=antialias)
+  return _resize_small
+
+
+@Registry.register("preprocess_ops.resize_long")
+@utils.InKeyOutKey()
+def get_resize_long(longer_size, method="area", antialias=True):
+  """Resizes the longer side to `longer_size` keeping aspect ratio.
+
+  Args:
+    longer_size: an integer representing the new size of the longer side of
+      an input image.
+    method: the resize method. `area` is a meaningful, bwd-compat default.
+    antialias: see tf.image.resize. Ideally set to True for all new configs.
+
+  Returns:
+    A function that resizes an image and preserves its aspect ratio.
+  """
+
+  def _resize_long(image):  # pylint: disable=missing-docstring
+    h, w = tf.shape(image)[0], tf.shape(image)[1]
+    factor = (
+        tf.cast(longer_size, tf.float32) /
+        tf.cast(tf.maximum(h, w), tf.float32))
+    return _resize_factor(image, factor, method=method, antialias=antialias)
+  return _resize_long
+
+
+@Registry.register("preprocess_ops.inception_crop")
+@utils.InKeyOutKey()
+def get_inception_crop(size=None, area_min=5, area_max=100,
+                       method="bilinear", antialias=False):
+  """Makes inception-style image crop.
+
+  Inception-style crop is a random image crop (its size and aspect ratio are
+  random) that was used for training Inception models, see
+  https://www.cs.unc.edu/~wliu/papers/GoogLeNet.pdf.
+
+  Args:
+    size: Resize image to [size, size] after crop.
+    area_min: minimal crop area.
+    area_max: maximal crop area.
+    method: resize method, see tf.image.resize docs for options.
+    antialias: see tf.image.resize. Ideally set to True for all new configs.
+
+  Returns:
+    A function that applies an inception-style crop.
+  """
+
+  def _inception_crop(image):  # pylint: disable=missing-docstring
+    begin, crop_size, _ = tf.image.sample_distorted_bounding_box(
+        tf.shape(image),
+        tf.zeros([0, 0, 4], tf.float32),
+        area_range=(area_min / 100, area_max / 100),
+        min_object_covered=0,  # Don't enforce a minimum area.
+        use_image_if_no_bounding_boxes=True)
+    crop = tf.slice(image, begin, crop_size)
+    # Unfortunately, the above operation loses the depth-dimension. So we need
+    # to restore it the manual way.
+    crop.set_shape([None, None, image.shape[-1]])
+    if size:
+      crop = get_resize(size, method, antialias)({"image": crop})["image"]
+    return crop
+
+  return _inception_crop
+
+
+@Registry.register("preprocess_ops.decode_jpeg_and_inception_crop")
+@utils.InKeyOutKey()
+def get_decode_jpeg_and_inception_crop(size=None, area_min=5, area_max=100,
+                                       ratio_min=0.75, ratio_max=1.33,
+                                       method="bilinear", antialias=False):
+  """Decode jpeg string and make inception-style image crop.
+
+  Inception-style crop is a random image crop (its size and aspect ratio are
+  random) that was used for training Inception models, see
+  https://www.cs.unc.edu/~wliu/papers/GoogLeNet.pdf.
+
+  Args:
+    size: Resize image to [size, size] after crop.
+    area_min: minimal crop area.
+    area_max: maximal crop area.
+    ratio_min: minimal aspect ratio.
+    ratio_max: maximal aspect ratio.
+    method: resize method, see tf.image.resize docs for options.
+    antialias: see tf.image.resize. Ideally set to True for all new configs.
+
+  Returns:
+    A function that applies an inception-style crop.
+  """
+
+  def _inception_crop(image_data):  # pylint: disable=missing-docstring
+    shape = tf.image.extract_jpeg_shape(image_data)
+    begin, crop_size, _ = tf.image.sample_distorted_bounding_box(
+        shape,
+        tf.zeros([0, 0, 4], tf.float32),
+        area_range=(area_min / 100, area_max / 100),
+        aspect_ratio_range=(ratio_min, ratio_max),
+        min_object_covered=0,  # Don't enforce a minimum area.
+        use_image_if_no_bounding_boxes=True)
+
+    # Crop the image to the specified bounding box.
+    offset_y, offset_x, _ = tf.unstack(begin)
+    target_height, target_width, _ = tf.unstack(crop_size)
+    crop_window = tf.stack([offset_y, offset_x, target_height, target_width])
+    image = tf.image.decode_and_crop_jpeg(image_data, crop_window, channels=3)
+
+    if size:
+      image = get_resize(size, method, antialias)({"image": image})["image"]
+
+    return image
+
+  return _inception_crop
+
+
+@Registry.register("preprocess_ops.random_crop")
+@utils.InKeyOutKey()
+def get_random_crop(crop_size):
+  """Makes a random crop of a given size.
+
+  Args:
+    crop_size: either an integer H, where H is both the height and width of
+      the random crop, or a list or tuple [H, W] of integers, where H and W
+      are height and width of the random crop respectively.
+
+  Returns:
+    A function that applies a random crop.
+  """
+  crop_size = utils.maybe_repeat(crop_size, 2)
+
+  def _crop(image):
+    return tf.image.random_crop(image, (*crop_size, image.shape[-1]))
+
+  return _crop
+
+
+@Registry.register("preprocess_ops.central_crop")
+@utils.InKeyOutKey()
+def get_central_crop(crop_size=None):
+  """Makes central crop of a given size.
+
+  Args:
+    crop_size: either an integer H, where H is both the height and width of
+      the central crop, or a list or tuple [H, W] of integers, where H and W
+      are height and width of the central crop respectively. If `crop_size`
+      is not specified, then the largest possible center crop will be taken.
+
+  Returns:
+    A function that applies a central crop.
+  """
+  if crop_size:
+    crop_size = utils.maybe_repeat(crop_size, 2)
+
+  def _crop(image):
+    if crop_size:
+      h, w = crop_size[0], crop_size[1]
+    else:
+      h = w = tf.minimum(tf.shape(image)[0], tf.shape(image)[1])
+    dy = (tf.shape(image)[0] - h) // 2
+    dx = (tf.shape(image)[1] - w) // 2
+    return tf.image.crop_to_bounding_box(image, dy, dx, h, w)
+
+  return _crop
+
+
+@Registry.register("preprocess_ops.flip_lr")
+@utils.InKeyOutKey()
+def get_random_flip_lr():
+  """Flips an image horizontally with probability 50%."""
+
+  def _random_flip_lr_pp(image):
+    return tf.image.random_flip_left_right(image)
+
+  return _random_flip_lr_pp
+
+
+@Registry.register("preprocess_ops.vgg_value_range")
+@utils.InKeyOutKey()
+def get_vgg_value_range(
+    mean=(0.485 * 255, 0.456 * 255, 0.406 * 255),
+    std=(0.229 * 255, 0.224 * 255, 0.225 * 255),
+):
+  """VGG-style preprocessing, subtracts mean and divides by stddev.
+
+  This preprocessing is very common for ImageNet pre-trained models since VGG,
+  and to this day the standard for models coming from most PyTorch codebases.
+
+  Args:
+    mean: Tuple of values to be subtracted. Defaults to widespread VGG values.
+    std: Tuple of values to be divided by. Defaults to widespread VGG values.
+
+  Returns:
+    A function to rescale the values.
+  """
+  mean = tf.constant(mean, tf.float32)
+  std = tf.constant(std, tf.float32)
+
+  def _vgg_value_range(image):
+    return (tf.cast(image, tf.float32) - mean) / std
+  return _vgg_value_range
+
+
+@Registry.register("preprocess_ops.clip_value_range")
+@utils.InKeyOutKey()
+def get_clip_value_range():
+  """Like vgg_value_range, but with CLIP's mean/std values."""
+  mean = (0.48145466 * 255, 0.4578275 * 255, 0.40821073 * 255)
+  std = (0.26862954 * 255, 0.26130258 * 255, 0.27577711 * 255)
+
+  def _clip_value_range(image):
+    return (tf.cast(image, tf.float32) - mean) / std
+  return _clip_value_range
+
+
+@Registry.register("preprocess_ops.convert_to_video")
+@utils.InKeyOutKey()
+def get_convert_to_video(num_frames):
+  """Converts an image to a video with zero padded frames.
+
+  Args:
+    num_frames: total number of frames that the video should have.
+
+  Returns:
+    A function for converting an image to a video.
+  """
+
+  def _convert_to_video(image):
+    return tf.pad(
+        tf.expand_dims(image, axis=0),
+        [[0, num_frames - 1], [0, 0], [0, 0], [0, 0]],
+    )
+
+  return _convert_to_video
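
A hedged usage sketch (not part of the commit) of how the ops above compose as plain dict-to-dict functions; it mirrors the calls exercised in the test file that follows:

```python
import tensorflow as tf
import big_vision.pp.ops_image as pp

data = {"image": tf.zeros((640, 480, 3), tf.uint8)}
data = pp.get_resize_small(256)(data)  # Shorter side -> 256, aspect kept.
data = pp.get_central_crop(224)(data)  # 224x224 center crop.
data = pp.get_vgg_value_range()(data)  # float32, VGG mean/std.
print(data["image"].shape)             # (224, 224, 3)
```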
Tipsomaly/model/big_vision/pp/ops_image_test.py ADDED
@@ -0,0 +1,82 @@
+# Copyright 2024 Big Vision Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Tests for ops_image."""
+
+import copy
+import io
+
+import big_vision.pp.ops_image as pp
+import matplotlib.pyplot as plt
+import numpy as np
+import tensorflow as tf
+
+
+def get_image_data():
+  img = tf.random.uniform((640, 480, 3), 0, 255, tf.int32)  # Can't ask uint8!?
+  return {"image": tf.cast(img, tf.uint8)}
+
+
+class PreprocessOpsTest(tf.test.TestCase):
+
+  def tfrun(self, ppfn, data):
+    # Run once as standalone, as could happen eg in colab.
+    yield {k: np.array(v) for k, v in ppfn(copy.deepcopy(data)).items()}
+
+    # And then once again as part of tfdata pipeline.
+    # You'd be surprised how much these two differ!
+    tfdata = tf.data.Dataset.from_tensors(copy.deepcopy(data))
+    for npdata in tfdata.map(ppfn).as_numpy_iterator():
+      yield npdata
+
+  def test_resize(self):
+    for data in self.tfrun(pp.get_resize([120, 80]), get_image_data()):
+      self.assertEqual(data["image"].shape, (120, 80, 3))
+
+  def test_resize_small(self):
+    for data in self.tfrun(pp.get_resize_small(240), get_image_data()):
+      self.assertEqual(data["image"].shape, (320, 240, 3))
+
+  def test_resize_long(self):
+    for data in self.tfrun(pp.get_resize_long(320), get_image_data()):
+      self.assertEqual(data["image"].shape, (320, 240, 3))
+
+  def test_inception_crop(self):
+    for data in self.tfrun(pp.get_inception_crop(), get_image_data()):
+      self.assertEqual(data["image"].shape[-1], 3)
+
+  def test_decode_jpeg_and_inception_crop(self):
+    f = io.BytesIO()
+    plt.imsave(f, get_image_data()["image"].numpy(), format="jpg")
+    data = {"image": tf.cast(f.getvalue(), tf.string)}
+    for data in self.tfrun(pp.get_decode_jpeg_and_inception_crop(), data):
+      self.assertEqual(data["image"].shape[-1], 3)
+
+  def test_random_crop(self):
+    for data in self.tfrun(pp.get_random_crop([120, 80]), get_image_data()):
+      self.assertEqual(data["image"].shape, (120, 80, 3))
+
+  def test_central_crop(self):
+    for data in self.tfrun(pp.get_central_crop([20, 80]), get_image_data()):
+      self.assertEqual(data["image"].shape, (20, 80, 3))
+
+  def test_random_flip_lr(self):
+    data_orig = get_image_data()
+    for data in self.tfrun(pp.get_random_flip_lr(), data_orig):
+      self.assertTrue(
+          np.all(data_orig["image"].numpy() == data["image"]) or
+          np.all(data_orig["image"].numpy() == data["image"][:, ::-1]))
+
+if __name__ == "__main__":
+  tf.test.main()
Tipsomaly/model/big_vision/pp/ops_text.py ADDED
@@ -0,0 +1,411 @@
+# Copyright 2024 Big Vision Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Text-centric preprocessing ops.
+
+All preprocessing ops should return a data-processing functor. A data example
+is represented as a dictionary of (TF) tensors. The functors output a modified
+dictionary.
+
+A commonly used key for the tokenized output is "labels".
+"""
+import functools
+import importlib
+import string
+
+from absl import logging
+from big_vision.datasets.imagenet import class_names as imagenet_class_names
+from big_vision.pp import ops_general
+from big_vision.pp import tokenizer as bv_tok
+from big_vision.pp import utils
+from big_vision.pp.registry import Registry
+import tensorflow as tf
+
+from tensorflow.io import gfile
+
+import sentencepiece
+SPProcessor = sentencepiece.SentencePieceProcessor
+
+import os
+os.environ['PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION'] = 'python'
+import sentencepiece.sentencepiece_model_pb2
+del os.environ['PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION']
+SPModelProto = sentencepiece.sentencepiece_model_pb2.ModelProto
+
+
+# TODO: b/lbeyer - softly introduce and move to new tokenizer API.
+
+KNOWN_TOKENIZERS = {
+    "mc4":  # used in multilingual models (mT5, PaLI), vocab_size=250_000
+        "gs://t5-data/vocabs/mc4.250000.100extra/sentencepiece.model",
+    "cc_all":  # vocab_size=32_000
+        "gs://t5-data/vocabs/cc_all.32000/sentencepiece.model",
+    "c4_en":  # vocab_size=32_000
+        "gs://t5-data/vocabs/cc_en.32000/sentencepiece.model",
+    "t5":  # same as cc_all, but with 100 extra dummy tokens used by T5 models
+        "gs://t5-data/vocabs/cc_all.32000.100extra/sentencepiece.model",
+    "mt5":  # same as mc4, but with 100 extra dummy tokens used by T5 models
+        "gs://t5-data/vocabs/mc4.250000.100extra/sentencepiece.model",
+}
+
+
+def create_tokenizer(model="c4_en", add_eos=True, add_bos=False):
+  """Creates a tokenizer which can be used in tfds."""
+  logging.info("Creating tokenizer: %s", model)
+  with gfile.GFile(KNOWN_TOKENIZERS.get(model, model), "rb") as f:
+    model = f.read()
+
+  # Lazy import of tensorflow_text so it is an optional dependency for
+  # the users of this file.
+  import tensorflow_text
+  return tensorflow_text.SentencepieceTokenizer(
+      model=model, add_eos=add_eos, add_bos=add_bos
+  )
+
+
+def tokenize(input_text, tokenizer, max_len, *, pad_value, force_eos,
+             multi_text=False):
+  """Tokenizes string, truncating to `max_len` and padding with `pad_value`."""
+
+  def pad(tokens):
+    # Truncate/pad to max_len.
+    if force_eos:
+      tokens = tf.cond(
+          tf.shape(tokens)[0] >= max_len,
+          lambda: tf.concat(
+              # For too long, cut them off, but do keep the final EOS token.
+              [tokens[:max_len - 1], tokens[-1:]], axis=0),
+          lambda: tf.pad(
+              tokens, [(0, max_len - tf.shape(tokens)[0])],
+              constant_values=pad_value),
+      )
+    else:
+      tokens = tokens[:max_len]
+      tokens = tf.pad(
+          tokens, [(0, max_len - tf.shape(tokens)[0])],
+          constant_values=pad_value)
+    tokens.set_shape([max_len])
+    return tokens
+
+  tokens = tokenizer.tokenize(input_text)
+
+  if multi_text:
+    tokens = tokens.to_tensor(pad_value)  # tf.RaggedTensor to tf.Tensor
+    tokens = tf.reshape(tokens, [-1, tf.shape(tokens)[-1]])
+    tokens = tf.map_fn(pad, tokens)  # `map_fn` only maps on axis 0
+
+    final_shape = tf.concat([tf.shape(input_text), [max_len]], axis=0)
+    return tf.reshape(tokens, final_shape)
+  else:
+    return pad(tokens)
+
+
+@Registry.register("preprocess_ops.tokenize")
+@utils.InKeyOutKey(indefault=None, outdefault="labels")
+def get_pp_tokenize(
+    max_len,
+    eos,
+    model="c4_en",
+    lower=True,
+    sample_if_multi=True,
+    pad_value="<pad>",
+    add_bos=False
+):
+  """Tokenizes a text.
+
+  Let's assume max_len=3 and id("</s>")=1, id("a")=2, then we have
+
+  1. `eos="none", pad_value=0`:
+     - "a" -> [2, 0, 0]
+     - "aa" -> [2, 2, 0]
+     - "aaa" -> [2, 2, 2]
+
+  2. `eos="yes", pad_value=0`:
+     - "a" -> [2, 1, 0]
+     - "aa" -> [2, 2, 1]
+     - "aaa" -> [2, 2, 2]
+
+  This is usually used with generative models that need to learn when to
+  properly predict a "</s>" (when the sentence is finished) and when to
+  abstain (when the sentence is truncated).
+
+  3. `eos="sticky", pad_value=0`:
+     - "a" -> [2, 1, 0]
+     - "aa" -> [2, 2, 1]
+     - "aaa" -> [2, 2, 1]
+
+  4. `eos="sticky", pad_value=1`:
+     - "a" -> [2, 1, 1]
+     - "aa" -> [2, 2, 1]
+     - "aaa" -> [2, 2, 1]
+
+  This is traditionally used with contrastive models that use the last token
+  for embeddings, similarly to "cls" tokens in BERT-style models.
+
+  Args:
+    max_len: maximum length of the tokenized text.
+    eos: Whether to add an "</s>" (end of sentence) token and whether to keep
+      it when the sequence is longer than `max_len - 1`. See examples above
+      for details. Valid values: "none", "yes", "sticky".
+    model: a path to the pretrained sentencepiece model.
+    lower: lowercase the text before tokenizing.
+    sample_if_multi: if the input contains multiple texts, randomly pick one
+      when True; otherwise tokenize all texts and keep the input's batch
+      shape in the result.
+    pad_value: which token to pad the sequence with. If a string (for example
+      `"<pad>"`), tokenize it and use its first token. Note that there is no
+      guarantee to have any padding at the end of the sentence, if the
+      sentence is longer than `max_len`.
+    add_bos: adds beginning of sentence symbol.
+
+  Returns:
+    an op that outputs tokenized text.
+  """
+
+  if eos not in ("yes", "none", "sticky"):
+    raise ValueError(f"Invalid value for eos: '{eos}'.")
+
+  tokenizer = create_tokenizer(model, add_eos=eos != "none", add_bos=add_bos)
+
+  if isinstance(pad_value, str):
+    pad_value = tokenizer.string_to_id(pad_value)
+
+  def _pp_tokenize(txt):
+    if sample_if_multi and tf.convert_to_tensor(txt).ndim:
+      # TODO: I wish this code-path could die.
+      logging.warning("sample_if_multi is deprecated and will be removed. "
+                      "Call `choice` (and maybe `setdefault`) instead.")
+      txt = ops_general.get_choice(key="t")(
+          ops_general.get_setdefault("t", "")({"t": txt}))["t"]
+
+    if lower:
+      txt = tf.strings.lower(txt) if sample_if_multi else tf.map_fn(
+          tf.strings.lower, txt)
+
+    return tokenize(
+        txt,
+        tokenizer,
+        max_len,
+        pad_value=pad_value,
+        force_eos=eos == "sticky",
+        multi_text=not sample_if_multi)
+
+  return _pp_tokenize
+
+
+@Registry.register("preprocess_ops.coco_captions")
+def get_coco_captions(outkey="captions"):
+  """Extracts coco's captions from nested dict."""
+
+  def _pp_coco_captions(data):
+    data[outkey] = data["captions"]["text"]
+    return data
+
+  return _pp_coco_captions
+
+
+@Registry.register("preprocess_ops.clip_i1k_label_names")
+@utils.InKeyOutKey(indefault="label", outdefault="labels")
+def get_pp_clip_i1k_label_names():
+  """Converts i1k label numbers to strings, using CLIP's class names."""
+
+  def _pp_imagenet_labels(label):
+    return tf.gather(imagenet_class_names.CLIP_IMAGENET_CLASS_NAMES, label)
+
+  return _pp_imagenet_labels
+
+
+@Registry.register("preprocess_ops.i21k_label_names")
+@utils.InKeyOutKey(indefault="label", outdefault="labels")
+def get_pp_i21k_label_names():
+  """Converts i21k label ids to strings."""
+
+  def _pp_imagenet_labels(label):
+    return tf.gather(imagenet_class_names.IMAGENET21k_CLASS_NAMES, label)
+
+  return _pp_imagenet_labels
+
+
+@Registry.register("preprocess_ops.lower")
+@utils.InKeyOutKey(indefault="text", outdefault="text")
+def get_lower():
+  """Lowercases text feature."""
+
+  def _pp_lower(text):
+    return tf.strings.lower(text)
+
+  return _pp_lower
+
+
+@Registry.register("preprocess_ops.strfmt")
+def get_strfmt(template, outkey="text"):
+  """Formats a string template with content from the data dict."""
+
+  def _template(data):
+    outputs = []
+    parts = string.Formatter().parse(template)
+    for (literal_text, field_name, format_spec, conversion) in parts:
+      # For now, we keep it simple and don't support fancy format specs.
+      # But we can add support to that via py_func as soon as we need it.
+      assert not format_spec and not conversion
+      outputs.append(tf.constant(literal_text))
+      if field_name:
+        value = data[field_name]
+        # Convert any non-strings (numbers, vectors) to a string.
+        if tf.convert_to_tensor(value).dtype != tf.string:
+          value = tf.strings.format("{}", value, summarize=-1)
+        outputs.append(value)
+    data[outkey] = tf.strings.join(outputs)
+    return data
+
+  return _template
+
+
+def _add_pieces(model_bytes, extra_pieces):
+  """Adds extra pieces to the sentencepiece model given by `model_bytes`."""
+
+  model = SPProcessor()
+  model.LoadFromSerializedProto(model_bytes)
+  unk_idx = model.PieceToId("<unk>")
+  assert model.IdToPiece(unk_idx) == "<unk>", model.IdToPiece(unk_idx)
+
+  model_proto = SPModelProto.FromString(model_bytes)
+  idx_to_updated_piece = {}
+  for piece in extra_pieces:
+    # The SentencePieceModel proto stores whitespaces as the special
+    # character '▁'. We perform the conversion here.
+    piece = piece.replace(" ", "▁")
+    spiece = model_proto.SentencePiece(
+        piece=piece,
+        # We set the highest score to force priority on user defined tokens.
+        score=0.0,
+        type=model_proto.SentencePiece().Type.USER_DEFINED,
+    )
+    existing_idx = model.PieceToId(piece)
+    if (existing_idx != unk_idx) ^ (piece == "<unk>"):
+      idx_to_updated_piece[existing_idx] = spiece
+      logging.info("Updating token at idx %d: %s", existing_idx, spiece.piece)
+    else:
+      model_proto.pieces.append(spiece)
+
+  # Replace duplicated pieces with updated ones.
+  updated_pieces = [
+      idx_to_updated_piece.get(i, piece)
+      for i, piece in enumerate(model_proto.pieces)
+  ]
+  del model_proto.pieces[:]
+  model_proto.pieces.extend(updated_pieces)
+
+  return model_proto.SerializeToString()
+
+
+def _iterable(x):
+  if isinstance(x, tf.RaggedTensor):
+    return True
+  if getattr(x, "ndim", 0) > 1:  # np, jnp
+    return True
+  if isinstance(x, (list, tuple)) and not isinstance(x[0], (int, float)):
+    return True
+  return False
+
+
+@Registry.register("tokenizers.sp")
+class SentencepieceTokenizer(bv_tok.Tokenizer):
+  """Wraps a `tftext.SentencepieceTokenizer`.
+
+  If you plan to use this tokenizer, please familiarize yourself with the
+  test cases first. This is likely to save you a lot of trouble down the
+  road, trust me!
+  """
+
+  def __init__(self, model, tokensets=()):
+    with gfile.GFile(KNOWN_TOKENIZERS.get(model, model), "rb") as f:
+      model_bytes = f.read()
+    extras = bv_tok.get_extra_tokens(tokensets)
+    model_bytes = _add_pieces(model_bytes, extras)
+    self._tok_sp = SPProcessor()
+    self._tok_sp.LoadFromSerializedProto(model_bytes)
+    self.extras = {self._tok_sp.PieceToId(x): x for x in extras}
+
+  def to_int(self, text, *, bos=False, eos=False):
+    def _single(s):
+      return (
+          ([self.bos_token] if bos else []) +
+          self._tok_sp.EncodeAsIds(s) +
+          ([self.eos_token] if eos else [])
+      )
+    if isinstance(text, str):
+      return _single(text)
+    return type(text)([_single(s) for s in text])
+
+  def to_str(self, tokens, *, stop_at_eos=True):
+    def _single(toks):
+      toks = [int(t) for t in toks]  # We really need this for DecodeIds.
+      if stop_at_eos:
+        try:  # SentencePiece strips eos, but does not stop at it, so we do.
+          toks = toks[:toks.index(self.eos_token)]
+        except ValueError:  # No eos token found, nothing to do.
+          pass
+      return self._tok_sp.DecodeIds(toks)
+    if _iterable(tokens):
+      return [_single(toks) for toks in tokens]
+    return _single(tokens)
+
+  def _check_known(self, piece):
+    if (id_ := self._tok_sp.PieceToId(piece)) == self._tok_sp.unk_id():
+      logging.error("Piece '%s' is not known (unk=%s)!", piece, id_)
+    return id_
+
+  def to_piece(self, idx):
+    return self._tok_sp.IdToPiece(int(idx))
+
+  @property
+  def pad_token(self):
+    return self._tok_sp.pad_id()
+
+  @property
+  def eos_token(self):
+    return self._tok_sp.eos_id()
+
+  @property
+  def bos_token(self):
+    return self._tok_sp.bos_id()
+
+  @property
+  def vocab_size(self):
+    return self._tok_sp.GetPieceSize()
+
+  # For the _tf_op variants, we need a lot of wrapping boilerplate.
+
+  def to_int_tf_op(self, text, *, bos=False, eos=False):
+    text = tf.convert_to_tensor(text)
+    if text.ndim == 0:
+      def fn(txt):
+        s = txt.numpy().decode()
+        return tf.constant(self.to_int(s, bos=bos, eos=eos), tf.int32)
+      return tf.py_function(fn, [text], tf.int32)
+    else:
+      def fn(txt):
+        strings = [s.decode() for s in txt.numpy().tolist()]
+        toks = self.to_int(strings, bos=bos, eos=eos)
+        return tf.ragged.constant(toks)
+      out_type = tf.RaggedTensorSpec([tf.shape(text)[0], None], tf.int32)
+      return tf.py_function(fn, [text], Tout=out_type)
+
+  def to_str_tf_op(self, tokens, *, stop_at_eos=True):
+    def single(t):
+      fn = functools.partial(self.to_str, stop_at_eos=stop_at_eos)
+      return tf.numpy_function(fn, [t], tf.string, stateful=False)
+    if _iterable(tokens):
+      return tf.map_fn(single, tokens, tf.string)
+    return single(tokens)
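
To make the `eos` modes concrete, here is a hedged sketch (not part of the commit; `some_spm.model` is a placeholder sentencepiece model path) of case 3 from the `get_pp_tokenize` docstring:

```python
import tensorflow as tf
import big_vision.pp.ops_text as ops_text

# The inkey/outkey kwargs come from the InKeyOutKey decorator.
op = ops_text.get_pp_tokenize(
    max_len=3, eos="sticky", model="some_spm.model", pad_value=0,
    inkey="text", outkey="labels")
out = op({"text": tf.constant("a")})
# Assuming id("</s>")=1 and id("a")=2, out["labels"] is [2, 1, 0].
```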
Tipsomaly/model/big_vision/pp/ops_text_test.py ADDED
@@ -0,0 +1,200 @@
+# Copyright 2024 Big Vision Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Tests for ops_text."""
+
+import copy
+
+from absl.testing import parameterized
+import big_vision.pp.ops_text as pp
+from big_vision.pp.registry import Registry
+import numpy as np
+import tensorflow as tf
+
+
+class PyToTfWrapper:
+  """Allows to use `to_{int,str}_tf_op()` via `to_{int,str}()`."""
+
+  def __init__(self, tok):
+    self.tok = tok
+    self.bos_token = tok.bos_token
+    self.eos_token = tok.eos_token
+    self.vocab_size = tok.vocab_size
+
+  def to_int(self, text, *, bos=False, eos=False):
+    ret = self.tok.to_int_tf_op(text, bos=bos, eos=eos)
+    if isinstance(ret, tf.RaggedTensor):
+      return [t.numpy().tolist() for t in ret]
+    return ret.numpy().tolist()
+
+  def to_str(self, tokens, stop_at_eos=True):
+    ret = self.tok.to_str_tf_op(
+        tf.ragged.constant(tokens),
+        stop_at_eos=stop_at_eos,
+    )
+    if ret.ndim == 0:
+      return ret.numpy().decode()
+    return [t.numpy().decode() for t in ret]
+
+
+class PpOpsTest(tf.test.TestCase, parameterized.TestCase):
+
+  def tfrun(self, ppfn, data):
+    # Run once as standalone, as could happen eg in colab.
+    yield {k: np.array(v) for k, v in ppfn(copy.deepcopy(data)).items()}
+
+    # And then once again as part of tfdata pipeline.
+    # You'd be surprised how much these two differ!
+    tfdata = tf.data.Dataset.from_tensors(copy.deepcopy(data))
+    for npdata in tfdata.map(ppfn).as_numpy_iterator():
+      yield npdata
+
+  def testtok(self):
+    # https://github.com/google/sentencepiece/blob/master/python/test/test_model.model
+    return "test_model.model"  # Should we just commit it? It's 200kB
+
+  def test_get_pp_clip_i1k_label_names(self):
+    op = pp.get_pp_clip_i1k_label_names()
+    labels = op({"label": tf.constant([0, 1])})["labels"].numpy().tolist()
+    self.assertAllEqual(labels, ["tench", "goldfish"])
+
+  def test_get_pp_i21k_label_names(self):
+    op = pp.get_pp_i21k_label_names()
+    labels = op({"label": tf.constant([0, 1])})["labels"].numpy().tolist()
+    self.assertAllEqual(labels, ["organism", "benthos"])
+
+  @parameterized.parameters((b"Hello world ScAlAr!", b"hello world scalar!"),
+                            (["Decoded Array!"], ["decoded array!"]),
+                            ([b"aA", "bB"], [b"aa", "bb"]))
+  def test_get_lower(self, inputs, expected_output):
+    op = pp.get_lower()
+    out = op({"text": tf.constant(inputs)})
+    self.assertAllEqual(out["text"].numpy(), np.array(expected_output))
+
+  @parameterized.named_parameters(
+      ("py", False),
+      ("tf", True),
+  )
+  def test_sentencepiece_tokenizer(self, wrap_tok):
+    tok = pp.SentencepieceTokenizer(self.testtok())
+    if wrap_tok:
+      tok = PyToTfWrapper(tok)
+    self.assertEqual(tok.vocab_size, 1000)
+    bos, eos = tok.bos_token, tok.eos_token
+    self.assertEqual(bos, 1)
+    self.assertEqual(eos, 2)
+    # Note: test model does NOT have a <pad> token (similar to e.g. "mistral").
+    # `.to_int()` wraps `.to_int_tf_op()` which is thus also tested
+    self.assertEqual(tok.to_int("blah"), [80, 180, 60])
+    self.assertEqual(tok.to_int("blah", bos=True), [bos, 80, 180, 60])
+    self.assertEqual(tok.to_int("blah", eos=True), [80, 180, 60, eos])
+    self.assertEqual(
+        tok.to_int("blah", bos=True, eos=True), [bos, 80, 180, 60, eos]
+    )
+    self.assertEqual(
+        tok.to_int(["blah", "blah blah"]),
+        [[80, 180, 60], [80, 180, 60, 80, 180, 60]],
+    )
+    # inverse of above
+    # `.to_str()` wraps `.to_str_tf_op()` which is thus also tested
+    self.assertEqual(tok.to_str([80, 180, 60]), "blah")
+    self.assertEqual(tok.to_str([1, 80, 180, 60]), "blah")
+    self.assertEqual(tok.to_str([80, 180, 60, 2]), "blah")
+    self.assertEqual(
+        tok.to_str([[80, 180, 60], [80, 180, 60, 80, 180, 60]]),
+        ["blah", "blah blah"],
+    )
+
+  def test_sentencepiece_tokenizer_tf_op_ndarray_input(self):
+    tok = pp.SentencepieceTokenizer(self.testtok())
+    bos, eos = tok.bos_token, tok.eos_token
+    arr = np.array([[bos, 80, 180, 60, eos]] * 2, dtype=np.int32)
+    self.assertEqual(tok.to_str_tf_op(arr).numpy().tolist(), [b"blah"] * 2)
+
+  def test_sentencepiece_tokenizer_tokensets(self):
+    tok = pp.SentencepieceTokenizer(self.testtok(), tokensets=["loc"])
+    self.assertEqual(tok.vocab_size, 2024)
+    self.assertEqual(
+        tok.to_int("blah<loc0000><loc1023>"), [80, 180, 60, 1000, 2023]
+    )
+
+  def test_sentencepiece_stop_at_eos(self):
+    tok = pp.SentencepieceTokenizer(self.testtok())
+    self.assertEqual(tok.to_str([80, 180, 60], stop_at_eos=False), "blah")
+    eos = tok.eos_token
+    self.assertEqual(tok.to_str([80, eos, 180, 60], stop_at_eos=False), "blah")
+    self.assertEqual(tok.to_str([80, eos, 180, 60], stop_at_eos=True), "b")
+    self.assertEqual(
+        tok.to_str([[80, eos, 180, 60], [80, 180, eos, 60]], stop_at_eos=True),
+        ["b", "bla"]
+    )
+
+  def test_sentencepiece_extra_tokens(self):
+    tok = pp.SentencepieceTokenizer(self.testtok())
+    self.assertEqual(tok.to_str([1, 80, 180, 60, 2], stop_at_eos=False), "blah")
+    tok = pp.SentencepieceTokenizer(
+        self.testtok(), tokensets=["sp_extra_tokens"]
+    )
+    self.assertEqual(tok.vocab_size, 1001)  # Also added the <pad> token.
+    self.assertEqual(
+        tok.to_str([1, 80, 180, 60, 2], stop_at_eos=False), "<s> blah</s>"
+    )
+
+  def test_strfmt(self):
+    data = {
+        "int": tf.constant(42, tf.uint8),
+        "float": tf.constant(3.14, tf.float32),
+        "vec": tf.range(3),
+        "empty_str": tf.constant(""),
+        "regex_problem1": tf.constant(r"no \replace pattern"),
+        "regex_problem2": tf.constant(r"yes \1 pattern"),
+    }
+    for out in self.tfrun(pp.get_strfmt("Nothing"), data):
+      self.assertEqual(out["text"], b"Nothing")
+    for out in self.tfrun(pp.get_strfmt("{int}"), data):
+      self.assertEqual(out["text"], b"42")
+    for out in self.tfrun(pp.get_strfmt("A{int}"), data):
+      self.assertEqual(out["text"], b"A42")
+    for out in self.tfrun(pp.get_strfmt("{int}A"), data):
+      self.assertEqual(out["text"], b"42A")
+    for out in self.tfrun(pp.get_strfmt("{int}{int}"), data):
+      self.assertEqual(out["text"], b"4242")
+    for out in self.tfrun(pp.get_strfmt("A{int}A{int}A"), data):
+      self.assertEqual(out["text"], b"A42A42A")
+    for out in self.tfrun(pp.get_strfmt("A{float}A"), data):
+      self.assertEqual(out["text"], b"A3.14A")
+    for out in self.tfrun(pp.get_strfmt("A{float}A{int}"), data):
+      self.assertEqual(out["text"], b"A3.14A42")
+    for out in self.tfrun(pp.get_strfmt("A{vec}A"), data):
+      self.assertEqual(out["text"], b"A[0 1 2]A")
+    for out in self.tfrun(pp.get_strfmt("A{empty_str}A"), data):
+      self.assertEqual(out["text"], b"AA")
+    for out in self.tfrun(pp.get_strfmt("{empty_str}"), data):
+      self.assertEqual(out["text"], b"")
+    for out in self.tfrun(pp.get_strfmt("A{regex_problem1}A"), data):
+      self.assertEqual(out["text"], br"Ano \replace patternA")
+    for out in self.tfrun(pp.get_strfmt("A{regex_problem2}A"), data):
+      self.assertEqual(out["text"], br"Ayes \1 patternA")
+
+
+@Registry.register("tokensets.sp_extra_tokens")
+def _get_sp_extra_tokens():
+  # For sentencepiece, adding these tokens will make them visible when
+  # decoding. If a token is not found (e.g. "<pad>" is not found in
+  # "mistral"), then it is added to the vocabulary, increasing the
+  # vocab_size accordingly.
+  return ["<s>", "</s>", "<pad>"]
+
+
+if __name__ == "__main__":
+  tf.test.main()
Tipsomaly/model/big_vision/pp/registry.py ADDED
@@ -0,0 +1,163 @@
+# Copyright 2024 Big Vision Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Global Registry for big_vision pp ops.
+
+Author: Joan Puigcerver (jpuigcerver@)
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import ast
+import contextlib
+import functools
+
+
+def parse_name(string_to_parse):
+  """Parses input to the registry's lookup function.
+
+  Args:
+    string_to_parse: can be either an arbitrary name or function call
+      (optionally with positional and keyword arguments),
+      e.g. "multiclass", "resnet50_v2(filters_factor=8)".
+
+  Returns:
+    A tuple of input name, argument tuple and a keyword argument dictionary.
+    Examples:
+      "multiclass" -> ("multiclass", (), {})
+      "resnet50_v2(9, filters_factor=4)" ->
+          ("resnet50_v2", (9,), {"filters_factor": 4})
+
+  Author: Joan Puigcerver (jpuigcerver@)
+  """
+  expr = ast.parse(string_to_parse, mode="eval").body  # pytype: disable=attribute-error
+  if not isinstance(expr, (ast.Attribute, ast.Call, ast.Name)):
+    raise ValueError(
+        "The given string should be a name or a call, but a {} was parsed "
+        "from the string {!r}".format(type(expr), string_to_parse))
+
+  # Notes:
+  # name="some_name" -> type(expr) = ast.Name
+  # name="module.some_name" -> type(expr) = ast.Attribute
+  # name="some_name()" -> type(expr) = ast.Call
+  # name="module.some_name()" -> type(expr) = ast.Call
+
+  if isinstance(expr, ast.Name):
+    return string_to_parse, (), {}
+  elif isinstance(expr, ast.Attribute):
+    return string_to_parse, (), {}
+
+  def _get_func_name(expr):
+    if isinstance(expr, ast.Attribute):
+      return _get_func_name(expr.value) + "." + expr.attr
+    elif isinstance(expr, ast.Name):
+      return expr.id
+    else:
+      raise ValueError(
+          "Type {!r} is not supported in a function name, the string to "
+          "parse was {!r}".format(type(expr), string_to_parse))
+
+  def _get_func_args_and_kwargs(call):
+    args = tuple([ast.literal_eval(arg) for arg in call.args])
+    kwargs = {
+        kwarg.arg: ast.literal_eval(kwarg.value) for kwarg in call.keywords
+    }
+    return args, kwargs
+
+  func_name = _get_func_name(expr.func)
+  func_args, func_kwargs = _get_func_args_and_kwargs(expr)
+
+  return func_name, func_args, func_kwargs
+
+
+class Registry(object):
+  """Implements global Registry.
+
+  Authors: Joan Puigcerver (jpuigcerver@), Alexander Kolesnikov (akolesnikov@)
+  """
+
+  _GLOBAL_REGISTRY = {}
+
+  @staticmethod
+  def global_registry():
+    return Registry._GLOBAL_REGISTRY
+
+  @staticmethod
+  def register(name, replace=False):
+    """Creates a function that registers its input."""
+
+    def _register(item):
+      if name in Registry.global_registry() and not replace:
+        raise KeyError("The name {!r} was already registered.".format(name))
+
+      Registry.global_registry()[name] = item
+      return item
+
+    return _register
+
+  @staticmethod
+  def lookup(lookup_string, kwargs_extra=None):
+    """Looks up a name in the registry."""
+
+    try:
+      name, args, kwargs = parse_name(lookup_string)
+    except ValueError as e:
+      raise ValueError(f"Error parsing:\n{lookup_string}") from e
+    if kwargs_extra:
+      kwargs.update(kwargs_extra)
+    item = Registry.global_registry()[name]
+    return functools.partial(item, *args, **kwargs)
+
+  @staticmethod
+  def knows(lookup_string):
+    try:
+      name, _, _ = parse_name(lookup_string)
+    except ValueError as e:
+      raise ValueError(f"Error parsing:\n{lookup_string}") from e
+    return name in Registry.global_registry()
+
+
+@contextlib.contextmanager
+def temporary_ops(**kw):
+  """Registers specified pp ops for use in a `with` block.
+
+  Example use:
+
+    with pp_registry.temporary_ops(
+        pow=lambda alpha: lambda d: {k: v**alpha for k, v in d.items()}):
+      pp = pp_builder.get_preprocess_fn("pow(alpha=2.0)|pow(alpha=0.5)")
+      features = pp(features)
+
+  Args:
+    **kw: Names are preprocess string function names to be used to specify
+      the preprocess function. Values are functions that can be called with
+      params (e.g. the `alpha` param in above example) and return functions
+      to be used to transform features.
+
+  Yields:
+    A context manager to be used in a `with` statement.
+  """
+  reg = Registry.global_registry()
+  kw = {f"preprocess_ops.{k}": v for k, v in kw.items()}
+  for k in kw:
+    assert k not in reg
+  for k, v in kw.items():
+    reg[k] = v
+  try:
+    yield
+  finally:
+    for k in kw:
+      del reg[k]
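
A minimal sketch (not part of the commit; the op name `preprocess_ops.scale` is made up) of the register/lookup round trip the ops files above rely on:

```python
from big_vision.pp.registry import Registry

@Registry.register("preprocess_ops.scale")
def get_scale(factor=2):
  def _scale(data):
    data["image"] = data["image"] * factor
    return data
  return _scale

# lookup() parses the call string via parse_name() and returns a
# functools.partial over the registered getter; calling it yields the
# actual data->data function.
op = Registry.lookup("preprocess_ops.scale(factor=3)")()
```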
Tipsomaly/model/big_vision/pp/registry_test.py ADDED
@@ -0,0 +1,128 @@
+# Copyright 2024 Big Vision Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Tests for registry."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from unittest import mock
+
+from absl.testing import absltest
+from big_vision.pp import registry
+
+
+class RegistryTest(absltest.TestCase):
+
+  def setUp(self):
+    super(RegistryTest, self).setUp()
+    # Mock global registry in each test to keep them isolated and allow for
+    # concurrent tests.
+    self.addCleanup(mock.patch.stopall)
+    self.global_registry = dict()
+    self.mocked_method = mock.patch.object(
+        registry.Registry, "global_registry",
+        return_value=self.global_registry).start()
+
+  def test_parse_name(self):
+    name, args, kwargs = registry.parse_name("f")
+    self.assertEqual(name, "f")
+    self.assertEqual(args, ())
+    self.assertEqual(kwargs, {})
+
+    name, args, kwargs = registry.parse_name("f()")
+    self.assertEqual(name, "f")
+    self.assertEqual(args, ())
+    self.assertEqual(kwargs, {})
+
+    name, args, kwargs = registry.parse_name("func(a=0,b=1,c='s')")
+    self.assertEqual(name, "func")
+    self.assertEqual(args, ())
+    self.assertEqual(kwargs, {"a": 0, "b": 1, "c": "s"})
+
+    name, args, kwargs = registry.parse_name("func(1,'foo',3)")
+    self.assertEqual(name, "func")
+    self.assertEqual(args, (1, "foo", 3))
+    self.assertEqual(kwargs, {})
+
+    name, args, kwargs = registry.parse_name("func(1,'2',a=3,foo='bar')")
+    self.assertEqual(name, "func")
+    self.assertEqual(args, (1, "2"))
+    self.assertEqual(kwargs, {"a": 3, "foo": "bar"})
+
+    name, args, kwargs = registry.parse_name("foo.bar.func(a=0,b=(1),c='s')")
+    self.assertEqual(name, "foo.bar.func")
+    self.assertEqual(kwargs, dict(a=0, b=1, c="s"))
+
+    with self.assertRaises(SyntaxError):
+      registry.parse_name("func(0")
+    with self.assertRaises(SyntaxError):
+      registry.parse_name("func(a=0,,b=0)")
+    with self.assertRaises(SyntaxError):
+      registry.parse_name("func(a=0,b==1,c='s')")
+    with self.assertRaises(ValueError):
+      registry.parse_name("func(a=0,b=undefined_name,c='s')")
+
+  def test_register(self):
+    # pylint: disable=unused-variable
+    @registry.Registry.register("func1")
+    def func1():
+      pass
+
+    self.assertLen(registry.Registry.global_registry(), 1)
+
+  def test_lookup_function(self):
+
+    @registry.Registry.register("func1")
+    def func1(arg1, arg2, arg3):  # pylint: disable=unused-variable
+      return arg1, arg2, arg3
+
+    self.assertTrue(callable(registry.Registry.lookup("func1")))
+    self.assertEqual(registry.Registry.lookup("func1")(1, 2, 3), (1, 2, 3))
+    self.assertEqual(
+        registry.Registry.lookup("func1(arg3=9)")(1, 2), (1, 2, 9))
+    self.assertEqual(
+        registry.Registry.lookup("func1(arg2=9,arg1=99)")(arg3=3), (99, 9, 3))
+    self.assertEqual(
+        registry.Registry.lookup("func1(arg2=9,arg1=99)")(arg1=1, arg3=3),
+        (1, 9, 3))
+
+    self.assertEqual(
+        registry.Registry.lookup("func1(1)")(1, 2), (1, 1, 2))
+    self.assertEqual(
+        registry.Registry.lookup("func1(1)")(arg3=3, arg2=2), (1, 2, 3))
+    self.assertEqual(
+        registry.Registry.lookup("func1(1, 2)")(3), (1, 2, 3))
+    self.assertEqual(
+        registry.Registry.lookup("func1(1, 2)")(arg3=3), (1, 2, 3))
+    self.assertEqual(
+        registry.Registry.lookup("func1(1, arg2=2)")(arg3=3), (1, 2, 3))
+    self.assertEqual(
+        registry.Registry.lookup("func1(1, arg3=2)")(arg2=3), (1, 3, 2))
+    self.assertEqual(
+        registry.Registry.lookup("func1(1, arg3=2)")(3), (1, 3, 2))
+
+    with self.assertRaises(TypeError):
+      registry.Registry.lookup("func1(1, arg2=2)")(3)
+    with self.assertRaises(TypeError):
+      registry.Registry.lookup("func1(1, arg3=3)")(arg3=3)
+    with self.assertRaises(TypeError):
+      registry.Registry.lookup("func1(1, arg3=3)")(arg1=3)
+    with self.assertRaises(SyntaxError):
+      registry.Registry.lookup("func1(arg1=1, 3)")(arg2=3)
+
+
+if __name__ == "__main__":
+  absltest.main()
Tipsomaly/model/big_vision/pp/tokenizer.py ADDED
@@ -0,0 +1,103 @@
+# Copyright 2024 Big Vision Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""The tokenizer API for big_vision, and central registration place."""
+import functools
+import importlib
+from typing import Protocol
+
+from absl import logging
+from big_vision.pp import registry
+import big_vision.utils as u
+import numpy as np
+
+
+class Tokenizer(Protocol):
+  """Just to unify on the API, as we now have many different tokenizers."""
+
+  def to_int(self, text, *, bos=False, eos=False):
+    """Tokenizes `text` into a list of integer tokens.
+
+    Args:
+      text: can be a single string, or a list of strings.
+      bos: Whether a beginning-of-sentence token should be prepended.
+      eos: Whether an end-of-sentence token should be appended.
+
+    Returns:
+      List or list-of-list of tokens.
+    """
+
+  def to_int_tf_op(self, text, *, bos=False, eos=False):
+    """Same as `to_int()`, but as TF ops to be used in pp."""
+
+  def to_str(self, tokens, *, stop_at_eos=True):
+    """Inverse of `to_int()`.
+
+    Args:
+      tokens: list of tokens, or list of lists of tokens.
+      stop_at_eos: remove everything that may come after the first EOS.
+
+    Returns:
+      A string (if `tokens` is a list of tokens), or a list of strings.
+      Note that most tokenizers strip a select few control tokens like
+      eos/bos/pad/unk from the output string.
+    """
+
+  def to_str_tf_op(self, tokens, *, stop_at_eos=True):
+    """Same as `to_str()`, but as TF ops to be used in pp."""
+
+  @property
+  def pad_token(self):
+    """Token id of the padding token."""
+
+  @property
+  def eos_token(self):
+    """Token id of the end-of-sentence token."""
+
+  @property
+  def bos_token(self):
+    """Token id of the beginning-of-sentence token."""
+
+  @property
+  def vocab_size(self):
+    """Returns the size of the vocabulary."""
+
+
+@functools.cache
+def get_tokenizer(name):
+  with u.chrono.log_timing(f"z/secs/tokenizer/{name}"):
+    if not registry.Registry.knows(f"tokenizers.{name}"):
+      raw_name, *_ = registry.parse_name(name)
+      logging.info("Tokenizer %s not registered, "
+                   "trying import big_vision.pp.%s", name, raw_name)
+      importlib.import_module(f"big_vision.pp.{raw_name}")
+
+    return registry.Registry.lookup(f"tokenizers.{name}")()
+
+
+def get_extra_tokens(tokensets):
+  extra_tokens = []
+  for tokenset in tokensets:
+    extra_tokens.extend(registry.Registry.lookup(f"tokensets.{tokenset}")())
+  return list(np.unique(extra_tokens))  # Sorted & deduped; dups make no sense.
+
+
+@registry.Registry.register("tokensets.loc")
+def _get_loc1024(n=1024):
+  return [f"<loc{i:04d}>" for i in range(n)]
+
+
+@registry.Registry.register("tokensets.seg")
+def _get_seg(n=128):
+  return [f"<seg{i:03d}>" for i in range(n)]
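
A short sketch (not part of the commit) of how the tokensets registered above combine; `get_extra_tokens` dedups the concatenated lists (and, via `np.unique`, returns them sorted):

```python
from big_vision.pp import tokenizer as bv_tok

extras = bv_tok.get_extra_tokens(["loc", "seg"])
assert len(extras) == 1024 + 128  # 1024 <locNNNN> plus 128 <segNNN> tokens.
```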