Upload 46 files
- LICENSE +203 -0
- README.md +113 -0
- __init__.py +0 -0
- class_orders/cifar100.yaml +1 -0
- class_orders/tinyimagenet.yaml +17 -0
- clip/README.md +1 -0
- clip/__init__.py +1 -0
- clip/adapter.py +75 -0
- clip/bpe_simple_vocab_16e6.txt.gz +3 -0
- clip/clip.py +310 -0
- clip/model.py +486 -0
- clip/tokenizer.py +140 -0
- configs/class/cifar100_10-10.yaml +33 -0
- configs/class/cifar100_2-2.yaml +33 -0
- configs/class/cifar100_5-5.yaml +37 -0
- configs/class/tinyimagenet_100-10.yaml +36 -0
- configs/class/tinyimagenet_100-20.yaml +36 -0
- configs/class/tinyimagenet_100-5.yaml +36 -0
- continual_clip/__init__.py +0 -0
- continual_clip/cc.py +53 -0
- continual_clip/clip_original/README.md +1 -0
- continual_clip/clip_original/__init__.py +1 -0
- continual_clip/clip_original/adapter.py +75 -0
- continual_clip/clip_original/bpe_simple_vocab_16e6.txt.gz +3 -0
- continual_clip/clip_original/clip.py +208 -0
- continual_clip/clip_original/model.py +568 -0
- continual_clip/clip_original/tokenizer.py +140 -0
- continual_clip/datasets.py +124 -0
- continual_clip/dynamic_dataset.py +108 -0
- continual_clip/models.py +228 -0
- continual_clip/utils.py +210 -0
- dataset_reqs/imagenet1000_classes.txt +1000 -0
- dataset_reqs/imagenet100_classes.txt +100 -0
- dataset_reqs/imagenet100_splits/train_100.txt +0 -0
- dataset_reqs/imagenet100_splits/val_100.txt +0 -0
- dataset_reqs/tinyimagenet_classes.txt +200 -0
- main.py +104 -0
- requirements.txt +19 -0
- run_cifar100-10-10.sh +9 -0
- templates/__init__.py +0 -0
- templates/fmow_template.py +20 -0
- templates/iwildcam_template.py +4 -0
- templates/openai_imagenet_template.py +82 -0
- templates/simple_template.py +3 -0
- templates/template_utils.py +28 -0
- templates/testing_template.py +83 -0
LICENSE
ADDED
|
@@ -0,0 +1,203 @@
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/

TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION

1. Definitions.

"License" shall mean the terms and conditions for use, reproduction, and distribution as defined by Sections 1 through 9 of this document.

"Licensor" shall mean the copyright owner or entity authorized by the copyright owner that is granting the License.

"Legal Entity" shall mean the union of the acting entity and all other entities that control, are controlled by, or are under common control with that entity. For the purposes of this definition, "control" means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity.

"You" (or "Your") shall mean an individual or Legal Entity exercising permissions granted by this License.

"Source" form shall mean the preferred form for making modifications, including but not limited to software source code, documentation source, and configuration files.

"Object" form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation, and conversions to other media types.

"Work" shall mean the work of authorship, whether in Source or Object form, made available under the License, as indicated by a copyright notice that is included in or attached to the work (an example is provided in the Appendix below).

"Derivative Works" shall mean any work, whether in Source or Object form, that is based on (or derived from) the Work and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship. For the purposes of this License, Derivative Works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work and Derivative Works thereof.

"Contribution" shall mean any work of authorship, including the original version of the Work and any modifications or additions to that Work or Derivative Works thereof, that is intentionally submitted to Licensor for inclusion in the Work by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Work, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as "Not a Contribution."

"Contributor" shall mean Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Work.

2. Grant of Copyright License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare Derivative Works of, publicly display, publicly perform, sublicense, and distribute the Work and such Derivative Works in Source or Object form.

3. Grant of Patent License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Work, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Work to which such Contribution(s) was submitted. If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Work or a Contribution incorporated within the Work constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for that Work shall terminate as of the date such litigation is filed.

4. Redistribution. You may reproduce and distribute copies of the Work or Derivative Works thereof in any medium, with or without modifications, and in Source or Object form, provided that You meet the following conditions:

(a) You must give any other recipients of the Work or Derivative Works a copy of this License; and

(b) You must cause any modified files to carry prominent notices stating that You changed the files; and

(c) You must retain, in the Source form of any Derivative Works that You distribute, all copyright, patent, trademark, and attribution notices from the Source form of the Work, excluding those notices that do not pertain to any part of the Derivative Works; and

(d) If the Work includes a "NOTICE" text file as part of its distribution, then any Derivative Works that You distribute must include a readable copy of the attribution notices contained within such NOTICE file, excluding those notices that do not pertain to any part of the Derivative Works, in at least one of the following places: within a NOTICE text file distributed as part of the Derivative Works; within the Source form or documentation, if provided along with the Derivative Works; or, within a display generated by the Derivative Works, if and wherever such third-party notices normally appear. The contents of the NOTICE file are for informational purposes only and do not modify the License. You may add Your own attribution notices within Derivative Works that You distribute, alongside or as an addendum to the NOTICE text from the Work, provided that such additional attribution notices cannot be construed as modifying the License.

You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Derivative Works as a whole, provided Your use, reproduction, and distribution of the Work otherwise complies with the conditions stated in this License.

5. Submission of Contributions. Unless You explicitly state otherwise, any Contribution intentionally submitted for inclusion in the Work by You to the Licensor shall be under the terms and conditions of this License, without any additional terms or conditions. Notwithstanding the above, nothing herein shall supersede or modify the terms of any separate license agreement you may have executed with Licensor regarding such Contributions.

6. Trademarks. This License does not grant permission to use the trade names, trademarks, service marks, or product names of the Licensor, except as required for reasonable and customary use in describing the origin of the Work and reproducing the content of the NOTICE file.

7. Disclaimer of Warranty. Unless required by applicable law or agreed to in writing, Licensor provides the Work (and each Contributor provides its Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Work and assume any risks associated with Your exercise of permissions under this License.

8. Limitation of Liability. In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Work (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages.

9. Accepting Warranty or Additional Liability. While redistributing the Work or Derivative Works thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability.

END OF TERMS AND CONDITIONS

APPENDIX: How to apply the Apache License to your work.

To apply the Apache License to your work, attach the following boilerplate notice, with the fields enclosed by brackets "[]" replaced with your own identifying information. (Don't include the brackets!) The text should be enclosed in the appropriate comment syntax for the file format. We also recommend that a file or class name and description of purpose be included on the same "printed page" as the copyright notice for easier identification within third-party archives.

Copyright 2020 - present, Facebook, Inc
Copyright 2022 - present, Arthur Douillard
Copyright 2023 - present, Zangwei Zheng, Mingyuan Ma

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.
README.md
ADDED
|
@@ -0,0 +1,113 @@
# DMNSP: Dynamic Multi-Layer Null Space Projection for Vision-Language Continual Learning

[](https://www.python.org/downloads/release/python-380/)
[](https://pytorch.org/)
[](LICENSE)

Official implementation of the paper "Dynamic Multi-Layer Null Space Projection for Vision-Language Continual Learning" (ICCV 2025) in PyTorch.

## 🎯 Abstract

Vision-Language Models (VLMs) have emerged as a highly promising approach for Continual Learning (CL) due to their powerful generalized features. While adapter-based VLMs can exploit both task-specific and task-agnostic features, current CL methods have largely overlooked the distinct and evolving parameter distributions of the visual and language modalities, which are crucial for effectively mitigating catastrophic forgetting. In this study, we find that the **visual modality experiences a broader parameter distribution and greater variance** during class increments than the textual modality, making it more vulnerable to forgetting. Consequently, we handle the branches of the two modalities asymmetrically.

### Key Contributions

- 🔍 **Asymmetric Modality Handling**: We handle the visual and language modalities differently based on their distinct parameter distribution characteristics
- 🚀 **Multi-layer Null Space Projection**: A novel strategy applied only to the visual modality branch to restrict parameter updates within specific subspaces (an illustrative sketch of null-space gradient projection follows below)
- ⚖️ **Dynamic Projection Coefficient**: Precise control of the gradient projection magnitude for an optimal stability-plasticity balance
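
The null-space idea can be illustrated with a small, generic sketch. This is **not** the exact DMNSP algorithm: the function names, the spectral threshold, and the fixed projection coefficient below are simplifying assumptions. The generic recipe is to collect a layer's input features from previous tasks, take the near-null directions of their uncentered covariance via SVD, and project new gradients onto those directions so that updates disturb old-task responses as little as possible.

```python
import torch

def null_space_projector(feats: torch.Tensor, eps: float = 1e-2) -> torch.Tensor:
    """Return a (d x d) matrix projecting onto the approximate null space of the
    uncentered covariance of `feats` (shape: n x d)."""
    cov = feats.T @ feats / feats.shape[0]      # d x d uncentered covariance
    U, S, _ = torch.linalg.svd(cov)             # spectral decomposition (cov is PSD)
    null_dims = S < eps * S.max()               # near-zero spectrum = approximate null space
    U0 = U[:, null_dims]                        # orthonormal basis of that subspace
    return U0 @ U0.T                            # projection matrix

def project_grad(weight: torch.nn.Parameter, P: torch.Tensor, coeff: float = 1.0) -> None:
    """Blend the raw gradient with its null-space projection; `coeff` stands in
    for a (here fixed) projection coefficient."""
    if weight.grad is not None:
        weight.grad = coeff * weight.grad @ P + (1.0 - coeff) * weight.grad

# Toy usage: features from old tasks, then one projected gradient step on a linear layer.
layer = torch.nn.Linear(16, 16)
old_feats = torch.randn(512, 16)
P = null_space_projector(old_feats)
layer(torch.randn(32, 16)).sum().backward()
project_grad(layer.weight, P, coeff=0.9)
```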

## 🛠️ Installation

### Setup Environment

```bash
# Install dependencies
pip install -r requirements.txt
```

## 📊 Datasets

The framework supports the following datasets for class-incremental learning:

- **CIFAR100**: 100 classes, with 2-2, 5-5, and 10-10 incremental settings
- **TinyImageNet**: 200 classes, with 100 initial classes and 5, 10, or 20 classes per step (100-5, 100-10, 100-20)

### Data Preparation

1. The datasets are downloaded automatically when running experiments
2. Update the `dataset_root` path in your configuration files or on the command line
3. Ensure sufficient disk space for dataset storage

## 🚀 Quick Start

### Basic Usage

```bash
# Run CIFAR100 with 10 initial classes and 10 incremental classes
sh run_cifar100-10-10.sh

# Or run with custom parameters
python main.py \
    --config-path ./configs/class \
    --config-name cifar100_10-10.yaml \
    dataset_root="/path/to/your/data" \
    class_order="./class_orders/cifar100.yaml"
```

### Configuration Options

The project uses Hydra for configuration management. Key parameters include:

```yaml
# Model settings
model_name: "ViT-B/16"                   # CLIP model variant
prompt_template: "a bad photo of a {}."  # Text prompt template

# Training settings
batch_size: 128     # Training batch size
lr: 1e-3            # Learning rate
weight_decay: 0.0   # Weight decay
ls: 0.0             # Label smoothing

# Incremental learning settings
initial_increment: 10   # Initial number of classes
increment: 10           # Classes per incremental step
method: "DMNSP"         # Method name
```
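
For orientation, here is a minimal sketch of a Hydra entry point that consumes the keys above. It is illustrative only: the script name and body are assumptions and do not reproduce the logic of this repository's `main.py`.

```python
# minimal_hydra_entry.py -- illustrative only; not this repository's main.py
import hydra
from omegaconf import DictConfig

@hydra.main(config_path="./configs/class", config_name="cifar100_10-10", version_base=None)
def run(cfg: DictConfig) -> None:
    # The fields below mirror the keys documented above.
    print(f"method={cfg.method}, model={cfg.model_name}")
    print(f"initial_increment={cfg.initial_increment}, increment={cfg.increment}")
    print(f"batch_size={cfg.batch_size}, lr={cfg.lr}")
    # Command-line overrides such as `dataset_root=/path/to/data` are merged
    # into `cfg` automatically by Hydra.

if __name__ == "__main__":
    run()
```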

## 🔧 Advanced Usage

### Custom Datasets

To add support for new datasets (a helper sketch for step 2 follows this list):

1. Add the dataset configuration in `continual_clip/datasets.py`
2. Create a corresponding class order file in `class_orders/`
3. Add a configuration YAML in `configs/class/`
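
As a convenience for step 2, the snippet below writes a class-order file in the same one-line format as `class_orders/cifar100.yaml`. It is a hypothetical helper, not part of the repository.

```python
# make_class_order.py -- illustrative helper, not part of this repository
import random

def write_class_order(num_classes: int, path: str, seed: int = 0) -> None:
    """Write a shuffled class order in the same format as class_orders/cifar100.yaml."""
    rng = random.Random(seed)
    order = list(range(num_classes))
    rng.shuffle(order)
    with open(path, "w") as f:
        f.write("class_order: " + str(order) + "\n")

write_class_order(100, "class_orders/my_dataset.yaml")
```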

## 📄 License

This project is licensed under the Apache License 2.0 - see the [LICENSE](LICENSE) file for details.

## 📚 Citation

If you find this work useful in your research, please consider citing:

```bibtex
@inproceedings{Kang2025DMNSP,
  title={Dynamic Multi-Layer Null Space Projection for Vision-Language Continual Learning},
  author={Borui Kang and Lei Wang and Zhiping Wu and Tao Feng and Yawen Li and Yang Gao and Wenbin Li},
  booktitle={ICCV},
  year={2025}
}
```

## 📞 Contact

For questions or issues, please:
- Open an issue on GitHub
- Contact the authors at kangborui.cn@gmail.com

---

**Note**: This implementation is for research purposes. Please ensure you comply with the respective licenses of the datasets and models used.
__init__.py
ADDED
|
File without changes
|
class_orders/cifar100.yaml
ADDED
|
@@ -0,0 +1 @@
class_order: [87, 0, 52, 58, 44, 91, 68, 97, 51, 15, 94, 92, 10, 72, 49, 78, 61, 14, 8, 86, 84, 96, 18, 24, 32, 45, 88, 11, 4, 67, 69, 66, 77, 47, 79, 93, 29, 50, 57, 83, 17, 81, 41, 12, 37, 59, 25, 20, 80, 73, 1, 28, 6, 46, 62, 82, 53, 9, 31, 75, 38, 63, 33, 74, 27, 22, 36, 3, 16, 21, 60, 19, 70, 90, 89, 43, 5, 42, 65, 76, 40, 30, 23, 85, 2, 95, 56, 48, 71, 64, 98, 13, 99, 7, 34, 55, 54, 26, 35, 39]
class_orders/tinyimagenet.yaml
ADDED
|
@@ -0,0 +1,17 @@
class_order: [
  131, 181, 22, 172, 144, 92, 97, 187, 58, 93, 6, 70, 106, 68,
  153, 168, 179, 199, 29, 46, 9, 142, 134, 88, 193, 110, 26,
  32, 117, 112, 17, 39, 166, 13, 94, 138, 109, 147, 51, 101,
  59, 188, 116, 5, 170, 99, 100, 167, 180, 146, 65, 1, 104,
  43, 38, 184, 123, 171, 137, 162, 71, 44, 95, 174, 12, 7,
  54, 152, 21, 47, 28, 176, 34, 2, 132, 118, 42, 189, 150,
  14, 165, 41, 192, 45, 82, 128, 63, 57, 197, 160, 53, 75,
  108, 135, 121, 159, 183, 67, 169, 50, 87, 69, 89, 196,
  115, 19, 148, 96, 86, 11, 8, 60, 33, 173, 78, 4, 119, 105,
  182, 127, 177, 30, 186, 40, 49, 178, 76, 157, 161, 73, 164,
  151, 31, 74, 191, 27, 125, 198, 81, 20, 155, 114, 139, 36,
  61, 56, 145, 48, 16, 83, 62, 85, 126, 0, 102, 23, 3, 140,
  15, 195, 133, 113, 190, 141, 52, 163, 156, 80, 111, 90, 175,
  143, 120, 84, 18, 25, 79, 37, 154, 136, 64, 158, 24, 185,
  72, 35, 129, 55, 149, 91, 122, 77, 103, 124, 130, 66, 10, 107, 194, 98
]
clip/README.md
ADDED
|
@@ -0,0 +1 @@
This folder is a lightly modified version of https://github.com/openai/CLIP.
clip/__init__.py
ADDED
|
@@ -0,0 +1 @@
from .clip import *
clip/adapter.py
ADDED
|
@@ -0,0 +1,75 @@
# --------------------------------------------------------
# References:
# https://github.com/jxhe/unify-parameter-efficient-tuning
# --------------------------------------------------------

import math
import torch
import torch.nn as nn


class Adapter(nn.Module):
    def __init__(self,
                 d_model=None,
                 bottleneck=None,
                 dropout=0.0,
                 init_option="lora",
                 adapter_scalar="1.0",
                 adapter_layernorm_option="in"):
        super().__init__()
        self.n_embd = d_model
        self.down_size = bottleneck

        # optional layer norm applied before ("in") or after ("out") the adapter
        self.adapter_layernorm_option = adapter_layernorm_option

        self.adapter_layer_norm_before = None
        if adapter_layernorm_option == "in" or adapter_layernorm_option == "out":
            self.adapter_layer_norm_before = nn.LayerNorm(self.n_embd)

        if adapter_scalar == "learnable_scalar":
            self.scale = nn.Parameter(torch.ones(1))
        else:
            self.scale = float(adapter_scalar)

        # self.linear = nn.Linear(self.n_embd, self.n_embd)

        # bottleneck projection: d_model -> down_size -> d_model
        self.down_proj = nn.Linear(self.n_embd, self.down_size)
        self.non_linear_func = nn.ReLU()
        self.up_proj = nn.Linear(self.down_size, self.n_embd)

        self.dropout = dropout
        if init_option == "bert":
            raise NotImplementedError
        elif init_option == "lora":
            # LoRA-style init: random down projection, zero-initialized up projection
            with torch.no_grad():
                nn.init.kaiming_uniform_(self.down_proj.weight, a=math.sqrt(5))
                nn.init.zeros_(self.up_proj.weight)
                nn.init.zeros_(self.down_proj.bias)
                nn.init.zeros_(self.up_proj.bias)
        elif init_option == "linear":
            # requires the `self.linear` layer above to be uncommented
            with torch.no_grad():
                nn.init.zeros_(self.linear.weight)

    def forward(self, x, add_residual=True, residual=None):

        residual = x if residual is None else residual
        if self.adapter_layernorm_option == 'in':  # none
            x = self.adapter_layer_norm_before(x)

        down = self.down_proj(x)
        down = self.non_linear_func(down)
        down = nn.functional.dropout(down, p=self.dropout, training=self.training)
        up = self.up_proj(down)

        up = up * self.scale

        if self.adapter_layernorm_option == 'out':  # none
            up = self.adapter_layer_norm_before(up)

        if add_residual:
            output = up + residual
        else:
            output = up

        # also expose the bottleneck activation and the adapter parameters so the
        # caller can inspect or project them
        return down, output, \
            self.up_proj.weight, self.down_proj.weight, self.up_proj.bias, self.down_proj.bias
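
A minimal usage sketch of this `Adapter` module follows. The batch shape and hyperparameter values are illustrative assumptions; `model.py` instantiates it with `bottleneck=64`, `adapter_scalar=0.1`, `adapter_layernorm_option='none'` and calls it with `add_residual=False`.

```python
import torch
from clip.adapter import Adapter

# bottleneck adapter on 768-dim ViT tokens (ffn_num = 64 in model.py)
adapter = Adapter(d_model=768, bottleneck=64, dropout=0.1,
                  init_option="lora", adapter_scalar=0.1,
                  adapter_layernorm_option="none")

tokens = torch.randn(8, 197, 768)                       # (batch, sequence, width)
down, out, w_up, w_down, b_up, b_down = adapter(tokens, add_residual=False)
print(down.shape, out.shape)                            # [8, 197, 64] and [8, 197, 768]
```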
clip/bpe_simple_vocab_16e6.txt.gz
ADDED
|
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:924691ac288e54409236115652ad4aa250f48203de50a9e4722a6ecd48d6804a
size 1356917
clip/clip.py
ADDED
|
@@ -0,0 +1,310 @@
| 1 |
+
# Code ported from https://github.com/openai/CLIP
|
| 2 |
+
|
| 3 |
+
import hashlib
|
| 4 |
+
import os
|
| 5 |
+
import urllib
|
| 6 |
+
import warnings
|
| 7 |
+
from typing import Union, List
|
| 8 |
+
|
| 9 |
+
import torch
|
| 10 |
+
from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize, RandomResizedCrop, InterpolationMode
|
| 11 |
+
from tqdm import tqdm
|
| 12 |
+
|
| 13 |
+
from clip.model import build_model
|
| 14 |
+
from clip.tokenizer import SimpleTokenizer as _Tokenizer
|
| 15 |
+
|
| 16 |
+
__all__ = ["available_models", "load", "tokenize"]
|
| 17 |
+
_tokenizer = _Tokenizer()
|
| 18 |
+
|
| 19 |
+
_MODELS = {
|
| 20 |
+
"RN50": "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt",
|
| 21 |
+
"RN101": "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt",
|
| 22 |
+
"RN50x4": "https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt",
|
| 23 |
+
"RN50x16": "https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt",
|
| 24 |
+
"ViT-B/32": "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt",
|
| 25 |
+
"ViT-B/16": "https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt",
|
| 26 |
+
}
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def _download(url: str, root: str = os.path.expanduser("~/.cache/clip")):
|
| 30 |
+
os.makedirs(root, exist_ok=True)
|
| 31 |
+
filename = os.path.basename(url)
|
| 32 |
+
|
| 33 |
+
expected_sha256 = url.split("/")[-2]
|
| 34 |
+
download_target = os.path.join(root, filename)
|
| 35 |
+
|
| 36 |
+
if os.path.exists(download_target) and not os.path.isfile(download_target):
|
| 37 |
+
raise RuntimeError(f"{download_target} exists and is not a regular file")
|
| 38 |
+
|
| 39 |
+
if os.path.isfile(download_target):
|
| 40 |
+
if hashlib.sha256(open(download_target, "rb").read()).hexdigest() == expected_sha256:
|
| 41 |
+
return download_target
|
| 42 |
+
else:
|
| 43 |
+
warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file")
|
| 44 |
+
|
| 45 |
+
with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
|
| 46 |
+
with tqdm(total=int(source.info().get("Content-Length")), ncols=80, unit='iB', unit_scale=True) as loop:
|
| 47 |
+
while True:
|
| 48 |
+
buffer = source.read(8192)
|
| 49 |
+
if not buffer:
|
| 50 |
+
break
|
| 51 |
+
|
| 52 |
+
output.write(buffer)
|
| 53 |
+
loop.update(len(buffer))
|
| 54 |
+
|
| 55 |
+
if hashlib.sha256(open(download_target, "rb").read()).hexdigest() != expected_sha256:
|
| 56 |
+
raise RuntimeError(f"Model has been downloaded but the SHA256 checksum does not not match")
|
| 57 |
+
|
| 58 |
+
return download_target
|
| 59 |
+
|
| 60 |
+
def _convert_to_rgb(image):
|
| 61 |
+
return image.convert('RGB')
|
| 62 |
+
|
| 63 |
+
def _transform(n_px: int, is_train: bool):
|
| 64 |
+
normalize = Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
|
| 65 |
+
if is_train:
|
| 66 |
+
return Compose([
|
| 67 |
+
RandomResizedCrop(n_px, scale=(0.9, 1.0), interpolation=InterpolationMode.BICUBIC),
|
| 68 |
+
_convert_to_rgb,
|
| 69 |
+
ToTensor(),
|
| 70 |
+
normalize,
|
| 71 |
+
])
|
| 72 |
+
else:
|
| 73 |
+
return Compose([
|
| 74 |
+
Resize(n_px, interpolation=InterpolationMode.BICUBIC),
|
| 75 |
+
CenterCrop(n_px),
|
| 76 |
+
_convert_to_rgb,
|
| 77 |
+
ToTensor(),
|
| 78 |
+
normalize,
|
| 79 |
+
])
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
def available_models() -> List[str]:
|
| 84 |
+
"""Returns the names of available CLIP models"""
|
| 85 |
+
return list(_MODELS.keys())
|
| 86 |
+
|
| 87 |
+
# def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", jit: bool = False, download_root: str = None):
|
| 88 |
+
# """Load a CLIP model
|
| 89 |
+
|
| 90 |
+
# Parameters
|
| 91 |
+
# ----------
|
| 92 |
+
# name : str
|
| 93 |
+
# A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict
|
| 94 |
+
|
| 95 |
+
# device : Union[str, torch.device]
|
| 96 |
+
# The device to put the loaded model
|
| 97 |
+
|
| 98 |
+
# jit : bool
|
| 99 |
+
# Whether to load the optimized JIT model or more hackable non-JIT model (default).
|
| 100 |
+
|
| 101 |
+
# download_root: str
|
| 102 |
+
# path to download the model files; by default, it uses "~/.cache/clip"
|
| 103 |
+
|
| 104 |
+
# Returns
|
| 105 |
+
# -------
|
| 106 |
+
# model : torch.nn.Module
|
| 107 |
+
# The CLIP model
|
| 108 |
+
|
| 109 |
+
# preprocess : Callable[[PIL.Image], torch.Tensor]
|
| 110 |
+
# A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input
|
| 111 |
+
# """
|
| 112 |
+
# if name in _MODELS:
|
| 113 |
+
# model_path = _download(_MODELS[name], download_root or os.path.expanduser("~/.cache/clip"))
|
| 114 |
+
# elif os.path.isfile(name):
|
| 115 |
+
# model_path = name
|
| 116 |
+
# else:
|
| 117 |
+
# raise RuntimeError(f"Model {name} not found; available models = {available_models()}")
|
| 118 |
+
|
| 119 |
+
# try:
|
| 120 |
+
# # loading JIT archive
|
| 121 |
+
# model = torch.jit.load(model_path, map_location=device if jit else "cpu").eval()
|
| 122 |
+
# state_dict = None
|
| 123 |
+
# except RuntimeError:
|
| 124 |
+
# # loading saved state dict
|
| 125 |
+
# if jit:
|
| 126 |
+
# warnings.warn(f"File {model_path} is not a JIT archive. Loading as a state dict instead")
|
| 127 |
+
# jit = False
|
| 128 |
+
# state_dict = torch.load(model_path, map_location="cpu")
|
| 129 |
+
|
| 130 |
+
# if not jit:
|
| 131 |
+
# model = build_model(state_dict or model.state_dict()).to(device)
|
| 132 |
+
# if str(device) == "cpu":
|
| 133 |
+
# model.float()
|
| 134 |
+
# return model, _transform(model.visual.input_resolution)
|
| 135 |
+
|
| 136 |
+
# # patch the device names
|
| 137 |
+
# device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[])
|
| 138 |
+
# device_node = [n for n in device_holder.graph.findAllNodes("prim::Constant") if "Device" in repr(n)][-1]
|
| 139 |
+
|
| 140 |
+
# def patch_device(module):
|
| 141 |
+
# try:
|
| 142 |
+
# graphs = [module.graph] if hasattr(module, "graph") else []
|
| 143 |
+
# except RuntimeError:
|
| 144 |
+
# graphs = []
|
| 145 |
+
|
| 146 |
+
# if hasattr(module, "forward1"):
|
| 147 |
+
# graphs.append(module.forward1.graph)
|
| 148 |
+
|
| 149 |
+
# for graph in graphs:
|
| 150 |
+
# for node in graph.findAllNodes("prim::Constant"):
|
| 151 |
+
# if "value" in node.attributeNames() and str(node["value"]).startswith("cuda"):
|
| 152 |
+
# node.copyAttributes(device_node)
|
| 153 |
+
|
| 154 |
+
# model.apply(patch_device)
|
| 155 |
+
# patch_device(model.encode_image)
|
| 156 |
+
# patch_device(model.encode_text)
|
| 157 |
+
|
| 158 |
+
# # patch dtype to float32 on CPU
|
| 159 |
+
# if str(device) == "cpu":
|
| 160 |
+
# float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[])
|
| 161 |
+
# float_input = list(float_holder.graph.findNode("aten::to").inputs())[1]
|
| 162 |
+
# float_node = float_input.node()
|
| 163 |
+
|
| 164 |
+
# def patch_float(module):
|
| 165 |
+
# try:
|
| 166 |
+
# graphs = [module.graph] if hasattr(module, "graph") else []
|
| 167 |
+
# except RuntimeError:
|
| 168 |
+
# graphs = []
|
| 169 |
+
|
| 170 |
+
# if hasattr(module, "forward1"):
|
| 171 |
+
# graphs.append(module.forward1.graph)
|
| 172 |
+
|
| 173 |
+
# for graph in graphs:
|
| 174 |
+
# for node in graph.findAllNodes("aten::to"):
|
| 175 |
+
# inputs = list(node.inputs())
|
| 176 |
+
# for i in [1, 2]: # dtype can be the second or third argument to aten::to()
|
| 177 |
+
# if inputs[i].node()["value"] == 5:
|
| 178 |
+
# inputs[i].node().copyAttributes(float_node)
|
| 179 |
+
|
| 180 |
+
# model.apply(patch_float)
|
| 181 |
+
# patch_float(model.encode_image)
|
| 182 |
+
# patch_float(model.encode_text)
|
| 183 |
+
|
| 184 |
+
# model.float()
|
| 185 |
+
|
| 186 |
+
# return model, _transform(model.input_resolution.item())
|
| 187 |
+
|
| 188 |
+
|
| 189 |
+
def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", jit=True, is_train=False, pretrained=True):
|
| 190 |
+
"""Load a CLIP model
|
| 191 |
+
Parameters
|
| 192 |
+
----------
|
| 193 |
+
name : str
|
| 194 |
+
A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict
|
| 195 |
+
device : Union[str, torch.device]
|
| 196 |
+
The device to put the loaded model
|
| 197 |
+
jit : bool
|
| 198 |
+
Whether to load the optimized JIT model (default) or more hackable non-JIT model.
|
| 199 |
+
Returns
|
| 200 |
+
-------
|
| 201 |
+
model : torch.nn.Module
|
| 202 |
+
The CLIP model
|
| 203 |
+
preprocess : Callable[[PIL.Image], torch.Tensor]
|
| 204 |
+
A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input
|
| 205 |
+
"""
|
| 206 |
+
if name in _MODELS:
|
| 207 |
+
model_path = _download(_MODELS[name])
|
| 208 |
+
elif os.path.isfile(name):
|
| 209 |
+
model_path = name
|
| 210 |
+
else:
|
| 211 |
+
raise RuntimeError(f"Model {name} not found; available models = {available_models()}")
|
| 212 |
+
|
| 213 |
+
try:
|
| 214 |
+
# loading JIT archive
|
| 215 |
+
model = torch.jit.load(model_path, map_location=device if jit else "cpu").eval()
|
| 216 |
+
state_dict = None
|
| 217 |
+
except RuntimeError:
|
| 218 |
+
# loading saved state dict
|
| 219 |
+
if jit:
|
| 220 |
+
warnings.warn(f"File {model_path} is not a JIT archive. Loading as a state dict instead")
|
| 221 |
+
jit = False
|
| 222 |
+
state_dict = torch.load(model_path, map_location="cpu")
|
| 223 |
+
|
| 224 |
+
if not jit:
|
| 225 |
+
try:
|
| 226 |
+
model = build_model(state_dict or model.state_dict()).to(device)
|
| 227 |
+
except KeyError:
|
| 228 |
+
sd = {k[7:]: v for k,v in state_dict["state_dict"].items()}
|
| 229 |
+
model = build_model(sd).to(device)
|
| 230 |
+
|
| 231 |
+
if str(device) == "cpu":
|
| 232 |
+
model.float()
|
| 233 |
+
return model, \
|
| 234 |
+
_transform(model.visual.input_resolution, is_train=True), \
|
| 235 |
+
_transform(model.visual.input_resolution, is_train=False)
|
| 236 |
+
|
| 237 |
+
# patch the device names
|
| 238 |
+
device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[])
|
| 239 |
+
device_node = [n for n in device_holder.graph.findAllNodes("prim::Constant") if "Device" in repr(n)][-1]
|
| 240 |
+
|
| 241 |
+
def patch_device(module):
|
| 242 |
+
graphs = [module.graph] if hasattr(module, "graph") else []
|
| 243 |
+
if hasattr(module, "forward1"):
|
| 244 |
+
graphs.append(module.forward1.graph)
|
| 245 |
+
|
| 246 |
+
for graph in graphs:
|
| 247 |
+
for node in graph.findAllNodes("prim::Constant"):
|
| 248 |
+
if "value" in node.attributeNames() and str(node["value"]).startswith("cuda"):
|
| 249 |
+
node.copyAttributes(device_node)
|
| 250 |
+
|
| 251 |
+
model.apply(patch_device)
|
| 252 |
+
patch_device(model.encode_image)
|
| 253 |
+
patch_device(model.encode_text)
|
| 254 |
+
|
| 255 |
+
# patch dtype to float32 on CPU
|
| 256 |
+
if str(device) == "cpu":
|
| 257 |
+
float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[])
|
| 258 |
+
float_input = list(float_holder.graph.findNode("aten::to").inputs())[1]
|
| 259 |
+
float_node = float_input.node()
|
| 260 |
+
|
| 261 |
+
def patch_float(module):
|
| 262 |
+
graphs = [module.graph] if hasattr(module, "graph") else []
|
| 263 |
+
if hasattr(module, "forward1"):
|
| 264 |
+
graphs.append(module.forward1.graph)
|
| 265 |
+
|
| 266 |
+
for graph in graphs:
|
| 267 |
+
for node in graph.findAllNodes("aten::to"):
|
| 268 |
+
inputs = list(node.inputs())
|
| 269 |
+
for i in [1, 2]: # dtype can be the second or third argument to aten::to()
|
| 270 |
+
if inputs[i].node()["value"] == 5:
|
| 271 |
+
inputs[i].node().copyAttributes(float_node)
|
| 272 |
+
|
| 273 |
+
model.apply(patch_float)
|
| 274 |
+
patch_float(model.encode_image)
|
| 275 |
+
patch_float(model.encode_text)
|
| 276 |
+
|
| 277 |
+
model.float()
|
| 278 |
+
|
| 279 |
+
return model, \
|
| 280 |
+
_transform(model.input_resolution.item(), is_train=True), \
|
| 281 |
+
_transform(model.input_resolution.item(), is_train=False)
|
| 282 |
+
|
| 283 |
+
|
| 284 |
+
def tokenize(texts: Union[str, List[str]], context_length: int = 77) -> torch.LongTensor:
|
| 285 |
+
"""
|
| 286 |
+
Returns the tokenized representation of given input string(s)
|
| 287 |
+
Parameters
|
| 288 |
+
----------
|
| 289 |
+
texts : Union[str, List[str]]
|
| 290 |
+
An input string or a list of input strings to tokenize
|
| 291 |
+
context_length : int
|
| 292 |
+
The context length to use; all CLIP models use 77 as the context length
|
| 293 |
+
Returns
|
| 294 |
+
-------
|
| 295 |
+
A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length]
|
| 296 |
+
"""
|
| 297 |
+
if isinstance(texts, str):
|
| 298 |
+
texts = [texts]
|
| 299 |
+
|
| 300 |
+
sot_token = _tokenizer.encoder["<start_of_text>"]
|
| 301 |
+
eot_token = _tokenizer.encoder["<end_of_text>"]
|
| 302 |
+
all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts]
|
| 303 |
+
result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)
|
| 304 |
+
|
| 305 |
+
for i, tokens in enumerate(all_tokens):
|
| 306 |
+
if len(tokens) > context_length: # Truncate
|
| 307 |
+
tokens = tokens[:context_length]
|
| 308 |
+
result[i, :len(tokens)] = torch.tensor(tokens)
|
| 309 |
+
|
| 310 |
+
return result
|
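
A short usage sketch of this module's `load` and `tokenize` helpers (run from the repository root; note that, unlike upstream OpenAI CLIP, this `load` returns both a training and an evaluation transform). The prompt strings are illustrative.

```python
import torch
import clip  # this repository's clip/ package

device = "cuda" if torch.cuda.is_available() else "cpu"

# jit=False builds the modified, adapter-equipped model from the checkpoint's state dict
model, train_preprocess, eval_preprocess = clip.load("ViT-B/16", device=device, jit=False)

# tokenize a couple of illustrative prompts -> LongTensor of shape [2, 77]
tokens = clip.tokenize(["a bad photo of a cat.", "a bad photo of a dog."]).to(device)
print(tokens.shape)
```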
clip/model.py
ADDED
|
@@ -0,0 +1,486 @@
| 1 |
+
from collections import OrderedDict
|
| 2 |
+
from typing import Tuple, Union
|
| 3 |
+
|
| 4 |
+
import os
|
| 5 |
+
import json
|
| 6 |
+
import numpy as np
|
| 7 |
+
import torch
|
| 8 |
+
import torch.nn.functional as F
|
| 9 |
+
from torch import nn
|
| 10 |
+
from .adapter import Adapter
|
| 11 |
+
from torch.distributions.normal import Normal
|
| 12 |
+
from collections import Counter
|
| 13 |
+
|
| 14 |
+
global_taskid = 0
|
| 15 |
+
global_is_train=True
|
| 16 |
+
|
| 17 |
+
class Bottleneck(nn.Module):
|
| 18 |
+
expansion = 4
|
| 19 |
+
|
| 20 |
+
def __init__(self, inplanes, planes, stride=1):
|
| 21 |
+
super().__init__()
|
| 22 |
+
|
| 23 |
+
# all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1
|
| 24 |
+
self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False)
|
| 25 |
+
self.bn1 = nn.BatchNorm2d(planes)
|
| 26 |
+
|
| 27 |
+
self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False)
|
| 28 |
+
self.bn2 = nn.BatchNorm2d(planes)
|
| 29 |
+
|
| 30 |
+
self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity()
|
| 31 |
+
|
| 32 |
+
self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False)
|
| 33 |
+
self.bn3 = nn.BatchNorm2d(planes * self.expansion)
|
| 34 |
+
|
| 35 |
+
self.relu = nn.ReLU(inplace=True)
|
| 36 |
+
self.downsample = None
|
| 37 |
+
self.stride = stride
|
| 38 |
+
|
| 39 |
+
if stride > 1 or inplanes != planes * Bottleneck.expansion:
|
| 40 |
+
# downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1
|
| 41 |
+
self.downsample = nn.Sequential(OrderedDict([
|
| 42 |
+
("-1", nn.AvgPool2d(stride)),
|
| 43 |
+
("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)),
|
| 44 |
+
("1", nn.BatchNorm2d(planes * self.expansion))
|
| 45 |
+
]))
|
| 46 |
+
|
| 47 |
+
def forward(self, x: torch.Tensor):
|
| 48 |
+
identity = x
|
| 49 |
+
|
| 50 |
+
out = self.relu(self.bn1(self.conv1(x)))
|
| 51 |
+
out = self.relu(self.bn2(self.conv2(out)))
|
| 52 |
+
out = self.avgpool(out)
|
| 53 |
+
out = self.bn3(self.conv3(out))
|
| 54 |
+
|
| 55 |
+
if self.downsample is not None:
|
| 56 |
+
identity = self.downsample(x)
|
| 57 |
+
|
| 58 |
+
out += identity
|
| 59 |
+
out = self.relu(out)
|
| 60 |
+
return out
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
class AttentionPool2d(nn.Module):
|
| 64 |
+
def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None):
|
| 65 |
+
super().__init__()
|
| 66 |
+
self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5)
|
| 67 |
+
self.k_proj = nn.Linear(embed_dim, embed_dim)
|
| 68 |
+
self.q_proj = nn.Linear(embed_dim, embed_dim)
|
| 69 |
+
self.v_proj = nn.Linear(embed_dim, embed_dim)
|
| 70 |
+
self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
|
| 71 |
+
self.num_heads = num_heads
|
| 72 |
+
|
| 73 |
+
def forward(self, x):
|
| 74 |
+
x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3]).permute(2, 0, 1) # NCHW -> (HW)NC
|
| 75 |
+
x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC
|
| 76 |
+
x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC
|
| 77 |
+
x, _ = F.multi_head_attention_forward(
|
| 78 |
+
query=x, key=x, value=x,
|
| 79 |
+
embed_dim_to_check=x.shape[-1],
|
| 80 |
+
num_heads=self.num_heads,
|
| 81 |
+
q_proj_weight=self.q_proj.weight,
|
| 82 |
+
k_proj_weight=self.k_proj.weight,
|
| 83 |
+
v_proj_weight=self.v_proj.weight,
|
| 84 |
+
in_proj_weight=None,
|
| 85 |
+
in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]),
|
| 86 |
+
bias_k=None,
|
| 87 |
+
bias_v=None,
|
| 88 |
+
add_zero_attn=False,
|
| 89 |
+
dropout_p=0,
|
| 90 |
+
out_proj_weight=self.c_proj.weight,
|
| 91 |
+
out_proj_bias=self.c_proj.bias,
|
| 92 |
+
use_separate_proj_weight=True,
|
| 93 |
+
training=self.training,
|
| 94 |
+
need_weights=False
|
| 95 |
+
)
|
| 96 |
+
|
| 97 |
+
return x[0]
|
| 98 |
+
|
| 99 |
+
|
| 100 |
+
class ModifiedResNet(nn.Module):
|
| 101 |
+
"""
|
| 102 |
+
A ResNet class that is similar to torchvision's but contains the following changes:
|
| 103 |
+
- There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool.
|
| 104 |
+
- Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1
|
| 105 |
+
- The final pooling layer is a QKV attention instead of an average pool
|
| 106 |
+
"""
|
| 107 |
+
|
| 108 |
+
def __init__(self, layers, output_dim, heads, input_resolution=224, width=64):
|
| 109 |
+
super().__init__()
|
| 110 |
+
self.output_dim = output_dim
|
| 111 |
+
self.input_resolution = input_resolution
|
| 112 |
+
|
| 113 |
+
# the 3-layer stem
|
| 114 |
+
self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False)
|
| 115 |
+
self.bn1 = nn.BatchNorm2d(width // 2)
|
| 116 |
+
self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False)
|
| 117 |
+
self.bn2 = nn.BatchNorm2d(width // 2)
|
| 118 |
+
self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False)
|
| 119 |
+
self.bn3 = nn.BatchNorm2d(width)
|
| 120 |
+
self.avgpool = nn.AvgPool2d(2)
|
| 121 |
+
self.relu = nn.ReLU(inplace=True)
|
| 122 |
+
|
| 123 |
+
# residual layers
|
| 124 |
+
self._inplanes = width # this is a *mutable* variable used during construction
|
| 125 |
+
self.layer1 = self._make_layer(width, layers[0])
|
| 126 |
+
self.layer2 = self._make_layer(width * 2, layers[1], stride=2)
|
| 127 |
+
self.layer3 = self._make_layer(width * 4, layers[2], stride=2)
|
| 128 |
+
self.layer4 = self._make_layer(width * 8, layers[3], stride=2)
|
| 129 |
+
|
| 130 |
+
embed_dim = width * 32 # the ResNet feature dimension
|
| 131 |
+
self.attnpool = AttentionPool2d(input_resolution // 32, embed_dim, heads, output_dim)
|
| 132 |
+
|
| 133 |
+
def _make_layer(self, planes, blocks, stride=1):
|
| 134 |
+
layers = [Bottleneck(self._inplanes, planes, stride)]
|
| 135 |
+
|
| 136 |
+
self._inplanes = planes * Bottleneck.expansion
|
| 137 |
+
for _ in range(1, blocks):
|
| 138 |
+
layers.append(Bottleneck(self._inplanes, planes))
|
| 139 |
+
|
| 140 |
+
return nn.Sequential(*layers)
|
| 141 |
+
|
| 142 |
+
def forward(self, x):
|
| 143 |
+
def stem(x):
|
| 144 |
+
for conv, bn in [(self.conv1, self.bn1), (self.conv2, self.bn2), (self.conv3, self.bn3)]:
|
| 145 |
+
x = self.relu(bn(conv(x)))
|
| 146 |
+
x = self.avgpool(x)
|
| 147 |
+
return x
|
| 148 |
+
|
| 149 |
+
x = x.type(self.conv1.weight.dtype)
|
| 150 |
+
x = stem(x)
|
| 151 |
+
x = self.layer1(x)
|
| 152 |
+
x = self.layer2(x)
|
| 153 |
+
x = self.layer3(x)
|
| 154 |
+
x = self.layer4(x)
|
| 155 |
+
x = self.attnpool(x)
|
| 156 |
+
|
| 157 |
+
return x
|
| 158 |
+
|
| 159 |
+
|
| 160 |
+
class LayerNorm(nn.LayerNorm):
|
| 161 |
+
"""Subclass torch's LayerNorm to handle fp16."""
|
| 162 |
+
|
| 163 |
+
def forward(self, x: torch.Tensor):
|
| 164 |
+
orig_type = x.dtype
|
| 165 |
+
ret = super().forward(x.type(torch.float32))
|
| 166 |
+
return ret.type(orig_type)
|
| 167 |
+
|
| 168 |
+
|
| 169 |
+
class QuickGELU(nn.Module):
|
| 170 |
+
def forward(self, x: torch.Tensor):
|
| 171 |
+
return x * torch.sigmoid(1.702 * x)
|
| 172 |
+
|
| 173 |
+
|
| 174 |
+
class ResidualAttentionBlock(nn.Module):
    def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None, text_or_image=None, flag=False):
        super().__init__()
        self.register_buffer("mean", torch.tensor([0.0]))
        self.register_buffer("std", torch.tensor([1.0]))
        self.attn = nn.MultiheadAttention(d_model, n_head)
        self.ln_1 = LayerNorm(d_model)
        self.mlp = nn.Sequential(OrderedDict([
            ("c_fc", nn.Linear(d_model, d_model * 4)),
            ("gelu", QuickGELU()),
            ("c_proj", nn.Linear(d_model * 4, d_model))
        ]))
        self.ln_2 = LayerNorm(d_model)
        self.attn_mask = attn_mask
        self.is_train = global_is_train
        self.ffn_num = 64
        self.softmax = nn.Softmax(1)
        self.softplus = nn.Softplus()
        self.noisy_gating = True
        self.adaptmlp_list = nn.ModuleList()
        self.text_or_image = text_or_image
        self.flag = flag
        self.adaptmlp = Adapter(d_model=d_model, dropout=0.1, bottleneck=self.ffn_num,
                                init_option='lora',
                                adapter_scalar=0.1,
                                adapter_layernorm_option='none',
                                )

    def attention(self, x: torch.Tensor):
        self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None
        return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0]

    def forward(self, x: torch.Tensor):
        x = x + self.attention(self.ln_1(x))
        x_re = x.permute(1, 0, 2)
        down, adapt_x, up_proj_weight, down_proj_weight, up_proj_bias, down_proj_bias = self.adaptmlp(x_re, add_residual=False)
        adapt_x = adapt_x.permute(1, 0, 2)
        down = down.permute(1, 0, 2)
        x = x + self.mlp(self.ln_2(x)) + adapt_x

        return x, down, adapt_x, up_proj_weight, down_proj_weight, up_proj_bias, down_proj_bias


class Transformer(nn.Module):
    def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None, text_or_image=None,
                 flag=True,):
        super().__init__()
        self.width = width
        self.layers = layers
        self.resblocks = nn.Sequential(*[ResidualAttentionBlock(width, heads, attn_mask, text_or_image, flag) for _ in range(layers)])
        self.lora_feature = {}

    def forward(self, x: torch.Tensor):
        for i in range(len(self.resblocks)):
            x, down, adapt_x, up_proj_weight, down_proj_weight, up_proj_bias, down_proj_bias = self.resblocks[i](x)
            self.lora_feature[i] = adapt_x
        return x


class VisualTransformer(nn.Module):
    def __init__(self, input_resolution: int, patch_size: int, width: int, layers: int, heads: int, output_dim: int, text_or_image=None):
        super().__init__()
        self.input_resolution = input_resolution
        self.output_dim = output_dim
        # Added so this info is available. should not change anything.
        self.patch_size = patch_size
        self.width = width
        self.layers = layers
        self.heads = heads

        self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False)

        scale = width ** -0.5
        self.class_embedding = nn.Parameter(scale * torch.randn(width))
        self.positional_embedding = nn.Parameter(scale * torch.randn((input_resolution // patch_size) ** 2 + 1, width))
        self.ln_pre = LayerNorm(width)

        self.transformer = Transformer(width, layers, heads, text_or_image=text_or_image, flag=True)

        self.ln_post = LayerNorm(width)
        self.proj = nn.Parameter(scale * torch.randn(width, output_dim))

    def forward(self, x: torch.Tensor):
        x = self.conv1(x)
        x = x.reshape(x.shape[0], x.shape[1], -1)
        x = x.permute(0, 2, 1)
        x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1)  # shape = [*, grid ** 2 + 1, width]
        x = x + self.positional_embedding.to(x.dtype)
        x = self.ln_pre(x)

        x = x.permute(1, 0, 2)  # NLD -> LND
        x = self.transformer(x)
        x_before_fusion = x
        x = x.permute(1, 0, 2)  # LND -> NLD

        x = self.ln_post(x[:, 0, :])

        if self.proj is not None:
            x = x @ self.proj

        return x, x_before_fusion


class CLIP(nn.Module):
    def __init__(self,
                 embed_dim: int,
                 # vision
                 image_resolution: int,
                 vision_layers: Union[Tuple[int, int, int, int], int],
                 vision_width: int,
                 vision_patch_size: int,
                 # text
                 context_length: int,
                 vocab_size: int,
                 transformer_width: int,
                 transformer_heads: int,
                 transformer_layers: int,
                 baseline=False
                 ):
        super().__init__()
        self.baseline = baseline

        self.context_length = context_length

        if isinstance(vision_layers, (tuple, list)):
            vision_heads = vision_width * 32 // 64
            self.visual = ModifiedResNet(
                layers=vision_layers,
                output_dim=embed_dim,
                heads=vision_heads,
                input_resolution=image_resolution,
                width=vision_width
            )
        else:
            vision_heads = vision_width // 64
            self.visual = VisualTransformer(
                input_resolution=image_resolution,
                patch_size=vision_patch_size,
                width=vision_width,
                layers=vision_layers,
                heads=vision_heads,
                output_dim=embed_dim,
                text_or_image='image',
            )

        self.transformer = Transformer(
            width=transformer_width,
            layers=transformer_layers,
            heads=transformer_heads,
            attn_mask=self.build_attention_mask(),
            text_or_image='text',
            flag=True,
        )

        self.vocab_size = vocab_size
        self.token_embedding = nn.Embedding(vocab_size, transformer_width)
        self.positional_embedding = nn.Parameter(torch.empty(self.context_length, transformer_width))
        self.ln_final = LayerNorm(transformer_width)

        self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim))
        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
        # self.adapt_lamda = [torch.nn.Parameter(30 * torch.rand(1)) for _ in range(12)]

        self.initialize_parameters()

    def initialize_parameters(self):
        nn.init.normal_(self.token_embedding.weight, std=0.02)
        nn.init.normal_(self.positional_embedding, std=0.01)
        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))

        if isinstance(self.visual, ModifiedResNet):
            if self.visual.attnpool is not None:
                std = self.visual.attnpool.c_proj.in_features ** -0.5
                nn.init.normal_(self.visual.attnpool.q_proj.weight, std=std)
                nn.init.normal_(self.visual.attnpool.k_proj.weight, std=std)
                nn.init.normal_(self.visual.attnpool.v_proj.weight, std=std)
                nn.init.normal_(self.visual.attnpool.c_proj.weight, std=std)

            for resnet_block in [self.visual.layer1, self.visual.layer2, self.visual.layer3, self.visual.layer4]:
                for name, param in resnet_block.named_parameters():
                    if name.endswith("bn3.weight"):
                        nn.init.zeros_(param)

        proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5)
        attn_std = self.transformer.width ** -0.5
        fc_std = (2 * self.transformer.width) ** -0.5
        for block in self.transformer.resblocks:
            nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
            nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
            nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
            nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)

        if self.text_projection is not None:
            nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5)

    def build_attention_mask(self):
        # lazily create causal attention mask, with full attention between the vision tokens
        # pytorch uses additive attention mask; fill with -inf
        mask = torch.empty(self.context_length, self.context_length)
        mask.fill_(float("-inf"))
        mask.triu_(1)  # zero out the lower diagonal
        return mask

    @property
    def dtype(self):
        return self.visual.conv1.weight.dtype

    def encode_image(self, image):
        return self.visual(image.type(self.dtype))

    def encode_text(self, text):
        x = self.token_embedding(text).type(self.dtype)  # [batch_size, n_ctx, d_model]

        x = x + self.positional_embedding.type(self.dtype)
        x = x.permute(1, 0, 2)  # NLD -> LND
        x = self.transformer(x)
        x_before_fusion = x
        x = x.permute(1, 0, 2)  # LND -> NLD
        x = self.ln_final(x).type(self.dtype)

        # take features from the eot embedding (eot_token is the highest number in each sequence)
        x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection

        return x, x_before_fusion

    def forward(self, image, text, taskid, is_train):
        global global_taskid, global_is_train
        global_taskid = taskid
        global_is_train = is_train
        if image is None:
            return self.encode_text(text)
        elif text is None:
            return self.encode_image(image)
        image_features, x_img_before_fusion = self.encode_image(image)
        text_features, x_txt_before_fusion = self.encode_text(text)
        image_features = image_features / image_features.norm(dim=-1, keepdim=True)
        text_features = text_features / text_features.norm(dim=-1, keepdim=True)

        # if self.baseline:
        logit_scale = self.logit_scale.exp()
        logits_per_image = logit_scale * image_features @ text_features.t()
        logits_per_text = logits_per_image.t()
        return logits_per_image, logits_per_text


def convert_weights(model: nn.Module):
    """Convert applicable model parameters to fp16"""

    def _convert_weights_to_fp16(l):
        if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)):
            l.weight.data = l.weight.data.half()
            if l.bias is not None:
                l.bias.data = l.bias.data.half()

        if isinstance(l, nn.MultiheadAttention):
            for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]:
                tensor = getattr(l, attr)
                if tensor is not None:
                    tensor.data = tensor.data.half()

        for name in ["text_projection", "proj"]:
            if hasattr(l, name):
                attr = getattr(l, name)
                if attr is not None:
                    attr.data = attr.data.half()

    model.apply(_convert_weights_to_fp16)


def build_model(state_dict: dict):
    vit = "visual.proj" in state_dict

    if vit:
        vision_width = state_dict["visual.conv1.weight"].shape[0]
        vision_layers = len([k for k in state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")])
        vision_patch_size = state_dict["visual.conv1.weight"].shape[-1]
        grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5)
        image_resolution = vision_patch_size * grid_size
    else:
        counts: list = [len(set(k.split(".")[2] for k in state_dict if k.startswith(f"visual.layer{b}"))) for b in [1, 2, 3, 4]]
        vision_layers = tuple(counts)
        vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0]
        output_width = round((state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5)
        vision_patch_size = None
        assert output_width ** 2 + 1 == state_dict["visual.attnpool.positional_embedding"].shape[0]
        image_resolution = output_width * 32

    embed_dim = state_dict["text_projection"].shape[1]
    context_length = state_dict["positional_embedding"].shape[0]
    vocab_size = state_dict["token_embedding.weight"].shape[0]
    transformer_width = state_dict["ln_final.weight"].shape[0]
    transformer_heads = transformer_width // 64
    transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith(f"transformer.resblocks")))

    model = CLIP(
        embed_dim,
        image_resolution, vision_layers, vision_width, vision_patch_size,
        context_length, vocab_size, transformer_width, transformer_heads, transformer_layers
    )

    for key in ["input_resolution", "context_length", "vocab_size"]:
        if key in state_dict:
            del state_dict[key]

    model.load_state_dict(state_dict, strict=False)
    for p in model.parameters():
        p.data = p.data.float()
    return model.eval()
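Compared with the stock OpenAI implementation, every `ResidualAttentionBlock` above carries a LoRA-style `Adapter` whose output is added to the MLP branch, `Transformer.forward` caches the per-layer adapter features in `lora_feature`, and `CLIP.forward` takes two extra arguments (`taskid`, `is_train`) that are written into module-level globals. A minimal usage sketch, assuming a ViT-B/16 checkpoint is available locally; the path and dummy tensors are illustrative, not part of this diff:

import torch
from clip.model import build_model

# Illustrative only: in the repository the checkpoint is normally fetched and
# loaded through clip.clip.load(); "ViT-B-16.pt" is a placeholder path here.
state_dict = torch.jit.load("ViT-B-16.pt", map_location="cpu").state_dict()
model = build_model(state_dict)  # build_model casts parameters back to fp32 and calls eval()

images = torch.randn(2, 3, 224, 224)      # preprocessed image batch
texts = torch.randint(1, 1000, (4, 77))   # tokenized prompt batch (dummy ids)
logits_per_image, logits_per_text = model(images, texts, taskid=0, is_train=False)
print(logits_per_image.shape)             # torch.Size([2, 4])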
clip/tokenizer.py
ADDED
@@ -0,0 +1,140 @@
import gzip
import html
import os
from functools import lru_cache

import ftfy
import regex as re


@lru_cache()
def default_bpe():
    return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz")


@lru_cache()
def bytes_to_unicode():
    """
    Returns a list of utf-8 bytes and a corresponding list of unicode strings.
    The reversible bpe codes work on unicode strings.
    This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
    When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
    This is a significant percentage of your normal, say, 32K bpe vocab.
    To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
    And avoids mapping to whitespace/control characters the bpe code barfs on.
    """
    bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1))
    cs = bs[:]
    n = 0
    for b in range(2**8):
        if b not in bs:
            bs.append(b)
            cs.append(2**8+n)
            n += 1
    cs = [chr(n) for n in cs]
    return dict(zip(bs, cs))


def get_pairs(word):
    """Return set of symbol pairs in a word.
    Word is represented as tuple of symbols (symbols being variable-length strings).
    """
    pairs = set()
    prev_char = word[0]
    for char in word[1:]:
        pairs.add((prev_char, char))
        prev_char = char
    return pairs


def basic_clean(text):
    text = ftfy.fix_text(text)
    text = html.unescape(html.unescape(text))
    return text.strip()


def whitespace_clean(text):
    text = re.sub(r'\s+', ' ', text)
    text = text.strip()
    return text


class SimpleTokenizer(object):
    def __init__(self, bpe_path: str = default_bpe(), special_tokens=None):
        self.byte_encoder = bytes_to_unicode()
        self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
        merges = gzip.open(bpe_path).read().decode("utf-8").split('\n')
        merges = merges[1:49152-256-2+1]
        merges = [tuple(merge.split()) for merge in merges]
        vocab = list(bytes_to_unicode().values())
        vocab = vocab + [v+'</w>' for v in vocab]
        for merge in merges:
            vocab.append(''.join(merge))
        if not special_tokens:
            special_tokens = ['<start_of_text>', '<end_of_text>']
        else:
            special_tokens = ['<start_of_text>', '<end_of_text>'] + special_tokens
        vocab.extend(special_tokens)
        self.encoder = dict(zip(vocab, range(len(vocab))))
        self.decoder = {v: k for k, v in self.encoder.items()}
        self.bpe_ranks = dict(zip(merges, range(len(merges))))
        self.cache = {t: t for t in special_tokens}
        special = "|".join(special_tokens)
        self.pat = re.compile(special + r"""|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE)

        self.vocab_size = len(self.encoder)
        self.all_special_ids = [self.encoder[t] for t in special_tokens]

    def bpe(self, token):
        if token in self.cache:
            return self.cache[token]
        word = tuple(token[:-1]) + (token[-1] + '</w>',)
        pairs = get_pairs(word)

        if not pairs:
            return token+'</w>'

        while True:
            bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float('inf')))
            if bigram not in self.bpe_ranks:
                break
            first, second = bigram
            new_word = []
            i = 0
            while i < len(word):
                try:
                    j = word.index(first, i)
                    new_word.extend(word[i:j])
                    i = j
                except:
                    new_word.extend(word[i:])
                    break

                if word[i] == first and i < len(word)-1 and word[i+1] == second:
                    new_word.append(first+second)
                    i += 2
                else:
                    new_word.append(word[i])
                    i += 1
            new_word = tuple(new_word)
            word = new_word
            if len(word) == 1:
                break
            else:
                pairs = get_pairs(word)
        word = ' '.join(word)
        self.cache[token] = word
        return word

    def encode(self, text):
        bpe_tokens = []
        text = whitespace_clean(basic_clean(text)).lower()
        for token in re.findall(self.pat, text):
            token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
            bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
        return bpe_tokens

    def decode(self, tokens):
        text = ''.join([self.decoder[token] for token in tokens])
        text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('</w>', ' ')
        return text
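The tokenizer's `encode` lower-cases and cleans the text, applies byte-level BPE, and returns raw token ids without the start/end markers (those are added later by `clip.tokenize`). A quick round-trip sketch:

from clip.tokenizer import SimpleTokenizer

tokenizer = SimpleTokenizer()                    # reads the bundled bpe_simple_vocab_16e6.txt.gz
ids = tokenizer.encode("a bad photo of a dog.")  # list of BPE ids, no <start_of_text>/<end_of_text>
print(tokenizer.decode(ids))                     # roughly "a bad photo of a dog . "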
configs/class/cifar100_10-10.yaml
ADDED
@@ -0,0 +1,33 @@
hydra:
  run:
    dir: ./experiments/${scenario}/${dataset}_${initial_increment}-${increment}-${method}
  job:
    chdir: true

  job_logging:
    version: 1
    formatters:
      simple:
        format: '%(message)s'

class_order: ""
dataset_root: ""
workdir: ""
log_path: "metrics.json"
model_name: "ViT-B/16"
prompt_template: "a bad photo of a {}."

batch_size: 128
increment: ${initial_increment}
initial_increment: 10
scenario: "class"
dataset: "cifar100"

weight_decay: 0.0
l2: 0
ce_method: 0

method: "DMNSP"

lr: 1e-3
ls: 0.0
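Since `increment` is the interpolation `${initial_increment}`, this file describes ten tasks of ten CIFAR-100 classes each, and the Hydra run directory is built from the same fields. A sketch of how such a config is typically consumed (the decorator arguments are assumptions; the repository's actual entry point is `main.py`):

import hydra
from omegaconf import DictConfig

@hydra.main(config_path="configs/class", config_name="cifar100_10-10", version_base=None)
def run(cfg: DictConfig) -> None:
    # interpolations such as ${initial_increment} resolve on access
    print(cfg.dataset, cfg.initial_increment, cfg.increment, cfg.method)  # cifar100 10 10 DMNSP

if __name__ == "__main__":
    run()

The remaining configs below follow the same layout and differ only in the increment sizes and dataset-specific keys.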
configs/class/cifar100_2-2.yaml
ADDED
@@ -0,0 +1,33 @@
hydra:
  run:
    dir: ./experiments/${scenario}/${dataset}_${initial_increment}-${increment}-${method}-${ls}
  job:
    chdir: true

  job_logging:
    version: 1
    formatters:
      simple:
        format: '%(message)s'

class_order: ""
dataset_root: ""
workdir: ""
log_path: "metrics.json"
model_name: "ViT-B/16"
prompt_template: "a bad photo of a {}."

batch_size: 128
increment: ${initial_increment}
initial_increment: 2
scenario: "class"
dataset: "cifar100"

weight_decay: 0.0
l2: 0
ce_method: 0

method: "DMNSP"
lr: 1e-3
ls: 0.0
configs/class/cifar100_5-5.yaml
ADDED
@@ -0,0 +1,37 @@
hydra:
  run:
    dir: ./experiments/${scenario}/${dataset}_${initial_increment}-${increment}-${method}
  job:
    chdir: true

  job_logging:
    version: 1
    formatters:
      simple:
        format: '%(message)s'

class_order: ""
dataset_root: ""
workdir: ""
log_path: "metrics.json"
model_name: "ViT-B/16"
prompt_template: "a bad photo of a {}."

batch_size: 128
increment: ${initial_increment}
initial_increment: 5
scenario: "class"
dataset: "cifar100"

weight_decay: 0.0
l2: 0
ce_method: 0

method: "DMNSP"
lr: 1e-3
ls: 0.0
configs/class/tinyimagenet_100-10.yaml
ADDED
@@ -0,0 +1,36 @@
hydra:
  run:
    dir: ./experiments/${scenario}/${dataset}_${initial_increment}-${increment}-${method}
  job:
    chdir: true

  job_logging:
    version: 1
    formatters:
      simple:
        format: '%(message)s'

class_order: ""
dataset_root: ""
workdir: ""
log_path: "metrics.json"
model_name: "ViT-B/16"
prompt_template: "a bad photo of a {}."

batch_size: 128
initial_increment: 100
increment: 10
scenario: "class"
dataset: "tinyimagenet"

weight_decay: 0.0
l2: 0
ce_method: 0

method: "DMNSP"
lr: 1e-3
ls: 0.0
we:
avg_freq:
ref_dataset:
ref_sentences: random
configs/class/tinyimagenet_100-20.yaml
ADDED
@@ -0,0 +1,36 @@
hydra:
  run:
    dir: ./experiments/${scenario}/${dataset}_${initial_increment}-${increment}-${method}
  job:
    chdir: true

  job_logging:
    version: 1
    formatters:
      simple:
        format: '%(message)s'

class_order: ""
dataset_root: ""
workdir: ""
log_path: "metrics.json"
model_name: "ViT-B/16"
prompt_template: "a bad photo of a {}."

batch_size: 128
initial_increment: 100
increment: 20
scenario: "class"
dataset: "tinyimagenet"

weight_decay: 0.0
l2: 0
ce_method: 0

method: "DMNSP"
lr: 1e-3
ls: 0.0
we:
avg_freq:
ref_dataset:
ref_sentences: random
configs/class/tinyimagenet_100-5.yaml
ADDED
@@ -0,0 +1,36 @@
hydra:
  run:
    dir: ./experiments/${scenario}/${dataset}_${initial_increment}-${increment}-${method}
  job:
    chdir: true

  job_logging:
    version: 1
    formatters:
      simple:
        format: '%(message)s'

class_order: ""
dataset_root: ""
workdir: ""
log_path: "metrics.json"
model_name: "ViT-B/16"
prompt_template: "a bad photo of a {}."

batch_size: 128
initial_increment: 100
increment: 5
scenario: "class"
dataset: "tinyimagenet"

weight_decay: 0.0
l2: 0
ce_method: 0

method: "DMNSP"
lr: 1e-3
ls: 0.0
we:
avg_freq:
ref_dataset:
ref_sentences: random
continual_clip/__init__.py
ADDED
File without changes
continual_clip/cc.py
ADDED
@@ -0,0 +1,53 @@
import os
import pandas as pd
from PIL import Image
import torch
from torch.utils.data import (
    DataLoader,
    Dataset,
    IterableDataset,
    SubsetRandomSampler,
    get_worker_info,
)
import clip.clip as clip


class CsvDataset(Dataset):
    def __init__(self, input_filename, transforms, img_key, caption_key, sep="\t"):
        df = pd.read_csv(input_filename, sep=sep)

        self.location = os.path.dirname(input_filename)
        self.images = df[img_key].tolist()
        self.captions = df[caption_key].tolist()
        self.transforms = transforms

    def __len__(self):
        return len(self.captions)

    def __getitem__(self, idx):
        image_path = os.path.join(self.location, str(self.images[idx]))
        images = self.transforms(Image.open(image_path))
        texts = clip.tokenize([str(self.captions[idx])])[0]
        return images, texts


class conceptual_captions(Dataset):
    def __init__(
        self, transforms, location, batch_size, *args, num_workers=16, **kwargs
    ):
        file_name = "Validation_GCC-1.1.0-Validation_output.csv"
        file_path = os.path.join(location, file_name)
        self.template = lambda c: f"a photo of a {c}."
        self.train_dataset = CsvDataset(
            input_filename=file_path,
            transforms=transforms,
            img_key="filepath",
            caption_key="title",
        )
        # breakpoint()
        self.train_loader = torch.utils.data.DataLoader(
            self.train_dataset,
            batch_size=batch_size,
            shuffle=True,
            num_workers=num_workers,
        )
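`conceptual_captions` wraps a Conceptual Captions validation CSV (columns `filepath` and `title`) in a `CsvDataset` and exposes a shuffled `train_loader` of (image, tokenized caption) pairs. A usage sketch, assuming `clip.load` returns `(model, train_transform, val_transform)` as in the bundled CLIP code; the data root below is a placeholder:

import clip.clip as clip
from continual_clip.cc import conceptual_captions

model, _, preprocess_val = clip.load("ViT-B/16", jit=False)
cc = conceptual_captions(transforms=preprocess_val,
                         location="/path/to/conceptual_captions",  # placeholder data root
                         batch_size=64)
images, texts = next(iter(cc.train_loader))  # images: [64, 3, 224, 224], texts: [64, 77]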
continual_clip/clip_original/README.md
ADDED
@@ -0,0 +1 @@
This folder is a lightly modified version of https://github.com/openai/CLIP.
continual_clip/clip_original/__init__.py
ADDED
@@ -0,0 +1 @@
from .clip import *
continual_clip/clip_original/adapter.py
ADDED
@@ -0,0 +1,75 @@
# --------------------------------------------------------
# References:
# https://github.com/jxhe/unify-parameter-efficient-tuning
# --------------------------------------------------------

import math
import torch
import torch.nn as nn


class Adapter(nn.Module):
    def __init__(self,
                 d_model=None,
                 bottleneck=None,
                 dropout=0.0,
                 init_option="lora",
                 adapter_scalar="1.0",
                 adapter_layernorm_option="in"):
        super().__init__()
        self.n_embd = d_model if d_model is None else d_model
        self.down_size = bottleneck

        # _before
        self.adapter_layernorm_option = adapter_layernorm_option

        self.adapter_layer_norm_before = None
        if adapter_layernorm_option == "in" or adapter_layernorm_option == "out":
            self.adapter_layer_norm_before = nn.LayerNorm(self.n_embd)

        if adapter_scalar == "learnable_scalar":
            self.scale = nn.Parameter(torch.ones(1))
        else:
            self.scale = float(adapter_scalar)

        self.linear = nn.Linear(self.n_embd, self.n_embd)

        self.down_proj = nn.Linear(self.n_embd, 64)
        self.non_linear_func = nn.ReLU()
        self.up_proj = nn.Linear(self.down_size, self.n_embd)

        self.dropout = dropout
        if init_option == "bert":
            raise NotImplementedError
        elif init_option == "lora":
            with torch.no_grad():
                nn.init.kaiming_uniform_(self.down_proj.weight, a=math.sqrt(5))
                nn.init.zeros_(self.up_proj.weight)
                nn.init.zeros_(self.down_proj.bias)
                nn.init.zeros_(self.up_proj.bias)
        elif init_option == "linear":
            with torch.no_grad():
                nn.init.zeros_(self.linear.weight)

    def forward(self, x, add_residual=True, residual=None):

        residual = x if residual is None else residual
        if self.adapter_layernorm_option == 'in':  # none
            x = self.adapter_layer_norm_before(x)

        down = self.down_proj(x)
        down = self.non_linear_func(down)
        down = nn.functional.dropout(down, p=self.dropout, training=self.training)
        up = self.up_proj(down)

        up = up * self.scale

        if self.adapter_layernorm_option == 'out':  # none
            up = self.adapter_layer_norm_before(up)

        if add_residual:
            output = up + residual
        else:
            output = up
        return output
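The adapter is a down-projection, ReLU, dropout, up-projection bottleneck whose output is scaled by `adapter_scalar` and optionally added back to the residual stream; with `init_option='lora'` the up-projection starts at zero, so a freshly built adapter is an identity map when the residual is added. Note that `down_proj` is hard-wired to 64 units, so only `bottleneck=64` is consistent with `up_proj`. A small sketch of the identity property (shapes illustrative):

import torch
from continual_clip.clip_original.adapter import Adapter

adapter = Adapter(d_model=768, bottleneck=64, dropout=0.1, init_option="lora",
                  adapter_scalar=0.1, adapter_layernorm_option="none")
x = torch.randn(4, 197, 768)        # [batch, tokens, width]
out = adapter(x, add_residual=True)
print(torch.allclose(out, x))       # True: zero-initialized up_proj contributes nothing yet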
continual_clip/clip_original/bpe_simple_vocab_16e6.txt.gz
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:924691ac288e54409236115652ad4aa250f48203de50a9e4722a6ecd48d6804a
size 1356917
continual_clip/clip_original/clip.py
ADDED
@@ -0,0 +1,208 @@
# Code ported from https://github.com/openai/CLIP

import hashlib
import os
import urllib
import warnings
from typing import Union, List

import torch
from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize, RandomResizedCrop, InterpolationMode
from tqdm import tqdm
from .model import build_model
from clip.tokenizer import SimpleTokenizer as _Tokenizer

__all__ = ["available_models", "load", "tokenize"]
_tokenizer = _Tokenizer()

_MODELS = {
    "RN50": "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt",
    "RN101": "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt",
    "RN50x4": "https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt",
    "RN50x16": "https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt",
    "ViT-B/32": "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt",
    "ViT-B/16": "https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt",
}


def _download(url: str, root: str = os.path.expanduser("~/.cache/clip")):
    os.makedirs(root, exist_ok=True)
    filename = os.path.basename(url)

    expected_sha256 = url.split("/")[-2]
    download_target = os.path.join(root, filename)

    if os.path.exists(download_target) and not os.path.isfile(download_target):
        raise RuntimeError(f"{download_target} exists and is not a regular file")

    if os.path.isfile(download_target):
        if hashlib.sha256(open(download_target, "rb").read()).hexdigest() == expected_sha256:
            return download_target
        else:
            warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file")

    with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
        with tqdm(total=int(source.info().get("Content-Length")), ncols=80, unit='iB', unit_scale=True) as loop:
            while True:
                buffer = source.read(8192)
                if not buffer:
                    break

                output.write(buffer)
                loop.update(len(buffer))

    if hashlib.sha256(open(download_target, "rb").read()).hexdigest() != expected_sha256:
        raise RuntimeError(f"Model has been downloaded but the SHA256 checksum does not match")

    return download_target


def _convert_to_rgb(image):
    return image.convert('RGB')


def _transform(n_px: int, is_train: bool):
    normalize = Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
    if is_train:
        return Compose([
            RandomResizedCrop(n_px, scale=(0.9, 1.0), interpolation=InterpolationMode.BICUBIC),
            _convert_to_rgb,
            ToTensor(),
            normalize,
        ])
    else:
        return Compose([
            Resize(n_px, interpolation=InterpolationMode.BICUBIC),
            CenterCrop(n_px),
            _convert_to_rgb,
            ToTensor(),
            normalize,
        ])


def available_models() -> List[str]:
    """Returns the names of available CLIP models"""
    return list(_MODELS.keys())


def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", jit=True, is_train=False, pretrained=True):
    """Load a CLIP model
    Parameters
    ----------
    name : str
        A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict
    device : Union[str, torch.device]
        The device to put the loaded model
    jit : bool
        Whether to load the optimized JIT model (default) or more hackable non-JIT model.
    Returns
    -------
    model : torch.nn.Module
        The CLIP model
    preprocess : Callable[[PIL.Image], torch.Tensor]
        A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input
    """
    if name in _MODELS:
        model_path = _download(_MODELS[name])
    elif os.path.isfile(name):
        model_path = name
    else:
        raise RuntimeError(f"Model {name} not found; available models = {available_models()}")

    try:
        # loading JIT archive
        model = torch.jit.load(model_path, map_location=device if jit else "cpu").eval()
        state_dict = None
    except RuntimeError:
        # loading saved state dict
        if jit:
            warnings.warn(f"File {model_path} is not a JIT archive. Loading as a state dict instead")
            jit = False
        state_dict = torch.load(model_path, map_location="cpu")

    if not jit:
        try:
            model = build_model(state_dict or model.state_dict()).to(device)
        except KeyError:
            sd = {k[7:]: v for k, v in state_dict["state_dict"].items()}
            model = build_model(sd).to(device)

        if str(device) == "cpu":
            model.float()
        return model, \
            _transform(model.visual.input_resolution, is_train=True), \
            _transform(model.visual.input_resolution, is_train=False)

    # patch the device names
    device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[])
    device_node = [n for n in device_holder.graph.findAllNodes("prim::Constant") if "Device" in repr(n)][-1]

    def patch_device(module):
        graphs = [module.graph] if hasattr(module, "graph") else []
        if hasattr(module, "forward1"):
            graphs.append(module.forward1.graph)

        for graph in graphs:
            for node in graph.findAllNodes("prim::Constant"):
                if "value" in node.attributeNames() and str(node["value"]).startswith("cuda"):
                    node.copyAttributes(device_node)

    model.apply(patch_device)
    patch_device(model.encode_image)
    patch_device(model.encode_text)

    # patch dtype to float32 on CPU
    if str(device) == "cpu":
        float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[])
        float_input = list(float_holder.graph.findNode("aten::to").inputs())[1]
        float_node = float_input.node()

        def patch_float(module):
            graphs = [module.graph] if hasattr(module, "graph") else []
            if hasattr(module, "forward1"):
                graphs.append(module.forward1.graph)

            for graph in graphs:
                for node in graph.findAllNodes("aten::to"):
                    inputs = list(node.inputs())
                    for i in [1, 2]:  # dtype can be the second or third argument to aten::to()
                        if inputs[i].node()["value"] == 5:
                            inputs[i].node().copyAttributes(float_node)

        model.apply(patch_float)
        patch_float(model.encode_image)
        patch_float(model.encode_text)

        model.float()

    return model, \
        _transform(model.input_resolution.item(), is_train=True), \
        _transform(model.input_resolution.item(), is_train=False)


def tokenize(texts: Union[str, List[str]], context_length: int = 77) -> torch.LongTensor:
    """
    Returns the tokenized representation of given input string(s)
    Parameters
    ----------
    texts : Union[str, List[str]]
        An input string or a list of input strings to tokenize
    context_length : int
        The context length to use; all CLIP models use 77 as the context length
    Returns
    -------
    A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length]
    """
    if isinstance(texts, str):
        texts = [texts]

    sot_token = _tokenizer.encoder["<start_of_text>"]
    eot_token = _tokenizer.encoder["<end_of_text>"]
    all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts]
    result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)

    for i, tokens in enumerate(all_tokens):
        if len(tokens) > context_length:  # Truncate
            tokens = tokens[:context_length]
        result[i, :len(tokens)] = torch.tensor(tokens)

    return result
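`tokenize` brackets each prompt with the `<start_of_text>` / `<end_of_text>` ids, truncates to the 77-token context length, and zero-pads the rest, so the result is always `[len(texts), 77]`. For example:

from continual_clip.clip_original.clip import tokenize

tokens = tokenize(["a bad photo of a dog.", "a bad photo of a cat."])
print(tokens.shape)   # torch.Size([2, 77])
print(tokens[0, :8])  # SOT id, prompt BPE ids, EOT id, then zeros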
continual_clip/clip_original/model.py
ADDED
|
@@ -0,0 +1,568 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from collections import OrderedDict
|
| 2 |
+
from typing import Tuple, Union
|
| 3 |
+
|
| 4 |
+
import os
|
| 5 |
+
import json
|
| 6 |
+
import numpy as np
|
| 7 |
+
import torch
|
| 8 |
+
import torch.nn.functional as F
|
| 9 |
+
from torch import nn
|
| 10 |
+
from .adapter import Adapter
|
| 11 |
+
from torch.distributions.normal import Normal
|
| 12 |
+
from collections import Counter
|
| 13 |
+
|
| 14 |
+
global_taskid = 0
|
| 15 |
+
global_is_train=True
|
| 16 |
+
|
| 17 |
+
class Bottleneck(nn.Module):
|
| 18 |
+
expansion = 4
|
| 19 |
+
|
| 20 |
+
def __init__(self, inplanes, planes, stride=1):
|
| 21 |
+
super().__init__()
|
| 22 |
+
|
| 23 |
+
# all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1
|
| 24 |
+
self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False)
|
| 25 |
+
self.bn1 = nn.BatchNorm2d(planes)
|
| 26 |
+
|
| 27 |
+
self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False)
|
| 28 |
+
self.bn2 = nn.BatchNorm2d(planes)
|
| 29 |
+
|
| 30 |
+
self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity()
|
| 31 |
+
|
| 32 |
+
self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False)
|
| 33 |
+
self.bn3 = nn.BatchNorm2d(planes * self.expansion)
|
| 34 |
+
|
| 35 |
+
self.relu = nn.ReLU(inplace=True)
|
| 36 |
+
self.downsample = None
|
| 37 |
+
self.stride = stride
|
| 38 |
+
|
| 39 |
+
if stride > 1 or inplanes != planes * Bottleneck.expansion:
|
| 40 |
+
# downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1
|
| 41 |
+
self.downsample = nn.Sequential(OrderedDict([
|
| 42 |
+
("-1", nn.AvgPool2d(stride)),
|
| 43 |
+
("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)),
|
| 44 |
+
("1", nn.BatchNorm2d(planes * self.expansion))
|
| 45 |
+
]))
|
| 46 |
+
|
| 47 |
+
def forward(self, x: torch.Tensor):
|
| 48 |
+
identity = x
|
| 49 |
+
|
| 50 |
+
out = self.relu(self.bn1(self.conv1(x)))
|
| 51 |
+
out = self.relu(self.bn2(self.conv2(out)))
|
| 52 |
+
out = self.avgpool(out)
|
| 53 |
+
out = self.bn3(self.conv3(out))
|
| 54 |
+
|
| 55 |
+
if self.downsample is not None:
|
| 56 |
+
identity = self.downsample(x)
|
| 57 |
+
|
| 58 |
+
out += identity
|
| 59 |
+
out = self.relu(out)
|
| 60 |
+
return out
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
class AttentionPool2d(nn.Module):
|
| 64 |
+
def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None):
|
| 65 |
+
super().__init__()
|
| 66 |
+
self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5)
|
| 67 |
+
self.k_proj = nn.Linear(embed_dim, embed_dim)
|
| 68 |
+
self.q_proj = nn.Linear(embed_dim, embed_dim)
|
| 69 |
+
self.v_proj = nn.Linear(embed_dim, embed_dim)
|
| 70 |
+
self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
|
| 71 |
+
self.num_heads = num_heads
|
| 72 |
+
|
| 73 |
+
def forward(self, x):
|
| 74 |
+
x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3]).permute(2, 0, 1) # NCHW -> (HW)NC
|
| 75 |
+
x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC
|
| 76 |
+
x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC
|
| 77 |
+
x, _ = F.multi_head_attention_forward(
|
| 78 |
+
query=x, key=x, value=x,
|
| 79 |
+
embed_dim_to_check=x.shape[-1],
|
| 80 |
+
num_heads=self.num_heads,
|
| 81 |
+
q_proj_weight=self.q_proj.weight,
|
| 82 |
+
k_proj_weight=self.k_proj.weight,
|
| 83 |
+
v_proj_weight=self.v_proj.weight,
|
| 84 |
+
in_proj_weight=None,
|
| 85 |
+
in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]),
|
| 86 |
+
bias_k=None,
|
| 87 |
+
bias_v=None,
|
| 88 |
+
add_zero_attn=False,
|
| 89 |
+
dropout_p=0,
|
| 90 |
+
out_proj_weight=self.c_proj.weight,
|
| 91 |
+
out_proj_bias=self.c_proj.bias,
|
| 92 |
+
use_separate_proj_weight=True,
|
| 93 |
+
training=self.training,
|
| 94 |
+
need_weights=False
|
| 95 |
+
)
|
| 96 |
+
|
| 97 |
+
return x[0]
|
| 98 |
+
|
| 99 |
+
|
| 100 |
+
class ModifiedResNet(nn.Module):
|
| 101 |
+
"""
|
| 102 |
+
A ResNet class that is similar to torchvision's but contains the following changes:
|
| 103 |
+
- There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool.
|
| 104 |
+
- Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1
|
| 105 |
+
- The final pooling layer is a QKV attention instead of an average pool
|
| 106 |
+
"""
|
| 107 |
+
|
| 108 |
+
def __init__(self, layers, output_dim, heads, input_resolution=224, width=64):
|
| 109 |
+
super().__init__()
|
| 110 |
+
self.output_dim = output_dim
|
| 111 |
+
self.input_resolution = input_resolution
|
| 112 |
+
|
| 113 |
+
# the 3-layer stem
|
| 114 |
+
self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False)
|
| 115 |
+
self.bn1 = nn.BatchNorm2d(width // 2)
|
| 116 |
+
self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False)
|
| 117 |
+
self.bn2 = nn.BatchNorm2d(width // 2)
|
| 118 |
+
self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False)
|
| 119 |
+
self.bn3 = nn.BatchNorm2d(width)
|
| 120 |
+
self.avgpool = nn.AvgPool2d(2)
|
| 121 |
+
self.relu = nn.ReLU(inplace=True)
|
| 122 |
+
|
| 123 |
+
# residual layers
|
| 124 |
+
self._inplanes = width # this is a *mutable* variable used during construction
|
| 125 |
+
self.layer1 = self._make_layer(width, layers[0])
|
| 126 |
+
self.layer2 = self._make_layer(width * 2, layers[1], stride=2)
|
| 127 |
+
self.layer3 = self._make_layer(width * 4, layers[2], stride=2)
|
| 128 |
+
self.layer4 = self._make_layer(width * 8, layers[3], stride=2)
|
| 129 |
+
|
| 130 |
+
embed_dim = width * 32 # the ResNet feature dimension
|
| 131 |
+
self.attnpool = AttentionPool2d(input_resolution // 32, embed_dim, heads, output_dim)
|
| 132 |
+
|
| 133 |
+
def _make_layer(self, planes, blocks, stride=1):
|
| 134 |
+
layers = [Bottleneck(self._inplanes, planes, stride)]
|
| 135 |
+
|
| 136 |
+
self._inplanes = planes * Bottleneck.expansion
|
| 137 |
+
for _ in range(1, blocks):
|
| 138 |
+
layers.append(Bottleneck(self._inplanes, planes))
|
| 139 |
+
|
| 140 |
+
return nn.Sequential(*layers)
|
| 141 |
+
|
| 142 |
+
def forward(self, x):
|
| 143 |
+
def stem(x):
|
| 144 |
+
for conv, bn in [(self.conv1, self.bn1), (self.conv2, self.bn2), (self.conv3, self.bn3)]:
|
| 145 |
+
x = self.relu(bn(conv(x)))
|
| 146 |
+
x = self.avgpool(x)
|
| 147 |
+
return x
|
| 148 |
+
|
| 149 |
+
x = x.type(self.conv1.weight.dtype)
|
| 150 |
+
x = stem(x)
|
| 151 |
+
x = self.layer1(x)
|
| 152 |
+
x = self.layer2(x)
|
| 153 |
+
x = self.layer3(x)
|
| 154 |
+
x = self.layer4(x)
|
| 155 |
+
x = self.attnpool(x)
|
| 156 |
+
|
| 157 |
+
return x
|
| 158 |
+
|
| 159 |
+
|
| 160 |
+
class LayerNorm(nn.LayerNorm):
|
| 161 |
+
"""Subclass torch's LayerNorm to handle fp16."""
|
| 162 |
+
|
| 163 |
+
def forward(self, x: torch.Tensor):
|
| 164 |
+
orig_type = x.dtype
|
| 165 |
+
ret = super().forward(x.type(torch.float32))
|
| 166 |
+
return ret.type(orig_type)
|
| 167 |
+
|
| 168 |
+
|
| 169 |
+
class QuickGELU(nn.Module):
|
| 170 |
+
def forward(self, x: torch.Tensor):
|
| 171 |
+
return x * torch.sigmoid(1.702 * x)
|
| 172 |
+
|
| 173 |
+
|
| 174 |
+
class ResidualAttentionBlock(nn.Module):
|
| 175 |
+
def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None, text_or_image=None, flag=False):
|
| 176 |
+
super().__init__()
|
| 177 |
+
self.register_buffer("mean", torch.tensor([0.0]))
|
| 178 |
+
self.register_buffer("std", torch.tensor([1.0]))
|
| 179 |
+
self.attn = nn.MultiheadAttention(d_model, n_head)
|
| 180 |
+
self.ln_1 = LayerNorm(d_model)
|
| 181 |
+
self.mlp = nn.Sequential(OrderedDict([
|
| 182 |
+
("c_fc", nn.Linear(d_model, d_model * 4)),
|
| 183 |
+
("gelu", QuickGELU()),
|
| 184 |
+
("c_proj", nn.Linear(d_model * 4, d_model))
|
| 185 |
+
]))
|
| 186 |
+
self.ln_2 = LayerNorm(d_model)
|
| 187 |
+
self.attn_mask = attn_mask
|
| 188 |
+
self.is_train = global_is_train
|
| 189 |
+
self.step = 1
|
| 190 |
+
self.top_k = 2
|
| 191 |
+
self.ffn_num = 64
|
| 192 |
+
self.experts_num = 1
|
| 193 |
+
self.softmax = nn.Softmax(1)
|
| 194 |
+
self.softplus = nn.Softplus()
|
| 195 |
+
self.noisy_gating = True
|
| 196 |
+
self.adaptmlp_list = nn.ModuleList()
|
| 197 |
+
self.text_or_image = text_or_image
|
| 198 |
+
self.flag = flag
|
| 199 |
+
if text_or_image == 'text':
|
| 200 |
+
print('vanilla text transformer')
|
| 201 |
+
self.choose_map_text = torch.zeros([ self.experts_num])
|
| 202 |
+
else:
|
| 203 |
+
print('vanilla image transformer')
|
| 204 |
+
self.choose_map_image = torch.zeros([ self.experts_num])
|
| 205 |
+
|
| 206 |
+
# self.taskid = None
|
| 207 |
+
def attention(self, x: torch.Tensor):
|
| 208 |
+
self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None
|
| 209 |
+
return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0]
|
| 210 |
+
|
| 211 |
+
def cv_squared(self, x):
|
| 212 |
+
"""The squared coefficient of variation of a sample.
|
| 213 |
+
Useful as a loss to encourage a positive distribution to be more uniform.
|
| 214 |
+
Epsilons added for numerical stability.
|
| 215 |
+
Returns 0 for an empty Tensor.
|
| 216 |
+
Args:
|
| 217 |
+
x: a `Tensor`.
|
| 218 |
+
Returns:
|
| 219 |
+
a `Scalar`.
|
| 220 |
+
"""
|
| 221 |
+
eps = 1e-10
|
| 222 |
+
# if only num_experts = 1
|
| 223 |
+
|
| 224 |
+
if x.shape[0] == 1:
|
| 225 |
+
return torch.tensor([0], device=x.device, dtype=x.dtype)
|
| 226 |
+
return x.float().var() / (x.float().mean()**2 + eps)
|
| 227 |
+
|
| 228 |
+
def _gates_to_load(self, gates):
|
| 229 |
+
"""Compute the true load per expert, given the gates.
|
| 230 |
+
The load is the number of examples for which the corresponding gate is >0.
|
| 231 |
+
Args:
|
| 232 |
+
gates: a `Tensor` of shape [batch_size, n]
|
| 233 |
+
Returns:
|
| 234 |
+
a float32 `Tensor` of shape [n]
|
| 235 |
+
"""
|
| 236 |
+
return (gates > 0).sum(0)
|
| 237 |
+
|
| 238 |
+
def _prob_in_top_k(self, clean_values, noisy_values, noise_stddev, noisy_top_values):
|
| 239 |
+
"""Helper function to NoisyTopKGating.
|
| 240 |
+
Computes the probability that value is in top k, given different random noise.
|
| 241 |
+
This gives us a way of backpropagating from a loss that balances the number
|
| 242 |
+
of times each expert is in the top k experts per example.
|
| 243 |
+
In the case of no noise, pass in None for noise_stddev, and the result will
|
| 244 |
+
not be differentiable.
|
| 245 |
+
Args:
|
| 246 |
+
clean_values: a `Tensor` of shape [batch, n].
|
| 247 |
+
noisy_values: a `Tensor` of shape [batch, n]. Equal to clean values plus
|
| 248 |
+
normally distributed noise with standard deviation noise_stddev.
|
| 249 |
+
noise_stddev: a `Tensor` of shape [batch, n], or None
|
| 250 |
+
noisy_top_values: a `Tensor` of shape [batch, m].
|
| 251 |
+
"values" Output of tf.top_k(noisy_top_values, m). m >= k+1
|
| 252 |
+
Returns:
|
| 253 |
+
a `Tensor` of shape [batch, n].
|
| 254 |
+
"""
|
| 255 |
+
# print('1231',clean_values) # 全nan
|
| 256 |
+
batch = clean_values.size(0)
|
| 257 |
+
m = noisy_top_values.size(1)
|
| 258 |
+
top_values_flat = noisy_top_values.flatten()
|
| 259 |
+
|
| 260 |
+
threshold_positions_if_in = torch.arange(batch, device=clean_values.device) * m + self.top_k
|
| 261 |
+
threshold_if_in = torch.unsqueeze(torch.gather(top_values_flat, 0, threshold_positions_if_in), 1)
|
| 262 |
+
is_in = torch.gt(noisy_values, threshold_if_in)
|
| 263 |
+
threshold_positions_if_out = threshold_positions_if_in - 1
|
| 264 |
+
threshold_if_out = torch.unsqueeze(torch.gather(top_values_flat, 0, threshold_positions_if_out), 1)
|
| 265 |
+
# is each value currently in the top k.
|
| 266 |
+
normal = Normal(self.mean, self.std)
|
| 267 |
+
#
|
| 268 |
+
|
| 269 |
+
prob_if_in = normal.cdf((clean_values - threshold_if_in)/noise_stddev)
|
| 270 |
+
prob_if_out = normal.cdf((clean_values - threshold_if_out)/noise_stddev)
|
| 271 |
+
prob = torch.where(is_in, prob_if_in, prob_if_out)
|
| 272 |
+
return prob
|
| 273 |
+
|
| 274 |
+
def noisy_top_k_gating(self, x, train, w_gate, w_noise, noise_epsilon=1e-2):
|
| 275 |
+
"""Noisy top-k gating.
|
| 276 |
+
See paper: https://arxiv.org/abs/1701.06538.
|
| 277 |
+
Args:
|
| 278 |
+
x: input Tensor with shape [batch_size, input_size]
|
| 279 |
+
train: a boolean - we only add noise at training time.
|
| 280 |
+
noise_epsilon: a float
|
| 281 |
+
Returns:
|
| 282 |
+
gates: a Tensor with shape [batch_size, num_experts]
|
| 283 |
+
load: a Tensor with shape [num_experts]
|
| 284 |
+
"""
|
| 285 |
+
|
| 286 |
+
clean_logits = x @ w_gate.to(x)
|
| 287 |
+
if self.noisy_gating and train:
|
| 288 |
+
raw_noise_stddev = x @ w_noise.to(x)
|
| 289 |
+
noise_stddev = self.softplus(raw_noise_stddev) + noise_epsilon
|
| 290 |
+
noisy_logits = clean_logits + (torch.randn_like(clean_logits) * noise_stddev)
|
| 291 |
+
logits = noisy_logits
|
| 292 |
+
else:
|
| 293 |
+
logits = clean_logits
|
| 294 |
+
# calculate topk + 1 that will be needed for the noisy gates
|
| 295 |
+
top_logits, top_indices = logits.topk(min(self.top_k + 1, self.experts_num), dim=1)
|
| 296 |
+
top_k_logits = top_logits[:, :self.top_k]
|
| 297 |
+
top_k_indices = top_indices[:, :self.top_k]
|
| 298 |
+
top_k_gates = self.softmax(top_k_logits)
|
| 299 |
+
zeros = torch.zeros_like(logits)
|
| 300 |
+
gates = zeros.scatter(1, top_k_indices, top_k_gates)
|
| 301 |
+
if self.noisy_gating and self.top_k < self.experts_num and train:  # currently unused branch
|
| 302 |
+
load = (self._prob_in_top_k(clean_logits, noisy_logits, noise_stddev, top_logits)).sum(0)
|
| 303 |
+
else:
|
| 304 |
+
load = self._gates_to_load(gates)
|
| 305 |
+
return gates, load
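# Illustrative sketch (not repository code): the no-noise branch of the gating above,
# reduced to a standalone snippet to make the shapes explicit. All names here are
# local toy variables, not attributes of the module.
import torch
batch, num_experts, top_k = 4, 8, 2
x = torch.randn(batch, 16)
w_gate = torch.randn(16, num_experts)
logits = x @ w_gate                                    # [batch, num_experts]
top_vals, top_idx = logits.topk(top_k, dim=1)          # keep the k best experts per example
gates = torch.zeros_like(logits).scatter(1, top_idx, top_vals.softmax(dim=1))
load = (gates > 0).sum(0)                              # examples routed to each expert
print(gates.shape, load.sum())                         # torch.Size([4, 8]), tensor(8)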
|
| 306 |
+
|
| 307 |
+
def forward(self, x: torch.Tensor):
|
| 308 |
+
x = x + self.attention(self.ln_1(x))
|
| 309 |
+
x = x + self.mlp(self.ln_2(x))
|
| 310 |
+
return x
|
| 311 |
+
|
| 312 |
+
|
| 313 |
+
class Transformer(nn.Module):
|
| 314 |
+
def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None, text_or_image=None,
|
| 315 |
+
flag =True,):
|
| 316 |
+
super().__init__()
|
| 317 |
+
self.width = width
|
| 318 |
+
self.layers = layers
|
| 319 |
+
self.resblocks = nn.Sequential(*[ResidualAttentionBlock(width, heads, attn_mask, text_or_image, flag) for _ in range(layers)])
|
| 320 |
+
|
| 321 |
+
def forward(self, x: torch.Tensor):
|
| 322 |
+
return self.resblocks(x)
|
| 323 |
+
|
| 324 |
+
|
| 325 |
+
class VisualTransformer(nn.Module):
|
| 326 |
+
def __init__(self, input_resolution: int, patch_size: int, width: int, layers: int, heads: int, output_dim: int, text_or_image=None):
|
| 327 |
+
super().__init__()
|
| 328 |
+
self.input_resolution = input_resolution
|
| 329 |
+
self.output_dim = output_dim
|
| 330 |
+
# Added so this info is available. should not change anything.
|
| 331 |
+
self.patch_size = patch_size
|
| 332 |
+
self.width = width
|
| 333 |
+
self.layers = layers
|
| 334 |
+
self.heads = heads
|
| 335 |
+
|
| 336 |
+
self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False)
|
| 337 |
+
|
| 338 |
+
scale = width ** -0.5
|
| 339 |
+
self.class_embedding = nn.Parameter(scale * torch.randn(width))
|
| 340 |
+
self.positional_embedding = nn.Parameter(scale * torch.randn((input_resolution // patch_size) ** 2 + 1, width))
|
| 341 |
+
self.ln_pre = LayerNorm(width)
|
| 342 |
+
|
| 343 |
+
self.transformer = Transformer(width, layers, heads, text_or_image=text_or_image, flag=True)
|
| 344 |
+
|
| 345 |
+
self.ln_post = LayerNorm(width)
|
| 346 |
+
self.proj = nn.Parameter(scale * torch.randn(width, output_dim))
|
| 347 |
+
|
| 348 |
+
def forward(self, x: torch.Tensor):
|
| 349 |
+
x = self.conv1(x)
|
| 350 |
+
x = x.reshape(x.shape[0], x.shape[1], -1)
|
| 351 |
+
x = x.permute(0, 2, 1)
|
| 352 |
+
x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1) # shape = [*, grid ** 2 + 1, width]
|
| 353 |
+
x = x + self.positional_embedding.to(x.dtype)
|
| 354 |
+
x = self.ln_pre(x)
|
| 355 |
+
|
| 356 |
+
x = x.permute(1, 0, 2) # NLD -> LND
|
| 357 |
+
x = self.transformer(x)
|
| 358 |
+
x = x.permute(1, 0, 2) # LND -> NLD
|
| 359 |
+
|
| 360 |
+
x = self.ln_post(x[:, 0, :])
|
| 361 |
+
|
| 362 |
+
if self.proj is not None:
|
| 363 |
+
x = x @ self.proj
|
| 364 |
+
|
| 365 |
+
return x
|
| 366 |
+
|
| 367 |
+
|
| 368 |
+
class CLIP(nn.Module):
|
| 369 |
+
def __init__(self,
|
| 370 |
+
embed_dim: int,
|
| 371 |
+
# vision
|
| 372 |
+
image_resolution: int,
|
| 373 |
+
vision_layers: Union[Tuple[int, int, int, int], int],
|
| 374 |
+
vision_width: int,
|
| 375 |
+
vision_patch_size: int,
|
| 376 |
+
# text
|
| 377 |
+
context_length: int,
|
| 378 |
+
vocab_size: int,
|
| 379 |
+
transformer_width: int,
|
| 380 |
+
transformer_heads: int,
|
| 381 |
+
transformer_layers: int,
|
| 382 |
+
baseline = False
|
| 383 |
+
):
|
| 384 |
+
super().__init__()
|
| 385 |
+
self.baseline = baseline
|
| 386 |
+
|
| 387 |
+
self.context_length = context_length
|
| 388 |
+
|
| 389 |
+
if isinstance(vision_layers, (tuple, list)):
|
| 390 |
+
vision_heads = vision_width * 32 // 64
|
| 391 |
+
self.visual = ModifiedResNet(
|
| 392 |
+
layers=vision_layers,
|
| 393 |
+
output_dim=embed_dim,
|
| 394 |
+
heads=vision_heads,
|
| 395 |
+
input_resolution=image_resolution,
|
| 396 |
+
width=vision_width
|
| 397 |
+
)
|
| 398 |
+
else:
|
| 399 |
+
vision_heads = vision_width // 64
|
| 400 |
+
self.visual = VisualTransformer(
|
| 401 |
+
input_resolution=image_resolution,
|
| 402 |
+
patch_size=vision_patch_size,
|
| 403 |
+
width=vision_width,
|
| 404 |
+
layers=vision_layers,
|
| 405 |
+
heads=vision_heads,
|
| 406 |
+
output_dim=embed_dim,
|
| 407 |
+
text_or_image='image',
|
| 408 |
+
)
|
| 409 |
+
|
| 410 |
+
# self.transformer = Transformer(
|
| 411 |
+
# width=transformer_width,
|
| 412 |
+
# layers=transformer_layers,
|
| 413 |
+
# heads=transformer_heads,
|
| 414 |
+
# attn_mask=self.build_attention_mask(),
|
| 415 |
+
# text_or_image='text'
|
| 416 |
+
# )
|
| 417 |
+
self.transformer = Transformer(
|
| 418 |
+
width=transformer_width,
|
| 419 |
+
layers=transformer_layers,
|
| 420 |
+
heads=transformer_heads,
|
| 421 |
+
attn_mask=self.build_attention_mask(),
|
| 422 |
+
text_or_image='text',
|
| 423 |
+
flag = True,
|
| 424 |
+
)
|
| 425 |
+
|
| 426 |
+
|
| 427 |
+
self.vocab_size = vocab_size
|
| 428 |
+
self.token_embedding = nn.Embedding(vocab_size, transformer_width)
|
| 429 |
+
self.positional_embedding = nn.Parameter(torch.empty(self.context_length, transformer_width))
|
| 430 |
+
self.ln_final = LayerNorm(transformer_width)
|
| 431 |
+
|
| 432 |
+
self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim))
|
| 433 |
+
self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
|
| 434 |
+
|
| 435 |
+
|
| 436 |
+
self.initialize_parameters()
|
| 437 |
+
|
| 438 |
+
def initialize_parameters(self):
|
| 439 |
+
nn.init.normal_(self.token_embedding.weight, std=0.02)
|
| 440 |
+
nn.init.normal_(self.positional_embedding, std=0.01)
|
| 441 |
+
self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
|
| 442 |
+
|
| 443 |
+
if isinstance(self.visual, ModifiedResNet):
|
| 444 |
+
if self.visual.attnpool is not None:
|
| 445 |
+
std = self.visual.attnpool.c_proj.in_features ** -0.5
|
| 446 |
+
nn.init.normal_(self.visual.attnpool.q_proj.weight, std=std)
|
| 447 |
+
nn.init.normal_(self.visual.attnpool.k_proj.weight, std=std)
|
| 448 |
+
nn.init.normal_(self.visual.attnpool.v_proj.weight, std=std)
|
| 449 |
+
nn.init.normal_(self.visual.attnpool.c_proj.weight, std=std)
|
| 450 |
+
|
| 451 |
+
for resnet_block in [self.visual.layer1, self.visual.layer2, self.visual.layer3, self.visual.layer4]:
|
| 452 |
+
for name, param in resnet_block.named_parameters():
|
| 453 |
+
if name.endswith("bn3.weight"):
|
| 454 |
+
nn.init.zeros_(param)
|
| 455 |
+
|
| 456 |
+
proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5)
|
| 457 |
+
attn_std = self.transformer.width ** -0.5
|
| 458 |
+
fc_std = (2 * self.transformer.width) ** -0.5
|
| 459 |
+
for block in self.transformer.resblocks:
|
| 460 |
+
nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
|
| 461 |
+
nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
|
| 462 |
+
nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
|
| 463 |
+
nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)
|
| 464 |
+
|
| 465 |
+
if self.text_projection is not None:
|
| 466 |
+
nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5)
|
| 467 |
+
|
| 468 |
+
def build_attention_mask(self):
|
| 469 |
+
# lazily create causal attention mask, with full attention between the vision tokens
|
| 470 |
+
# pytorch uses additive attention mask; fill with -inf
|
| 471 |
+
mask = torch.empty(self.context_length, self.context_length)
|
| 472 |
+
mask.fill_(float("-inf"))
|
| 473 |
+
mask.triu_(1) # zero out the lower diagonal
|
| 474 |
+
return mask
|
| 475 |
+
|
| 476 |
+
@property
|
| 477 |
+
def dtype(self):
|
| 478 |
+
return self.visual.conv1.weight.dtype
|
| 479 |
+
|
| 480 |
+
def encode_image(self, image):
|
| 481 |
+
return self.visual(image.type(self.dtype))
|
| 482 |
+
|
| 483 |
+
def encode_text(self, text):
|
| 484 |
+
|
| 485 |
+
x = self.token_embedding(text).type(self.dtype) # [batch_size, n_ctx, d_model]
|
| 486 |
+
|
| 487 |
+
x = x + self.positional_embedding.type(self.dtype)
|
| 488 |
+
x = x.permute(1, 0, 2) # NLD -> LND
|
| 489 |
+
x = self.transformer(x)
|
| 490 |
+
x = x.permute(1, 0, 2) # LND -> NLD
|
| 491 |
+
x = self.ln_final(x).type(self.dtype)
|
| 492 |
+
|
| 493 |
+
# take features from the eot embedding (eot_token is the highest number in each sequence)
|
| 494 |
+
x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection
|
| 495 |
+
|
| 496 |
+
return x
|
| 497 |
+
|
| 498 |
+
def forward(self, image, text, taskid, is_train):
|
| 499 |
+
global global_taskid, global_is_train
|
| 500 |
+
global_taskid = taskid
|
| 501 |
+
global_is_train = is_train
|
| 502 |
+
return self.encode_image(image)
|
| 503 |
+
|
| 504 |
+
|
| 505 |
+
|
| 506 |
+
def convert_weights(model: nn.Module):
|
| 507 |
+
"""Convert applicable model parameters to fp16"""
|
| 508 |
+
|
| 509 |
+
def _convert_weights_to_fp16(l):
|
| 510 |
+
if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)):
|
| 511 |
+
l.weight.data = l.weight.data.half()
|
| 512 |
+
if l.bias is not None:
|
| 513 |
+
l.bias.data = l.bias.data.half()
|
| 514 |
+
|
| 515 |
+
if isinstance(l, nn.MultiheadAttention):
|
| 516 |
+
for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]:
|
| 517 |
+
tensor = getattr(l, attr)
|
| 518 |
+
if tensor is not None:
|
| 519 |
+
tensor.data = tensor.data.half()
|
| 520 |
+
|
| 521 |
+
for name in ["text_projection", "proj"]:
|
| 522 |
+
if hasattr(l, name):
|
| 523 |
+
attr = getattr(l, name)
|
| 524 |
+
if attr is not None:
|
| 525 |
+
attr.data = attr.data.half()
|
| 526 |
+
|
| 527 |
+
model.apply(_convert_weights_to_fp16)
|
| 528 |
+
|
| 529 |
+
|
| 530 |
+
def build_model(state_dict: dict):
|
| 531 |
+
vit = "visual.proj" in state_dict
|
| 532 |
+
|
| 533 |
+
if vit:
|
| 534 |
+
vision_width = state_dict["visual.conv1.weight"].shape[0]
|
| 535 |
+
vision_layers = len([k for k in state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")])
|
| 536 |
+
vision_patch_size = state_dict["visual.conv1.weight"].shape[-1]
|
| 537 |
+
grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5)
|
| 538 |
+
image_resolution = vision_patch_size * grid_size
|
| 539 |
+
else:
|
| 540 |
+
counts: list = [len(set(k.split(".")[2] for k in state_dict if k.startswith(f"visual.layer{b}"))) for b in [1, 2, 3, 4]]
|
| 541 |
+
vision_layers = tuple(counts)
|
| 542 |
+
vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0]
|
| 543 |
+
output_width = round((state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5)
|
| 544 |
+
vision_patch_size = None
|
| 545 |
+
assert output_width ** 2 + 1 == state_dict["visual.attnpool.positional_embedding"].shape[0]
|
| 546 |
+
image_resolution = output_width * 32
|
| 547 |
+
|
| 548 |
+
embed_dim = state_dict["text_projection"].shape[1]
|
| 549 |
+
context_length = state_dict["positional_embedding"].shape[0]
|
| 550 |
+
vocab_size = state_dict["token_embedding.weight"].shape[0]
|
| 551 |
+
transformer_width = state_dict["ln_final.weight"].shape[0]
|
| 552 |
+
transformer_heads = transformer_width // 64
|
| 553 |
+
transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith(f"transformer.resblocks")))
|
| 554 |
+
|
| 555 |
+
model = CLIP(
|
| 556 |
+
embed_dim,
|
| 557 |
+
image_resolution, vision_layers, vision_width, vision_patch_size,
|
| 558 |
+
context_length, vocab_size, transformer_width, transformer_heads, transformer_layers
|
| 559 |
+
)
|
| 560 |
+
|
| 561 |
+
for key in ["input_resolution", "context_length", "vocab_size"]:
|
| 562 |
+
if key in state_dict:
|
| 563 |
+
del state_dict[key]
|
| 564 |
+
|
| 565 |
+
model.load_state_dict(state_dict, strict=False)
|
| 566 |
+
for p in model.parameters():
|
| 567 |
+
p.data = p.data.float()
|
| 568 |
+
return model.eval()
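# Illustrative usage sketch (not part of the upload; the checkpoint path is hypothetical):
# build_model infers the ViT vs. ResNet configuration from the keys of a CLIP state dict.
import torch
state_dict = torch.load("ViT-B-16_state_dict.pt", map_location="cpu")  # hypothetical plain state dict
model = build_model(state_dict)            # returns a float32 CLIP in eval mode
with torch.no_grad():
    res = model.visual.input_resolution
    image_features = model.encode_image(torch.randn(1, 3, res, res))
print(image_features.shape)                # [1, embed_dim]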
|
continual_clip/clip_original/tokenizer.py
ADDED
|
@@ -0,0 +1,140 @@
| 1 |
+
import gzip
|
| 2 |
+
import html
|
| 3 |
+
import os
|
| 4 |
+
from functools import lru_cache
|
| 5 |
+
|
| 6 |
+
import ftfy
|
| 7 |
+
import regex as re
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
@lru_cache()
|
| 11 |
+
def default_bpe():
|
| 12 |
+
return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz")
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
@lru_cache()
|
| 16 |
+
def bytes_to_unicode():
|
| 17 |
+
"""
|
| 18 |
+
Returns list of utf-8 byte and a corresponding list of unicode strings.
|
| 19 |
+
The reversible bpe codes work on unicode strings.
|
| 20 |
+
This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
|
| 21 |
+
When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
|
| 22 |
+
This is a significant percentage of your normal, say, 32K bpe vocab.
|
| 23 |
+
To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
|
| 24 |
+
And avoids mapping to whitespace/control characters the bpe code barfs on.
|
| 25 |
+
"""
|
| 26 |
+
bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1))
|
| 27 |
+
cs = bs[:]
|
| 28 |
+
n = 0
|
| 29 |
+
for b in range(2**8):
|
| 30 |
+
if b not in bs:
|
| 31 |
+
bs.append(b)
|
| 32 |
+
cs.append(2**8+n)
|
| 33 |
+
n += 1
|
| 34 |
+
cs = [chr(n) for n in cs]
|
| 35 |
+
return dict(zip(bs, cs))
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
def get_pairs(word):
|
| 39 |
+
"""Return set of symbol pairs in a word.
|
| 40 |
+
Word is represented as tuple of symbols (symbols being variable-length strings).
|
| 41 |
+
"""
|
| 42 |
+
pairs = set()
|
| 43 |
+
prev_char = word[0]
|
| 44 |
+
for char in word[1:]:
|
| 45 |
+
pairs.add((prev_char, char))
|
| 46 |
+
prev_char = char
|
| 47 |
+
return pairs
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
def basic_clean(text):
|
| 51 |
+
text = ftfy.fix_text(text)
|
| 52 |
+
text = html.unescape(html.unescape(text))
|
| 53 |
+
return text.strip()
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
def whitespace_clean(text):
|
| 57 |
+
text = re.sub(r'\s+', ' ', text)
|
| 58 |
+
text = text.strip()
|
| 59 |
+
return text
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
class SimpleTokenizer(object):
|
| 63 |
+
def __init__(self, bpe_path: str = default_bpe(), special_tokens=None):
|
| 64 |
+
self.byte_encoder = bytes_to_unicode()
|
| 65 |
+
self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
|
| 66 |
+
merges = gzip.open(bpe_path).read().decode("utf-8").split('\n')
|
| 67 |
+
merges = merges[1:49152-256-2+1]
|
| 68 |
+
merges = [tuple(merge.split()) for merge in merges]
|
| 69 |
+
vocab = list(bytes_to_unicode().values())
|
| 70 |
+
vocab = vocab + [v+'</w>' for v in vocab]
|
| 71 |
+
for merge in merges:
|
| 72 |
+
vocab.append(''.join(merge))
|
| 73 |
+
if not special_tokens:
|
| 74 |
+
special_tokens = ['<start_of_text>', '<end_of_text>']
|
| 75 |
+
else:
|
| 76 |
+
special_tokens = ['<start_of_text>', '<end_of_text>'] + special_tokens
|
| 77 |
+
vocab.extend(special_tokens)
|
| 78 |
+
self.encoder = dict(zip(vocab, range(len(vocab))))
|
| 79 |
+
self.decoder = {v: k for k, v in self.encoder.items()}
|
| 80 |
+
self.bpe_ranks = dict(zip(merges, range(len(merges))))
|
| 81 |
+
self.cache = {t:t for t in special_tokens}
|
| 82 |
+
special = "|".join(special_tokens)
|
| 83 |
+
self.pat = re.compile(special + r"""|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE)
|
| 84 |
+
|
| 85 |
+
self.vocab_size = len(self.encoder)
|
| 86 |
+
self.all_special_ids = [self.encoder[t] for t in special_tokens]
|
| 87 |
+
|
| 88 |
+
def bpe(self, token):
|
| 89 |
+
if token in self.cache:
|
| 90 |
+
return self.cache[token]
|
| 91 |
+
word = tuple(token[:-1]) + ( token[-1] + '</w>',)
|
| 92 |
+
pairs = get_pairs(word)
|
| 93 |
+
|
| 94 |
+
if not pairs:
|
| 95 |
+
return token+'</w>'
|
| 96 |
+
|
| 97 |
+
while True:
|
| 98 |
+
bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf')))
|
| 99 |
+
if bigram not in self.bpe_ranks:
|
| 100 |
+
break
|
| 101 |
+
first, second = bigram
|
| 102 |
+
new_word = []
|
| 103 |
+
i = 0
|
| 104 |
+
while i < len(word):
|
| 105 |
+
try:
|
| 106 |
+
j = word.index(first, i)
|
| 107 |
+
new_word.extend(word[i:j])
|
| 108 |
+
i = j
|
| 109 |
+
except:
|
| 110 |
+
new_word.extend(word[i:])
|
| 111 |
+
break
|
| 112 |
+
|
| 113 |
+
if word[i] == first and i < len(word)-1 and word[i+1] == second:
|
| 114 |
+
new_word.append(first+second)
|
| 115 |
+
i += 2
|
| 116 |
+
else:
|
| 117 |
+
new_word.append(word[i])
|
| 118 |
+
i += 1
|
| 119 |
+
new_word = tuple(new_word)
|
| 120 |
+
word = new_word
|
| 121 |
+
if len(word) == 1:
|
| 122 |
+
break
|
| 123 |
+
else:
|
| 124 |
+
pairs = get_pairs(word)
|
| 125 |
+
word = ' '.join(word)
|
| 126 |
+
self.cache[token] = word
|
| 127 |
+
return word
|
| 128 |
+
|
| 129 |
+
def encode(self, text):
|
| 130 |
+
bpe_tokens = []
|
| 131 |
+
text = whitespace_clean(basic_clean(text)).lower()
|
| 132 |
+
for token in re.findall(self.pat, text):
|
| 133 |
+
token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
|
| 134 |
+
bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
|
| 135 |
+
return bpe_tokens
|
| 136 |
+
|
| 137 |
+
def decode(self, tokens):
|
| 138 |
+
text = ''.join([self.decoder[token] for token in tokens])
|
| 139 |
+
text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('</w>', ' ')
|
| 140 |
+
return text
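# Illustrative usage sketch (not part of the upload):
tokenizer = SimpleTokenizer()                        # uses the bundled bpe_simple_vocab_16e6.txt.gz
ids = tokenizer.encode("a photo of a black swan")    # list of BPE token ids
text = tokenizer.decode(ids)                         # round-trips to "a photo of a black swan "
sot = tokenizer.encoder["<start_of_text>"]
eot = tokenizer.encoder["<end_of_text>"]
context = [sot] + ids + [eot]                        # typical framing before padding to the context length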
|
continual_clip/datasets.py
ADDED
|
@@ -0,0 +1,124 @@
| 1 |
+
|
| 2 |
+
|
| 3 |
+
import os
|
| 4 |
+
import torch.nn as nn
|
| 5 |
+
|
| 6 |
+
from continuum import ClassIncremental, InstanceIncremental
|
| 7 |
+
from continuum.datasets import (
|
| 8 |
+
CIFAR100, ImageNet100, TinyImageNet200, ImageFolderDataset, Core50,
|
| 9 |
+
fgvc_aircraft, Caltech101, DTD, EuroSAT, flowers102, food101,
|
| 10 |
+
MNIST, OxfordPet, SUN397
|
| 11 |
+
|
| 12 |
+
)
|
| 13 |
+
from .utils import get_dataset_class_names
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
class ImageNet1000(ImageFolderDataset):
|
| 17 |
+
"""Continuum dataset for datasetsss with tree-like structure.
|
| 18 |
+
:param train_folder: The folder of the train data.
|
| 19 |
+
:param test_folder: The folder of the test data.
|
| 20 |
+
:param download: Dummy parameter.
|
| 21 |
+
"""
|
| 22 |
+
|
| 23 |
+
def __init__(
|
| 24 |
+
self,
|
| 25 |
+
data_path: str,
|
| 26 |
+
train: bool = True,
|
| 27 |
+
download: bool = False,
|
| 28 |
+
):
|
| 29 |
+
super().__init__(data_path=data_path, train=train, download=download)
|
| 30 |
+
|
| 31 |
+
def get_data(self):
|
| 32 |
+
if self.train:
|
| 33 |
+
self.data_path = os.path.join(self.data_path, "train")
|
| 34 |
+
else:
|
| 35 |
+
self.data_path = os.path.join(self.data_path, "val")
|
| 36 |
+
return super().get_data()
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
def get_dataset(cfg, is_train, transforms=None):
|
| 40 |
+
if cfg.dataset == "cifar100":
|
| 41 |
+
data_path = os.path.join(cfg.dataset_root, cfg.dataset)
|
| 42 |
+
dataset = CIFAR100(
|
| 43 |
+
data_path=data_path,
|
| 44 |
+
download=True,
|
| 45 |
+
train=is_train,
|
| 46 |
+
# transforms=transforms
|
| 47 |
+
)
|
| 48 |
+
classes_names = dataset.dataset.classes
|
| 49 |
+
|
| 50 |
+
# elif cfg.dataset == "tiny-imagenet-200":
|
| 51 |
+
elif cfg.dataset == "tinyimagenet":
|
| 52 |
+
# data_path = '/data/kangborui/'
|
| 53 |
+
data_path = os.path.join(cfg.dataset_root, cfg.dataset)
|
| 54 |
+
dataset = TinyImageNet200(
|
| 55 |
+
data_path,
|
| 56 |
+
train=is_train,
|
| 57 |
+
download=True
|
| 58 |
+
)
|
| 59 |
+
classes_names = get_dataset_class_names(cfg.workdir, cfg.dataset)
|
| 60 |
+
|
| 61 |
+
elif cfg.dataset == "imagenet100":
|
| 62 |
+
data_path = cfg.dataset_root
|
| 63 |
+
# data_path = os.path.join(cfg.dataset_root, "ImageNet")
|
| 64 |
+
dataset = ImageNet100(
|
| 65 |
+
data_path,
|
| 66 |
+
train=is_train,
|
| 67 |
+
data_subset=os.path.join('/home/kangborui/ClProject/MoE-Adapters4CL-cross-guild-fusion/cil/dataset_reqs/imagenet100_splits', "train_100.txt" if is_train else "val_100.txt")
|
| 68 |
+
)
|
| 69 |
+
classes_names = get_dataset_class_names(cfg.workdir, cfg.dataset)
|
| 70 |
+
|
| 71 |
+
elif cfg.dataset == "imagenet1000":
|
| 72 |
+
data_path = os.path.join(cfg.dataset_root, cfg.dataset)
|
| 73 |
+
dataset = ImageNet1000(
|
| 74 |
+
data_path,
|
| 75 |
+
train=is_train
|
| 76 |
+
)
|
| 77 |
+
classes_names = get_dataset_class_names(cfg.workdir, cfg.dataset)
|
| 78 |
+
|
| 79 |
+
elif cfg.dataset == "core50":
|
| 80 |
+
data_path = os.path.join(cfg.dataset_root, cfg.dataset)
|
| 81 |
+
dataset = Core50(
|
| 82 |
+
data_path,
|
| 83 |
+
scenario="domains",
|
| 84 |
+
classification="category",
|
| 85 |
+
train=is_train
|
| 86 |
+
)
|
| 87 |
+
classes_names = [
|
| 88 |
+
"plug adapters", "mobile phones", "scissors", "light bulbs", "cans",
|
| 89 |
+
"glasses", "balls", "markers", "cups", "remote controls"
|
| 90 |
+
]
|
| 91 |
+
|
| 92 |
+
else:
|
| 93 |
+
ValueError(f"'{cfg.dataset}' is a invalid dataset.")
|
| 94 |
+
|
| 95 |
+
return dataset, classes_names
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
def build_cl_scenarios(cfg, is_train, transforms) -> nn.Module:
|
| 99 |
+
|
| 100 |
+
dataset, classes_names = get_dataset(cfg, is_train)
|
| 101 |
+
|
| 102 |
+
if cfg.scenario == "class":
|
| 103 |
+
scenario = ClassIncremental(
|
| 104 |
+
dataset,
|
| 105 |
+
initial_increment=cfg.initial_increment,
|
| 106 |
+
increment=cfg.increment,
|
| 107 |
+
transformations=transforms.transforms, # Convert Compose into list
|
| 108 |
+
class_order=cfg.class_order,
|
| 109 |
+
)
|
| 110 |
+
|
| 111 |
+
elif cfg.scenario == "domain":
|
| 112 |
+
scenario = InstanceIncremental(
|
| 113 |
+
dataset,
|
| 114 |
+
transformations=transforms.transforms,
|
| 115 |
+
)
|
| 116 |
+
|
| 117 |
+
elif cfg.scenario == "task-agnostic":
|
| 118 |
+
raise NotImplementedError("The task-agnostic scenario has not been implemented yet; it will be added soon.")
|
| 119 |
+
|
| 120 |
+
else:
|
| 121 |
+
ValueError(f"You have entered `{cfg.scenario}` which is not a defined scenario, "
|
| 122 |
+
"please choose from {{'class', 'domain', 'task-agnostic'}}.")
|
| 123 |
+
|
| 124 |
+
return scenario, classes_names
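# Illustrative sketch of driving build_cl_scenarios (the config values below are hypothetical,
# loosely modeled on configs/class/cifar100_10-10.yaml; this snippet is not part of the upload):
from omegaconf import OmegaConf
from torchvision import transforms as T
cfg = OmegaConf.create({
    "dataset": "cifar100", "dataset_root": "./data", "scenario": "class",
    "initial_increment": 10, "increment": 10, "class_order": list(range(100)),
})
train_tf = T.Compose([T.Resize(224), T.ToTensor()])
scenario, class_names = build_cl_scenarios(cfg, is_train=True, transforms=train_tf)
for task_id, task_set in enumerate(scenario):
    print(task_id, len(task_set))                    # ten tasks of ten CIFAR-100 classes each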
|
continual_clip/dynamic_dataset.py
ADDED
|
@@ -0,0 +1,108 @@
| 1 |
+
import os
|
| 2 |
+
import re
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
from tqdm import tqdm
|
| 6 |
+
import numpy as np
|
| 7 |
+
# from torchvision import datasets
|
| 8 |
+
from . import datasets, utils
|
| 9 |
+
# import clip.clip as clip
|
| 10 |
+
import clip
|
| 11 |
+
import numpy as np
|
| 12 |
+
|
| 13 |
+
class DynamicDataset():
|
| 14 |
+
|
| 15 |
+
def __init__(self, cfg):
|
| 16 |
+
self.ref_database = {} # all data key = dataset_name; value is 4D tensor (num_image, 3, 224, 224)
|
| 17 |
+
self.ref_names = [] # collect the name of the dataset
|
| 18 |
+
self.ref_model, _, self.test_preprocess = clip.load(cfg.model_name, jit=False)
|
| 19 |
+
self.cur_dataset = None
|
| 20 |
+
self.memory_size = 5000
|
| 21 |
+
self.batch_id = 0
|
| 22 |
+
|
| 23 |
+
def update(self, dataset, load):
|
| 24 |
+
# load is a model directly
|
| 25 |
+
self.cur_dataset = dataset
|
| 26 |
+
if not load:  # first round; not used in the class-incremental (CICL) setting
|
| 27 |
+
new_dataset = self.getNewDataset()
|
| 28 |
+
self.ref_database[self.cur_dataset] = new_dataset[:self.memory_size]
|
| 29 |
+
self.ref_names.append(self.cur_dataset)
|
| 30 |
+
else: # other rounds
|
| 31 |
+
self.ref_model = load
|
| 32 |
+
self.reduceExampleSet()
|
| 33 |
+
self.constructExampleSet()
|
| 34 |
+
|
| 35 |
+
def reduceExampleSet(self):
|
| 36 |
+
print("Reducing Example Set")
|
| 37 |
+
K, t = self.memory_size, len(self.ref_names)+1
|
| 38 |
+
m = K // t
|
| 39 |
+
for dataset in self.ref_names:
|
| 40 |
+
self.ref_database[dataset] = self.ref_database[dataset][:m]
|
| 41 |
+
|
| 42 |
+
def constructExampleSet(self):
|
| 43 |
+
# breakpoint()
|
| 44 |
+
print("Constructing Example Set")
|
| 45 |
+
self.ref_names.append(self.batch_id)
|
| 46 |
+
new_dataset = torch.tensor(self.getNewDataset())
|
| 47 |
+
image_feature = []
|
| 48 |
+
num = new_dataset.shape[0]
|
| 49 |
+
|
| 50 |
+
print("[Constructing] Calculating Distance")
|
| 51 |
+
for ndx in tqdm(np.arange(num)):
|
| 52 |
+
img = torch.unsqueeze(new_dataset[ndx], dim=0)
|
| 53 |
+
img = img.cuda()
|
| 54 |
+
img_feature = self.ref_model(img, None)
|
| 55 |
+
image_feature.append(img_feature.cpu().detach().tolist())
|
| 56 |
+
image_feature = torch.tensor(image_feature)
|
| 57 |
+
image_feature = torch.squeeze(image_feature, dim=1)
|
| 58 |
+
image_feature = image_feature / image_feature.norm(dim=-1, keepdim=True)
|
| 59 |
+
image_feature = np.array(image_feature.cpu().detach())
|
| 60 |
+
image_feature_average = image_feature.mean(axis=0)
|
| 61 |
+
|
| 62 |
+
K, t = self.memory_size, len(self.ref_names)
|
| 63 |
+
m = K - K // t
|
| 64 |
+
update_dataset = []
|
| 65 |
+
if not m:
|
| 66 |
+
m = self.memory_size
|
| 67 |
+
cur_embedding_sum = None
|
| 68 |
+
print("[Constructing] Collecting Examples")
|
| 69 |
+
for k in tqdm(np.arange(min(m, len(image_feature)))):
|
| 70 |
+
if not k:
|
| 71 |
+
index = np.argmin(
|
| 72 |
+
np.sum((image_feature_average - image_feature)**2, axis=1)
|
| 73 |
+
)
|
| 74 |
+
cur_embedding_sum = image_feature[index]
|
| 75 |
+
update_dataset.append((new_dataset.cpu())[index].tolist())
|
| 76 |
+
image_feature = np.delete(image_feature, index, axis=0)
|
| 77 |
+
else:
|
| 78 |
+
index = np.argmin(
|
| 79 |
+
np.sum((
|
| 80 |
+
image_feature_average - (1/(k+1))*(image_feature + cur_embedding_sum)
|
| 81 |
+
)**2, axis=1)
|
| 82 |
+
)
|
| 83 |
+
cur_embedding_sum += image_feature[index]
|
| 84 |
+
update_dataset.append((new_dataset.cpu())[index].tolist())
|
| 85 |
+
image_feature = np.delete(image_feature, index, axis=0)
|
| 86 |
+
|
| 87 |
+
self.ref_database[self.batch_id] = update_dataset
|
| 88 |
+
print("finishing current task", self.batch_id)
|
| 89 |
+
self.batch_id = self.batch_id + 1
|
| 90 |
+
|
| 91 |
+
def getNewDataset(self):
|
| 92 |
+
samples = []
|
| 93 |
+
count = 0
|
| 94 |
+
for sample in tqdm(self.cur_dataset):
|
| 95 |
+
if count == 10000:
|
| 96 |
+
return samples
|
| 97 |
+
count += 1
|
| 98 |
+
samples.append(sample[0].tolist())
|
| 99 |
+
return samples
|
| 100 |
+
|
| 101 |
+
def get(self):
|
| 102 |
+
print("Getting Reference Images")
|
| 103 |
+
value = list(self.ref_database.values())
|
| 104 |
+
out = []
|
| 105 |
+
for i in tqdm(value):
|
| 106 |
+
out += i
|
| 107 |
+
return torch.tensor(out)
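# Illustrative sketch (not repository code): the selection in constructExampleSet follows
# the iCaRL-style herding idea, greedily picking samples whose running feature mean stays
# closest to the overall mean. The same criterion in isolation, on toy features:
import numpy as np
def herding_select(features, m):
    mu = features.mean(axis=0)
    picked, running_sum = [], np.zeros_like(mu)
    candidates = list(range(len(features)))
    for k in range(1, m + 1):
        scores = [np.sum((mu - (running_sum + features[i]) / k) ** 2) for i in candidates]
        best = candidates.pop(int(np.argmin(scores)))
        running_sum += features[best]
        picked.append(best)
    return picked
feats = np.random.randn(100, 16)
feats /= np.linalg.norm(feats, axis=1, keepdims=True)
print(herding_select(feats, 5))                      # indices of 5 herded exemplars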
|
| 108 |
+
|
continual_clip/models.py
ADDED
|
@@ -0,0 +1,228 @@
| 1 |
+
from omegaconf import DictConfig
|
| 2 |
+
from tqdm import tqdm
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
+
import clip.clip as clip
|
| 5 |
+
import torch
|
| 6 |
+
import torch.nn as nn
|
| 7 |
+
from torch.utils.data import DataLoader
|
| 8 |
+
from .utils import get_class_ids_per_task, get_class_names
|
| 9 |
+
from . import utils
|
| 10 |
+
from .dynamic_dataset import DynamicDataset
|
| 11 |
+
|
| 12 |
+
DEFAULT_THRESHOLD = 0.985
|
| 13 |
+
TOP_SELECT = 1
|
| 14 |
+
EPOCH_NUM = 4
|
| 15 |
+
TOP_K_RATIO = 0.1
|
| 16 |
+
LAMBDA_SCALE = 30
|
| 17 |
+
LAYER_NUM = 12
|
| 18 |
+
|
| 19 |
+
class ClassIncremental(nn.Module):
|
| 20 |
+
def __init__(self, cfg, device, origin_flag, jit=False):
|
| 21 |
+
super().__init__()
|
| 22 |
+
self.prompt_template = cfg.prompt_template
|
| 23 |
+
self.device = device
|
| 24 |
+
self.classes_names = None
|
| 25 |
+
self.origin_flag = origin_flag
|
| 26 |
+
self.model, self.transforms, _ = clip.load(cfg.model_name, device=device, jit=jit)
|
| 27 |
+
self.ref_model = None
|
| 28 |
+
self.class_ids_per_task = list(get_class_ids_per_task(cfg))
|
| 29 |
+
self.current_class_names = []
|
| 30 |
+
self.text_tokens = None
|
| 31 |
+
self.dynamic_dataset = DynamicDataset(cfg)
|
| 32 |
+
self.prev_gradients = None
|
| 33 |
+
self.visual_cur_matrix = {}
|
| 34 |
+
self.visual_U = {}
|
| 35 |
+
self.loss_list = []
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
def forward(self, image, taskid):
|
| 40 |
+
with torch.no_grad():
|
| 41 |
+
logits_per_image, _ = self.model(image, self.text_tokens, 0, is_train=False)
|
| 42 |
+
probs = logits_per_image.softmax(dim=-1)
|
| 43 |
+
return probs
|
| 44 |
+
|
| 45 |
+
def adaptation(self, task_id, cfg, train_dataset, train_classes_names, world):
|
| 46 |
+
self.current_class_names += get_class_names(self.classes_names, self.class_ids_per_task[task_id])
|
| 47 |
+
self.text_tokens = clip.tokenize(
|
| 48 |
+
[self.prompt_template.format(c) for c in self.current_class_names]
|
| 49 |
+
).cuda(device=2)
|
| 50 |
+
if cfg.method != "zeroshot":
|
| 51 |
+
self.train(task_id, cfg, train_dataset, train_classes_names, world)
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
def train(self, task_id, cfg, train_dataset, train_classes_names, world):
|
| 56 |
+
|
| 57 |
+
train_loader = DataLoader(train_dataset[task_id:task_id + 1],
|
| 58 |
+
batch_size=cfg.batch_size,
|
| 59 |
+
shuffle=True, num_workers=8)
|
| 60 |
+
|
| 61 |
+
train_iter = iter(train_loader)
|
| 62 |
+
EPOCH = EPOCH_NUM
|
| 63 |
+
num_batches = len(train_loader)
|
| 64 |
+
total_iterations = EPOCH * num_batches
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
for k, v in self.model.named_parameters():
|
| 68 |
+
if "adapt" not in k:
|
| 69 |
+
v.requires_grad = False
|
| 70 |
+
|
| 71 |
+
params = [
|
| 72 |
+
v for k, v in self.model.named_parameters() if "adapt" in k
|
| 73 |
+
]
|
| 74 |
+
params_name = [
|
| 75 |
+
k for k, v in self.model.named_parameters() if "adapt" in k
|
| 76 |
+
]
|
| 77 |
+
|
| 78 |
+
print('========trainable params============', params_name)
|
| 79 |
+
# optimizer
|
| 80 |
+
optimizer = torch.optim.AdamW(params, lr=cfg.lr, weight_decay=cfg.weight_decay)
|
| 81 |
+
scheduler = utils.cosine_lr(
|
| 82 |
+
optimizer, cfg.lr, 30, total_iterations
|
| 83 |
+
)
|
| 84 |
+
self.model = self.model.cuda(device=2)
|
| 85 |
+
|
| 86 |
+
classnames = get_class_names(self.classes_names, self.class_ids_per_task[task_id])
|
| 87 |
+
print(classnames)
|
| 88 |
+
texts = [self.prompt_template.format(c) for c in classnames]
|
| 89 |
+
texts = clip.tokenize(texts).cuda(device=2)
|
| 90 |
+
|
| 91 |
+
self.model.train()
|
| 92 |
+
|
| 93 |
+
batch_count = 0
|
| 94 |
+
lamda = [[0 for _ in range(LAYER_NUM)] for _ in range(LAYER_NUM)]
|
| 95 |
+
for iteration in tqdm(range(total_iterations + 1)):
|
| 96 |
+
scheduler(iteration)
|
| 97 |
+
try:
|
| 98 |
+
inputs, targets, task_ids = next(train_iter)
|
| 99 |
+
except:
|
| 100 |
+
train_iter = iter(train_loader)
|
| 101 |
+
inputs, targets, task_ids = next(train_iter)
|
| 102 |
+
|
| 103 |
+
if cfg.dataset == "tinyimagenet" and task_id != 0:
|
| 104 |
+
shift = 100 + (task_id - 1) * cfg.increment
|
| 105 |
+
targets -= shift
|
| 106 |
+
elif cfg.dataset == "imagenet100" and task_id != 0:
|
| 107 |
+
shift = cfg.initial_increment + (task_id - 1) * cfg.increment
|
| 108 |
+
targets -= shift
|
| 109 |
+
else:
|
| 110 |
+
shift = task_id * cfg.increment
|
| 111 |
+
targets -= shift
|
| 112 |
+
|
| 113 |
+
inputs, targets = inputs.cuda(device=2), targets.cuda(device=2)
|
| 114 |
+
logits_per_image, _ = self.model.cuda(device=2)(inputs, texts.cuda(device=2), 0, is_train=True)  # image and text handled separately
|
| 115 |
+
|
| 116 |
+
loss = F.cross_entropy(logits_per_image, targets, label_smoothing=cfg.ls)
|
| 117 |
+
self.loss_list.append(loss)
|
| 118 |
+
print('CELoss: {}'.format(loss))
|
| 119 |
+
optimizer.zero_grad()
|
| 120 |
+
loss.backward()
|
| 121 |
+
|
| 122 |
+
if task_id != 0:
|
| 123 |
+
if batch_count == 0:
|
| 124 |
+
for j in range(LAYER_NUM):
|
| 125 |
+
activation_visual = self.model.visual.transformer.lora_feature[j]
|
| 126 |
+
activation_visual = torch.bmm(activation_visual.detach().permute(1, 2, 0),
|
| 127 |
+
activation_visual.detach().permute(1, 0, 2)).sum(dim=0)
|
| 128 |
+
U_visual, S, Vh = torch.linalg.svd(activation_visual, full_matrices=False)
|
| 129 |
+
U_visual = U_visual[:, :TOP_SELECT]
|
| 130 |
+
|
| 131 |
+
for k in range(LAYER_NUM):
|
| 132 |
+
v_visual = self.visual_U[k]
|
| 133 |
+
|
| 134 |
+
normalized_vector_visual = U_visual / torch.norm(U_visual)
|
| 135 |
+
similarities_visual = []
|
| 136 |
+
for column_visual in v_visual.t():
|
| 137 |
+
normalized_column_visual = column_visual / torch.norm(column_visual)
|
| 138 |
+
cos_sim_visual = torch.dot(normalized_vector_visual.squeeze(),
|
| 139 |
+
normalized_column_visual.squeeze())
|
| 140 |
+
similarities_visual.append(cos_sim_visual)
|
| 141 |
+
|
| 142 |
+
dot_products_visual = torch.mean(
|
| 143 |
+
torch.topk(torch.stack(similarities_visual), int(len(similarities_visual) * TOP_K_RATIO))[0])
|
| 144 |
+
lamda[j][k] = torch.exp(-dot_products_visual) * LAMBDA_SCALE
|
| 145 |
+
|
| 146 |
+
batch_count = batch_count + 1
|
| 147 |
+
for name, params in self.model.named_parameters():
|
| 148 |
+
|
| 149 |
+
for i in range(LAYER_NUM):
|
| 150 |
+
if 'visual' in name and 'adapt' in name and 'down' in name and 'weight' in name:
|
| 151 |
+
v = self.visual_U[i]
|
| 152 |
+
v_ = torch.mm(params.grad.data, v)
|
| 153 |
+
params.grad.data = torch.mm(v_, v.T)* lamda[int(name.split(".")[3])][i]
|
| 154 |
+
|
| 155 |
+
elif 'visual' in name and 'adapt' in name and 'up' in name and 'weight' in name:
|
| 156 |
+
v = self.visual_U[i]
|
| 157 |
+
v_ = torch.mm(v.T, params.grad.data)
|
| 158 |
+
params.grad.data = torch.mm(v, v_)* lamda[int(name.split(".")[3])][i]
|
| 159 |
+
|
| 160 |
+
optimizer.step()
|
| 161 |
+
|
| 162 |
+
torch.cuda.empty_cache()
|
| 163 |
+
|
| 164 |
+
train_loader_ = DataLoader(train_dataset[task_id:task_id + 1],
|
| 165 |
+
batch_size=128,
|
| 166 |
+
shuffle=True, num_workers=8)
|
| 167 |
+
counts = 0
|
| 168 |
+
models = self.model.cuda(2)
|
| 169 |
+
for inputs, targets, task_ids in tqdm(train_loader_):
|
| 170 |
+
inputs = inputs.cuda(device=2)
|
| 171 |
+
with torch.no_grad():
|
| 172 |
+
outputs = models(inputs, texts.cuda(2), 0, is_train=False)
|
| 173 |
+
|
| 174 |
+
for i in range(LAYER_NUM):
|
| 175 |
+
if len(self.visual_cur_matrix) == i:
|
| 176 |
+
activation = models.visual.transformer.lora_feature[i]
|
| 177 |
+
activation = torch.bmm(activation.detach().permute(1, 2, 0),
|
| 178 |
+
activation.detach().permute(1, 0, 2)).sum(dim=0)
|
| 179 |
+
self.visual_cur_matrix[i] = activation
|
| 180 |
+
|
| 181 |
+
U, S, Vh = torch.linalg.svd(activation, full_matrices=False)
|
| 182 |
+
self.visual_U[i] = U[:,TOP_SELECT:]
|
| 183 |
+
|
| 184 |
+
else:
|
| 185 |
+
activation = models.visual.transformer.lora_feature[i]
|
| 186 |
+
activation = torch.bmm(activation.detach().permute(1, 2, 0),
|
| 187 |
+
activation.detach().permute(1, 0, 2)).sum(dim=0)
|
| 188 |
+
|
| 189 |
+
U1, S1, Vh1 = torch.linalg.svd(activation, full_matrices=False)
|
| 190 |
+
Ui = torch.cat((self.visual_U[i], U1[:, TOP_SELECT:]), dim=1)
|
| 191 |
+
self.visual_U[i] = Ui
|
| 192 |
+
|
| 193 |
+
counts = counts + 1
|
| 194 |
+
if counts == 1:
|
| 195 |
+
break
|
| 196 |
+
|
| 197 |
+
torch.cuda.empty_cache()
|
| 198 |
+
self.model.eval()
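# Illustrative sketch (not repository code): the gradient edit above keeps only the component
# of each adapter gradient that lies in the subspace spanned by previously accumulated
# activations (self.visual_U) and rescales it by lamda. The same operation on toy tensors:
import torch
d, r = 64, 8
grad = torch.randn(16, d)                            # e.g. a LoRA 'down' weight gradient
U = torch.linalg.qr(torch.randn(d, r)).Q             # orthonormal basis of past activations
lam = 0.5                                            # scale from the cosine-similarity heuristic
projected = (grad @ U) @ U.T * lam                   # project onto span(U), then rescale
print(projected.shape)                               # torch.Size([16, 64])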
|
| 199 |
+
|
| 200 |
+
class DomainIncremental(nn.Module):
|
| 201 |
+
pass
|
| 202 |
+
|
| 203 |
+
|
| 204 |
+
class TaskAgnostic(nn.Module):
|
| 205 |
+
pass
|
| 206 |
+
|
| 207 |
+
|
| 208 |
+
def load_model(cfg: DictConfig, device: torch.device, origin_flag) -> nn.Module:
|
| 209 |
+
r"""Load a CLIP model in different continual scenarios.
|
| 210 |
+
|
| 211 |
+
Arguments:
|
| 212 |
+
cfg (DictConfig): Experiment configurations.
|
| 213 |
+
device (torch.device): Device to train (or) evaluate the model on.
|
| 214 |
+
|
| 215 |
+
Returns:
|
| 216 |
+
nn.Module: Return scenario specific CLIP model.
|
| 217 |
+
"""
|
| 218 |
+
if cfg.scenario == "class":
|
| 219 |
+
return ClassIncremental(cfg, device, origin_flag)
|
| 220 |
+
elif cfg.scenario == "domain":
|
| 221 |
+
return DomainIncremental(cfg, device)
|
| 222 |
+
elif cfg.scenario == "task-aganostic":
|
| 223 |
+
return TaskAgnostic(cfg, device)
|
| 224 |
+
else:
|
| 225 |
+
raise ValueError(f"""
|
| 226 |
+
`{cfg.scenario}` is not a valid scenario,
|
| 227 |
+
Please choose from ['class', "domain', 'task-agnostic']
|
| 228 |
+
""")
|
continual_clip/utils.py
ADDED
|
@@ -0,0 +1,210 @@
| 1 |
+
|
| 2 |
+
import os
|
| 3 |
+
import json
|
| 4 |
+
import yaml
|
| 5 |
+
|
| 6 |
+
from omegaconf import DictConfig, OmegaConf
|
| 7 |
+
import torch
|
| 8 |
+
import torch.nn.functional as F
|
| 9 |
+
|
| 10 |
+
import pickle
|
| 11 |
+
import random
|
| 12 |
+
|
| 13 |
+
import numpy as np
|
| 14 |
+
|
| 15 |
+
from clip.tokenizer import SimpleTokenizer as _Tokenizer
|
| 16 |
+
|
| 17 |
+
__all__ = ["available_models", "load", "tokenize"]
|
| 18 |
+
_tokenizer = _Tokenizer()
|
| 19 |
+
|
| 20 |
+
def get_class_order(file_name: str) -> list:
|
| 21 |
+
r"""TO BE DOCUMENTED"""
|
| 22 |
+
with open(file_name, "r+") as f:
|
| 23 |
+
data = yaml.safe_load(f)
|
| 24 |
+
return data["class_order"]
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
def get_class_ids_per_task(args):
|
| 28 |
+
yield args.class_order[:args.initial_increment]
|
| 29 |
+
for i in range(args.initial_increment, len(args.class_order), args.increment):
|
| 30 |
+
yield args.class_order[i:i + args.increment]
|
| 31 |
+
|
| 32 |
+
def get_class_names(classes_names, class_ids_per_task):
|
| 33 |
+
return [classes_names[class_id] for class_id in class_ids_per_task]
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def get_dataset_class_names(workdir, dataset_name, long=False):
|
| 37 |
+
with open(os.path.join(workdir, "dataset_reqs", f"{dataset_name}_classes.txt"), "r") as f:
|
| 38 |
+
lines = f.read().splitlines()
|
| 39 |
+
return [line.split("\t")[-1] for line in lines]
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
def save_config(config: DictConfig) -> None:
|
| 43 |
+
OmegaConf.save(config, "config.yaml")
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
def get_workdir(path):
|
| 47 |
+
split_path = path.split("/")
|
| 48 |
+
workdir_idx = split_path.index("cil")
|
| 49 |
+
return "/".join(split_path[:workdir_idx+1])
|
| 50 |
+
|
| 51 |
+
###########################
|
| 52 |
+
def assign_learning_rate(param_group, new_lr):
|
| 53 |
+
param_group["lr"] = new_lr
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
def _warmup_lr(base_lr, warmup_length, step):
|
| 57 |
+
return base_lr * (step + 1) / warmup_length
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
def cosine_lr(optimizer, base_lrs, warmup_length, steps):
|
| 61 |
+
if not isinstance(base_lrs, list):
|
| 62 |
+
base_lrs = [base_lrs for _ in optimizer.param_groups]
|
| 63 |
+
assert len(base_lrs) == len(optimizer.param_groups)
|
| 64 |
+
|
| 65 |
+
def _lr_adjuster(step):
|
| 66 |
+
for param_group, base_lr in zip(optimizer.param_groups, base_lrs):
|
| 67 |
+
if step < warmup_length:
|
| 68 |
+
lr = _warmup_lr(base_lr, warmup_length, step)
|
| 69 |
+
else:
|
| 70 |
+
e = step - warmup_length
|
| 71 |
+
es = steps - warmup_length
|
| 72 |
+
lr = 0.5 * (1 + np.cos(np.pi * e / es)) * base_lr
|
| 73 |
+
assign_learning_rate(param_group, lr)
|
| 74 |
+
|
| 75 |
+
return _lr_adjuster
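# Illustrative usage sketch (not part of the upload): cosine_lr returns a step-indexed
# adjuster rather than a torch scheduler object; it is called once per iteration.
import torch
params = [torch.nn.Parameter(torch.randn(4, 4))]
optimizer = torch.optim.AdamW(params, lr=1e-3)
scheduler = cosine_lr(optimizer, base_lrs=1e-3, warmup_length=30, steps=1000)
for step in range(1000):
    scheduler(step)                                  # linear warmup for 30 steps, cosine decay after
    if step in (0, 29, 500, 999):
        print(step, optimizer.param_groups[0]["lr"])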
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
def accuracy(output, target, topk=(1,)):
|
| 79 |
+
pred = output.topk(max(topk), 1, True, True)[1].t()
|
| 80 |
+
correct = pred.eq(target.view(1, -1).expand_as(pred))
|
| 81 |
+
return [
|
| 82 |
+
float(correct[:k].reshape(-1).float().sum(0, keepdim=True).cpu().numpy())
|
| 83 |
+
for k in topk
|
| 84 |
+
]
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
def torch_save(classifier, save_path):
|
| 88 |
+
if os.path.dirname(save_path) != "":
|
| 89 |
+
os.makedirs(os.path.dirname(save_path), exist_ok=True)
|
| 90 |
+
torch.save({"state_dict": classifier.state_dict()}, save_path)
|
| 91 |
+
print("Checkpoint saved to", save_path)
|
| 92 |
+
|
| 93 |
+
# with open(save_path, 'wb') as f:
|
| 94 |
+
# pickle.dump(classifier.cpu(), f)
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
def torch_load(classifier, save_path, device=None):
|
| 98 |
+
checkpoint = torch.load(save_path)
|
| 99 |
+
missing_keys, unexpected_keys = classifier.load_state_dict(
|
| 100 |
+
checkpoint["state_dict"], strict=False
|
| 101 |
+
)
|
| 102 |
+
if len(missing_keys) > 0 or len(unexpected_keys) > 0:
|
| 103 |
+
print("Missing keys:", missing_keys)
|
| 104 |
+
print("Unexpected keys:", unexpected_keys)
|
| 105 |
+
print("Checkpoint loaded from", save_path)
|
| 106 |
+
# with open(save_path, 'rb') as f:
|
| 107 |
+
# classifier = pickle.load(f)
|
| 108 |
+
|
| 109 |
+
if device is not None:
|
| 110 |
+
classifier = classifier.to(device)
|
| 111 |
+
return classifier
|
| 112 |
+
|
| 113 |
+
|
| 114 |
+
def get_logits(inputs, classifier):
|
| 115 |
+
assert callable(classifier)
|
| 116 |
+
if hasattr(classifier, "to"):
|
| 117 |
+
classifier = classifier.to(inputs.device)
|
| 118 |
+
return classifier(inputs)
|
| 119 |
+
|
| 120 |
+
|
| 121 |
+
def get_probs(inputs, classifier):
|
| 122 |
+
if hasattr(classifier, "predict_proba"):
|
| 123 |
+
probs = classifier.predict_proba(inputs.detach().cpu().numpy())
|
| 124 |
+
return torch.from_numpy(probs)
|
| 125 |
+
logits = get_logits(inputs, classifier)
|
| 126 |
+
return logits.softmax(dim=1)
|
| 127 |
+
|
| 128 |
+
|
| 129 |
+
class LabelSmoothing(torch.nn.Module):
|
| 130 |
+
def __init__(self, smoothing=0.0):
|
| 131 |
+
super(LabelSmoothing, self).__init__()
|
| 132 |
+
self.confidence = 1.0 - smoothing
|
| 133 |
+
self.smoothing = smoothing
|
| 134 |
+
|
| 135 |
+
def forward(self, x, target):
|
| 136 |
+
logprobs = torch.nn.functional.log_softmax(x, dim=-1)
|
| 137 |
+
|
| 138 |
+
nll_loss = -logprobs.gather(dim=-1, index=target.unsqueeze(1))
|
| 139 |
+
nll_loss = nll_loss.squeeze(1)
|
| 140 |
+
smooth_loss = -logprobs.mean(dim=-1)
|
| 141 |
+
loss = self.confidence * nll_loss + self.smoothing * smooth_loss
|
| 142 |
+
return loss.mean()
|
| 143 |
+
|
| 144 |
+
|
| 145 |
+
def seed_all(seed):
|
| 146 |
+
torch.manual_seed(seed)
|
| 147 |
+
torch.cuda.manual_seed_all(seed)
|
| 148 |
+
np.random.seed(seed)
|
| 149 |
+
random.seed(seed)
|
| 150 |
+
torch.backends.cudnn.deterministic = True
|
| 151 |
+
torch.backends.cudnn.benchmark = False
|
| 152 |
+
|
| 153 |
+
|
| 154 |
+
def num_parameters(model):
|
| 155 |
+
return sum(p.numel() for p in model.parameters() if p.requires_grad)
|
| 156 |
+
|
| 157 |
+
def batch(iterable, n=64):
|
| 158 |
+
l = len(iterable)
|
| 159 |
+
for ndx in range(0, l, n):
|
| 160 |
+
yield iterable[ndx:min(ndx + n, l)]
|
| 161 |
+
|
| 162 |
+
def merge_we(model_0, model_1, sma_count):
|
| 163 |
+
for param_q, param_k in zip(model_0.parameters(), model_1.parameters()):
|
| 164 |
+
param_k.data = (param_k.data * sma_count + param_q.data) / (1.0 + sma_count)
|
| 165 |
+
return model_1
|
| 166 |
+
|
| 167 |
+
def wise_we(model_0, model_1, sma_count, model_n, alpha=0.95):
|
| 168 |
+
for param_q, param_k, param_n in zip(model_0.parameters(), model_1.parameters(), model_n.parameters()):
|
| 169 |
+
param_k.data = (
|
| 170 |
+
(param_k.data * sma_count + param_q.data) / (1.0 + sma_count)
|
| 171 |
+
) * alpha + param_n.data * (1-alpha)
|
| 172 |
+
return model_1
|
| 173 |
+
|
| 174 |
+
def merge_we_router(model_0, model_1, sma_count):
|
| 175 |
+
for param_q, param_k, name_q, name_k in zip(model_0.parameters(), model_1.parameters(), model_0.named_parameters(), model_1.named_parameters()):
|
| 176 |
+
if "router" in name_k[0] or "noise" in name_k[0]:
|
| 177 |
+
param_k.data = (param_k.data * sma_count + param_q.data) / (1.0 + sma_count)
|
| 178 |
+
# print('111', name_k[0], name_q[0])
|
| 179 |
+
return model_1
|
| 180 |
+
|
| 181 |
+
def moving_avg(model_0, model_1, alpha=0.999):
|
| 182 |
+
for param_q, param_k in zip(model_0.parameters(), model_1.parameters()):
|
| 183 |
+
param_q.data = param_q.data * alpha + param_k.data * (1 - alpha)
|
| 184 |
+
|
| 185 |
+
|
| 186 |
+
def l2_loss(model, model_ref):
|
| 187 |
+
loss = 0.0
|
| 188 |
+
for param_q, param_k in zip(model.parameters(), model_ref.parameters()):
|
| 189 |
+
loss += F.mse_loss(param_q, param_k.detach(), reduction="sum")
|
| 190 |
+
return loss
|
| 191 |
+
|
| 192 |
+
|
| 193 |
+
def virtual_vocab(length=10, n_class=1000):
|
| 194 |
+
voc_len = len(_tokenizer.encoder)
|
| 195 |
+
# breakpoint()
|
| 196 |
+
texts = torch.randint(0, voc_len, (n_class, length))
|
| 197 |
+
start = torch.full((n_class, 1), _tokenizer.encoder["<start_of_text>"])
|
| 198 |
+
end = torch.full((n_class, 1), _tokenizer.encoder["<end_of_text>"])
|
| 199 |
+
zeros = torch.zeros((n_class, 75 - length), dtype=torch.long)
|
| 200 |
+
|
| 201 |
+
texts = torch.cat([start, texts, end, zeros], dim=1)
|
| 202 |
+
return texts
|
| 203 |
+
|
| 204 |
+
def distillation(t, s, T=2):
|
| 205 |
+
p = F.softmax(t / T, dim=1)
|
| 206 |
+
loss = F.cross_entropy(s / T, p, reduction="mean") * (T ** 2)
|
| 207 |
+
return loss
|
| 208 |
+
|
| 209 |
+
|
| 210 |
+
|
dataset_reqs/imagenet1000_classes.txt
ADDED
|
@@ -0,0 +1,1000 @@
| 1 |
+
0 tench
|
| 2 |
+
1 goldfish
|
| 3 |
+
2 great white shark
|
| 4 |
+
3 tiger shark
|
| 5 |
+
4 hammerhead shark
|
| 6 |
+
5 electric ray
|
| 7 |
+
6 stingray
|
| 8 |
+
7 rooster
|
| 9 |
+
8 hen
|
| 10 |
+
9 ostrich
|
| 11 |
+
10 brambling
|
| 12 |
+
11 goldfinch
|
| 13 |
+
12 house finch
|
| 14 |
+
13 junco
|
| 15 |
+
14 indigo bunting
|
| 16 |
+
15 American robin
|
| 17 |
+
16 bulbul
|
| 18 |
+
17 jay
|
| 19 |
+
18 magpie
|
| 20 |
+
19 chickadee
|
| 21 |
+
20 American dipper
|
| 22 |
+
21 kite (bird of prey)
|
| 23 |
+
22 bald eagle
|
| 24 |
+
23 vulture
|
| 25 |
+
24 great grey owl
|
| 26 |
+
25 fire salamander
|
| 27 |
+
26 smooth newt
|
| 28 |
+
27 newt
|
| 29 |
+
28 spotted salamander
|
| 30 |
+
29 axolotl
|
| 31 |
+
30 American bullfrog
|
| 32 |
+
31 tree frog
|
| 33 |
+
32 tailed frog
|
| 34 |
+
33 loggerhead sea turtle
|
| 35 |
+
34 leatherback sea turtle
|
| 36 |
+
35 mud turtle
|
| 37 |
+
36 terrapin
|
| 38 |
+
37 box turtle
|
| 39 |
+
38 banded gecko
|
| 40 |
+
39 green iguana
|
| 41 |
+
40 Carolina anole
|
| 42 |
+
41 desert grassland whiptail lizard
|
| 43 |
+
42 agama
|
| 44 |
+
43 frilled-necked lizard
|
| 45 |
+
44 alligator lizard
|
| 46 |
+
45 Gila monster
|
| 47 |
+
46 European green lizard
|
| 48 |
+
47 chameleon
|
| 49 |
+
48 Komodo dragon
|
| 50 |
+
49 Nile crocodile
|
| 51 |
+
50 American alligator
|
| 52 |
+
51 triceratops
|
| 53 |
+
52 worm snake
|
| 54 |
+
53 ring-necked snake
|
| 55 |
+
54 eastern hog-nosed snake
|
| 56 |
+
55 smooth green snake
|
| 57 |
+
56 kingsnake
|
| 58 |
+
57 garter snake
|
| 59 |
+
58 water snake
|
| 60 |
+
59 vine snake
|
| 61 |
+
60 night snake
|
| 62 |
+
61 boa constrictor
|
| 63 |
+
62 African rock python
|
| 64 |
+
63 Indian cobra
|
| 65 |
+
64 green mamba
|
| 66 |
+
65 sea snake
|
| 67 |
+
66 Saharan horned viper
|
| 68 |
+
67 eastern diamondback rattlesnake
|
| 69 |
+
68 sidewinder rattlesnake
|
| 70 |
+
69 trilobite
|
| 71 |
+
70 harvestman
|
| 72 |
+
71 scorpion
|
| 73 |
+
72 yellow garden spider
|
| 74 |
+
73 barn spider
|
| 75 |
+
74 European garden spider
|
| 76 |
+
75 southern black widow
|
| 77 |
+
76 tarantula
|
| 78 |
+
77 wolf spider
|
| 79 |
+
78 tick
|
| 80 |
+
79 centipede
|
| 81 |
+
80 black grouse
|
| 82 |
+
81 ptarmigan
|
| 83 |
+
82 ruffed grouse
|
| 84 |
+
83 prairie grouse
|
| 85 |
+
84 peafowl
|
| 86 |
+
85 quail
|
| 87 |
+
86 partridge
|
| 88 |
+
87 african grey parrot
|
| 89 |
+
88 macaw
|
| 90 |
+
89 sulphur-crested cockatoo
|
| 91 |
+
90 lorikeet
|
| 92 |
+
91 coucal
|
| 93 |
+
92 bee eater
|
| 94 |
+
93 hornbill
|
| 95 |
+
94 hummingbird
|
| 96 |
+
95 jacamar
|
| 97 |
+
96 toucan
|
| 98 |
+
97 duck
|
| 99 |
+
98 red-breasted merganser
|
| 100 |
+
99 goose
|
| 101 |
+
100 black swan
|
| 102 |
+
101 tusker
|
| 103 |
+
102 echidna
|
| 104 |
+
103 platypus
|
| 105 |
+
104 wallaby
|
| 106 |
+
105 koala
|
| 107 |
+
106 wombat
|
| 108 |
+
107 jellyfish
|
| 109 |
+
108 sea anemone
|
| 110 |
+
109 brain coral
|
| 111 |
+
110 flatworm
|
| 112 |
+
111 nematode
|
| 113 |
+
112 conch
|
| 114 |
+
113 snail
|
| 115 |
+
114 slug
|
| 116 |
+
115 sea slug
|
| 117 |
+
116 chiton
|
| 118 |
+
117 chambered nautilus
|
| 119 |
+
118 Dungeness crab
|
| 120 |
+
119 rock crab
|
| 121 |
+
120 fiddler crab
|
| 122 |
+
121 red king crab
|
| 123 |
+
122 American lobster
|
| 124 |
+
123 spiny lobster
|
| 125 |
+
124 crayfish
|
| 126 |
+
125 hermit crab
|
| 127 |
+
126 isopod
|
| 128 |
+
127 white stork
|
| 129 |
+
128 black stork
|
| 130 |
+
129 spoonbill
|
| 131 |
+
130 flamingo
|
| 132 |
+
131 little blue heron
|
| 133 |
+
132 great egret
|
| 134 |
+
133 bittern bird
|
| 135 |
+
134 crane bird
|
| 136 |
+
135 limpkin
|
| 137 |
+
136 common gallinule
|
| 138 |
+
137 American coot
|
| 139 |
+
138 bustard
|
| 140 |
+
139 ruddy turnstone
|
| 141 |
+
140 dunlin
|
| 142 |
+
141 common redshank
|
| 143 |
+
142 dowitcher
|
| 144 |
+
143 oystercatcher
|
| 145 |
+
144 pelican
|
| 146 |
+
145 king penguin
|
| 147 |
+
146 albatross
|
| 148 |
+
147 grey whale
|
| 149 |
+
148 killer whale
|
| 150 |
+
149 dugong
|
| 151 |
+
150 sea lion
|
| 152 |
+
151 Chihuahua
|
| 153 |
+
152 Japanese Chin
|
| 154 |
+
153 Maltese
|
| 155 |
+
154 Pekingese
|
| 156 |
+
155 Shih Tzu
|
| 157 |
+
156 King Charles Spaniel
|
| 158 |
+
157 Papillon
|
| 159 |
+
158 toy terrier
|
| 160 |
+
159 Rhodesian Ridgeback
|
| 161 |
+
160 Afghan Hound
|
| 162 |
+
161 Basset Hound
|
| 163 |
+
162 Beagle
|
| 164 |
+
163 Bloodhound
|
| 165 |
+
164 Bluetick Coonhound
|
| 166 |
+
165 Black and Tan Coonhound
|
| 167 |
+
166 Treeing Walker Coonhound
|
| 168 |
+
167 English foxhound
|
| 169 |
+
168 Redbone Coonhound
|
| 170 |
+
169 borzoi
|
| 171 |
+
170 Irish Wolfhound
|
| 172 |
+
171 Italian Greyhound
|
| 173 |
+
172 Whippet
|
| 174 |
+
173 Ibizan Hound
|
| 175 |
+
174 Norwegian Elkhound
|
| 176 |
+
175 Otterhound
|
| 177 |
+
176 Saluki
|
| 178 |
+
177 Scottish Deerhound
|
| 179 |
+
178 Weimaraner
|
| 180 |
+
179 Staffordshire Bull Terrier
|
| 181 |
+
180 American Staffordshire Terrier
|
| 182 |
+
181 Bedlington Terrier
|
| 183 |
+
182 Border Terrier
|
| 184 |
+
183 Kerry Blue Terrier
|
| 185 |
+
184 Irish Terrier
|
| 186 |
+
185 Norfolk Terrier
|
| 187 |
+
186 Norwich Terrier
|
| 188 |
+
187 Yorkshire Terrier
|
| 189 |
+
188 Wire Fox Terrier
|
| 190 |
+
189 Lakeland Terrier
|
| 191 |
+
190 Sealyham Terrier
|
| 192 |
+
191 Airedale Terrier
|
| 193 |
+
192 Cairn Terrier
|
| 194 |
+
193 Australian Terrier
|
| 195 |
+
194 Dandie Dinmont Terrier
|
| 196 |
+
195 Boston Terrier
|
| 197 |
+
196 Miniature Schnauzer
|
| 198 |
+
197 Giant Schnauzer
|
| 199 |
+
198 Standard Schnauzer
|
| 200 |
+
199 Scottish Terrier
|
| 201 |
+
200 Tibetan Terrier
|
| 202 |
+
201 Australian Silky Terrier
|
| 203 |
+
202 Soft-coated Wheaten Terrier
|
| 204 |
+
203 West Highland White Terrier
|
| 205 |
+
204 Lhasa Apso
|
| 206 |
+
205 Flat-Coated Retriever
|
| 207 |
+
206 Curly-coated Retriever
|
| 208 |
+
207 Golden Retriever
|
| 209 |
+
208 Labrador Retriever
|
| 210 |
+
209 Chesapeake Bay Retriever
|
| 211 |
+
210 German Shorthaired Pointer
|
| 212 |
+
211 Vizsla
|
| 213 |
+
212 English Setter
|
| 214 |
+
213 Irish Setter
|
| 215 |
+
214 Gordon Setter
|
| 216 |
+
215 Brittany dog
|
| 217 |
+
216 Clumber Spaniel
|
| 218 |
+
217 English Springer Spaniel
|
| 219 |
+
218 Welsh Springer Spaniel
|
| 220 |
+
219 Cocker Spaniel
|
| 221 |
+
220 Sussex Spaniel
|
| 222 |
+
221 Irish Water Spaniel
|
| 223 |
+
222 Kuvasz
|
| 224 |
+
223 Schipperke
|
| 225 |
+
224 Groenendael dog
|
| 226 |
+
225 Malinois
|
| 227 |
+
226 Briard
|
| 228 |
+
227 Australian Kelpie
|
| 229 |
+
228 Komondor
|
| 230 |
+
229 Old English Sheepdog
|
| 231 |
+
230 Shetland Sheepdog
|
| 232 |
+
231 collie
|
| 233 |
+
232 Border Collie
|
| 234 |
+
233 Bouvier des Flandres dog
|
| 235 |
+
234 Rottweiler
|
| 236 |
+
235 German Shepherd Dog
|
| 237 |
+
236 Dobermann
|
| 238 |
+
237 Miniature Pinscher
|
| 239 |
+
238 Greater Swiss Mountain Dog
|
| 240 |
+
239 Bernese Mountain Dog
|
| 241 |
+
240 Appenzeller Sennenhund
|
| 242 |
+
241 Entlebucher Sennenhund
|
| 243 |
+
242 Boxer
|
| 244 |
+
243 Bullmastiff
|
| 245 |
+
244 Tibetan Mastiff
|
| 246 |
+
245 French Bulldog
|
| 247 |
+
246 Great Dane
|
| 248 |
+
247 St. Bernard
|
| 249 |
+
248 husky
|
| 250 |
+
249 Alaskan Malamute
|
| 251 |
+
250 Siberian Husky
|
| 252 |
+
251 Dalmatian
|
| 253 |
+
252 Affenpinscher
|
| 254 |
+
253 Basenji
|
| 255 |
+
254 pug
|
| 256 |
+
255 Leonberger
|
| 257 |
+
256 Newfoundland dog
|
| 258 |
+
257 Great Pyrenees dog
|
| 259 |
+
258 Samoyed
|
| 260 |
+
259 Pomeranian
|
| 261 |
+
260 Chow Chow
|
| 262 |
+
261 Keeshond
|
| 263 |
+
262 brussels griffon
|
| 264 |
+
263 Pembroke Welsh Corgi
|
| 265 |
+
264 Cardigan Welsh Corgi
|
| 266 |
+
265 Toy Poodle
|
| 267 |
+
266 Miniature Poodle
|
| 268 |
+
267 Standard Poodle
|
| 269 |
+
268 Mexican hairless dog (xoloitzcuintli)
|
| 270 |
+
269 grey wolf
|
| 271 |
+
270 Alaskan tundra wolf
|
| 272 |
+
271 red wolf or maned wolf
|
| 273 |
+
272 coyote
|
| 274 |
+
273 dingo
|
| 275 |
+
274 dhole
|
| 276 |
+
275 African wild dog
|
| 277 |
+
276 hyena
|
| 278 |
+
277 red fox
|
| 279 |
+
278 kit fox
|
| 280 |
+
279 Arctic fox
|
| 281 |
+
280 grey fox
|
| 282 |
+
281 tabby cat
|
| 283 |
+
282 tiger cat
|
| 284 |
+
283 Persian cat
|
| 285 |
+
284 Siamese cat
|
| 286 |
+
285 Egyptian Mau
|
| 287 |
+
286 cougar
|
| 288 |
+
287 lynx
|
| 289 |
+
288 leopard
|
| 290 |
+
289 snow leopard
|
| 291 |
+
290 jaguar
|
| 292 |
+
291 lion
|
| 293 |
+
292 tiger
|
| 294 |
+
293 cheetah
|
| 295 |
+
294 brown bear
|
| 296 |
+
295 American black bear
|
| 297 |
+
296 polar bear
|
| 298 |
+
297 sloth bear
|
| 299 |
+
298 mongoose
|
| 300 |
+
299 meerkat
|
| 301 |
+
300 tiger beetle
|
| 302 |
+
301 ladybug
|
| 303 |
+
302 ground beetle
|
| 304 |
+
303 longhorn beetle
|
| 305 |
+
304 leaf beetle
|
| 306 |
+
305 dung beetle
|
| 307 |
+
306 rhinoceros beetle
|
| 308 |
+
307 weevil
|
| 309 |
+
308 fly
|
| 310 |
+
309 bee
|
| 311 |
+
310 ant
|
| 312 |
+
311 grasshopper
|
| 313 |
+
312 cricket insect
|
| 314 |
+
313 stick insect
|
| 315 |
+
314 cockroach
|
| 316 |
+
315 praying mantis
|
| 317 |
+
316 cicada
|
| 318 |
+
317 leafhopper
|
| 319 |
+
318 lacewing
|
| 320 |
+
319 dragonfly
|
| 321 |
+
320 damselfly
|
| 322 |
+
321 red admiral butterfly
|
| 323 |
+
322 ringlet butterfly
|
| 324 |
+
323 monarch butterfly
|
| 325 |
+
324 small white butterfly
|
| 326 |
+
325 sulphur butterfly
|
| 327 |
+
326 gossamer-winged butterfly
|
| 328 |
+
327 starfish
|
| 329 |
+
328 sea urchin
|
| 330 |
+
329 sea cucumber
|
| 331 |
+
330 cottontail rabbit
|
| 332 |
+
331 hare
|
| 333 |
+
332 Angora rabbit
|
| 334 |
+
333 hamster
|
| 335 |
+
334 porcupine
|
| 336 |
+
335 fox squirrel
|
| 337 |
+
336 marmot
|
| 338 |
+
337 beaver
|
| 339 |
+
338 guinea pig
|
| 340 |
+
339 common sorrel horse
|
| 341 |
+
340 zebra
|
| 342 |
+
341 pig
|
| 343 |
+
342 wild boar
|
| 344 |
+
343 warthog
|
| 345 |
+
344 hippopotamus
|
| 346 |
+
345 ox
|
| 347 |
+
346 water buffalo
|
| 348 |
+
347 bison
|
| 349 |
+
348 ram (adult male sheep)
|
| 350 |
+
349 bighorn sheep
|
| 351 |
+
350 Alpine ibex
|
| 352 |
+
351 hartebeest
|
| 353 |
+
352 impala (antelope)
|
| 354 |
+
353 gazelle
|
| 355 |
+
354 arabian camel
|
| 356 |
+
355 llama
|
| 357 |
+
356 weasel
|
| 358 |
+
357 mink
|
| 359 |
+
358 European polecat
|
| 360 |
+
359 black-footed ferret
|
| 361 |
+
360 otter
|
| 362 |
+
361 skunk
|
| 363 |
+
362 badger
|
| 364 |
+
363 armadillo
|
| 365 |
+
364 three-toed sloth
|
| 366 |
+
365 orangutan
|
| 367 |
+
366 gorilla
|
| 368 |
+
367 chimpanzee
|
| 369 |
+
368 gibbon
|
| 370 |
+
369 siamang
|
| 371 |
+
370 guenon
|
| 372 |
+
371 patas monkey
|
| 373 |
+
372 baboon
|
| 374 |
+
373 macaque
|
| 375 |
+
374 langur
|
| 376 |
+
375 black-and-white colobus
|
| 377 |
+
376 proboscis monkey
|
| 378 |
+
377 marmoset
|
| 379 |
+
378 white-headed capuchin
|
| 380 |
+
379 howler monkey
|
| 381 |
+
380 titi monkey
|
| 382 |
+
381 Geoffroy's spider monkey
|
| 383 |
+
382 common squirrel monkey
|
| 384 |
+
383 ring-tailed lemur
|
| 385 |
+
384 indri
|
| 386 |
+
385 Asian elephant
|
| 387 |
+
386 African bush elephant
|
| 388 |
+
387 red panda
|
| 389 |
+
388 giant panda
|
| 390 |
+
389 snoek fish
|
| 391 |
+
390 eel
|
| 392 |
+
391 silver salmon
|
| 393 |
+
392 rock beauty fish
|
| 394 |
+
393 clownfish
|
| 395 |
+
394 sturgeon
|
| 396 |
+
395 gar fish
|
| 397 |
+
396 lionfish
|
| 398 |
+
397 pufferfish
|
| 399 |
+
398 abacus
|
| 400 |
+
399 abaya
|
| 401 |
+
400 academic gown
|
| 402 |
+
401 accordion
|
| 403 |
+
402 acoustic guitar
|
| 404 |
+
403 aircraft carrier
|
| 405 |
+
404 airliner
|
| 406 |
+
405 airship
|
| 407 |
+
406 altar
|
| 408 |
+
407 ambulance
|
| 409 |
+
408 amphibious vehicle
|
| 410 |
+
409 analog clock
|
| 411 |
+
410 apiary
|
| 412 |
+
411 apron
|
| 413 |
+
412 trash can
|
| 414 |
+
413 assault rifle
|
| 415 |
+
414 backpack
|
| 416 |
+
415 bakery
|
| 417 |
+
416 balance beam
|
| 418 |
+
417 balloon
|
| 419 |
+
418 ballpoint pen
|
| 420 |
+
419 Band-Aid
|
| 421 |
+
420 banjo
|
| 422 |
+
421 baluster / handrail
|
| 423 |
+
422 barbell
|
| 424 |
+
423 barber chair
|
| 425 |
+
424 barbershop
|
| 426 |
+
425 barn
|
| 427 |
+
426 barometer
|
| 428 |
+
427 barrel
|
| 429 |
+
428 wheelbarrow
|
| 430 |
+
429 baseball
|
| 431 |
+
430 basketball
|
| 432 |
+
431 bassinet
|
| 433 |
+
432 bassoon
|
| 434 |
+
433 swimming cap
|
| 435 |
+
434 bath towel
|
| 436 |
+
435 bathtub
|
| 437 |
+
436 station wagon
|
| 438 |
+
437 lighthouse
|
| 439 |
+
438 beaker
|
| 440 |
+
439 military hat (bearskin or shako)
|
| 441 |
+
440 beer bottle
|
| 442 |
+
441 beer glass
|
| 443 |
+
442 bell tower
|
| 444 |
+
443 baby bib
|
| 445 |
+
444 tandem bicycle
|
| 446 |
+
445 bikini
|
| 447 |
+
446 ring binder
|
| 448 |
+
447 binoculars
|
| 449 |
+
448 birdhouse
|
| 450 |
+
449 boathouse
|
| 451 |
+
450 bobsleigh
|
| 452 |
+
451 bolo tie
|
| 453 |
+
452 poke bonnet
|
| 454 |
+
453 bookcase
|
| 455 |
+
454 bookstore
|
| 456 |
+
455 bottle cap
|
| 457 |
+
456 hunting bow
|
| 458 |
+
457 bow tie
|
| 459 |
+
458 brass memorial plaque
|
| 460 |
+
459 bra
|
| 461 |
+
460 breakwater
|
| 462 |
+
461 breastplate
|
| 463 |
+
462 broom
|
| 464 |
+
463 bucket
|
| 465 |
+
464 buckle
|
| 466 |
+
465 bulletproof vest
|
| 467 |
+
466 high-speed train
|
| 468 |
+
467 butcher shop
|
| 469 |
+
468 taxicab
|
| 470 |
+
469 cauldron
|
| 471 |
+
470 candle
|
| 472 |
+
471 cannon
|
| 473 |
+
472 canoe
|
| 474 |
+
473 can opener
|
| 475 |
+
474 cardigan
|
| 476 |
+
475 car mirror
|
| 477 |
+
476 carousel
|
| 478 |
+
477 tool kit
|
| 479 |
+
478 cardboard box / carton
|
| 480 |
+
479 car wheel
|
| 481 |
+
480 automated teller machine
|
| 482 |
+
481 cassette
|
| 483 |
+
482 cassette player
|
| 484 |
+
483 castle
|
| 485 |
+
484 catamaran
|
| 486 |
+
485 CD player
|
| 487 |
+
486 cello
|
| 488 |
+
487 mobile phone
|
| 489 |
+
488 chain
|
| 490 |
+
489 chain-link fence
|
| 491 |
+
490 chain mail
|
| 492 |
+
491 chainsaw
|
| 493 |
+
492 storage chest
|
| 494 |
+
493 chiffonier
|
| 495 |
+
494 bell or wind chime
|
| 496 |
+
495 china cabinet
|
| 497 |
+
496 Christmas stocking
|
| 498 |
+
497 church
|
| 499 |
+
498 movie theater
|
| 500 |
+
499 cleaver
|
| 501 |
+
500 cliff dwelling
|
| 502 |
+
501 cloak
|
| 503 |
+
502 clogs
|
| 504 |
+
503 cocktail shaker
|
| 505 |
+
504 coffee mug
|
| 506 |
+
505 coffeemaker
|
| 507 |
+
506 spiral or coil
|
| 508 |
+
507 combination lock
|
| 509 |
+
508 computer keyboard
|
| 510 |
+
509 candy store
|
| 511 |
+
510 container ship
|
| 512 |
+
511 convertible
|
| 513 |
+
512 corkscrew
|
| 514 |
+
513 cornet
|
| 515 |
+
514 cowboy boot
|
| 516 |
+
515 cowboy hat
|
| 517 |
+
516 cradle
|
| 518 |
+
517 construction crane
|
| 519 |
+
518 crash helmet
|
| 520 |
+
519 crate
|
| 521 |
+
520 infant bed
|
| 522 |
+
521 Crock Pot
|
| 523 |
+
522 croquet ball
|
| 524 |
+
523 crutch
|
| 525 |
+
524 cuirass
|
| 526 |
+
525 dam
|
| 527 |
+
526 desk
|
| 528 |
+
527 desktop computer
|
| 529 |
+
528 rotary dial telephone
|
| 530 |
+
529 diaper
|
| 531 |
+
530 digital clock
|
| 532 |
+
531 digital watch
|
| 533 |
+
532 dining table
|
| 534 |
+
533 dishcloth
|
| 535 |
+
534 dishwasher
|
| 536 |
+
535 disc brake
|
| 537 |
+
536 dock
|
| 538 |
+
537 dog sled
|
| 539 |
+
538 dome
|
| 540 |
+
539 doormat
|
| 541 |
+
540 drilling rig
|
| 542 |
+
541 drum
|
| 543 |
+
542 drumstick
|
| 544 |
+
543 dumbbell
|
| 545 |
+
544 Dutch oven
|
| 546 |
+
545 electric fan
|
| 547 |
+
546 electric guitar
|
| 548 |
+
547 electric locomotive
|
| 549 |
+
548 entertainment center
|
| 550 |
+
549 envelope
|
| 551 |
+
550 espresso machine
|
| 552 |
+
551 face powder
|
| 553 |
+
552 feather boa
|
| 554 |
+
553 filing cabinet
|
| 555 |
+
554 fireboat
|
| 556 |
+
555 fire truck
|
| 557 |
+
556 fire screen
|
| 558 |
+
557 flagpole
|
| 559 |
+
558 flute
|
| 560 |
+
559 folding chair
|
| 561 |
+
560 football helmet
|
| 562 |
+
561 forklift
|
| 563 |
+
562 fountain
|
| 564 |
+
563 fountain pen
|
| 565 |
+
564 four-poster bed
|
| 566 |
+
565 freight car
|
| 567 |
+
566 French horn
|
| 568 |
+
567 frying pan
|
| 569 |
+
568 fur coat
|
| 570 |
+
569 garbage truck
|
| 571 |
+
570 gas mask or respirator
|
| 572 |
+
571 gas pump
|
| 573 |
+
572 goblet
|
| 574 |
+
573 go-kart
|
| 575 |
+
574 golf ball
|
| 576 |
+
575 golf cart
|
| 577 |
+
576 gondola
|
| 578 |
+
577 gong
|
| 579 |
+
578 gown
|
| 580 |
+
579 grand piano
|
| 581 |
+
580 greenhouse
|
| 582 |
+
581 radiator grille
|
| 583 |
+
582 grocery store
|
| 584 |
+
583 guillotine
|
| 585 |
+
584 hair clip
|
| 586 |
+
585 hair spray
|
| 587 |
+
586 half-track
|
| 588 |
+
587 hammer
|
| 589 |
+
588 hamper
|
| 590 |
+
589 hair dryer
|
| 591 |
+
590 hand-held computer
|
| 592 |
+
591 handkerchief
|
| 593 |
+
592 hard disk drive
|
| 594 |
+
593 harmonica
|
| 595 |
+
594 harp
|
| 596 |
+
595 combine harvester
|
| 597 |
+
596 hatchet
|
| 598 |
+
597 holster
|
| 599 |
+
598 home theater
|
| 600 |
+
599 honeycomb
|
| 601 |
+
600 hook
|
| 602 |
+
601 hoop skirt
|
| 603 |
+
602 gymnastic horizontal bar
|
| 604 |
+
603 horse-drawn vehicle
|
| 605 |
+
604 hourglass
|
| 606 |
+
605 iPod
|
| 607 |
+
606 clothes iron
|
| 608 |
+
607 carved pumpkin
|
| 609 |
+
608 jeans
|
| 610 |
+
609 jeep
|
| 611 |
+
610 T-shirt
|
| 612 |
+
611 jigsaw puzzle
|
| 613 |
+
612 rickshaw
|
| 614 |
+
613 joystick
|
| 615 |
+
614 kimono
|
| 616 |
+
615 knee pad
|
| 617 |
+
616 knot
|
| 618 |
+
617 lab coat
|
| 619 |
+
618 ladle
|
| 620 |
+
619 lampshade
|
| 621 |
+
620 laptop computer
|
| 622 |
+
621 lawn mower
|
| 623 |
+
622 lens cap
|
| 624 |
+
623 letter opener
|
| 625 |
+
624 library
|
| 626 |
+
625 lifeboat
|
| 627 |
+
626 lighter
|
| 628 |
+
627 limousine
|
| 629 |
+
628 ocean liner
|
| 630 |
+
629 lipstick
|
| 631 |
+
630 slip-on shoe
|
| 632 |
+
631 lotion
|
| 633 |
+
632 music speaker
|
| 634 |
+
633 loupe magnifying glass
|
| 635 |
+
634 sawmill
|
| 636 |
+
635 magnetic compass
|
| 637 |
+
636 messenger bag
|
| 638 |
+
637 mailbox
|
| 639 |
+
638 tights
|
| 640 |
+
639 one-piece bathing suit
|
| 641 |
+
640 manhole cover
|
| 642 |
+
641 maraca
|
| 643 |
+
642 marimba
|
| 644 |
+
643 mask
|
| 645 |
+
644 matchstick
|
| 646 |
+
645 maypole
|
| 647 |
+
646 maze
|
| 648 |
+
647 measuring cup
|
| 649 |
+
648 medicine cabinet
|
| 650 |
+
649 megalith
|
| 651 |
+
650 microphone
|
| 652 |
+
651 microwave oven
|
| 653 |
+
652 military uniform
|
| 654 |
+
653 milk can
|
| 655 |
+
654 minibus
|
| 656 |
+
655 miniskirt
|
| 657 |
+
656 minivan
|
| 658 |
+
657 missile
|
| 659 |
+
658 mitten
|
| 660 |
+
659 mixing bowl
|
| 661 |
+
660 mobile home
|
| 662 |
+
661 ford model t
|
| 663 |
+
662 modem
|
| 664 |
+
663 monastery
|
| 665 |
+
664 monitor
|
| 666 |
+
665 moped
|
| 667 |
+
666 mortar and pestle
|
| 668 |
+
667 graduation cap
|
| 669 |
+
668 mosque
|
| 670 |
+
669 mosquito net
|
| 671 |
+
670 vespa
|
| 672 |
+
671 mountain bike
|
| 673 |
+
672 tent
|
| 674 |
+
673 computer mouse
|
| 675 |
+
674 mousetrap
|
| 676 |
+
675 moving van
|
| 677 |
+
676 muzzle
|
| 678 |
+
677 metal nail
|
| 679 |
+
678 neck brace
|
| 680 |
+
679 necklace
|
| 681 |
+
680 baby pacifier
|
| 682 |
+
681 notebook computer
|
| 683 |
+
682 obelisk
|
| 684 |
+
683 oboe
|
| 685 |
+
684 ocarina
|
| 686 |
+
685 odometer
|
| 687 |
+
686 oil filter
|
| 688 |
+
687 pipe organ
|
| 689 |
+
688 oscilloscope
|
| 690 |
+
689 overskirt
|
| 691 |
+
690 bullock cart
|
| 692 |
+
691 oxygen mask
|
| 693 |
+
692 product packet / packaging
|
| 694 |
+
693 paddle
|
| 695 |
+
694 paddle wheel
|
| 696 |
+
695 padlock
|
| 697 |
+
696 paintbrush
|
| 698 |
+
697 pajamas
|
| 699 |
+
698 palace
|
| 700 |
+
699 pan flute
|
| 701 |
+
700 paper towel
|
| 702 |
+
701 parachute
|
| 703 |
+
702 parallel bars
|
| 704 |
+
703 park bench
|
| 705 |
+
704 parking meter
|
| 706 |
+
705 railroad car
|
| 707 |
+
706 patio
|
| 708 |
+
707 payphone
|
| 709 |
+
708 pedestal
|
| 710 |
+
709 pencil case
|
| 711 |
+
710 pencil sharpener
|
| 712 |
+
711 perfume
|
| 713 |
+
712 Petri dish
|
| 714 |
+
713 photocopier
|
| 715 |
+
714 plectrum
|
| 716 |
+
715 Pickelhaube
|
| 717 |
+
716 picket fence
|
| 718 |
+
717 pickup truck
|
| 719 |
+
718 pier
|
| 720 |
+
719 piggy bank
|
| 721 |
+
720 pill bottle
|
| 722 |
+
721 pillow
|
| 723 |
+
722 ping-pong ball
|
| 724 |
+
723 pinwheel
|
| 725 |
+
724 pirate ship
|
| 726 |
+
725 drink pitcher
|
| 727 |
+
726 block plane
|
| 728 |
+
727 planetarium
|
| 729 |
+
728 plastic bag
|
| 730 |
+
729 plate rack
|
| 731 |
+
730 farm plow
|
| 732 |
+
731 plunger
|
| 733 |
+
732 Polaroid camera
|
| 734 |
+
733 pole
|
| 735 |
+
734 police van
|
| 736 |
+
735 poncho
|
| 737 |
+
736 pool table
|
| 738 |
+
737 soda bottle
|
| 739 |
+
738 plant pot
|
| 740 |
+
739 potter's wheel
|
| 741 |
+
740 power drill
|
| 742 |
+
741 prayer rug
|
| 743 |
+
742 printer
|
| 744 |
+
743 prison
|
| 745 |
+
744 missile
|
| 746 |
+
745 projector
|
| 747 |
+
746 hockey puck
|
| 748 |
+
747 punching bag
|
| 749 |
+
748 purse
|
| 750 |
+
749 quill
|
| 751 |
+
750 quilt
|
| 752 |
+
751 race car
|
| 753 |
+
752 racket
|
| 754 |
+
753 radiator
|
| 755 |
+
754 radio
|
| 756 |
+
755 radio telescope
|
| 757 |
+
756 rain barrel
|
| 758 |
+
757 recreational vehicle
|
| 759 |
+
758 fishing casting reel
|
| 760 |
+
759 reflex camera
|
| 761 |
+
760 refrigerator
|
| 762 |
+
761 remote control
|
| 763 |
+
762 restaurant
|
| 764 |
+
763 revolver
|
| 765 |
+
764 rifle
|
| 766 |
+
765 rocking chair
|
| 767 |
+
766 rotisserie
|
| 768 |
+
767 eraser
|
| 769 |
+
768 rugby ball
|
| 770 |
+
769 ruler measuring stick
|
| 771 |
+
770 sneaker
|
| 772 |
+
771 safe
|
| 773 |
+
772 safety pin
|
| 774 |
+
773 salt shaker
|
| 775 |
+
774 sandal
|
| 776 |
+
775 sarong
|
| 777 |
+
776 saxophone
|
| 778 |
+
777 scabbard
|
| 779 |
+
778 weighing scale
|
| 780 |
+
779 school bus
|
| 781 |
+
780 schooner
|
| 782 |
+
781 scoreboard
|
| 783 |
+
782 CRT monitor
|
| 784 |
+
783 screw
|
| 785 |
+
784 screwdriver
|
| 786 |
+
785 seat belt
|
| 787 |
+
786 sewing machine
|
| 788 |
+
787 shield
|
| 789 |
+
788 shoe store
|
| 790 |
+
789 shoji screen / room divider
|
| 791 |
+
790 shopping basket
|
| 792 |
+
791 shopping cart
|
| 793 |
+
792 shovel
|
| 794 |
+
793 shower cap
|
| 795 |
+
794 shower curtain
|
| 796 |
+
795 ski
|
| 797 |
+
796 balaclava ski mask
|
| 798 |
+
797 sleeping bag
|
| 799 |
+
798 slide rule
|
| 800 |
+
799 sliding door
|
| 801 |
+
800 slot machine
|
| 802 |
+
801 snorkel
|
| 803 |
+
802 snowmobile
|
| 804 |
+
803 snowplow
|
| 805 |
+
804 soap dispenser
|
| 806 |
+
805 soccer ball
|
| 807 |
+
806 sock
|
| 808 |
+
807 solar thermal collector
|
| 809 |
+
808 sombrero
|
| 810 |
+
809 soup bowl
|
| 811 |
+
810 keyboard space bar
|
| 812 |
+
811 space heater
|
| 813 |
+
812 space shuttle
|
| 814 |
+
813 spatula
|
| 815 |
+
814 motorboat
|
| 816 |
+
815 spider web
|
| 817 |
+
816 spindle
|
| 818 |
+
817 sports car
|
| 819 |
+
818 spotlight
|
| 820 |
+
819 stage
|
| 821 |
+
820 steam locomotive
|
| 822 |
+
821 through arch bridge
|
| 823 |
+
822 steel drum
|
| 824 |
+
823 stethoscope
|
| 825 |
+
824 scarf
|
| 826 |
+
825 stone wall
|
| 827 |
+
826 stopwatch
|
| 828 |
+
827 stove
|
| 829 |
+
828 strainer
|
| 830 |
+
829 tram
|
| 831 |
+
830 stretcher
|
| 832 |
+
831 couch
|
| 833 |
+
832 stupa
|
| 834 |
+
833 submarine
|
| 835 |
+
834 suit
|
| 836 |
+
835 sundial
|
| 837 |
+
836 sunglasses
|
| 838 |
+
837 sunglasses
|
| 839 |
+
838 sunscreen
|
| 840 |
+
839 suspension bridge
|
| 841 |
+
840 mop
|
| 842 |
+
841 sweatshirt
|
| 843 |
+
842 swim trunks / shorts
|
| 844 |
+
843 swing
|
| 845 |
+
844 electrical switch
|
| 846 |
+
845 syringe
|
| 847 |
+
846 table lamp
|
| 848 |
+
847 tank
|
| 849 |
+
848 tape player
|
| 850 |
+
849 teapot
|
| 851 |
+
850 teddy bear
|
| 852 |
+
851 television
|
| 853 |
+
852 tennis ball
|
| 854 |
+
853 thatched roof
|
| 855 |
+
854 front curtain
|
| 856 |
+
855 thimble
|
| 857 |
+
856 threshing machine
|
| 858 |
+
857 throne
|
| 859 |
+
858 tile roof
|
| 860 |
+
859 toaster
|
| 861 |
+
860 tobacco shop
|
| 862 |
+
861 toilet seat
|
| 863 |
+
862 torch
|
| 864 |
+
863 totem pole
|
| 865 |
+
864 tow truck
|
| 866 |
+
865 toy store
|
| 867 |
+
866 tractor
|
| 868 |
+
867 semi-trailer truck
|
| 869 |
+
868 tray
|
| 870 |
+
869 trench coat
|
| 871 |
+
870 tricycle
|
| 872 |
+
871 trimaran
|
| 873 |
+
872 tripod
|
| 874 |
+
873 triumphal arch
|
| 875 |
+
874 trolleybus
|
| 876 |
+
875 trombone
|
| 877 |
+
876 hot tub
|
| 878 |
+
877 turnstile
|
| 879 |
+
878 typewriter keyboard
|
| 880 |
+
879 umbrella
|
| 881 |
+
880 unicycle
|
| 882 |
+
881 upright piano
|
| 883 |
+
882 vacuum cleaner
|
| 884 |
+
883 vase
|
| 885 |
+
884 vaulted or arched ceiling
|
| 886 |
+
885 velvet fabric
|
| 887 |
+
886 vending machine
|
| 888 |
+
887 vestment
|
| 889 |
+
888 viaduct
|
| 890 |
+
889 violin
|
| 891 |
+
890 volleyball
|
| 892 |
+
891 waffle iron
|
| 893 |
+
892 wall clock
|
| 894 |
+
893 wallet
|
| 895 |
+
894 wardrobe
|
| 896 |
+
895 military aircraft
|
| 897 |
+
896 sink
|
| 898 |
+
897 washing machine
|
| 899 |
+
898 water bottle
|
| 900 |
+
899 water jug
|
| 901 |
+
900 water tower
|
| 902 |
+
901 whiskey jug
|
| 903 |
+
902 whistle
|
| 904 |
+
903 hair wig
|
| 905 |
+
904 window screen
|
| 906 |
+
905 window shade
|
| 907 |
+
906 Windsor tie
|
| 908 |
+
907 wine bottle
|
| 909 |
+
908 airplane wing
|
| 910 |
+
909 wok
|
| 911 |
+
910 wooden spoon
|
| 912 |
+
911 wool
|
| 913 |
+
912 split-rail fence
|
| 914 |
+
913 shipwreck
|
| 915 |
+
914 sailboat
|
| 916 |
+
915 yurt
|
| 917 |
+
916 website
|
| 918 |
+
917 comic book
|
| 919 |
+
918 crossword
|
| 920 |
+
919 traffic or street sign
|
| 921 |
+
920 traffic light
|
| 922 |
+
921 dust jacket
|
| 923 |
+
922 menu
|
| 924 |
+
923 plate
|
| 925 |
+
924 guacamole
|
| 926 |
+
925 consomme
|
| 927 |
+
926 hot pot
|
| 928 |
+
927 trifle
|
| 929 |
+
928 ice cream
|
| 930 |
+
929 popsicle
|
| 931 |
+
930 baguette
|
| 932 |
+
931 bagel
|
| 933 |
+
932 pretzel
|
| 934 |
+
933 cheeseburger
|
| 935 |
+
934 hot dog
|
| 936 |
+
935 mashed potatoes
|
| 937 |
+
936 cabbage
|
| 938 |
+
937 broccoli
|
| 939 |
+
938 cauliflower
|
| 940 |
+
939 zucchini
|
| 941 |
+
940 spaghetti squash
|
| 942 |
+
941 acorn squash
|
| 943 |
+
942 butternut squash
|
| 944 |
+
943 cucumber
|
| 945 |
+
944 artichoke
|
| 946 |
+
945 bell pepper
|
| 947 |
+
946 cardoon
|
| 948 |
+
947 mushroom
|
| 949 |
+
948 Granny Smith apple
|
| 950 |
+
949 strawberry
|
| 951 |
+
950 orange
|
| 952 |
+
951 lemon
|
| 953 |
+
952 fig
|
| 954 |
+
953 pineapple
|
| 955 |
+
954 banana
|
| 956 |
+
955 jackfruit
|
| 957 |
+
956 cherimoya (custard apple)
|
| 958 |
+
957 pomegranate
|
| 959 |
+
958 hay
|
| 960 |
+
959 carbonara
|
| 961 |
+
960 chocolate syrup
|
| 962 |
+
961 dough
|
| 963 |
+
962 meatloaf
|
| 964 |
+
963 pizza
|
| 965 |
+
964 pot pie
|
| 966 |
+
965 burrito
|
| 967 |
+
966 red wine
|
| 968 |
+
967 espresso
|
| 969 |
+
968 tea cup
|
| 970 |
+
969 eggnog
|
| 971 |
+
970 mountain
|
| 972 |
+
971 bubble
|
| 973 |
+
972 cliff
|
| 974 |
+
973 coral reef
|
| 975 |
+
974 geyser
|
| 976 |
+
975 lakeshore
|
| 977 |
+
976 promontory
|
| 978 |
+
977 sandbar
|
| 979 |
+
978 beach
|
| 980 |
+
979 valley
|
| 981 |
+
980 volcano
|
| 982 |
+
981 baseball player
|
| 983 |
+
982 bridegroom
|
| 984 |
+
983 scuba diver
|
| 985 |
+
984 rapeseed
|
| 986 |
+
985 daisy
|
| 987 |
+
986 yellow lady's slipper
|
| 988 |
+
987 corn
|
| 989 |
+
988 acorn
|
| 990 |
+
989 rose hip
|
| 991 |
+
990 horse chestnut seed
|
| 992 |
+
991 coral fungus
|
| 993 |
+
992 agaric
|
| 994 |
+
993 gyromitra
|
| 995 |
+
994 stinkhorn mushroom
|
| 996 |
+
995 earth star fungus
|
| 997 |
+
996 hen of the woods mushroom
|
| 998 |
+
997 bolete
|
| 999 |
+
998 corn cob
|
| 1000 |
+
999 toilet paper
|
dataset_reqs/imagenet100_classes.txt
ADDED
|
@@ -0,0 +1,100 @@
| 1 |
+
0 n01440764 tench
|
| 2 |
+
1 n01443537 goldfish
|
| 3 |
+
2 n01484850 great white shark
|
| 4 |
+
3 n01491361 tiger shark
|
| 5 |
+
4 n01494475 hammerhead shark
|
| 6 |
+
5 n01496331 electric ray
|
| 7 |
+
6 n01498041 stingray
|
| 8 |
+
7 n01514668 rooster
|
| 9 |
+
8 n01514859 hen
|
| 10 |
+
9 n01518878 ostrich
|
| 11 |
+
10 n01530575 brambling
|
| 12 |
+
11 n01531178 goldfinch
|
| 13 |
+
12 n01532829 house finch
|
| 14 |
+
13 n01534433 junco
|
| 15 |
+
14 n01537544 indigo bunting
|
| 16 |
+
15 n01558993 American robin
|
| 17 |
+
16 n01560419 bulbul
|
| 18 |
+
17 n01580077 jay
|
| 19 |
+
18 n01582220 magpie
|
| 20 |
+
19 n01592084 chickadee
|
| 21 |
+
20 n01601694 American dipper
|
| 22 |
+
21 n01608432 kite (bird of prey)
|
| 23 |
+
22 n01614925 bald eagle
|
| 24 |
+
23 n01616318 vulture
|
| 25 |
+
24 n01622779 great grey owl
|
| 26 |
+
25 n01629819 fire salamander
|
| 27 |
+
26 n01630670 smooth newt
|
| 28 |
+
27 n01631663 newt
|
| 29 |
+
28 n01632458 spotted salamander
|
| 30 |
+
29 n01632777 axolotl
|
| 31 |
+
30 n01641577 American bullfrog
|
| 32 |
+
31 n01644373 tree frog
|
| 33 |
+
32 n01644900 tailed frog
|
| 34 |
+
33 n01664065 loggerhead sea turtle
|
| 35 |
+
34 n01665541 leatherback sea turtle
|
| 36 |
+
35 n01667114 mud turtle
|
| 37 |
+
36 n01667778 terrapin
|
| 38 |
+
37 n01669191 box turtle
|
| 39 |
+
38 n01675722 banded gecko
|
| 40 |
+
39 n01677366 green iguana
|
| 41 |
+
40 n01682714 Carolina anole
|
| 42 |
+
41 n01685808 desert grassland whiptail lizard
|
| 43 |
+
42 n01687978 agama
|
| 44 |
+
43 n01688243 frilled-necked lizard
|
| 45 |
+
44 n01689811 alligator lizard
|
| 46 |
+
45 n01692333 Gila monster
|
| 47 |
+
46 n01693334 European green lizard
|
| 48 |
+
47 n01694178 chameleon
|
| 49 |
+
48 n01695060 Komodo dragon
|
| 50 |
+
49 n01697457 Nile crocodile
|
| 51 |
+
50 n01698640 American alligator
|
| 52 |
+
51 n01704323 triceratops
|
| 53 |
+
52 n01728572 worm snake
|
| 54 |
+
53 n01728920 ring-necked snake
|
| 55 |
+
54 n01729322 eastern hog-nosed snake
|
| 56 |
+
55 n01729977 smooth green snake
|
| 57 |
+
56 n01734418 kingsnake
|
| 58 |
+
57 n01735189 garter snake
|
| 59 |
+
58 n01737021 water snake
|
| 60 |
+
59 n01739381 vine snake
|
| 61 |
+
60 n01740131 night snake
|
| 62 |
+
61 n01742172 boa constrictor
|
| 63 |
+
62 n01744401 African rock python
|
| 64 |
+
63 n01748264 Indian cobra
|
| 65 |
+
64 n01749939 green mamba
|
| 66 |
+
65 n01751748 sea snake
|
| 67 |
+
66 n01753488 Saharan horned viper
|
| 68 |
+
67 n01755581 eastern diamondback rattlesnake
|
| 69 |
+
68 n01756291 sidewinder rattlesnake
|
| 70 |
+
69 n01768244 trilobite
|
| 71 |
+
70 n01770081 harvestman
|
| 72 |
+
71 n01770393 scorpion
|
| 73 |
+
72 n01773157 yellow garden spider
|
| 74 |
+
73 n01773549 barn spider
|
| 75 |
+
74 n01773797 European garden spider
|
| 76 |
+
75 n01774384 southern black widow
|
| 77 |
+
76 n01774750 tarantula
|
| 78 |
+
77 n01775062 wolf spider
|
| 79 |
+
78 n01776313 tick
|
| 80 |
+
79 n01784675 centipede
|
| 81 |
+
80 n01795545 black grouse
|
| 82 |
+
81 n01796340 ptarmigan
|
| 83 |
+
82 n01797886 ruffed grouse
|
| 84 |
+
83 n01798484 prairie grouse
|
| 85 |
+
84 n01806143 peafowl
|
| 86 |
+
85 n01806567 quail
|
| 87 |
+
86 n01807496 partridge
|
| 88 |
+
87 n01817953 african grey parrot
|
| 89 |
+
88 n01818515 macaw
|
| 90 |
+
89 n01819313 sulphur-crested cockatoo
|
| 91 |
+
90 n01820546 lorikeet
|
| 92 |
+
91 n01824575 coucal
|
| 93 |
+
92 n01828970 bee eater
|
| 94 |
+
93 n01829413 hornbill
|
| 95 |
+
94 n01833805 hummingbird
|
| 96 |
+
95 n01843065 jacamar
|
| 97 |
+
96 n01843383 toucan
|
| 98 |
+
97 n01847000 duck
|
| 99 |
+
98 n01855032 red-breasted merganser
|
| 100 |
+
99 n01855672 goose
|
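Note: the class-list files in `dataset_reqs/` (`imagenet1000_classes.txt`, `imagenet100_classes.txt`, and `tinyimagenet_classes.txt` below) all use the same whitespace-delimited layout, `index [wordnet_id] class name`. As a hedged illustration only — the repository's actual loading code lives in `continual_clip/datasets.py` and may differ — a file in this format can be read like this:

```python
# Sketch only: read "index [wnid] class name" lines into {index: name}.
# The wnid column (e.g. "n01440764") is optional, as in imagenet1000_classes.txt.
def load_class_names(path):
    names = {}
    with open(path, encoding="utf-8") as f:
        for line in f:
            parts = line.split()
            if not parts:
                continue
            idx = int(parts[0])
            # Skip the WordNet id if the second token looks like "n########".
            has_wnid = len(parts) > 2 and parts[1][0] == "n" and parts[1][1:].isdigit()
            names[idx] = " ".join(parts[2:] if has_wnid else parts[1:])
    return names

# e.g. load_class_names("dataset_reqs/imagenet100_classes.txt")[0] -> "tench"
```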
dataset_reqs/imagenet100_splits/train_100.txt
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
dataset_reqs/imagenet100_splits/val_100.txt
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
dataset_reqs/tinyimagenet_classes.txt
ADDED
|
@@ -0,0 +1,200 @@
| 1 |
+
0 n02124075 Egyptian Mau
|
| 2 |
+
1 n04067472 fishing casting reel
|
| 3 |
+
2 n04540053 volleyball
|
| 4 |
+
3 n04099969 rocking chair
|
| 5 |
+
4 n07749582 lemon
|
| 6 |
+
5 n01641577 American bullfrog
|
| 7 |
+
6 n02802426 basketball
|
| 8 |
+
7 n09246464 cliff
|
| 9 |
+
8 n07920052 espresso
|
| 10 |
+
9 n03970156 plunger
|
| 11 |
+
10 n03891332 parking meter
|
| 12 |
+
11 n02106662 German Shepherd Dog
|
| 13 |
+
12 n03201208 dining table
|
| 14 |
+
13 n02279972 monarch butterfly
|
| 15 |
+
14 n02132136 brown bear
|
| 16 |
+
15 n04146614 school bus
|
| 17 |
+
16 n07873807 pizza
|
| 18 |
+
17 n02364673 guinea pig
|
| 19 |
+
18 n04507155 umbrella
|
| 20 |
+
19 n03854065 pipe organ
|
| 21 |
+
20 n03838899 oboe
|
| 22 |
+
21 n03733131 maypole
|
| 23 |
+
22 n01443537 goldfish
|
| 24 |
+
23 n07875152 pot pie
|
| 25 |
+
24 n03544143 hourglass
|
| 26 |
+
25 n09428293 beach
|
| 27 |
+
26 n03085013 computer keyboard
|
| 28 |
+
27 n02437312 arabian camel
|
| 29 |
+
28 n07614500 ice cream
|
| 30 |
+
29 n03804744 metal nail
|
| 31 |
+
30 n04265275 space heater
|
| 32 |
+
31 n02963159 cardigan
|
| 33 |
+
32 n02486410 baboon
|
| 34 |
+
33 n01944390 snail
|
| 35 |
+
34 n09256479 coral reef
|
| 36 |
+
35 n02058221 albatross
|
| 37 |
+
36 n04275548 spider web
|
| 38 |
+
37 n02321529 sea cucumber
|
| 39 |
+
38 n02769748 backpack
|
| 40 |
+
39 n02099712 Labrador Retriever
|
| 41 |
+
40 n07695742 pretzel
|
| 42 |
+
41 n02056570 king penguin
|
| 43 |
+
42 n02281406 sulphur butterfly
|
| 44 |
+
43 n01774750 tarantula
|
| 45 |
+
44 n02509815 red panda
|
| 46 |
+
45 n03983396 soda bottle
|
| 47 |
+
46 n07753592 banana
|
| 48 |
+
47 n04254777 sock
|
| 49 |
+
48 n02233338 cockroach
|
| 50 |
+
49 n04008634 missile
|
| 51 |
+
50 n02823428 beer bottle
|
| 52 |
+
51 n02236044 praying mantis
|
| 53 |
+
52 n03393912 freight car
|
| 54 |
+
53 n07583066 guacamole
|
| 55 |
+
54 n04074963 remote control
|
| 56 |
+
55 n01629819 fire salamander
|
| 57 |
+
56 n09332890 lakeshore
|
| 58 |
+
57 n02481823 chimpanzee
|
| 59 |
+
58 n03902125 payphone
|
| 60 |
+
59 n03404251 fur coat
|
| 61 |
+
60 n09193705 mountain
|
| 62 |
+
61 n03637318 lampshade
|
| 63 |
+
62 n04456115 torch
|
| 64 |
+
63 n02666196 abacus
|
| 65 |
+
64 n03796401 moving van
|
| 66 |
+
65 n02795169 barrel
|
| 67 |
+
66 n02123045 tabby cat
|
| 68 |
+
67 n01855672 goose
|
| 69 |
+
68 n01882714 koala
|
| 70 |
+
69 n02917067 high-speed train
|
| 71 |
+
70 n02988304 CD player
|
| 72 |
+
71 n04398044 teapot
|
| 73 |
+
72 n02843684 birdhouse
|
| 74 |
+
73 n02423022 gazelle
|
| 75 |
+
74 n02669723 academic gown
|
| 76 |
+
75 n04465501 tractor
|
| 77 |
+
76 n02165456 ladybug
|
| 78 |
+
77 n03770439 miniskirt
|
| 79 |
+
78 n02099601 Golden Retriever
|
| 80 |
+
79 n04486054 triumphal arch
|
| 81 |
+
80 n02950826 cannon
|
| 82 |
+
81 n03814639 neck brace
|
| 83 |
+
82 n04259630 sombrero
|
| 84 |
+
83 n03424325 gas mask or respirator
|
| 85 |
+
84 n02948072 candle
|
| 86 |
+
85 n03179701 desk
|
| 87 |
+
86 n03400231 frying pan
|
| 88 |
+
87 n02206856 bee
|
| 89 |
+
88 n03160309 dam
|
| 90 |
+
89 n01984695 spiny lobster
|
| 91 |
+
90 n03977966 police van
|
| 92 |
+
91 n03584254 iPod
|
| 93 |
+
92 n04023962 punching bag
|
| 94 |
+
93 n02814860 lighthouse
|
| 95 |
+
94 n01910747 jellyfish
|
| 96 |
+
95 n04596742 wok
|
| 97 |
+
96 n03992509 potter's wheel
|
| 98 |
+
97 n04133789 sandal
|
| 99 |
+
98 n03937543 pill bottle
|
| 100 |
+
99 n02927161 butcher shop
|
| 101 |
+
100 n01945685 slug
|
| 102 |
+
101 n02395406 pig
|
| 103 |
+
102 n02125311 cougar
|
| 104 |
+
103 n03126707 construction crane
|
| 105 |
+
104 n04532106 vestment
|
| 106 |
+
105 n02268443 dragonfly
|
| 107 |
+
106 n02977058 automated teller machine
|
| 108 |
+
107 n07734744 mushroom
|
| 109 |
+
108 n03599486 rickshaw
|
| 110 |
+
109 n04562935 water tower
|
| 111 |
+
110 n03014705 storage chest
|
| 112 |
+
111 n04251144 snorkel
|
| 113 |
+
112 n04356056 sunglasses
|
| 114 |
+
113 n02190166 fly
|
| 115 |
+
114 n03670208 limousine
|
| 116 |
+
115 n02002724 black stork
|
| 117 |
+
116 n02074367 dugong
|
| 118 |
+
117 n04285008 sports car
|
| 119 |
+
118 n04560804 water jug
|
| 120 |
+
119 n04366367 suspension bridge
|
| 121 |
+
120 n02403003 ox
|
| 122 |
+
121 n07615774 popsicle
|
| 123 |
+
122 n04501370 turnstile
|
| 124 |
+
123 n03026506 Christmas stocking
|
| 125 |
+
124 n02906734 broom
|
| 126 |
+
125 n01770393 scorpion
|
| 127 |
+
126 n04597913 wooden spoon
|
| 128 |
+
127 n03930313 picket fence
|
| 129 |
+
128 n04118538 rugby ball
|
| 130 |
+
129 n04179913 sewing machine
|
| 131 |
+
130 n04311004 through arch bridge
|
| 132 |
+
131 n02123394 Persian cat
|
| 133 |
+
132 n04070727 refrigerator
|
| 134 |
+
133 n02793495 barn
|
| 135 |
+
134 n02730930 apron
|
| 136 |
+
135 n02094433 Yorkshire Terrier
|
| 137 |
+
136 n04371430 swim trunks / shorts
|
| 138 |
+
137 n04328186 stopwatch
|
| 139 |
+
138 n03649909 lawn mower
|
| 140 |
+
139 n04417672 thatched roof
|
| 141 |
+
140 n03388043 fountain
|
| 142 |
+
141 n01774384 southern black widow
|
| 143 |
+
142 n02837789 bikini
|
| 144 |
+
143 n07579787 plate
|
| 145 |
+
144 n04399382 teddy bear
|
| 146 |
+
145 n02791270 barbershop
|
| 147 |
+
146 n03089624 candy store
|
| 148 |
+
147 n02814533 station wagon
|
| 149 |
+
148 n04149813 scoreboard
|
| 150 |
+
149 n07747607 orange
|
| 151 |
+
150 n03355925 flagpole
|
| 152 |
+
151 n01983481 American lobster
|
| 153 |
+
152 n04487081 trolleybus
|
| 154 |
+
153 n03250847 drumstick
|
| 155 |
+
154 n03255030 dumbbell
|
| 156 |
+
155 n02892201 brass memorial plaque
|
| 157 |
+
156 n02883205 bow tie
|
| 158 |
+
157 n03100240 convertible
|
| 159 |
+
158 n02415577 bighorn sheep
|
| 160 |
+
159 n02480495 orangutan
|
| 161 |
+
160 n01698640 American alligator
|
| 162 |
+
161 n01784675 centipede
|
| 163 |
+
162 n04376876 syringe
|
| 164 |
+
163 n03444034 go-kart
|
| 165 |
+
164 n01917289 brain coral
|
| 166 |
+
165 n01950731 sea slug
|
| 167 |
+
166 n03042490 cliff dwelling
|
| 168 |
+
167 n07711569 mashed potatoes
|
| 169 |
+
168 n04532670 viaduct
|
| 170 |
+
169 n03763968 military uniform
|
| 171 |
+
170 n07768694 pomegranate
|
| 172 |
+
171 n02999410 chain
|
| 173 |
+
172 n03617480 kimono
|
| 174 |
+
173 n06596364 comic book
|
| 175 |
+
174 n01768244 trilobite
|
| 176 |
+
175 n02410509 bison
|
| 177 |
+
176 n03976657 pole
|
| 178 |
+
177 n01742172 boa constrictor
|
| 179 |
+
178 n03980874 poncho
|
| 180 |
+
179 n02808440 bathtub
|
| 181 |
+
180 n02226429 grasshopper
|
| 182 |
+
181 n02231487 stick insect
|
| 183 |
+
182 n02085620 Chihuahua
|
| 184 |
+
183 n01644900 tailed frog
|
| 185 |
+
184 n02129165 lion
|
| 186 |
+
185 n02699494 altar
|
| 187 |
+
186 n03837869 obelisk
|
| 188 |
+
187 n02815834 beaker
|
| 189 |
+
188 n07720875 bell pepper
|
| 190 |
+
189 n02788148 baluster / handrail
|
| 191 |
+
190 n02909870 bucket
|
| 192 |
+
191 n03706229 magnetic compass
|
| 193 |
+
192 n07871810 meatloaf
|
| 194 |
+
193 n03447447 gondola
|
| 195 |
+
194 n02113799 Standard Poodle
|
| 196 |
+
195 n12267677 acorn
|
| 197 |
+
196 n03662601 lifeboat
|
| 198 |
+
197 n02841315 binoculars
|
| 199 |
+
198 n07715103 cauliflower
|
| 200 |
+
199 n02504458 African bush elephant
|
main.py
ADDED
|
@@ -0,0 +1,104 @@
| 1 |
+
import os
|
| 2 |
+
import json
|
| 3 |
+
import random
|
| 4 |
+
|
| 5 |
+
import hydra
|
| 6 |
+
import logging
|
| 7 |
+
import numpy as np
|
| 8 |
+
from omegaconf import DictConfig
|
| 9 |
+
from tqdm import tqdm
|
| 10 |
+
import torch
|
| 11 |
+
import statistics
|
| 12 |
+
from continuum.metrics import Logger
|
| 13 |
+
from continual_clip import utils
|
| 14 |
+
from continual_clip.models import load_model
|
| 15 |
+
from continual_clip.datasets import build_cl_scenarios
|
| 16 |
+
from torch.utils.data import DataLoader, DistributedSampler
|
| 17 |
+
|
| 18 |
+
WORLD_NUM = 1
|
| 19 |
+
@hydra.main(config_path=None, config_name=None, version_base="1.1")
|
| 20 |
+
def continual_clip(cfg: DictConfig) -> None:
|
| 21 |
+
|
| 22 |
+
set_seed(RANDOM_SEED)
|
| 23 |
+
|
| 24 |
+
cfg.workdir = "/***/DMNSP/cil"
|
| 25 |
+
cfg.dataset_root = os.path.join(cfg.workdir, cfg.dataset_root)
|
| 26 |
+
|
| 27 |
+
utils.save_config(cfg)
|
| 28 |
+
cfg.class_order = utils.get_class_order(os.path.join(cfg.workdir, cfg.class_order))
|
| 29 |
+
origin_flag = False
|
| 30 |
+
devices = [0]
|
| 31 |
+
model = load_model(cfg, devices[0], origin_flag)
|
| 32 |
+
|
| 33 |
+
eval_dataset, classes_names = build_cl_scenarios(
|
| 34 |
+
cfg, is_train=False, transforms=model.transforms
|
| 35 |
+
)
|
| 36 |
+
print(eval_dataset, eval_dataset)
|
| 37 |
+
train_dataset, train_classes_names = build_cl_scenarios(
|
| 38 |
+
cfg, is_train=True, transforms=model.transforms
|
| 39 |
+
)
|
| 40 |
+
model.classes_names = classes_names
|
| 41 |
+
|
| 42 |
+
print("Using devices", devices)
|
| 43 |
+
model = torch.nn.DataParallel(model, device_ids=devices)
|
| 44 |
+
|
| 45 |
+
with open(cfg.log_path, 'w+') as f:
|
| 46 |
+
pass
|
| 47 |
+
|
| 48 |
+
acc_list = []
|
| 49 |
+
forgetting_list = []
|
| 50 |
+
metric_logger = Logger(list_subsets=["test"])
|
| 51 |
+
world = WORLD_NUM
|
| 52 |
+
|
| 53 |
+
for task_id, _ in enumerate(eval_dataset):
|
| 54 |
+
|
| 55 |
+
logging.info(f"Evaluation for task {task_id} has started.")
|
| 56 |
+
|
| 57 |
+
model.module.adaptation(task_id, cfg, train_dataset, train_classes_names, world)  # task_id has already been passed into the model
|
| 58 |
+
eval_sampler = DistributedSampler(eval_dataset[:task_id + 1], num_replicas=world, rank=0)
|
| 59 |
+
eval_loader = DataLoader(eval_dataset[:task_id + 1], batch_size=64, sampler=eval_sampler, num_workers=8)
|
| 60 |
+
|
| 61 |
+
for inputs, targets, task_ids in tqdm(eval_loader):
|
| 62 |
+
inputs, targets = inputs.cuda(device=devices[0]), targets.cuda(device=devices[0])
|
| 63 |
+
outputs = model.module.cuda(devices[0])(inputs.cuda(devices[0]), task_ids)
|
| 64 |
+
metric_logger.add([outputs.cpu().argmax(dim=1), targets.cpu(), task_ids], subset="test")
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
acc_list.append(100 * metric_logger.accuracy)
|
| 68 |
+
forgetting_list.append(100 * metric_logger.forgetting)
|
| 69 |
+
|
| 70 |
+
with open(cfg.log_path, 'a+') as f:
|
| 71 |
+
f.write(json.dumps({
|
| 72 |
+
'task': task_id,
|
| 73 |
+
'acc': round(100 * metric_logger.accuracy, 2),
|
| 74 |
+
'avg_acc': round(100 * metric_logger.average_incremental_accuracy, 2),
|
| 75 |
+
'forgetting': round(100 * metric_logger.forgetting, 6),
|
| 76 |
+
'acc_per_task': [round(100 * acc_t, 2) for acc_t in metric_logger.accuracy_per_task],
|
| 77 |
+
'bwt': round(100 * metric_logger.backward_transfer, 2),
|
| 78 |
+
'fwt': round(100 * metric_logger.forward_transfer, 2),
|
| 79 |
+
}) + '\n')
|
| 80 |
+
metric_logger.end_task()
|
| 81 |
+
|
| 82 |
+
with open(cfg.log_path, 'a+') as f:
|
| 83 |
+
f.write(json.dumps({
|
| 84 |
+
'last_Cifar100': round(acc_list[-1], 2),
|
| 85 |
+
'avg_Cifar100': round(statistics.mean(acc_list), 2),
|
| 86 |
+
'avg_forgetting': round(statistics.mean(forgetting_list), 2)
|
| 87 |
+
}) + '\n')
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
# Seeds: 386, 2345, 157 (Performance might slightly vary across different machines)
|
| 91 |
+
RANDOM_SEED = 386
|
| 92 |
+
|
| 93 |
+
def set_seed(seed):
|
| 94 |
+
random.seed(seed)
|
| 95 |
+
np.random.seed(seed)
|
| 96 |
+
torch.manual_seed(seed)
|
| 97 |
+
torch.cuda.manual_seed(seed)
|
| 98 |
+
torch.cuda.manual_seed_all(seed)
|
| 99 |
+
os.environ['PYTHONHASHSEED'] = str(seed)
|
| 100 |
+
torch.backends.cudnn.deterministic = True
|
| 101 |
+
torch.backends.cudnn.benchmark = False
|
| 102 |
+
|
| 103 |
+
if __name__ == "__main__":
|
| 104 |
+
continual_clip()
|
requirements.txt
ADDED
|
@@ -0,0 +1,19 @@
| 1 |
+
continuum
|
| 2 |
+
hydra-core==1.2.0
|
| 3 |
+
numpy
|
| 4 |
+
oauthlib==3.2.1
|
| 5 |
+
omegaconf==2.2.3
|
| 6 |
+
open-clip-torch==1.3.0
|
| 7 |
+
pandas==1.4.3
|
| 8 |
+
Pillow==9.2.0
|
| 9 |
+
pipreqs==0.4.11
|
| 10 |
+
scikit-image==0.19.3
|
| 11 |
+
scikit-learn==1.1.1
|
| 12 |
+
scipy==1.8.1
|
| 13 |
+
tensorboard==2.10.0
|
| 14 |
+
timm @ git+https://github.com/Arnav0400/pytorch-image-models.git@ceea7127c1ef608179ba06eaeddc22ad3ef22de0
|
| 15 |
+
tokenizers==0.12.1
|
| 16 |
+
tqdm==4.64.0
|
| 17 |
+
transformers==4.21.1
|
| 18 |
+
ftfy
|
| 19 |
+
regex
|
run_cifar100-10-10.sh
ADDED
|
@@ -0,0 +1,9 @@
| 1 |
+
#!/bin/bash
|
| 2 |
+
|
| 3 |
+
python main.py \
|
| 4 |
+
--config-path /DMNSP/configs/class \
|
| 5 |
+
--config-name cifar100_10-10.yaml \
|
| 6 |
+
dataset_root="/data/**/" \
|
| 7 |
+
class_order="/DMNSP/class_orders/cifar100.yaml"
|
| 8 |
+
|
| 9 |
+
|
templates/__init__.py
ADDED
|
File without changes
|
templates/fmow_template.py
ADDED
|
@@ -0,0 +1,20 @@
| 1 |
+
from .template_utils import append_proper_article
|
| 2 |
+
|
| 3 |
+
fmow_template = [
|
| 4 |
+
lambda c : f"satellite photo of a {c}.",
|
| 5 |
+
lambda c : f"aerial photo of a {c}.",
|
| 6 |
+
lambda c : f"satellite photo of {append_proper_article(c)}.",
|
| 7 |
+
lambda c : f"aerial photo of {append_proper_article(c)}.",
|
| 8 |
+
lambda c : f"satellite photo of a {c} in asia.",
|
| 9 |
+
lambda c : f"aerial photo of a {c} in asia.",
|
| 10 |
+
lambda c : f"satellite photo of a {c} in africa.",
|
| 11 |
+
lambda c : f"aerial photo of a {c} in africa.",
|
| 12 |
+
lambda c : f"satellite photo of a {c} in the americas.",
|
| 13 |
+
lambda c : f"aerial photo of a {c} in the americas.",
|
| 14 |
+
lambda c : f"satellite photo of a {c} in europe.",
|
| 15 |
+
lambda c : f"aerial photo of a {c} in europe.",
|
| 16 |
+
lambda c : f"satellite photo of a {c} in oceania.",
|
| 17 |
+
lambda c : f"aerial photo of a {c} in oceania.",
|
| 18 |
+
lambda c: f"a photo of a {c}.",
|
| 19 |
+
lambda c: f"{c}.",
|
| 20 |
+
]
|
templates/iwildcam_template.py
ADDED
|
@@ -0,0 +1,4 @@
| 1 |
+
iwildcam_template = [
|
| 2 |
+
lambda c: f"a photo of {c}.",
|
| 3 |
+
lambda c: f"{c} in the wild.",
|
| 4 |
+
]
|
templates/openai_imagenet_template.py
ADDED
|
@@ -0,0 +1,82 @@
| 1 |
+
openai_imagenet_template = [
|
| 2 |
+
lambda c: f'a bad photo of a {c}.',
|
| 3 |
+
lambda c: f'a photo of many {c}.',
|
| 4 |
+
lambda c: f'a sculpture of a {c}.',
|
| 5 |
+
lambda c: f'a photo of the hard to see {c}.',
|
| 6 |
+
lambda c: f'a low resolution photo of the {c}.',
|
| 7 |
+
lambda c: f'a rendering of a {c}.',
|
| 8 |
+
lambda c: f'graffiti of a {c}.',
|
| 9 |
+
lambda c: f'a bad photo of the {c}.',
|
| 10 |
+
lambda c: f'a cropped photo of the {c}.',
|
| 11 |
+
lambda c: f'a tattoo of a {c}.',
|
| 12 |
+
lambda c: f'the embroidered {c}.',
|
| 13 |
+
lambda c: f'a photo of a hard to see {c}.',
|
| 14 |
+
lambda c: f'a bright photo of a {c}.',
|
| 15 |
+
lambda c: f'a photo of a clean {c}.',
|
| 16 |
+
lambda c: f'a photo of a dirty {c}.',
|
| 17 |
+
lambda c: f'a dark photo of the {c}.',
|
| 18 |
+
lambda c: f'a drawing of a {c}.',
|
| 19 |
+
lambda c: f'a photo of my {c}.',
|
| 20 |
+
lambda c: f'the plastic {c}.',
|
| 21 |
+
lambda c: f'a photo of the cool {c}.',
|
| 22 |
+
lambda c: f'a close-up photo of a {c}.',
|
| 23 |
+
lambda c: f'a black and white photo of the {c}.',
|
| 24 |
+
lambda c: f'a painting of the {c}.',
|
| 25 |
+
lambda c: f'a painting of a {c}.',
|
| 26 |
+
lambda c: f'a pixelated photo of the {c}.',
|
| 27 |
+
lambda c: f'a sculpture of the {c}.',
|
| 28 |
+
lambda c: f'a bright photo of the {c}.',
|
| 29 |
+
lambda c: f'a cropped photo of a {c}.',
|
| 30 |
+
lambda c: f'a plastic {c}.',
|
| 31 |
+
lambda c: f'a photo of the dirty {c}.',
|
| 32 |
+
lambda c: f'a jpeg corrupted photo of a {c}.',
|
| 33 |
+
lambda c: f'a blurry photo of the {c}.',
|
| 34 |
+
lambda c: f'a photo of the {c}.',
|
| 35 |
+
lambda c: f'a good photo of the {c}.',
|
| 36 |
+
lambda c: f'a rendering of the {c}.',
|
| 37 |
+
lambda c: f'a {c} in a video game.',
|
| 38 |
+
lambda c: f'a photo of one {c}.',
|
| 39 |
+
lambda c: f'a doodle of a {c}.',
|
| 40 |
+
lambda c: f'a close-up photo of the {c}.',
|
| 41 |
+
lambda c: f'a photo of a {c}.',
|
| 42 |
+
lambda c: f'the origami {c}.',
|
| 43 |
+
lambda c: f'the {c} in a video game.',
|
| 44 |
+
lambda c: f'a sketch of a {c}.',
|
| 45 |
+
lambda c: f'a doodle of the {c}.',
|
| 46 |
+
lambda c: f'a origami {c}.',
|
| 47 |
+
lambda c: f'a low resolution photo of a {c}.',
|
| 48 |
+
lambda c: f'the toy {c}.',
|
| 49 |
+
lambda c: f'a rendition of the {c}.',
|
| 50 |
+
lambda c: f'a photo of the clean {c}.',
|
| 51 |
+
lambda c: f'a photo of a large {c}.',
|
| 52 |
+
lambda c: f'a rendition of a {c}.',
|
| 53 |
+
lambda c: f'a photo of a nice {c}.',
|
| 54 |
+
lambda c: f'a photo of a weird {c}.',
|
| 55 |
+
lambda c: f'a blurry photo of a {c}.',
|
| 56 |
+
lambda c: f'a cartoon {c}.',
|
| 57 |
+
lambda c: f'art of a {c}.',
|
| 58 |
+
lambda c: f'a sketch of the {c}.',
|
| 59 |
+
lambda c: f'a embroidered {c}.',
|
| 60 |
+
lambda c: f'a pixelated photo of a {c}.',
|
| 61 |
+
lambda c: f'itap of the {c}.',
|
| 62 |
+
lambda c: f'a jpeg corrupted photo of the {c}.',
|
| 63 |
+
lambda c: f'a good photo of a {c}.',
|
| 64 |
+
lambda c: f'a plushie {c}.',
|
| 65 |
+
lambda c: f'a photo of the nice {c}.',
|
| 66 |
+
lambda c: f'a photo of the small {c}.',
|
| 67 |
+
lambda c: f'a photo of the weird {c}.',
|
| 68 |
+
lambda c: f'the cartoon {c}.',
|
| 69 |
+
lambda c: f'art of the {c}.',
|
| 70 |
+
lambda c: f'a drawing of the {c}.',
|
| 71 |
+
lambda c: f'a photo of the large {c}.',
|
| 72 |
+
lambda c: f'a black and white photo of a {c}.',
|
| 73 |
+
lambda c: f'the plushie {c}.',
|
| 74 |
+
lambda c: f'a dark photo of a {c}.',
|
| 75 |
+
lambda c: f'itap of a {c}.',
|
| 76 |
+
lambda c: f'graffiti of the {c}.',
|
| 77 |
+
lambda c: f'a toy {c}.',
|
| 78 |
+
lambda c: f'itap of my {c}.',
|
| 79 |
+
lambda c: f'a photo of a cool {c}.',
|
| 80 |
+
lambda c: f'a photo of a small {c}.',
|
| 81 |
+
lambda c: f'a tattoo of the {c}.',
|
| 82 |
+
]
|
templates/simple_template.py
ADDED
|
@@ -0,0 +1,3 @@
| 1 |
+
simple_template = [
|
| 2 |
+
lambda c: f"a photo of a {c}."
|
| 3 |
+
]
|
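Each template module above exposes a plain list of `lambda c: ...` prompt builders. The repository's own use of these lists is in `continual_clip/models.py`; purely as a hedged sketch of the usual pattern (assuming the bundled `clip` package follows the standard OpenAI CLIP API with `clip.load`, `clip.tokenize`, and `encode_text`), a template list is turned into one averaged text embedding per class like this:

```python
import torch
import clip  # assumption: bundled package exposes the usual clip.load / clip.tokenize API

@torch.no_grad()
def build_text_classifier(class_names, templates, model, device="cuda"):
    """Average the normalised CLIP text embeddings of all prompt variants per class."""
    weights = []
    for name in class_names:
        prompts = [t(name) for t in templates]             # e.g. "a photo of a goldfish."
        tokens = clip.tokenize(prompts).to(device)
        feats = model.encode_text(tokens).float()
        feats = feats / feats.norm(dim=-1, keepdim=True)    # L2-normalise each variant
        w = feats.mean(dim=0)
        weights.append(w / w.norm())                        # renormalise the ensemble mean
    return torch.stack(weights, dim=0)                      # [num_classes, embed_dim]

# usage sketch:
#   model, _ = clip.load("ViT-B/16", device="cuda")
#   from templates.simple_template import simple_template
#   W = build_text_classifier(["goldfish", "tench"], simple_template, model)
#   logits = image_features @ W.t()
```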
templates/template_utils.py
ADDED
|
@@ -0,0 +1,28 @@
| 1 |
+
def get_plural(name):
|
| 2 |
+
name = name.replace('_', ' ')
|
| 3 |
+
if name[-2:] == 'sh':
|
| 4 |
+
name = name + 'es'
|
| 5 |
+
elif name[-2:] == 'ch':
|
| 6 |
+
name = name + 'es'
|
| 7 |
+
elif name[-1:] == 'y':
|
| 8 |
+
name = name[:-1] + 'ies'
|
| 9 |
+
elif name[-1:] == 's':
|
| 10 |
+
name = name + 'es'
|
| 11 |
+
elif name[-1:] == 'x':
|
| 12 |
+
name = name + 'es'
|
| 13 |
+
elif name[-3:] == 'man':
|
| 14 |
+
name = name[:-3] + 'men'
|
| 15 |
+
elif name == 'mouse':
|
| 16 |
+
name = 'mice'
|
| 17 |
+
elif name[-1:] == 'f':
|
| 18 |
+
name = name[:-1] + 'ves'
|
| 19 |
+
else:
|
| 20 |
+
name = name + 's'
|
| 21 |
+
return name
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def append_proper_article(name):
|
| 25 |
+
name = name.replace('_', ' ')
|
| 26 |
+
if name[0] in 'aeiou':
|
| 27 |
+
return 'an ' + name
|
| 28 |
+
return 'a ' + name
|
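For quick reference, the two helpers above behave as follows (the outputs are simply what the code returns; the pluralisation rules are heuristic):

```python
from templates.template_utils import get_plural, append_proper_article

print(get_plural("box"))                # boxes
print(get_plural("daisy"))              # daisies
print(get_plural("policeman"))          # policemen
print(append_proper_article("orange"))  # an orange
print(append_proper_article("barn"))    # a barn
```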
templates/testing_template.py
ADDED
|
@@ -0,0 +1,83 @@
| 1 |
+
testing_template = [
|
| 2 |
+
lambda c : f'a photo of the number: "{c}".',
|
| 3 |
+
lambda c: f'a bad photo of a {c}.',
|
| 4 |
+
lambda c: f'a photo of many {c}.',
|
| 5 |
+
lambda c: f'a sculpture of a {c}.',
|
| 6 |
+
lambda c: f'a photo of the hard to see {c}.',
|
| 7 |
+
lambda c: f'a low resolution photo of the {c}.',
|
| 8 |
+
lambda c: f'a rendering of a {c}.',
|
| 9 |
+
lambda c: f'graffiti of a {c}.',
|
| 10 |
+
lambda c: f'a bad photo of the {c}.',
|
| 11 |
+
lambda c: f'a cropped photo of the {c}.',
|
| 12 |
+
lambda c: f'a tattoo of a {c}.',
|
| 13 |
+
lambda c: f'the embroidered {c}.',
|
| 14 |
+
lambda c: f'a photo of a hard to see {c}.',
|
| 15 |
+
lambda c: f'a bright photo of a {c}.',
|
| 16 |
+
lambda c: f'a photo of a clean {c}.',
|
| 17 |
+
lambda c: f'a photo of a dirty {c}.',
|
| 18 |
+
lambda c: f'a dark photo of the {c}.',
|
| 19 |
+
lambda c: f'a drawing of a {c}.',
|
| 20 |
+
lambda c: f'a photo of my {c}.',
|
| 21 |
+
lambda c: f'the plastic {c}.',
|
| 22 |
+
lambda c: f'a photo of the cool {c}.',
|
| 23 |
+
lambda c: f'a close-up photo of a {c}.',
|
| 24 |
+
lambda c: f'a black and white photo of the {c}.',
|
| 25 |
+
lambda c: f'a painting of the {c}.',
|
| 26 |
+
lambda c: f'a painting of a {c}.',
|
| 27 |
+
lambda c: f'a pixelated photo of the {c}.',
|
| 28 |
+
lambda c: f'a sculpture of the {c}.',
|
| 29 |
+
lambda c: f'a bright photo of the {c}.',
|
| 30 |
+
lambda c: f'a cropped photo of a {c}.',
|
| 31 |
+
lambda c: f'a plastic {c}.',
|
| 32 |
+
lambda c: f'a photo of the dirty {c}.',
|
| 33 |
+
lambda c: f'a jpeg corrupted photo of a {c}.',
|
| 34 |
+
lambda c: f'a blurry photo of the {c}.',
|
| 35 |
+
lambda c: f'a photo of the {c}.',
|
| 36 |
+
lambda c: f'a good photo of the {c}.',
|
| 37 |
+
lambda c: f'a rendering of the {c}.',
|
| 38 |
+
lambda c: f'a {c} in a video game.',
|
| 39 |
+
lambda c: f'a photo of one {c}.',
|
| 40 |
+
lambda c: f'a doodle of a {c}.',
|
| 41 |
+
lambda c: f'a close-up photo of the {c}.',
|
| 42 |
+
lambda c: f'a photo of a {c}.',
|
| 43 |
+
lambda c: f'the origami {c}.',
|
| 44 |
+
lambda c: f'the {c} in a video game.',
|
| 45 |
+
lambda c: f'a sketch of a {c}.',
|
| 46 |
+
lambda c: f'a doodle of the {c}.',
|
| 47 |
+
lambda c: f'a origami {c}.',
|
| 48 |
+
lambda c: f'a low resolution photo of a {c}.',
|
| 49 |
+
lambda c: f'the toy {c}.',
|
| 50 |
+
lambda c: f'a rendition of the {c}.',
|
| 51 |
+
lambda c: f'a photo of the clean {c}.',
|
| 52 |
+
lambda c: f'a photo of a large {c}.',
|
| 53 |
+
lambda c: f'a rendition of a {c}.',
|
| 54 |
+
lambda c: f'a photo of a nice {c}.',
|
| 55 |
+
lambda c: f'a photo of a weird {c}.',
|
| 56 |
+
lambda c: f'a blurry photo of a {c}.',
|
| 57 |
+
lambda c: f'a cartoon {c}.',
|
| 58 |
+
lambda c: f'art of a {c}.',
|
| 59 |
+
lambda c: f'a sketch of the {c}.',
|
| 60 |
+
lambda c: f'a embroidered {c}.',
|
| 61 |
+
lambda c: f'a pixelated photo of a {c}.',
|
| 62 |
+
lambda c: f'itap of the {c}.',
|
| 63 |
+
lambda c: f'a jpeg corrupted photo of the {c}.',
|
| 64 |
+
lambda c: f'a good photo of a {c}.',
|
| 65 |
+
lambda c: f'a plushie {c}.',
|
| 66 |
+
lambda c: f'a photo of the nice {c}.',
|
| 67 |
+
lambda c: f'a photo of the small {c}.',
|
| 68 |
+
lambda c: f'a photo of the weird {c}.',
|
| 69 |
+
lambda c: f'the cartoon {c}.',
|
| 70 |
+
lambda c: f'art of the {c}.',
|
| 71 |
+
lambda c: f'a drawing of the {c}.',
|
| 72 |
+
lambda c: f'a photo of the large {c}.',
|
| 73 |
+
lambda c: f'a black and white photo of a {c}.',
|
| 74 |
+
lambda c: f'the plushie {c}.',
|
| 75 |
+
lambda c: f'a dark photo of a {c}.',
|
| 76 |
+
lambda c: f'itap of a {c}.',
|
| 77 |
+
lambda c: f'graffiti of the {c}.',
|
| 78 |
+
lambda c: f'a toy {c}.',
|
| 79 |
+
lambda c: f'itap of my {c}.',
|
| 80 |
+
lambda c: f'a photo of a cool {c}.',
|
| 81 |
+
lambda c: f'a photo of a small {c}.',
|
| 82 |
+
lambda c: f'a tattoo of the {c}.',
|
| 83 |
+
]
|