Add files using upload-large-folder tool
- Tipsomaly/model/big_vision/configs/proj/paligemma/README.md +282 -0
- Tipsomaly/model/big_vision/configs/proj/paligemma/paligemma2.png +0 -0
- Tipsomaly/model/big_vision/configs/proj/reward_tune/detection_reward.py +232 -0
- Tipsomaly/model/big_vision/configs/proj/scaling_laws/train_vit_g.py +87 -0
- Tipsomaly/model/big_vision/configs/proj/uvim/README.md +84 -0
- Tipsomaly/model/big_vision/configs/proj/uvim/train_coco_panoptic_pretrained.py +164 -0
- Tipsomaly/model/big_vision/configs/proj/uvim/train_imagenet2012_colorization_pretrained.py +161 -0
- Tipsomaly/model/big_vision/configs/proj/uvim/train_nyu_depth_pretrained.py +170 -0
- Tipsomaly/model/big_vision/configs/proj/uvim/uvim_color_task.ipynb +167 -0
Tipsomaly/model/big_vision/configs/proj/paligemma/README.md
ADDED
@@ -0,0 +1,282 @@
# PaliGemma model README

PaliGemma is an open vision-language model (VLM) inspired by PaLI-3, built with
open components, such as
the [SigLIP vision model](https://colab.research.google.com/github/google-research/big_vision/blob/main/big_vision/configs/proj/image_text/SigLIP_demo.ipynb)
and
the [Gemma language model](https://ai.google.dev/gemma).
PaliGemma is designed as a versatile model for transfer to a wide range of
vision-language tasks such as image and short video captioning, visual question
answering, text reading, object detection and object segmentation. Together with
the pretrained checkpoints (PaliGemma and PaliGemma 2) we also provide transfer
checkpoints at multiple resolutions and a checkpoint transferred to a mixture of
tasks that can be used for off-the-shelf exploration (PaliGemma only).

## Quick Reference

This is the reference repository of the model; you may also want to check out the resources on

- Technical reports on arXiv: [PaliGemma](https://arxiv.org/abs/2407.07726),
  [PaliGemma 2](https://arxiv.org/abs/2412.03555)
- Pre-trained / mix checkpoints and model card on Kaggle:
  [PaliGemma](https://www.kaggle.com/models/google/paligemma),
  [PaliGemma transfers](https://www.kaggle.com/models/google/paligemma-ft),
  [PaliGemma 2](https://www.kaggle.com/models/google/paligemma-2)
- Google Cloud Vertex AI Model Garden:
  [PaliGemma](https://console.cloud.google.com/vertex-ai/publishers/google/model-garden/363)
- PyTorch and JAX models on Hugging Face:
  [PaliGemma](https://huggingface.co/collections/google/paligemma-release-6643a9ffbf57de2ae0448dda),
  [PaliGemma 2](https://huggingface.co/collections/google/paligemma-2-release-67500e1e1dbfdd4dee27ba48)
- Light fine-tuning using `big_vision` on a single (free) T4 GPU:
  [Colab](https://colab.research.google.com/github/google-research/big_vision/blob/main/big_vision/configs/proj/paligemma/finetune_paligemma.ipynb)
- Demo: [HuggingFace PaliGemma space](https://hf.co/spaces/google/paligemma)

### Citation BibTeX

```
@article{beyer2024paligemma,
  title={{PaliGemma: A versatile 3B VLM for transfer}},
  author={Lucas Beyer and Andreas Steiner and André Susano Pinto and Alexander Kolesnikov and Xiao Wang and Daniel Salz and Maxim Neumann and Ibrahim Alabdulmohsin and Michael Tschannen and Emanuele Bugliarello and Thomas Unterthiner and Daniel Keysers and Skanda Koppula and Fangyu Liu and Adam Grycner and Alexey Gritsenko and Neil Houlsby and Manoj Kumar and Keran Rong and Julian Eisenschlos and Rishabh Kabra and Matthias Bauer and Matko Bošnjak and Xi Chen and Matthias Minderer and Paul Voigtlaender and Ioana Bica and Ivana Balazevic and Joan Puigcerver and Pinelopi Papalampidi and Olivier Henaff and Xi Xiong and Radu Soricut and Jeremiah Harmsen and Xiaohua Zhai},
  year={2024},
  journal={arXiv preprint arXiv:2407.07726}
}
@article{steiner2024paligemma2,
  title={{PaliGemma 2: A Family of Versatile VLMs for Transfer}},
  author={Andreas Steiner and André Susano Pinto and Michael Tschannen and Daniel Keysers and Xiao Wang and Yonatan Bitton and Alexey Gritsenko and Matthias Minderer and Anthony Sherbondy and Shangbang Long and Siyang Qin and Reeve Ingle and Emanuele Bugliarello and Sahar Kazemzadeh and Thomas Mesnard and Ibrahim Alabdulmohsin and Lucas Beyer and Xiaohua Zhai},
  year={2024},
  journal={arXiv preprint arXiv:2412.03555}
}
```

## Model description

### Overview

PaliGemma is a Vision-Language model inspired by the PaLI-3 recipe. It is
built on the SigLIP visual encoder (specifically, SigLIP-So400m/14) and the
Gemma language model. PaliGemma takes as input one or more images, which are
turned into "soft tokens" by the SigLIP encoder, and input text (codenamed the
"prefix") that is tokenized by Gemma's tokenizer. The image tokens and prefix
tokens are concatenated (in this order) and passed to the Gemma decoder with
full block-attention, which then generates an output text (the "suffix")
auto-regressively with masked attention.



Note that PaliGemma uses the Gemma 2B model, while PaliGemma 2 uses the
Gemma 2 {2B,9B,27B} models.
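The attention pattern described above can be summarized as: image and prefix
tokens attend to each other with full (bidirectional) attention, while suffix
tokens only attend to the image/prefix block and to earlier suffix tokens. The
following is a minimal illustrative sketch of such a mask, not the repository's
actual implementation; the function name and token counts are made up for the
example:

```
import jax.numpy as jnp

def prefix_lm_attention_mask(n_image: int, n_prefix: int, n_suffix: int):
  """Boolean [n_total, n_total] mask; True means the query may attend to the key."""
  n_full = n_image + n_prefix        # Block with full (bidirectional) attention.
  n_total = n_full + n_suffix
  i = jnp.arange(n_total)[:, None]   # Query positions.
  j = jnp.arange(n_total)[None, :]   # Key positions.
  return (j < n_full) | (j <= i)     # Full over image+prefix, causal over suffix.

mask = prefix_lm_attention_mask(n_image=256, n_prefix=8, n_suffix=16)
print(mask.shape)  # (280, 280)
```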
### Training stages

Similar to PaLI-3, PaliGemma's training consists of multiple stages:

- Stage 0: unimodal pre-training. We use publicly available off-the-shelf
  SigLIP and Gemma models which have been pre-trained unimodally by their
  respective authors.
- Stage 1: multimodal pre-training. The combined PaliGemma model is now
  pre-trained on a fully multimodal training dataset, at a low resolution
  of 224px² and a prefix+suffix sequence length of 128 tokens. This results in
  the first base model that we release.
- Stage 2: high-resolution pre-training. We continue pre-training of the
  Stage 1 model at resolution 448px² with sequence length 512 tokens for a short
  duration on the same multimodal training data, but re-weighted with more
  emphasis on examples that make use of higher resolution or longer sequence
  length. We repeat this once more at resolution 896px². This results in two
  further "high res" base models that we also release.
- Stage 3: fine-tuning. The base models are transferred to
  specific tasks by fine-tuning. To facilitate further research and
  reproducibility, we release checkpoints fine-tuned on most of the benchmarks
  we evaluate on. We also provide a "mix" transfer model, fine-tuned on a wide
  variety of data, for use in interactive demos.

Most of the code examples, use-cases, and the code release are about Stage 3:
transferring to a task or dataset of interest to the user.

### Tokenizer

PaliGemma uses the Gemma tokenizer with 256'000 tokens, but we further extend
its vocabulary with 1024 entries that represent coordinates in normalized
image-space (\<loc0000>...\<loc1023>), and another 128 entries
(\<seg000>...\<seg127>) that are codewords used by a lightweight
referring-expression segmentation vector-quantized variational auto-encoder
(VQ-VAE) with the architecture of [Ning et al. (2023)](https://arxiv.org/abs/2301.02229) and trained on OpenImages
as in PaLI-3. While the `big_vision` codebase is flexible enough to extend
tokenizers on-the-fly, we also provide a SentencePiece model file of the Gemma
tokenizer with these additional tokens baked in, for the convenience of
other codebases.
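When transferring to detection, the model answers the structured `detect {things}`
prompt with groups of four location tokens followed by a class name. As an
illustrative sketch (this parser is not shipped in the repository, and the exact
coordinate normalization used downstream may differ slightly), the location
tokens can be mapped back to normalized box coordinates like this:

```
import re

def parse_detect_answer(text):
  """Yields (label, (ymin, xmin, ymax, xmax)) with coordinates roughly in [0, 1]."""
  pattern = re.compile(
      r"<loc(\d{4})><loc(\d{4})><loc(\d{4})><loc(\d{4})>\s*([^;]+)")
  for m in pattern.finditer(text):
    # Map the bin index (0..1023) to [0, 1]; check the convention of your decoder.
    ymin, xmin, ymax, xmax = (int(v) / 1023.0 for v in m.groups()[:4])
    yield m.group(5).strip(), (ymin, xmin, ymax, xmax)

answer = "<loc0123><loc0456><loc0789><loc0987> cat ; <loc0100><loc0200><loc0300><loc0400> dog"
print(list(parse_detect_answer(answer)))
```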
## Checkpoints

The PaliGemma models are released under the same open license as the Gemma
models, and hence require manual acknowledgement of the license terms. See
the [Quick Reference](#quick-reference) above for download links.

### Pretrained checkpoints

Use one of these checkpoints as initialization for fine-tuning:

- pt-224: Versatile pretrained model for tasks that do not require seeing
  small details in the image.
  Examples: natural image captioning and question-answering, detection and
  segmentation of medium-large objects. This model was trained with
  sequence length 128.
- pt-448: Versatile base model for mid/higher resolution tasks with access
  to smaller details. Besides the higher resolution, it got more weight on
  text reading, detection, and segmentation during its pre-training. Examples:
  as above, plus detection, segmentation, text/diagram reading. This model was
  trained with sequence length 512.
- pt-896: Further scaled-up version of pt-448, especially good at reading
  very small texts as often found in documents and infographics. This model
  was trained with sequence length 512.

Besides the reference float32 checkpoint (11GB), we further provide
bfloat16 and float16 variants of each, to reduce download and storage time.
These are good for inference and frozen transfers, but full fine-tuning
should happen in float32 or mixed precision.

### Mixture checkpoint

(Currently only available for PaliGemma)

This checkpoint is trained on a mixture of all our transfer tasks,
with a balancing intended to make it "nice to use" out of the box for
predictions. This model is multilingual and should
understand prompts in various languages, although English
is still its "mother tongue".
Questions can be asked in a natural way (including asking for a caption or
reading the text), and detection and segmentation should still work with the
structured `detect {things}` and `segment {things}` prompts as in the base model.

- mix-224: Similarly to pt-224, this model is good at many natural image
  tasks that do not require high resolution. Unlike the raw pre-trained model,
  however, it can be interacted with more freely. For example, ask it to
  "describe this image in great detail, please" or "How many coins do you see
  in the picture?". This model was trained with sequence length 256.
- mix-448: As above, but it is better at tasks that require higher-resolution
  input. For example, one could ask it "what is written in the 'sum' field?",
  "describe this figure", or "what is the GDP of France?" when shown an
  infographic of countries' GDPs. This model was trained with
  sequence length 512.

### Transfers results and checkpoints

(DOCCI only available for PaliGemma 2, others only available for PaliGemma)

We provide checkpoints transferred to most of the tasks we evaluated
transfer on, see the [kaggle page](https://www.kaggle.com/models/google/paligemma).
These are intended for use when a specialised model corresponding
to one of the tasks is needed, for academic research purposes only.
Depending on the task, they may require a specialised preprocessing format.

The transfer setup is reasonably unified, with the main factors of variation
being the training duration, learning-rate, and whether or not to use dropout
and label-smoothing. Details can be found in the corresponding config files or
in an upcoming tech report.

Importantly, none of these tasks or datasets are part of the pre-training data
mixture, and their images are explicitly removed from the web-scale
pretraining data.

#### Captioning

Benchmark (train split) | Metric (split) | pt-224 | pt-448 | pt-896
-----------------------|----------------|--------|--------|--------
[COCO captions](https://cocodataset.org/#home) (train+restval) | CIDEr (val) | 141.92 | 144.60 |
[NoCaps](https://nocaps.org/) (Eval of COCO captions transfer) | CIDEr (val) | 121.72 | 123.58 |
[COCO-35L](https://arxiv.org/abs/2205.12522) (train) | CIDEr dev (en / avg-34 / avg) | 139.2 / 115.8 / 116.4 | 141.2 / 118.0 / 118.6 |
[XM3600](https://arxiv.org/abs/2205.12522) (Eval of COCO-35L transfer) | CIDEr test (en / avg-35 / avg) | 78.1 / 41.3 / 42.4 | 80.0 / 41.9 / 42.9 |
[TextCaps](https://textvqa.org/textcaps/) (train) | CIDEr (val) | 127.48 | 153.94 |
[SciCap](https://arxiv.org/abs/2110.11624) (first sentence, no subfigure) (train+val) | CIDEr / BLEU-4 (test) | 162.25 / 0.192 | 181.49 / 0.211 |
[Screen2words](https://arxiv.org/abs/2108.03353) (train+dev) | CIDEr (test) | 117.57 | 119.59 |
[Widget Captioning](https://arxiv.org/abs/2010.04295) (train+dev) | CIDEr (test) | 136.07 | 148.36 |

#### Question Answering

Benchmark (train split) | Metric (split) | pt-224 | pt-448 | pt-896
-----------------------|----------------|--------|--------|--------
[VQAv2](https://visualqa.org/index.html) (train+validation) | Accuracy (Test server - std) | 83.19 | 85.64 |
[MMVP](https://arxiv.org/abs/2401.06209) (Eval of VQAv2 transfer) | Paired Accuracy | 47.33 | 45.33 |
[POPE](https://arxiv.org/abs/2305.10355) (Eval of VQAv2 transfer) | Accuracy (random / popular / adversarial) | 87.80 / 85.87 / 84.27 | 88.23 / 86.77 / 85.90 |
[Objaverse Multiview](https://arxiv.org/abs/2311.17851) (Eval of VQAv2 transfer) | Cosine Similarity (USEv4) | 62.7 | 62.8 |
[OKVQA](https://okvqa.allenai.org/) (train) | Accuracy (val) | 63.54 | 63.15 |
[A-OKVQA](https://allenai.org/project/a-okvqa/home) (MC) (train+val) | Accuracy (Test server) | 76.37 | 76.90 |
[A-OKVQA](https://allenai.org/project/a-okvqa/home) (DA) (train+val) | Accuracy (Test server) | 61.85 | 63.22 |
[GQA](https://cs.stanford.edu/people/dorarad/gqa/about.html) (train_balanced+val_balanced) | Accuracy (testdev balanced) | 65.61 | 67.03 |
[xGQA](https://aclanthology.org/2022.findings-acl.196/) (Eval of GQA transfer) | Mean Accuracy (bn,de,en,id,ko,pt,ru,zh) | 58.37 | 59.07 |
[NLVR2](https://lil.nlp.cornell.edu/nlvr/) (train+dev) | Accuracy (test) | 90.02 | 88.93 |
[MaRVL](https://marvl-challenge.github.io/) (Eval of NLVR2 transfer) | Mean Accuracy (test) (id,sw,ta,tr,zh) | 80.57 | 76.78 |
[AI2D](https://allenai.org/data/diagrams) (train) | Accuracy (test) | 72.12 | 73.28 |
[ScienceQA](https://scienceqa.github.io/) (Img subset, no CoT) (train+val) | Accuracy (test) | 95.39 | 95.93 |
[RSVQA-LR](https://zenodo.org/records/6344334) (Non numeric) (train+val) | Mean Accuracy (test) | 92.65 | 93.11 |
[RSVQA-HR](https://zenodo.org/records/6344367) (Non numeric) (train+val) | Mean Accuracy (test/test2) | 92.61 / 90.58 | 92.79 / 90.54 |
[ChartQA](https://arxiv.org/abs/2203.10244) (human+aug)x(train+val) | Mean Relaxed Accuracy (test_human, test_aug) | 57.08 | 71.36 |
[VizWiz](https://vizwiz.org/tasks-and-datasets/vqa/) VQA (train+val) | Accuracy (Test server - std) | 73.7 | 75.52 |
[TallyQA](https://arxiv.org/abs/1810.12440) (train) | Accuracy (test_simple/test_complex) | 81.72 / 69.56 | 84.86 / 72.27 |
[OCR-VQA](https://ocr-vqa.github.io/) (train+val) | Accuracy (test) | 73.24 | 75.60 | 75.90
[TextVQA](https://textvqa.org/) (train+val) | Accuracy (Test server - std) | 55.47 | 73.15 | 76.48
[DocVQA](https://www.docvqa.org/) (train+val) | ANLS (Test server) | 43.74 | 78.02 | 84.77
[Infographic VQA](https://openaccess.thecvf.com/content/WACV2022/papers/Mathew_InfographicVQA_WACV_2022_paper.pdf) (train+val) | ANLS (Test server) | 28.46 | 40.47 | 47.75
[SceneText VQA](https://arxiv.org/abs/1905.13648) (train+val) | ANLS (Test server) | 63.29 | 81.82 | 84.40

#### Segmentation

Benchmark (train split) | Metric (split) | pt-224 | pt-448 | pt-896
-----------------------|----------------|--------|--------|--------
[RefCOCO](https://arxiv.org/abs/1608.00272) (combined refcoco, refcoco+, refcocog excluding val and test images) | MIoU (validation) refcoco / refcoco+ / refcocog | 73.40 / 68.32 / 67.65 | 75.57 / 69.76 / 70.17 | 76.94 / 72.18 / 72.22

#### Video tasks (Caption/QA)

Benchmark (train split) | Metric (split) | pt-224 | pt-448 | pt-896
-----------------------|----------------|--------|--------|--------
[MSR-VTT](https://www.microsoft.com/en-us/research/publication/msr-vtt-a-large-video-description-dataset-for-bridging-video-and-language/) (Captioning) | CIDEr (test) | 70.54 |
[MSR-VTT](https://www.microsoft.com/en-us/research/publication/msr-vtt-a-large-video-description-dataset-for-bridging-video-and-language/) (QA) | Accuracy (test) | 50.09 |
[ActivityNet](http://activity-net.org/) (Captioning) | CIDEr (test) | 34.62 |
[ActivityNet](http://activity-net.org/) (QA) | Accuracy (test) | 50.78 |
[VATEX](https://eric-xw.github.io/vatex-website/about.html) (Captioning) | CIDEr (test) | 79.73 |
[MSVD](https://www.cs.utexas.edu/users/ml/clamp/videoDescription/) (QA) | Accuracy (test) | 60.22 |

#### Mix model (finetune on mixture of transfer tasks)

Benchmark | Metric (split) | mix-224 | mix-448
----------|----------------|---------|---------
[MMVP](https://arxiv.org/abs/2401.06209) | Paired Accuracy | 46.00 | 45.33
[POPE](https://arxiv.org/abs/2305.10355) | Accuracy (random / popular / adversarial) | 88.00 / 86.63 / 85.67 | 89.37 / 88.40 / 87.47


## How to run PaliGemma fine-tuning

To run PaliGemma fine-tuning, set up the `big_vision` repository by following the
main README file. Here we provide PaliGemma-specific instructions.

Checkpoints can be downloaded from Kaggle. You need to create an account and acknowledge the checkpoint usage policy. You can then download any checkpoint:

```
export KAGGLE_USERNAME=
export KAGGLE_KEY=

# See https://www.kaggle.com/models/google/paligemma-2 for a full list of models.
export MODEL_NAME=paligemma2-3b-pt-224

mkdir ckpts/
cd ckpts/

# Store as a "vanity name" from models/proj/paligemma/paligemma.py
curl -L -u $KAGGLE_USERNAME:$KAGGLE_KEY \
  -o pt_3b_224.bf16.npz \
  https://www.kaggle.com/api/v1/models/google/paligemma-2/jax/$MODEL_NAME/1/download/$MODEL_NAME.b16.npz
```

As an example, we provide the `forkme.py` config that is based on the easily-adjustable jsonl data source:

```
BV_GEMMA_DIR=ckpts/ python -m big_vision.trainers.proj.paligemma.train --config big_vision/configs/proj/paligemma/transfers/forkme.py --workdir workdirs/`date '+%m-%d_%H%M'`
```

If you want to use TFDS-based data, check out the other transfer configs. Remember to set `TFDS_DATA_DIR` to point to the folder with the data (this can be a GCS data bucket).


## Model Development Contributions

See the Appendices of the technical reports:
[PaliGemma](https://arxiv.org/abs/2407.07726),
[PaliGemma 2](https://arxiv.org/abs/2412.03555).
Tipsomaly/model/big_vision/configs/proj/paligemma/paligemma2.png
ADDED
Tipsomaly/model/big_vision/configs/proj/reward_tune/detection_reward.py
ADDED
@@ -0,0 +1,232 @@
# Copyright 2024 Big Vision Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Object detection reward from "Tuning computer vision models with task rewards" (https://arxiv.org/abs/2302.08242).

The `reward_fn` computes the reward for a batch of predictions and ground truth
annotations. When using it to optimize a model that outputs a prediction as a
sequence of tokens like [y0, x0, Y0, X0, class0, confidence0, y1, x1, Y1, ...]
the training loop may look like:

```
# Settings used in the paper.
config.max_level = 1000  # Coordinates are discretized into 1000 buckets.
config.max_conf = 2  # Two tokens are reserved to represent confidence.
config.num_cls = 80  # Number of classes in COCO.
config.nms_w = 0.3  # Weight for duplicate instances.
config.cls_smooth = 0.05  # Adjust the class weights based on their frequency.
config.reward_thr = (0.5, 0.55, 0.6, 0.65, 0.7, 0.75, 0.8, 0.85, 0.9, 0.95)
config.correct_thr = 0.5  # Learn the IoU when matching with threshold=0.5.
config.conf_w = 0.3  # Weight for the confidence loss.


# 1) Sample N outputs for each input and compute rewards, use one sample to
# optimize and others to compute a reward baseline.
sample_seqs = sample_fn(params, images, num_samples)
sample_rewards, aux = reward_fn(sample_seqs, labels, config)
labels = sample_seqs[:, 0, ...]
rewards = sample_rewards[:, 0]
match_iou = aux["match_iou"][:, 0]
baselines = (jnp.sum(sample_rewards, axis=-1) - rewards) / (num_samples - 1)

# 2) Optimize the model, using REINFORCE to adjust the likelihood of the
# sequence based on the reward, and supervision to teach the model to
# predict the expected IoU of each box in its own samples.
def loss_fn(params):
  logits = model.apply(params, images, labels, train=True, rngs=rngs)
  logits_softmax = jax.nn.log_softmax(logits)

  # Use REINFORCE to optimize the expected reward for the whole sequence.
  seq_rewards = (rewards - baselines)
  # Note: consider improving this code to skip this loss for confidence tokens.
  # The paper did not do it due to a bug (and it also does not seem to matter).
  target = jax.nn.one_hot(labels, logits.shape[-1]) * seq_rewards[:, None, None]
  loss_reward = -jnp.sum(target * logits_softmax, axis=-1)

  # Use a supervision loss to tune the confidence tokens to predict IoU:
  #  - (1.0, 0.0, 0.0, ...) -> for padded boxes.
  #  - (0.0, 1-iou, iou, ...) -> for sampled boxes.
  conf0 = (labels[:, 5::6] == 0)
  conf1 = (labels[:, 5::6] > 0) * (1.0 - match_iou)
  conf2 = (labels[:, 5::6] > 0) * match_iou
  target_conf = jnp.stack([conf0, conf1, conf2], axis=-1)
  logits_conf = logits_softmax[:, 5::6, :3]
  loss_conf = -jnp.sum(target_conf * logits_conf, axis=-1)

  loss = jnp.mean(loss_reward) + config.conf_w * jnp.mean(loss_conf)
  return loss
```
"""
import functools

import einops
import jax
import jax.numpy as jnp


# Frequency of COCO object detection classes as observed in the training set.
# pylint: disable=bad-whitespace,bad-continuation
CLS_COUNTS = [
    262465, 7113, 43867, 8725, 5135, 6069, 4571, 9973, 10759,
    12884, 1865, 1983, 1285, 9838, 10806, 4768, 5508, 6587,
    9509, 8147, 5513, 1294, 5303, 5131, 8720, 11431, 12354,
    6496, 6192, 2682, 6646, 2685, 6347, 9076, 3276, 3747,
    5543, 6126, 4812, 24342, 7913, 20650, 5479, 7770, 6165,
    14358, 9458, 5851, 4373, 6399, 7308, 7852, 2918, 5821,
    7179, 6353, 38491, 5779, 8652, 4192, 15714, 4157, 5805,
    4970, 2262, 5703, 2855, 6434, 1673, 3334, 225, 5610,
    2637, 24715, 6334, 6613, 1481, 4793, 198, 1954
]
# pylint: enable=bad-whitespace,bad-continuation


def seq2box(seq, max_level, max_conf, num_cls):
  """Extract boxes encoded as sequences."""
  # Reshape to instances of boxes.
  dim_per_box = 6
  seq_len = seq.shape[-1]
  seq = seq[..., :(seq_len - seq_len % dim_per_box)]
  seq = einops.rearrange(seq, "... (n d) -> ... n d", d=dim_per_box)

  # Unpack box fields.
  boxes, labels, confs = seq[..., 0:4], seq[..., 4], seq[..., 5]
  boxes = boxes - max_conf - 1
  labels = labels - max_conf - 1 - max_level - 1
  boxes = jnp.clip(boxes, 0, max_level) / max_level
  labels = jnp.clip(labels, 0, num_cls - 1)
  confs = jnp.clip(confs, 0, max_conf)

  return boxes, labels, confs


def iou_fn(box1, box2):
  """Compute IoU of two boxes."""
  ymin1, xmin1, ymax1, xmax1 = box1
  ymin2, xmin2, ymax2, xmax2 = box2

  a1 = jnp.abs((ymax1 - ymin1) * (xmax1 - xmin1))
  a2 = jnp.abs((ymax2 - ymin2) * (xmax2 - xmin2))

  yl = jnp.maximum(ymin1, ymin2)
  yr = jnp.minimum(ymax1, ymax2)
  yi = jnp.maximum(0, yr - yl)

  xl = jnp.maximum(xmin1, xmin2)
  xr = jnp.minimum(xmax1, xmax2)
  xi = jnp.maximum(0, xr - xl)

  inter = xi * yi
  return inter / (a1 + a2 - inter + 1e-9)

iou_fn_batched = jax.vmap(
    jax.vmap(iou_fn, in_axes=(None, 0)), in_axes=(0, None)
)


def _reward_fn_thr(seq_pred, seq_gt,
                   thr, nms_w, max_level, max_conf, num_cls, cls_smooth):
  """Compute detection reward function for a given IoU threshold."""
  # Weight matches of each label inversely proportional to the percentage of
  # GT instances with such label in the whole train dataset. Additionally
  # smooth out the observed distribution.
  cls_counts = jnp.array(CLS_COUNTS)
  weights = 1.0 / (cls_counts + cls_smooth*jnp.sum(cls_counts))
  weights = num_cls * weights / jnp.sum(weights)

  boxes_pred, labels_pred, confs_pred = seq2box(
      seq_pred, max_level, max_conf, num_cls)
  boxes_gt, labels_gt, confs_gt = seq2box(
      seq_gt, max_level, max_conf, num_cls)

  # Compute IoU matrix: Predictions X GT
  iou = iou_fn_batched(boxes_pred, boxes_gt)

  # IoU thr
  iou = jnp.where(iou > thr, iou, 0.0)

  # EOS mask
  confs_mask = (confs_pred[:, None] > 0) * (confs_gt[None, :] > 0)
  iou = confs_mask * iou

  # Label mask
  label_mask = labels_pred[:, None] == labels_gt[None, :]
  iou = label_mask * iou

  # Each prediction is matched to a single box
  single_match_mask = jax.nn.one_hot(jnp.argmax(iou, axis=1), iou.shape[1])
  iou = iou * single_match_mask

  # Pred. boxes indicators
  correct = jnp.any(iou > 0.0, axis=1).astype("int32") + 1
  correct = jnp.where(confs_pred > 0, correct, 0)

  # For each GT box find best match
  matches_idx = jnp.argmax(iou, axis=0)
  matches_iou = jnp.take_along_axis(iou, matches_idx[None], axis=0)[0]
  matches_idx = jnp.where(matches_iou > 0.0, matches_idx, -1)

  match_reward = jnp.sum((matches_idx >= 0) * weights[labels_gt][None, :])

  # Compute duplicate penalty (aka NMS).
  matches_mask = jax.nn.one_hot(matches_idx, iou.shape[0], axis=0)
  nms_penalty = jnp.sum(
      (iou > 0.0) * (1 - matches_mask) * weights[labels_pred][:, None])

  match_iou = jnp.sum(iou, axis=1)

  return {
      "reward": (match_reward - nms_w * nms_penalty),
      "num_matches": jnp.sum(matches_idx >= 0),
      "nms_penalty": nms_penalty,
      "correct": correct,
      "match_iou": match_iou,
  }


def reward_fn(seqs_pred, seqs_gt, config):
  """Total reward."""
  result = {}
  thrs = config.reward_thr
  correct_thr = config.correct_thr
  r_keys = ["reward", "num_matches", "nms_penalty"]
  for thr in thrs:
    fn = functools.partial(
        _reward_fn_thr,
        thr=thr,
        nms_w=config.nms_w,
        max_level=config.max_level,
        max_conf=config.max_conf,
        num_cls=config.num_cls,
        cls_smooth=config.cls_smooth,
    )
    rewards = jax.vmap(jax.vmap(fn, in_axes=(0, None)))(seqs_pred, seqs_gt)

    result = {**result, **{f"{k}-{thr:0.1f}": rewards[k]
                           for k in r_keys}}
    if thr == correct_thr:
      correct = rewards["correct"]
      match_iou = rewards["match_iou"]

  result = {
      **result,
      **{k: jnp.mean(
          jnp.array([result[f"{k}-{thr:0.1f}"] for thr in thrs]), axis=0)
         for k in r_keys}
  }

  return result["reward"], {
      "result": result,
      "correct": correct,
      "match_iou": match_iou,
  }
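

# ----------------------------------------------------------------------------
# Illustrative usage sketch, not part of the original file: builds a dummy
# config (mirroring the values documented in the module docstring) and random
# token sequences, just to show the expected argument and output shapes of
# `reward_fn`.
if __name__ == "__main__":
  import ml_collections

  cfg = ml_collections.ConfigDict()
  cfg.max_level = 1000
  cfg.max_conf = 2
  cfg.num_cls = 80
  cfg.nms_w = 0.3
  cfg.cls_smooth = 0.05
  cfg.reward_thr = (0.5, 0.75)  # Shortened list; the paper uses ten thresholds.
  cfg.correct_thr = 0.5

  batch, num_samples, seq_len = 2, 4, 60  # 60 tokens -> 10 boxes of 6 tokens.
  key = jax.random.PRNGKey(0)
  # Random token ids stand in for model samples and ground-truth sequences.
  seqs_pred = jax.random.randint(key, (batch, num_samples, seq_len), 0, 1100)
  seqs_gt = jax.random.randint(key, (batch, seq_len), 0, 1100)

  rewards, aux = reward_fn(seqs_pred, seqs_gt, cfg)
  print(rewards.shape)           # (batch, num_samples)
  print(aux["match_iou"].shape)  # (batch, num_samples, num_boxes)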
Tipsomaly/model/big_vision/configs/proj/scaling_laws/train_vit_g.py
ADDED
@@ -0,0 +1,87 @@
# Copyright 2024 Big Vision Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# pylint: disable=line-too-long
r"""Pre-train ViT-g (1B params) on JFT-3B as in https://arxiv.org/abs/2106.04560

To train ViT-G (2B params), simply update the following single line:
  `config.model.variant = 'G/14'`

The code is released for reference purposes.
One can test the code using the public ImageNet-1k or ImageNet-21k datasets.

big_vision.train \
    --config big_vision/configs/proj/scaling_laws/train_vit_g.py \
    --workdir gs://[your_bucket]/big_vision/`date '+%m-%d_%H%M'`

"""
from big_vision.configs.common_fewshot import get_fewshot_lsr
import ml_collections as mlc


def get_config():
  """Rocket config."""
  config = mlc.ConfigDict()

  config.dataset = 'jft_3b'
  config.val_split = 'val'
  config.train_split = 'train'
  config.num_classes = 29_593
  config.init_head_bias = -10.0

  # Fits 32 images per TPUv3 core with ViT-g/14.
  config.batch_size = 4096*4

  pp_common = '|value_range(-1, 1)'
  pp_common += f'|onehot({config.num_classes})'
  pp_common += '|keep("image", "labels")'
  config.pp_train = 'inception_crop(224)|flip_lr' + pp_common
  config.pp_eval = 'resize_small(256)|central_crop(224)' + pp_common
  config.shuffle_buffer_size = 250_000  # Per host, so small-ish is ok.

  config.log_training_steps = 50
  config.log_eval_steps = 1000
  # NOTE: eval is very fast O(seconds) so it's fine to run it often.

  config.ckpt_steps = 1000
  config.keep_ckpt_steps = 10_000

  config.prefetch_to_device = 1
  config.trial = 0

  # Model section
  config.model_name = 'vit'
  config.model = mlc.ConfigDict()
  config.model.variant = 'g/14'
  config.model.pool_type = 'map'

  # Optimizer section
  config.optax_name = 'big_vision.scale_by_adafactor'
  config.grad_clip_norm = 1.0
  config.lr = 8e-4
  config.wd = 0.03 * 8e-4
  config.wd_mults = [
      ('.*head/kernel', 100.0),
      ('.*/kernel', 1.0),
  ]
  config.schedule = dict(
      decay_type='rsqrt', timescale=10_000, warmup_steps=10_000,
      cooldown_steps=50_000)
  config.total_steps = 1_000_000

  # Few-shot eval section
  config.evals = {}
  config.evals.fewshot = dict(log_steps=10_000, **get_fewshot_lsr())

  return config
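

# Illustrative usage sketch, not part of the original file: build the config in
# Python and flip the single line mentioned in the docstring to switch from
# ViT-g/14 (1B params) to ViT-G/14 (2B params).
if __name__ == "__main__":
  cfg = get_config()
  cfg.model.variant = 'G/14'
  print(cfg.model.variant, cfg.batch_size, cfg.total_steps)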
Tipsomaly/model/big_vision/configs/proj/uvim/README.md
ADDED
@@ -0,0 +1,84 @@
# UViM: A Unified Modeling Approach for Vision with Learned Guiding Codes

*by Alexander Kolesnikov, André Susano Pinto, Lucas Beyer, Xiaohua Zhai, Jeremiah Harmsen, Neil Houlsby*

We provide pretrained UViM models from the [original paper](https://arxiv.org/abs/2205.10337),
as well as instructions on how to reproduce the core paper experiments.

## Pretrained models

The table below contains UViM models (stage I and II) trained for three
different tasks: panoptic segmentation, colorization and depth prediction.

| task | model | dataset | accuracy | download link |
| --------------------- | ------------------- | ------------------------------------------------------------------------ | ------------ | ----------------------------------------------------------------------------------------- |
| Panoptic segmentation | UViM Stage I model | [COCO(2017)](https://cocodataset.org/#home) | 75.8 PQ | [link](https://storage.googleapis.com/big_vision/uvim/panoptic_stageI_params.npz) |
| Panoptic segmentation | UViM Stage II model | [COCO(2017)](https://cocodataset.org/#home) | 43.1 PQ | [link](https://storage.googleapis.com/big_vision/uvim/panoptic_stageII_params.npz) |
| Colorization | UViM Stage I model | [ILSVRC-2012](https://www.image-net.org/) | 15.59 FID | [link](https://storage.googleapis.com/big_vision/uvim/color_stageI_params.npz) |
| Colorization | UViM Stage II model | [ILSVRC-2012](https://www.image-net.org/) | 16.99 FID | [link](https://storage.googleapis.com/big_vision/uvim/color_stageII_params.npz) |
| Depth | UViM Stage I model | [NYU Depth V2](https://cs.nyu.edu/~silberman/datasets/nyu_depth_v2.html) | 0.155 RMSE | [link](https://storage.googleapis.com/big_vision/uvim/depth_stageI_params.npz) |
| Depth | UViM Stage II model | [NYU Depth V2](https://cs.nyu.edu/~silberman/datasets/nyu_depth_v2.html) | 0.463 RMSE | [link](https://storage.googleapis.com/big_vision/uvim/depth_stageII_params.npz) |

All of these models can be interactively explored in our [colabs](configs/proj/uvim).
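For a quick look at one of these checkpoints without setting up the whole
codebase, the `.npz` parameter files linked above can also be fetched and
inspected directly. The following is a minimal sketch, assuming the files open
as plain NumPy archives:

```
import io
import urllib.request
import numpy as np

url = "https://storage.googleapis.com/big_vision/uvim/panoptic_stageI_params.npz"
with urllib.request.urlopen(url) as f:
  params = np.load(io.BytesIO(f.read()))
print(len(params.files), "parameter arrays, e.g.", sorted(params.files)[:3])
```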
## Running on a single-host TPU machine

Below we provide instructions on how to run UViM training (stage I and
stage II) using a single TPU host with 8 TPU accelerators. These instructions
can be easily adapted to a GPU host and a multi-host TPU setup, see the main
`big_vision` [README file](README.md).

We assume that the user has already created and `ssh`-ed to the TPU host
machine. The next step is to clone the `big_vision` repository:
`git clone https://github.com/google-research/big_vision.git`.

The next steps are to create a Python virtual environment and install the Python
dependencies:
```
virtualenv bv
source bv/bin/activate
cd big_vision/
pip3 install --upgrade pip
pip3 install -r big_vision/requirements.txt
pip install "jax[tpu]>=0.2.16" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
```

After this, invoke the helper tool to download and prepare data:
`python3 -m big_vision.tools.download_tfds_datasets coco/2017_panoptic nyu_depth_v2`.
For preparing the ImageNet dataset, consult the main codebase README.

> :warning: TPU machines have 100 GB of disk space. It may not be enough to
> store all training data (though only panoptic or only depth data may fit).
> Consider preparing the data on a separate machine and then copying it to
> the TPU machine's extra persistent disk or to a Google Cloud Bucket. See
> instructions for [creating an extra persistent disk](https://cloud.google.com/tpu/docs/users-guide-tpu-vm).
> Remember to set the correct data home directory, e.g. `export DISK=/mnt/disk/persist; export TFDS_DATA_DIR=$DISK/tensorflow_datasets`.

Our panoptic evaluator uses the raw variant of the COCO data, so we move it into a
separate folder. Note that `tfds` has already pre-downloaded the panoptic data,
except for one small json file that we fetch manually:
```
mkdir $DISK/coco_data
cd $DISK/coco_data
mv $TFDS_DATA_DIR/downloads/extracted/ZIP.image.cocod.org_annot_panop_annot_train<REPLACE_ME_WITH_THE_HASH_CODE>.zip/annotations/* .
wget https://raw.githubusercontent.com/cocodataset/panopticapi/master/panoptic_coco_categories.json
export COCO_DATA_DIR=$DISK/coco_data
```

For the FID evaluator, which is used for the colorization model, set the path to the
directory with image id files, e.g.
`export FID_DATA_DIR=<ROOT>/big_vision/evaluators/proj/uvim/coltran_fid_data`.

As an example, stage I panoptic training can be invoked as follows (note the `:singlehost` config parameter, which uses a lightweight configuration suitable for a single host):
```
python3 -m big_vision.trainers.proj.uvim.vqvae --config big_vision/configs/proj/uvim/vqvae_coco_panoptic.py:singlehost --workdir workdirs/`date '+%m-%d_%H%M'`
```
or stage II training
```
python3 -m big_vision.trainers.proj.uvim.train --config big_vision/configs/proj/uvim/train_coco_panoptic_pretrained.py:singlehost --workdir workdirs/`date '+%m-%d_%H%M'`
```

## Acknowledgments
The sampling code in the `models/proj/uvim/decode.py` module is based on contributions
from Anselm Levskaya, Ilya Tolstikhin and Maxim Neumann.
Tipsomaly/model/big_vision/configs/proj/uvim/train_coco_panoptic_pretrained.py
ADDED
@@ -0,0 +1,164 @@
# Copyright 2022 Big Vision Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# pylint: disable=line-too-long
r"""A config for training a UViM stage II model for the panoptic task.

This config is expected to reproduce the paper's result and achieve
approximately 43.7 PQ points on the COCO holdout data.

We also provide a low-resource variant of this config, which can be enabled
by adding `:singlehost` postfix to the config name. This one is expected to
achieve 39.4 PQ points on the COCO holdout data.
"""

import big_vision.configs.common as bvcc
from ml_collections import ConfigDict

VTT_MODELS = {
    'base': dict(num_layers=12, num_heads=12, mlp_dim=3072, emb_dim=768),
    'large': dict(num_layers=24, num_heads=16, mlp_dim=4096, emb_dim=1024),
}

VQVAE_MODELS = {
    'base': dict(enc_depth=6, dec_depth=12, num_heads=12, mlp_dim=3072, width=768),
}

RES = 512
PATCH_SIZE = 16
LABEL_RES = 512
LABEL_PATCH_SIZE = 16


def get_config(arg=''):
  """Config for training."""
  arg = bvcc.parse_arg(arg, runlocal=False, singlehost=False)
  config = ConfigDict()

  config.input = {}
  config.input.pp = (
      f'decode|coco_panoptic|concat(["semantics","instances"], "labels")|'
      f'randu("fliplr")|det_fliplr(key="image")|det_fliplr(key="labels")|'
      f'inception_box|crop_box(key="image")|crop_box(key="labels")|'
      f'resize({LABEL_RES}, inkey="image", outkey="image_ctx")|'
      f'resize({RES})|resize({LABEL_RES},key="labels",method="nearest")|'
      f'value_range(-1, 1, key="image_ctx")|'
      f'value_range(-1, 1)|make_canonical|keep("image","image_ctx","labels")'
  )
  pp_eval = (
      f'decode|coco_panoptic|concat(["semantics","instances"], "labels")|'
      f'resize({LABEL_RES}, inkey="image", outkey="image_ctx")|'
      f'resize({RES})|resize({LABEL_RES},key="labels",method="nearest")|'
      f'value_range(-1, 1, key="image_ctx")|'
      f'value_range(-1, 1)|make_canonical|keep("image","image_ctx","labels")'
  )
  pp_predict = (
      f'resize({LABEL_RES}, inkey="image", outkey="image_ctx")|resize({RES})|'
      f'value_range(-1, 1, key="image_ctx")|value_range(-1, 1)|'
      f'keep("image","image_ctx","image/id")'  # image/id used for rng seeds.
  )

  config.input.data = dict(name='coco/2017_panoptic', split='train[4096:]')
  config.input.batch_size = 512
  config.input.shuffle_buffer_size = 50_000

  config.total_epochs = 200

  config.log_training_steps = 50
  config.ckpt_steps = 1000
  config.keep_ckpt_steps = 5000
  config.prefetch_to_device = 2
  config.seed = 0

  # Optimizer section
  config.optax_name = 'big_vision.scale_by_adafactor'
  config.optax = dict(beta2_cap=0.95)

  config.lr = 0.001
  config.wd = 0.000001
  config.lr_mults = [
      ('pos_embedding_encoder.*', 0.1),
      ('EmbedPatches.*', 0.1),
      ('encoder.*', 0.1),
      ('decoder.*', 1.0)
  ]
  config.schedule = dict(decay_type='cosine', warmup_steps=4_000)

  # Oracle section
  config.oracle = ConfigDict()
  config.oracle.task = 'proj.uvim.panoptic_task'
  config.oracle.model_init = 'gs://big_vision/uvim/panoptic_stageI_params.npz'
  config.oracle.model_name = 'proj.uvim.vit'
  config.oracle.model = ConfigDict(VQVAE_MODELS['base'])
  config.oracle.model.input_size = (LABEL_RES, LABEL_RES)
  config.oracle.model.patch_size = (LABEL_PATCH_SIZE, LABEL_PATCH_SIZE)
  config.oracle.model.code_len = 256
  config.oracle.model.dict_size = 4096
  config.oracle.model.codeword_dim = 768
  config.oracle.model.with_encoder_ctx = True
  config.oracle.model.with_decoder_ctx = True
  config.oracle.model.code_dropout = 'random'
  config.oracle.model.bottleneck_resize = True
  config.oracle.model.inputs = {
      'semantics': (133 + 1, LABEL_PATCH_SIZE**2),  # +1 for void label
      'instances': (100, LABEL_PATCH_SIZE**2),  # COCO: actually 98 train/78 validation.
  }
  config.oracle.model.outputs = config.oracle.model.inputs

  # Model section
  config.model_name = 'proj.uvim.vtt'
  # config.model_init = {'encoder': 'howto-i21k-B/8'}
  config.model_init = {'encoder': 'howto-i21k-L/16'}
  config.model = ConfigDict(VTT_MODELS['large'])
  config.model.patches = ConfigDict({'size': (PATCH_SIZE, PATCH_SIZE)})
  config.model.vocab_size = config.oracle.model.get_ref('dict_size') + 1
  config.model.posemb_type = 'learn'
  config.model.input_size = (RES, RES)
  config.model.seq_len = config.oracle.model.get_ref('code_len')

  # Evaluation section
  config.evals = {}
  config.evals.val = ConfigDict()
  config.evals.val.type = 'proj.uvim.compute_mean'
  config.evals.val.pred = 'validation'
  config.evals.val.data = dict(name=config.input.data.name, split='train[:4096]')
  config.evals.val.pp_fn = pp_eval
  config.evals.val.log_steps = 1000

  base = {
      'type': 'proj.uvim.coco_panoptic',
      'pp_fn': pp_predict,
      'log_steps': 10_000,
      # Filters objects that occupy less than 0.03^2 fraction of all pixels.
      # 'predict_kwargs': {'min_fraction': 0.03 ** 2},
  }
  config.evals.coco_panoptic_train = dict(**base, split='train[4096:8192]')
  config.evals.coco_panoptic_holdout = dict(**base, split='train[:4096]')
  config.evals.coco_panoptic = dict(**base, split='validation')

  # config.evals.save_pred = dict(type='proj.uvim.save_predictions')
  # config.evals.save_pred.pp = pp_eval.replace('decode|', '')
  # config.evals.save_pred.log_steps = 100_000
  # config.evals.save_pred.dataset = config.dataset
  # config.evals.save_pred.split = 'validation[:1024]'
  # config.evals.save_pred.outfile = 'inference.npz'

  if arg.singlehost:
    config.input.batch_size = 32
    config.num_epochs = 50
  elif arg.runlocal:
    config.input.batch_size = 4
    config.input.shuffle_buffer_size = 10
    config.evals.val.data.split = 'train[:16]'
  return config
Tipsomaly/model/big_vision/configs/proj/uvim/train_imagenet2012_colorization_pretrained.py
ADDED
@@ -0,0 +1,161 @@
# Copyright 2022 Big Vision Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# pylint: disable=line-too-long
r"""A config for training a UViM stage II model for the colorization task.
"""

import big_vision.configs.common as bvcc
from ml_collections import ConfigDict

VTT_MODELS = {
    'base': dict(num_layers=12, num_heads=12, mlp_dim=3072, emb_dim=768),
    'large': dict(num_layers=24, num_heads=16, mlp_dim=4096, emb_dim=1024),
}

VQVAE_MODELS = {
    'base': dict(enc_depth=6, dec_depth=12, num_heads=12, mlp_dim=3072, width=768),
}

RES = 512
PATCH_SIZE = 16
LABEL_RES = 512
LABEL_PATCH_SIZE = 16


def get_config(arg=''):
  """Config for training."""
  arg = bvcc.parse_arg(arg, runlocal=False, singlehost=False)
  config = ConfigDict()

  config.input = {}
  config.input.pp = (
      f'decode_jpeg_and_inception_crop({RES})'
      f'|flip_lr'
      f'|copy(inkey="image", outkey="labels")'
      f'|resize({LABEL_RES},inkey="labels",outkey="labels",method="nearest")'
      f'|value_range(-1,1,key="labels")'
      f'|rgb_to_grayscale_to_rgb(inkey="image",outkey="image")'
      f'|value_range(-1,1,key="image")'
      f'|copy(inkey="image", outkey="image_ctx")'
      f'|resize({LABEL_RES},inkey="image_ctx",outkey="image_ctx")'
      f'|keep("image","image_ctx","labels")')
  pp_eval = (
      f'decode'
      f'|resize({RES})'
      f'|copy(inkey="image", outkey="labels")'
      f'|resize({LABEL_RES},inkey="labels",outkey="labels",method="nearest")'
      f'|value_range(-1,1,key="labels")'
      f'|rgb_to_grayscale_to_rgb(inkey="image",outkey="image")'
      f'|value_range(-1,1,key="image")'
      f'|copy(inkey="image", outkey="image_ctx")'
      f'|resize({LABEL_RES},inkey="image_ctx",outkey="image_ctx")'
      f'|strong_hash(inkey="tfds_id", outkey="image/id")'
      f'|keep("image","image_ctx","labels","image/id")')

  config.input.data = dict(name='imagenet2012', split='train[4096:]')
  config.input.batch_size = 512
  config.input.shuffle_buffer_size = 50_000

  config.total_epochs = 50

  config.log_training_steps = 50
  config.ckpt_steps = 1000
  config.keep_ckpt_steps = 5000
  config.prefetch_to_device = 2
  config.seed = 0

  # Optimizer section
  config.optax_name = 'big_vision.scale_by_adafactor'
  config.optax = dict(beta2_cap=0.95)

  config.lr = 0.001
  config.wd = 0.000001
  config.lr_mults = [
      ('pos_embedding_encoder.*', 0.1),
      ('EmbedPatches.*', 0.1),
      ('encoder.*', 0.1),
      ('decoder.*', 1.0)
  ]
  config.schedule = dict(decay_type='cosine', warmup_steps=4_000)

  # Oracle section
  config.oracle = ConfigDict()
  config.oracle.task = 'proj.uvim.colorization_task'
  config.oracle.model_init = 'gs://big_vision/uvim/color_stageI_params.npz'
  config.oracle.model_name = 'proj.uvim.vit'
  config.oracle.model = ConfigDict(VQVAE_MODELS['base'])
  config.oracle.model.input_size = (LABEL_RES, LABEL_RES)
  config.oracle.model.patch_size = (LABEL_PATCH_SIZE, LABEL_PATCH_SIZE)
  config.oracle.model.code_len = 256
  config.oracle.model.dict_size = 4096
  config.oracle.model.codeword_dim = 768
  config.oracle.model.with_encoder_ctx = True
  config.oracle.model.with_decoder_ctx = True
  config.oracle.model.code_dropout = 'random'
  config.oracle.model.bottleneck_resize = True
  config.oracle.model.inputs = {
      'color': (3, LABEL_PATCH_SIZE**2),
  }
  config.oracle.model.outputs = config.oracle.model.inputs

  # Model section
  config.model_name = 'proj.uvim.vtt'
  # config.model_init = {'encoder': 'howto-i21k-B/8'}
  config.model_init = {'encoder': 'howto-i21k-L/16'}
  config.model = ConfigDict(VTT_MODELS['large'])
  config.model.patches = ConfigDict({'size': (PATCH_SIZE, PATCH_SIZE)})
  config.model.vocab_size = config.oracle.model.get_ref('dict_size') + 1
  config.model.posemb_type = 'learn'
  config.model.input_size = (RES, RES)
  config.model.seq_len = config.oracle.model.get_ref('code_len')

  # Evaluation section
  config.evals = {}
  config.evals.val = ConfigDict()
  config.evals.val.type = 'proj.uvim.compute_mean'
  config.evals.val.pred = 'validation'
  config.evals.val.data = dict(name=config.input.data.name, split='train[:4096]')
  config.evals.val.pp_fn = pp_eval
  config.evals.val.log_steps = 1000

  base = {
      'type': 'proj.uvim.psnr',
      'pp_fn': pp_eval.replace('decode|', ''),
      'log_steps': 10_000,
  }
  config.evals.psnr_train = dict(**base, split='train[4096:8192]')
  config.evals.psnr_holdout = dict(**base, split='train[:4096]')
  config.evals.psnr_val = dict(**base, split='validation')

  config.evals.colorization_val_coltran_fid = {
      'type': 'proj.uvim.coltran_fid',
      'log_steps': 100_000,
  }

  # config.evals.save_pred = dict(type='proj.uvim.save_predictions')
  # config.evals.save_pred.pp_fn = pp_eval.replace('decode|', '')
  # config.evals.save_pred.log_steps = 100_000
  # config.evals.save_pred.dataset = config.dataset
  # config.evals.save_pred.split = 'validation[:1024]'
  # config.evals.save_pred.outfile = 'inference.npz'

  if arg.singlehost:
    config.input.batch_size = 32
    config.total_epochs = 20
  elif arg.runlocal:
    config.input.batch_size = 8
    config.input.shuffle_buffer_size = 10
    config.evals.val.data.split = 'validation[:256]'
  return config
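Note (not part of the config file above): the stage II model's output space is tied directly to the stage I oracle it is configured with: it must emit `code_len` tokens per image, drawn from the oracle's `dict_size`-entry codebook, and its vocabulary is one larger, presumably so that one token id stays reserved (the demo notebook later in this upload maps sampled ids back with `seqs - 1`). A small sketch of that bookkeeping, using the constants from this config; the helper name is hypothetical:

CODE_LEN = 256              # config.oracle.model.code_len
DICT_SIZE = 4096            # config.oracle.model.dict_size
VOCAB_SIZE = DICT_SIZE + 1  # config.model.vocab_size

def lm_ids_to_codebook_ids(sampled_ids):
  # Hypothetical helper mirroring the notebook's "seqs - 1" shift:
  # LM ids 1..4096 correspond to codebook entries 0..4095.
  return [i - 1 for i in sampled_ids]

assert VOCAB_SIZE == 4097
assert lm_ids_to_codebook_ids([1, 4096]) == [0, 4095]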
Tipsomaly/model/big_vision/configs/proj/uvim/train_nyu_depth_pretrained.py
ADDED
@@ -0,0 +1,170 @@
# Copyright 2022 Big Vision Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# pylint: disable=line-too-long
r"""A config for training a UViM stage II model for the depth task.
"""

import big_vision.configs.common as bvcc
from ml_collections import ConfigDict


VTT_MODELS = {
    'base': dict(num_layers=12, num_heads=12, mlp_dim=3072, emb_dim=768),
    'large': dict(num_layers=24, num_heads=16, mlp_dim=4096, emb_dim=1024),
}

VQVAE_MODELS = {
    'base': dict(enc_depth=6, dec_depth=12, num_heads=12, mlp_dim=3072, width=768),
}


RES = 512
PATCH_SIZE = 16
LABEL_RES = 512
LABEL_PATCH_SIZE = 16
QUANTIZATION_BINS = 256
# Same as values used in eval, see evaluators/nyu_depth.py.
MIN_DEPTH = 1e-3
MAX_DEPTH = 10


def get_config(arg='split=final'):
  """Config for training."""
  arg = bvcc.parse_arg(arg, split='final', runlocal=False, singlehost=False)
  config = ConfigDict()

  config.input = {}
  config.input.pp = (
      f'decode|nyu_depth|'
      f'randu("fliplr")|det_fliplr(key="image")|det_fliplr(key="labels")|'
      f'inception_box|crop_box(key="image")|crop_box(key="labels")|'
      f'resize({RES})|'
      f'resize({LABEL_RES},inkey="image",outkey="image_ctx")|'
      f'resize({LABEL_RES},key="labels",method="nearest")|'
      f'value_range(-1,1)|'
      f'value_range(-1,1,inkey="image_ctx",outkey="image_ctx")|'
      f'keep("image","image_ctx","labels")'
  )
  pp_eval = (
      f'decode|nyu_depth|'
      f'nyu_eval_crop|'
      f'resize({RES})|'
      f'resize({LABEL_RES},inkey="image",outkey="image_ctx")|'
      f'resize({LABEL_RES},key="labels",method="nearest")|'
      f'value_range(-1,1)|'
      f'value_range(-1,1,inkey="image_ctx",outkey="image_ctx")|'
      f'keep("image","image_ctx","labels")'
  )
  pp_predict = (
      f'nyu_depth|'
      f'nyu_eval_crop|copy("labels","ground_truth")|'
      f'resize({RES})|'
      f'resize({LABEL_RES},inkey="image",outkey="image_ctx")|'
      f'value_range(-1,1)|'
      f'value_range(-1,1,inkey="image_ctx",outkey="image_ctx")|'
      f'keep("image","image_ctx","ground_truth")'
  )

  config.input.data = dict(name='nyu_depth_v2', split='train')
  config.input.batch_size = 512
  config.input.shuffle_buffer_size = 50_000

  config.total_epochs = 50

  config.log_training_steps = 50
  config.ckpt_steps = 1000
  config.keep_ckpt_steps = 5000
  config.prefetch_to_device = 2
  config.seed = 0

  # Optimizer section
  config.optax_name = 'big_vision.scale_by_adafactor'
  config.optax = dict(beta2_cap=0.95)
  config.optax.clipping_threshold = None

  config.lr = 0.001
  config.wd = 0.000001
  config.lr_mults = (
      ('pos_embedding_encoder.*', 0.1),
      ('EmbedPatches.*', 0.1),
      ('encoder.*', 0.1),
      ('decoder.*', 1.0)
  )
  config.schedule = dict(decay_type='cosine', warmup_steps=4_000)

  # Oracle section
  config.oracle = ConfigDict()
  config.oracle.min_depth = MIN_DEPTH
  config.oracle.max_depth = MAX_DEPTH
  config.oracle.task = 'proj.uvim.depth_task'
  config.oracle.model_init = 'gs://big_vision/uvim/depth_stageI_params.npz'
  config.oracle.model_name = 'proj.uvim.vit'
  config.oracle.model = ConfigDict(VQVAE_MODELS['base'])
  config.oracle.model.input_size = (LABEL_RES, LABEL_RES)
  config.oracle.model.patch_size = (LABEL_PATCH_SIZE, LABEL_PATCH_SIZE)
  config.oracle.model.code_len = 256
  config.oracle.model.dict_size = 4096
  config.oracle.model.codeword_dim = 768
  config.oracle.model.with_encoder_ctx = True
  config.oracle.model.with_decoder_ctx = True
  config.oracle.model.code_dropout = 'random'
  config.oracle.model.bottleneck_resize = True
  config.oracle.model.inputs = {
      'depth': (QUANTIZATION_BINS, LABEL_PATCH_SIZE**2,),
  }
  config.oracle.model.outputs = config.oracle.model.inputs

  # Model section
  config.model_name = 'proj.uvim.vtt'
  # config.model_init = {'encoder': 'howto-i21k-B/8'}  # B/8 I21K
  config.model_init = {'encoder': 'howto-i21k-L/16'}  # L/16 I21K
  config.model = ConfigDict(VTT_MODELS['large'])
  config.model.patches = ConfigDict({'size': (PATCH_SIZE, PATCH_SIZE)})
  config.model.vocab_size = config.oracle.model.dict_size + 1
  config.model.posemb_type = 'learn'
  config.model.input_size = (RES, RES)
  config.model.seq_len = config.oracle.model.get_ref('code_len')
  config.model.zero_decoder_seq = False

  # Evaluation section
  config.evals = {}
  config.evals.val = ConfigDict()
  config.evals.val.type = 'proj.uvim.compute_mean'
  config.evals.val.pred = 'validation'
  config.evals.val.data = {**config.input.data}
  config.evals.val.data.split = 'validation'
  config.evals.val.pp_fn = pp_eval
  config.evals.val.log_steps = 1000

  base = {
      'type': 'proj.uvim.nyu_depth',
      'dataset': config.input.data.name,
      'pp_fn': pp_predict,
      'log_steps': 2000,
      'min_depth': MIN_DEPTH,
      'max_depth': MAX_DEPTH,
  }
  config.evals.nyu_depth_val = dict(**base, split='validation')

  if arg.singlehost:
    config.input.batch_size = 32
    config.total_epochs = 20
  elif arg.runlocal:
    config.oracle.model_init = '/tmp/checkpoint.npz'
    config.model_init = {'encoder': '/tmp/enc_checkpoint.npz'}
    config.evals = {}
    config.input.batch_size = 1
    config.input.shuffle_buffer_size = 10
  return config
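Note (not part of the config file above): `QUANTIZATION_BINS`, `MIN_DEPTH` and `MAX_DEPTH` indicate that the depth task discretizes metric depth into 256 classes before the stage I oracle encodes it. The helper below is a hypothetical illustration of such a linear binning, not the actual `proj.uvim.depth_task` code:

import numpy as np

MIN_DEPTH, MAX_DEPTH, QUANTIZATION_BINS = 1e-3, 10, 256  # constants from the config above

def depth_to_bin(depth_m):
  # Clip to the valid range, then map linearly onto integer bins 0..255.
  depth_m = np.clip(depth_m, MIN_DEPTH, MAX_DEPTH)
  frac = (depth_m - MIN_DEPTH) / (MAX_DEPTH - MIN_DEPTH)
  return np.minimum((frac * QUANTIZATION_BINS).astype(np.int32), QUANTIZATION_BINS - 1)

print(depth_to_bin(np.array([0.5, 2.0, 9.99])))  # -> [ 12  51 255]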
Tipsomaly/model/big_vision/configs/proj/uvim/uvim_color_task.ipynb
ADDED
@@ -0,0 +1,167 @@
{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "name": "UViM color task",
      "provenance": [],
      "collapsed_sections": [],
      "private_outputs": true
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "language_info": {
      "name": "python"
    },
    "accelerator": "GPU",
    "gpuClass": "standard"
  },
  "cells": [
    {
      "cell_type": "code",
      "source": [
        "# Fetch big_vision repository and move it into the current workdir (import path).\n",
        "!git clone --depth=1 https://github.com/google-research/big_vision big_vision_repo\n",
        "!cp -R big_vision_repo/big_vision big_vision\n",
        "!pip install -qr big_vision/requirements.txt"
      ],
      "metadata": {
        "id": "sKZK6_QpVI_O"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "import jax\n",
        "import jax.numpy as jnp\n",
        "import numpy as np\n",
        "\n",
        "from big_vision.models.proj.uvim import vtt  # stage-II model\n",
        "from big_vision.models.proj.uvim import vit  # stage-I model\n",
        "\n",
        "from big_vision.models.proj.uvim import decode\n",
        "from big_vision.trainers.proj.uvim import colorization_task as task\n",
        "from big_vision.configs.proj.uvim import train_imagenet2012_colorization_pretrained as config_module\n",
        "\n",
        "import big_vision.pp.ops_image\n",
        "import big_vision.pp.ops_general\n",
        "import big_vision.pp.proj.uvim.pp_ops\n",
        "from big_vision.pp import builder as pp_builder\n",
        "\n",
        "config = config_module.get_config()\n",
        "res = 512\n",
        "seq_len = config.model.seq_len\n",
        "\n",
        "lm_model = vtt.Model(**config.model)\n",
        "oracle_model = vit.Model(**config.oracle.model)\n",
        "\n",
        "preprocess_fn = pp_builder.get_preprocess_fn(\n",
        "    'decode|resize(512)|'\n",
        "    'rgb_to_grayscale_to_rgb|value_range(-1,1)|'\n",
        "    'copy(inkey=\"image\",outkey=\"image_ctx\")')\n",
        "\n",
        "@jax.jit\n",
        "def predict_code(params, x, rng, temperature):\n",
        "  prompts = jnp.zeros((x[\"image\"].shape[0], seq_len), dtype=jnp.int32)\n",
        "  seqs, _, _ = decode.temperature_sampling(\n",
        "      params=params, model=lm_model, seed=rng,\n",
        "      inputs=x[\"image\"],\n",
        "      prompts=prompts,\n",
        "      temperature=temperature,\n",
        "      num_samples=1, eos_token=-1, prefill=False)\n",
        "  seqs = jnp.squeeze(seqs, axis=1)  # drop num_samples axis\n",
        "  return seqs - 1\n",
        "\n",
        "@jax.jit\n",
        "def labels2code(params, x, ctx):\n",
        "  y, aux = oracle_model.apply(params, x, ctx=ctx, train=False, method=oracle_model.encode)\n",
        "  return aux[\"code\"]\n",
        "\n",
        "@jax.jit\n",
        "def code2labels(params, code, ctx):\n",
        "  logits, aux = oracle_model.apply(params, code, ctx=ctx, train=False, discrete_input=True, method=oracle_model.decode)\n",
        "  return task.predict_outputs(logits, config.oracle)"
      ],
      "metadata": {
        "id": "QzThueWDzc7I"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Load checkpoints\n",
        "!gsutil cp -n gs://big_vision/uvim/color_stageI_params.npz gs://big_vision/uvim/color_stageII_params.npz .\n",
        "\n",
        "oracle_params, oracle_state = vit.load(None, \"color_stageI_params.npz\")\n",
        "oracle_params = jax.device_put({\"params\": oracle_params, \"state\": oracle_state})\n",
        "\n",
        "lm_params = vtt.load(None, \"color_stageII_params.npz\")\n",
        "lm_params = jax.device_put({\"params\": lm_params})"
      ],
      "metadata": {
        "id": "AEjRgshLa6Fp"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Prepare set of images from coco/val2017:\n",
        "# - https://cocodataset.org/\n",
        "import os\n",
        "import tensorflow as tf\n",
        "\n",
        "if not os.path.exists(\"val2017/\"):\n",
        "  !wget --no-clobber http://images.cocodataset.org/zips/val2017.zip\n",
        "  !unzip -uq val2017.zip\n",
        "\n",
        "dataset = tf.data.Dataset.list_files(\"val2017/*.jpg\", shuffle=True)\n",
        "dataset = dataset.map(lambda filename: {\"image\": tf.io.read_file(filename)})\n",
        "dataset = dataset.map(preprocess_fn)"
      ],
      "metadata": {
        "id": "BKifDDRnH_Ll"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Run the model on a few examples:\n",
        "from matplotlib import pyplot as plt\n",
        "\n",
        "num_examples = 4\n",
        "data = dataset.batch(1).take(num_examples).as_numpy_iterator()\n",
        "key = jax.random.PRNGKey(0)\n",
        "temperature = jnp.array(1.0)\n",
        "\n",
        "def render_example(image, prediction):\n",
        "  f, ax = plt.subplots(1, 2, figsize=(10, 10))\n",
        "  ax[0].imshow(image*0.5 + 0.5)\n",
        "  ax[0].axis(\"off\")\n",
        "  ax[1].imshow(prediction*0.5 + 0.5)\n",
        "  ax[1].axis(\"off\")\n",
        "\n",
        "for idx, batch in enumerate(data):\n",
        "  subkey = jax.random.fold_in(key, idx)\n",
        "  code = predict_code(lm_params, batch, subkey, temperature)\n",
        "  aux_inputs = task.input_pp(batch, config.oracle)\n",
        "  prediction = code2labels(oracle_params, code, aux_inputs[\"ctx\"])\n",
        "  render_example(batch[\"image\"][0], prediction[\"color\"][0])"
      ],
      "metadata": {
        "id": "TuevCy33nuv3"
      },
      "execution_count": null,
      "outputs": []
    }
  ]
}
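Note (not part of the notebook above): after running the notebook's cells, the same functions can colorize a single local image instead of COCO samples. A minimal sketch, with "my_photo.jpg" as a placeholder path, reusing `preprocess_fn`, `predict_code`, `task.input_pp`, `code2labels` and `render_example` defined in the cells:

import tensorflow as tf

# One-element dataset so preprocess_fn runs exactly as it does in the notebook.
ds = tf.data.Dataset.from_tensors({"image": tf.io.read_file("my_photo.jpg")})
ds = ds.map(preprocess_fn).batch(1)
batch = next(ds.as_numpy_iterator())

code = predict_code(lm_params, batch, jax.random.PRNGKey(1), jnp.array(1.0))
ctx = task.input_pp(batch, config.oracle)["ctx"]
colorized = code2labels(oracle_params, code, ctx)["color"][0]  # values in [-1, 1]
render_example(batch["image"][0], colorized)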