boringKey committed on
Commit b389d26 · verified · 1 Parent(s): 8b8d9f0

Upload 46 files

Files changed (46)
  1. LICENSE +203 -0
  2. README.md +113 -0
  3. __init__.py +0 -0
  4. class_orders/cifar100.yaml +1 -0
  5. class_orders/tinyimagenet.yaml +17 -0
  6. clip/README.md +1 -0
  7. clip/__init__.py +1 -0
  8. clip/adapter.py +75 -0
  9. clip/bpe_simple_vocab_16e6.txt.gz +3 -0
  10. clip/clip.py +310 -0
  11. clip/model.py +486 -0
  12. clip/tokenizer.py +140 -0
  13. configs/class/cifar100_10-10.yaml +33 -0
  14. configs/class/cifar100_2-2.yaml +33 -0
  15. configs/class/cifar100_5-5.yaml +37 -0
  16. configs/class/tinyimagenet_100-10.yaml +36 -0
  17. configs/class/tinyimagenet_100-20.yaml +36 -0
  18. configs/class/tinyimagenet_100-5.yaml +36 -0
  19. continual_clip/__init__.py +0 -0
  20. continual_clip/cc.py +53 -0
  21. continual_clip/clip_original/README.md +1 -0
  22. continual_clip/clip_original/__init__.py +1 -0
  23. continual_clip/clip_original/adapter.py +75 -0
  24. continual_clip/clip_original/bpe_simple_vocab_16e6.txt.gz +3 -0
  25. continual_clip/clip_original/clip.py +208 -0
  26. continual_clip/clip_original/model.py +568 -0
  27. continual_clip/clip_original/tokenizer.py +140 -0
  28. continual_clip/datasets.py +124 -0
  29. continual_clip/dynamic_dataset.py +108 -0
  30. continual_clip/models.py +228 -0
  31. continual_clip/utils.py +210 -0
  32. dataset_reqs/imagenet1000_classes.txt +1000 -0
  33. dataset_reqs/imagenet100_classes.txt +100 -0
  34. dataset_reqs/imagenet100_splits/train_100.txt +0 -0
  35. dataset_reqs/imagenet100_splits/val_100.txt +0 -0
  36. dataset_reqs/tinyimagenet_classes.txt +200 -0
  37. main.py +104 -0
  38. requirements.txt +19 -0
  39. run_cifar100-10-10.sh +9 -0
  40. templates/__init__.py +0 -0
  41. templates/fmow_template.py +20 -0
  42. templates/iwildcam_template.py +4 -0
  43. templates/openai_imagenet_template.py +82 -0
  44. templates/simple_template.py +3 -0
  45. templates/template_utils.py +28 -0
  46. templates/testing_template.py +83 -0
LICENSE ADDED
@@ -0,0 +1,203 @@
+ Apache License
+ Version 2.0, January 2004
+ http://www.apache.org/licenses/
+
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+ 1. Definitions.
+
+ "License" shall mean the terms and conditions for use, reproduction,
+ and distribution as defined by Sections 1 through 9 of this document.
+
+ "Licensor" shall mean the copyright owner or entity authorized by
+ the copyright owner that is granting the License.
+
+ "Legal Entity" shall mean the union of the acting entity and all
+ other entities that control, are controlled by, or are under common
+ control with that entity. For the purposes of this definition,
+ "control" means (i) the power, direct or indirect, to cause the
+ direction or management of such entity, whether by contract or
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
+ outstanding shares, or (iii) beneficial ownership of such entity.
+
+ "You" (or "Your") shall mean an individual or Legal Entity
+ exercising permissions granted by this License.
+
+ "Source" form shall mean the preferred form for making modifications,
+ including but not limited to software source code, documentation
+ source, and configuration files.
+
+ "Object" form shall mean any form resulting from mechanical
+ transformation or translation of a Source form, including but
+ not limited to compiled object code, generated documentation,
+ and conversions to other media types.
+
+ "Work" shall mean the work of authorship, whether in Source or
+ Object form, made available under the License, as indicated by a
+ copyright notice that is included in or attached to the work
+ (an example is provided in the Appendix below).
+
+ "Derivative Works" shall mean any work, whether in Source or Object
+ form, that is based on (or derived from) the Work and for which the
+ editorial revisions, annotations, elaborations, or other modifications
+ represent, as a whole, an original work of authorship. For the purposes
+ of this License, Derivative Works shall not include works that remain
+ separable from, or merely link (or bind by name) to the interfaces of,
+ the Work and Derivative Works thereof.
+
+ "Contribution" shall mean any work of authorship, including
+ the original version of the Work and any modifications or additions
+ to that Work or Derivative Works thereof, that is intentionally
+ submitted to Licensor for inclusion in the Work by the copyright owner
+ or by an individual or Legal Entity authorized to submit on behalf of
+ the copyright owner. For the purposes of this definition, "submitted"
+ means any form of electronic, verbal, or written communication sent
+ to the Licensor or its representatives, including but not limited to
+ communication on electronic mailing lists, source code control systems,
+ and issue tracking systems that are managed by, or on behalf of, the
+ Licensor for the purpose of discussing and improving the Work, but
+ excluding communication that is conspicuously marked or otherwise
+ designated in writing by the copyright owner as "Not a Contribution."
+
+ "Contributor" shall mean Licensor and any individual or Legal Entity
+ on behalf of whom a Contribution has been received by Licensor and
+ subsequently incorporated within the Work.
+
+ 2. Grant of Copyright License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ copyright license to reproduce, prepare Derivative Works of,
+ publicly display, publicly perform, sublicense, and distribute the
+ Work and such Derivative Works in Source or Object form.
+
+ 3. Grant of Patent License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ (except as stated in this section) patent license to make, have made,
+ use, offer to sell, sell, import, and otherwise transfer the Work,
+ where such license applies only to those patent claims licensable
+ by such Contributor that are necessarily infringed by their
+ Contribution(s) alone or by combination of their Contribution(s)
+ with the Work to which such Contribution(s) was submitted. If You
+ institute patent litigation against any entity (including a
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
+ or a Contribution incorporated within the Work constitutes direct
+ or contributory patent infringement, then any patent licenses
+ granted to You under this License for that Work shall terminate
+ as of the date such litigation is filed.
+
+ 4. Redistribution. You may reproduce and distribute copies of the
+ Work or Derivative Works thereof in any medium, with or without
+ modifications, and in Source or Object form, provided that You
+ meet the following conditions:
+
+ (a) You must give any other recipients of the Work or
+ Derivative Works a copy of this License; and
+
+ (b) You must cause any modified files to carry prominent notices
+ stating that You changed the files; and
+
+ (c) You must retain, in the Source form of any Derivative Works
+ that You distribute, all copyright, patent, trademark, and
+ attribution notices from the Source form of the Work,
+ excluding those notices that do not pertain to any part of
+ the Derivative Works; and
+
+ (d) If the Work includes a "NOTICE" text file as part of its
+ distribution, then any Derivative Works that You distribute must
+ include a readable copy of the attribution notices contained
+ within such NOTICE file, excluding those notices that do not
+ pertain to any part of the Derivative Works, in at least one
+ of the following places: within a NOTICE text file distributed
+ as part of the Derivative Works; within the Source form or
+ documentation, if provided along with the Derivative Works; or,
+ within a display generated by the Derivative Works, if and
+ wherever such third-party notices normally appear. The contents
+ of the NOTICE file are for informational purposes only and
+ do not modify the License. You may add Your own attribution
+ notices within Derivative Works that You distribute, alongside
+ or as an addendum to the NOTICE text from the Work, provided
+ that such additional attribution notices cannot be construed
+ as modifying the License.
+
+ You may add Your own copyright statement to Your modifications and
+ may provide additional or different license terms and conditions
+ for use, reproduction, or distribution of Your modifications, or
+ for any such Derivative Works as a whole, provided Your use,
+ reproduction, and distribution of the Work otherwise complies with
+ the conditions stated in this License.
+
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
+ any Contribution intentionally submitted for inclusion in the Work
+ by You to the Licensor shall be under the terms and conditions of
+ this License, without any additional terms or conditions.
+ Notwithstanding the above, nothing herein shall supersede or modify
+ the terms of any separate license agreement you may have executed
+ with Licensor regarding such Contributions.
+
+ 6. Trademarks. This License does not grant permission to use the trade
+ names, trademarks, service marks, or product names of the Licensor,
+ except as required for reasonable and customary use in describing the
+ origin of the Work and reproducing the content of the NOTICE file.
+
+ 7. Disclaimer of Warranty. Unless required by applicable law or
+ agreed to in writing, Licensor provides the Work (and each
+ Contributor provides its Contributions) on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ implied, including, without limitation, any warranties or conditions
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+ PARTICULAR PURPOSE. You are solely responsible for determining the
+ appropriateness of using or redistributing the Work and assume any
+ risks associated with Your exercise of permissions under this License.
+
+ 8. Limitation of Liability. In no event and under no legal theory,
+ whether in tort (including negligence), contract, or otherwise,
+ unless required by applicable law (such as deliberate and grossly
+ negligent acts) or agreed to in writing, shall any Contributor be
+ liable to You for damages, including any direct, indirect, special,
+ incidental, or consequential damages of any character arising as a
+ result of this License or out of the use or inability to use the
+ Work (including but not limited to damages for loss of goodwill,
+ work stoppage, computer failure or malfunction, or any and all
+ other commercial damages or losses), even if such Contributor
+ has been advised of the possibility of such damages.
+
+ 9. Accepting Warranty or Additional Liability. While redistributing
+ the Work or Derivative Works thereof, You may choose to offer,
+ and charge a fee for, acceptance of support, warranty, indemnity,
+ or other liability obligations and/or rights consistent with this
+ License. However, in accepting such obligations, You may act only
+ on Your own behalf and on Your sole responsibility, not on behalf
+ of any other Contributor, and only if You agree to indemnify,
+ defend, and hold each Contributor harmless for any liability
+ incurred by, or claims asserted against, such Contributor by reason
+ of your accepting any such warranty or additional liability.
+
+ END OF TERMS AND CONDITIONS
+
+ APPENDIX: How to apply the Apache License to your work.
+
+ To apply the Apache License to your work, attach the following
+ boilerplate notice, with the fields enclosed by brackets "[]"
+ replaced with your own identifying information. (Don't include
+ the brackets!) The text should be enclosed in the appropriate
+ comment syntax for the file format. We also recommend that a
+ file or class name and description of purpose be included on the
+ same "printed page" as the copyright notice for easier
+ identification within third-party archives.
+
+ Copyright 2020 - present, Facebook, Inc
+ Copyright 2022 - present, Arthur Douillard
+ Copyright 2023 - present, Zangwei Zheng, Mingyuan Ma
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
README.md ADDED
@@ -0,0 +1,113 @@
+ # DMNSP: Dynamic Multi-Layer Null Space Projection for Vision-Language Continual Learning
+
+ [![Python 3.8+](https://img.shields.io/badge/python-3.8+-blue.svg)](https://www.python.org/downloads/release/python-380/)
+ [![PyTorch](https://img.shields.io/badge/PyTorch-1.8+-red.svg)](https://pytorch.org/)
+ [![License](https://img.shields.io/badge/License-Apache%202.0-green.svg)](LICENSE)
+
+ Official PyTorch implementation of the paper "Dynamic Multi-Layer Null Space Projection for Vision-Language Continual Learning" (ICCV 2025).
+
+ ## 🎯 Abstract
+
+ Vision-Language Models (VLMs) have emerged as a highly promising approach to Continual Learning (CL) thanks to their powerful generalized features. While adapter-based VLMs can exploit both task-specific and task-agnostic features, current CL methods have largely overlooked the distinct and evolving parameter distributions of the visual and language modalities, which turn out to be crucial for effectively mitigating catastrophic forgetting. In this study, we find that the **visual modality exhibits a broader parameter distribution and greater variance** during class increments than the textual modality, making it more vulnerable to forgetting. Consequently, we handle the branches of the two modalities asymmetrically.
+
+ ### Key Contributions
+
+ - 🔍 **Asymmetric Modality Handling**: We propose handling the visual and language modalities differently, based on their distinct parameter distribution characteristics
+ - 🚀 **Multi-layer Null Space Projection**: A novel strategy, applied only to the visual branch, that restricts parameter updates to specific subspaces (a minimal sketch of the idea follows below)
+ - ⚖️ **Dynamic Projection Coefficient**: Precise control of the gradient projection magnitude for an optimal stability-plasticity balance
+
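A minimal, self-contained sketch of the projection idea, assuming the null space is estimated from the covariance of stored past-task features. The function names, the eigenvalue threshold `eps`, and the blending rule here are illustrative assumptions, not the exact code in `continual_clip/models.py`:

```python
import torch

def null_space_basis(features: torch.Tensor, eps: float = 1e-2) -> torch.Tensor:
    """Basis of the approximate null space of past-task inputs: eigenvectors of
    the uncentered feature covariance whose eigenvalues are close to zero."""
    cov = features.T @ features / features.shape[0]   # (d, d), uncentered covariance
    eigvals, eigvecs = torch.linalg.eigh(cov)         # eigenvalues in ascending order
    return eigvecs[:, eigvals < eps]                  # (d, k) basis matrix U

def project_gradient(grad: torch.Tensor, U: torch.Tensor, coeff: float) -> torch.Tensor:
    """Blend the raw gradient with its null-space component; coeff controls the
    stability (coeff -> 1) vs. plasticity (coeff -> 0) trade-off."""
    g_null = grad @ U @ U.T                           # component lying in span(U)
    return coeff * g_null + (1.0 - coeff) * grad

# toy usage: a 512-dim layer and 1000 stored past-task input features
past_feats = torch.randn(1000, 100) @ torch.randn(100, 512)   # rank-deficient on purpose
U = null_space_basis(past_feats)
grad = torch.randn(64, 512)                           # dL/dW for a weight of shape (64, 512)
grad = project_gradient(grad, U, coeff=0.9)           # update now barely disturbs old tasks
```

Updating along `U U^T` leaves the layer's responses to past-task inputs (approximately) unchanged, which is the stability half of the trade-off; the coefficient re-admits a controlled amount of the raw gradient for plasticity.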
+ ## 🛠️ Installation
+
+ ### Setup Environment
+
+ ```bash
+ # Install dependencies
+ pip install -r requirements.txt
+ ```
+
+ ## 📊 Datasets
+
+ The framework supports the following datasets for class-incremental learning:
+
+ - **CIFAR100**: 100 classes, various incremental settings (2-2, 5-5, 10-10)
+ - **TinyImageNet**: 200 classes, incremental settings (100-5, 100-10, 100-20)
+
+ ### Data Preparation
+
+ 1. The datasets are downloaded automatically when running experiments
+ 2. Update the `dataset_root` path in your configuration files or on the command line
+ 3. Ensure sufficient disk space for dataset storage
+
+ ## 🚀 Quick Start
+
+ ### Basic Usage
+
+ ```bash
+ # Run CIFAR100 with 10 initial classes and 10 incremental classes
+ sh run_cifar100-10-10.sh
+
+ # Or run with custom parameters
+ python main.py \
+     --config-path ./configs/class \
+     --config-name cifar100_10-10.yaml \
+     dataset_root="/path/to/your/data" \
+     class_order="./class_orders/cifar100.yaml"
+ ```
+
+ ### Configuration Options
+
+ The project uses Hydra for configuration management. Key parameters include:
+
+ ```yaml
+ # Model settings
+ model_name: "ViT-B/16"                   # CLIP model variant
+ prompt_template: "a bad photo of a {}."  # Text prompt template
+
+ # Training settings
+ batch_size: 128                          # Training batch size
+ lr: 1e-3                                 # Learning rate
+ weight_decay: 0.0                        # Weight decay
+ ls: 0.0                                  # Label smoothing
+
+ # Incremental learning settings
+ initial_increment: 10                    # Initial number of classes
+ increment: 10                            # Classes per incremental step
+ method: "DMNSP"                          # Method name
+ ```
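The entry point that consumes these keys lives in `main.py`; the Hydra wiring looks roughly like this sketch (the config path and field access are assumptions based on the commands above, not a copy of `main.py`):

```python
import hydra
from omegaconf import DictConfig, OmegaConf

@hydra.main(config_path="configs/class", config_name="cifar100_10-10")
def main(cfg: DictConfig) -> None:
    # cfg fields map 1:1 onto the YAML keys, with CLI overrides already applied
    print(OmegaConf.to_yaml(cfg))
    print(cfg.model_name, cfg.lr, cfg.increment)

if __name__ == "__main__":
    main()
```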
+
+ ## 🔧 Advanced Usage
+
+ ### Custom Datasets
+
+ To add support for new datasets (a hypothetical example follows the list):
+
+ 1. Add the dataset configuration in `continual_clip/datasets.py`
+ 2. Create a corresponding class order file in `class_orders/`
+ 3. Add a configuration YAML in `configs/class/`
+
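A hypothetical illustration of steps 2 and 3 — the file names and the `dataset` key are invented to mirror the existing CIFAR100 files, not taken from this repo:

```yaml
# class_orders/mydataset.yaml  -- step 2: a fixed permutation of class ids
class_order: [3, 17, 42, 0, 99]   # one entry per class, each id exactly once

# configs/class/mydataset_10-10.yaml  -- step 3
dataset: "mydataset"              # assumed key; must match the entry added in datasets.py
initial_increment: 10
increment: 10
class_order: "./class_orders/mydataset.yaml"
```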
+ ## 📄 License
+
+ This project is licensed under the Apache License 2.0 - see the [LICENSE](LICENSE) file for details.
+
+ ## 📚 Citation
+
+ If you find this work useful in your research, please consider citing:
+
+ ```bibtex
+ @inproceedings{Kang2025DMNSP,
+   title={Dynamic Multi-Layer Null Space Projection for Vision-Language Continual Learning},
+   author={Borui Kang and Lei Wang and Zhiping Wu and Tao Feng and Yawen Li and Yang Gao and Wenbin Li},
+   booktitle={ICCV},
+   year={2025}
+ }
+ ```
+
+ ## 📞 Contact
+
+ For questions or issues, please:
+ - Open an issue on GitHub
+ - Contact the authors at kangborui.cn@gmail.com
+
+ ---
+
+ **Note**: This implementation is for research purposes. Please ensure you comply with the respective licenses of the datasets and models used.
__init__.py ADDED
File without changes
class_orders/cifar100.yaml ADDED
@@ -0,0 +1 @@
+ class_order: [87, 0, 52, 58, 44, 91, 68, 97, 51, 15, 94, 92, 10, 72, 49, 78, 61, 14, 8, 86, 84, 96, 18, 24, 32, 45, 88, 11, 4, 67, 69, 66, 77, 47, 79, 93, 29, 50, 57, 83, 17, 81, 41, 12, 37, 59, 25, 20, 80, 73, 1, 28, 6, 46, 62, 82, 53, 9, 31, 75, 38, 63, 33, 74, 27, 22, 36, 3, 16, 21, 60, 19, 70, 90, 89, 43, 5, 42, 65, 76, 40, 30, 23, 85, 2, 95, 56, 48, 71, 64, 98, 13, 99, 7, 34, 55, 54, 26, 35, 39]
class_orders/tinyimagenet.yaml ADDED
@@ -0,0 +1,17 @@
+ class_order: [
+     131, 181, 22, 172, 144, 92, 97, 187, 58, 93, 6, 70, 106, 68,
+     153, 168, 179, 199, 29, 46, 9, 142, 134, 88, 193, 110, 26,
+     32, 117, 112, 17, 39, 166, 13, 94, 138, 109, 147, 51, 101,
+     59, 188, 116, 5, 170, 99, 100, 167, 180, 146, 65, 1, 104,
+     43, 38, 184, 123, 171, 137, 162, 71, 44, 95, 174, 12, 7,
+     54, 152, 21, 47, 28, 176, 34, 2, 132, 118, 42, 189, 150,
+     14, 165, 41, 192, 45, 82, 128, 63, 57, 197, 160, 53, 75,
+     108, 135, 121, 159, 183, 67, 169, 50, 87, 69, 89, 196,
+     115, 19, 148, 96, 86, 11, 8, 60, 33, 173, 78, 4, 119, 105,
+     182, 127, 177, 30, 186, 40, 49, 178, 76, 157, 161, 73, 164,
+     151, 31, 74, 191, 27, 125, 198, 81, 20, 155, 114, 139, 36,
+     61, 56, 145, 48, 16, 83, 62, 85, 126, 0, 102, 23, 3, 140,
+     15, 195, 133, 113, 190, 141, 52, 163, 156, 80, 111, 90, 175,
+     143, 120, 84, 18, 25, 79, 37, 154, 136, 64, 158, 24, 185,
+     72, 35, 129, 55, 149, 91, 122, 77, 103, 124, 130, 66, 10, 107, 194, 98
+ ]
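Class order files pin a fixed permutation of class ids so that incremental splits are reproducible across runs and methods. A sketch of how such a file can be consumed (illustrative; the repo's own loading code may differ):

```python
import yaml

with open("class_orders/tinyimagenet.yaml") as f:
    order = yaml.safe_load(f)["class_order"]

assert sorted(order) == list(range(200))  # a true permutation of the 200 class ids
# classes seen per task, with 100 initial classes and steps of 10:
task0_classes = order[:100]
task1_classes = order[100:110]
```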
clip/README.md ADDED
@@ -0,0 +1 @@
+ This folder is a lightly modified version of https://github.com/openai/CLIP.
clip/__init__.py ADDED
@@ -0,0 +1 @@
+ from .clip import *
clip/adapter.py ADDED
@@ -0,0 +1,75 @@
+ # --------------------------------------------------------
+ # References:
+ # https://github.com/jxhe/unify-parameter-efficient-tuning
+ # --------------------------------------------------------
+
+ import math
+ import torch
+ import torch.nn as nn
+
+
+ class Adapter(nn.Module):
+     def __init__(self,
+                  d_model=None,
+                  bottleneck=None,
+                  dropout=0.0,
+                  init_option="lora",
+                  adapter_scalar="1.0",
+                  adapter_layernorm_option="in"):
+         super().__init__()
+         self.n_embd = d_model
+         self.down_size = bottleneck
+
+         # _before
+         self.adapter_layernorm_option = adapter_layernorm_option
+
+         self.adapter_layer_norm_before = None
+         if adapter_layernorm_option == "in" or adapter_layernorm_option == "out":
+             self.adapter_layer_norm_before = nn.LayerNorm(self.n_embd)
+
+         if adapter_scalar == "learnable_scalar":
+             self.scale = nn.Parameter(torch.ones(1))
+         else:
+             self.scale = float(adapter_scalar)
+
+         # self.linear = nn.Linear(self.n_embd, self.n_embd)
+
+         self.down_proj = nn.Linear(self.n_embd, self.down_size)
+         self.non_linear_func = nn.ReLU()
+         self.up_proj = nn.Linear(self.down_size, self.n_embd)
+
+         self.dropout = dropout
+         if init_option == "bert":
+             raise NotImplementedError
+         elif init_option == "lora":
+             with torch.no_grad():
+                 nn.init.kaiming_uniform_(self.down_proj.weight, a=math.sqrt(5))
+                 nn.init.zeros_(self.up_proj.weight)
+                 nn.init.zeros_(self.down_proj.bias)
+                 nn.init.zeros_(self.up_proj.bias)
+         elif init_option == "linear":
+             with torch.no_grad():
+                 nn.init.zeros_(self.linear.weight)
+
+     def forward(self, x, add_residual=True, residual=None):
+
+         residual = x if residual is None else residual
+         if self.adapter_layernorm_option == 'in':  # none
+             x = self.adapter_layer_norm_before(x)
+
+         down = self.down_proj(x)
+         down = self.non_linear_func(down)
+         down = nn.functional.dropout(down, p=self.dropout, training=self.training)
+         up = self.up_proj(down)
+
+         up = up * self.scale
+
+         if self.adapter_layernorm_option == 'out':  # none
+             up = self.adapter_layer_norm_before(up)
+
+         if add_residual:
+             output = up + residual
+         else:
+             output = up
+         return down, output, \
+             self.up_proj.weight, self.down_proj.weight, self.up_proj.bias, self.down_proj.bias
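A quick shape check for the adapter above (values are illustrative; with the LoRA-style init the zeroed `up_proj` makes the adapter output exactly zero at initialization):

```python
import torch
from clip.adapter import Adapter

adapter = Adapter(d_model=768, bottleneck=64, dropout=0.1,
                  init_option="lora", adapter_scalar="0.1",
                  adapter_layernorm_option="none")

x = torch.randn(2, 197, 768)               # (batch, tokens, d_model), ViT-B/16 sizes
down, out, w_up, w_down, b_up, b_down = adapter(x, add_residual=False)
print(down.shape)                          # torch.Size([2, 197, 64])  bottleneck activations
print(out.shape, out.abs().max().item())   # torch.Size([2, 197, 768]) 0.0 at init
```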
clip/bpe_simple_vocab_16e6.txt.gz ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:924691ac288e54409236115652ad4aa250f48203de50a9e4722a6ecd48d6804a
+ size 1356917
clip/clip.py ADDED
@@ -0,0 +1,310 @@
+ # Code ported from https://github.com/openai/CLIP
+
+ import hashlib
+ import os
+ import urllib
+ import warnings
+ from typing import Union, List
+
+ import torch
+ from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize, RandomResizedCrop, InterpolationMode
+ from tqdm import tqdm
+
+ from clip.model import build_model
+ from clip.tokenizer import SimpleTokenizer as _Tokenizer
+
+ __all__ = ["available_models", "load", "tokenize"]
+ _tokenizer = _Tokenizer()
+
+ _MODELS = {
+     "RN50": "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt",
+     "RN101": "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt",
+     "RN50x4": "https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt",
+     "RN50x16": "https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt",
+     "ViT-B/32": "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt",
+     "ViT-B/16": "https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt",
+ }
+
+
+ def _download(url: str, root: str = os.path.expanduser("~/.cache/clip")):
+     os.makedirs(root, exist_ok=True)
+     filename = os.path.basename(url)
+
+     expected_sha256 = url.split("/")[-2]
+     download_target = os.path.join(root, filename)
+
+     if os.path.exists(download_target) and not os.path.isfile(download_target):
+         raise RuntimeError(f"{download_target} exists and is not a regular file")
+
+     if os.path.isfile(download_target):
+         if hashlib.sha256(open(download_target, "rb").read()).hexdigest() == expected_sha256:
+             return download_target
+         else:
+             warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file")
+
+     with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
+         with tqdm(total=int(source.info().get("Content-Length")), ncols=80, unit='iB', unit_scale=True) as loop:
+             while True:
+                 buffer = source.read(8192)
+                 if not buffer:
+                     break
+
+                 output.write(buffer)
+                 loop.update(len(buffer))
+
+     if hashlib.sha256(open(download_target, "rb").read()).hexdigest() != expected_sha256:
+         raise RuntimeError("Model has been downloaded but the SHA256 checksum does not match")
+
+     return download_target
+
+ def _convert_to_rgb(image):
+     return image.convert('RGB')
+
+ def _transform(n_px: int, is_train: bool):
+     normalize = Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
+     if is_train:
+         return Compose([
+             RandomResizedCrop(n_px, scale=(0.9, 1.0), interpolation=InterpolationMode.BICUBIC),
+             _convert_to_rgb,
+             ToTensor(),
+             normalize,
+         ])
+     else:
+         return Compose([
+             Resize(n_px, interpolation=InterpolationMode.BICUBIC),
+             CenterCrop(n_px),
+             _convert_to_rgb,
+             ToTensor(),
+             normalize,
+         ])
+
+
+ def available_models() -> List[str]:
+     """Returns the names of available CLIP models"""
+     return list(_MODELS.keys())
+
+ # def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", jit: bool = False, download_root: str = None):
+ #     """Load a CLIP model
+ #
+ #     Parameters
+ #     ----------
+ #     name : str
+ #         A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict
+ #
+ #     device : Union[str, torch.device]
+ #         The device to put the loaded model
+ #
+ #     jit : bool
+ #         Whether to load the optimized JIT model or more hackable non-JIT model (default).
+ #
+ #     download_root: str
+ #         path to download the model files; by default, it uses "~/.cache/clip"
+ #
+ #     Returns
+ #     -------
+ #     model : torch.nn.Module
+ #         The CLIP model
+ #
+ #     preprocess : Callable[[PIL.Image], torch.Tensor]
+ #         A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input
+ #     """
+ #     if name in _MODELS:
+ #         model_path = _download(_MODELS[name], download_root or os.path.expanduser("~/.cache/clip"))
+ #     elif os.path.isfile(name):
+ #         model_path = name
+ #     else:
+ #         raise RuntimeError(f"Model {name} not found; available models = {available_models()}")
+ #
+ #     try:
+ #         # loading JIT archive
+ #         model = torch.jit.load(model_path, map_location=device if jit else "cpu").eval()
+ #         state_dict = None
+ #     except RuntimeError:
+ #         # loading saved state dict
+ #         if jit:
+ #             warnings.warn(f"File {model_path} is not a JIT archive. Loading as a state dict instead")
+ #             jit = False
+ #         state_dict = torch.load(model_path, map_location="cpu")
+ #
+ #     if not jit:
+ #         model = build_model(state_dict or model.state_dict()).to(device)
+ #         if str(device) == "cpu":
+ #             model.float()
+ #         return model, _transform(model.visual.input_resolution)
+ #
+ #     # patch the device names
+ #     device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[])
+ #     device_node = [n for n in device_holder.graph.findAllNodes("prim::Constant") if "Device" in repr(n)][-1]
+ #
+ #     def patch_device(module):
+ #         try:
+ #             graphs = [module.graph] if hasattr(module, "graph") else []
+ #         except RuntimeError:
+ #             graphs = []
+ #
+ #         if hasattr(module, "forward1"):
+ #             graphs.append(module.forward1.graph)
+ #
+ #         for graph in graphs:
+ #             for node in graph.findAllNodes("prim::Constant"):
+ #                 if "value" in node.attributeNames() and str(node["value"]).startswith("cuda"):
+ #                     node.copyAttributes(device_node)
+ #
+ #     model.apply(patch_device)
+ #     patch_device(model.encode_image)
+ #     patch_device(model.encode_text)
+ #
+ #     # patch dtype to float32 on CPU
+ #     if str(device) == "cpu":
+ #         float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[])
+ #         float_input = list(float_holder.graph.findNode("aten::to").inputs())[1]
+ #         float_node = float_input.node()
+ #
+ #         def patch_float(module):
+ #             try:
+ #                 graphs = [module.graph] if hasattr(module, "graph") else []
+ #             except RuntimeError:
+ #                 graphs = []
+ #
+ #             if hasattr(module, "forward1"):
+ #                 graphs.append(module.forward1.graph)
+ #
+ #             for graph in graphs:
+ #                 for node in graph.findAllNodes("aten::to"):
+ #                     inputs = list(node.inputs())
+ #                     for i in [1, 2]:  # dtype can be the second or third argument to aten::to()
+ #                         if inputs[i].node()["value"] == 5:
+ #                             inputs[i].node().copyAttributes(float_node)
+ #
+ #         model.apply(patch_float)
+ #         patch_float(model.encode_image)
+ #         patch_float(model.encode_text)
+ #
+ #         model.float()
+ #
+ #     return model, _transform(model.input_resolution.item())
+
+
+ def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", jit=True, is_train=False, pretrained=True):
+     """Load a CLIP model
+     Parameters
+     ----------
+     name : str
+         A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict
+     device : Union[str, torch.device]
+         The device to put the loaded model
+     jit : bool
+         Whether to load the optimized JIT model (default) or more hackable non-JIT model.
+     Returns
+     -------
+     model : torch.nn.Module
+         The CLIP model
+     preprocess : Callable[[PIL.Image], torch.Tensor]
+         A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input
+     """
+     if name in _MODELS:
+         model_path = _download(_MODELS[name])
+     elif os.path.isfile(name):
+         model_path = name
+     else:
+         raise RuntimeError(f"Model {name} not found; available models = {available_models()}")
+
+     try:
+         # loading JIT archive
+         model = torch.jit.load(model_path, map_location=device if jit else "cpu").eval()
+         state_dict = None
+     except RuntimeError:
+         # loading saved state dict
+         if jit:
+             warnings.warn(f"File {model_path} is not a JIT archive. Loading as a state dict instead")
+             jit = False
+         state_dict = torch.load(model_path, map_location="cpu")
+
+     if not jit:
+         try:
+             model = build_model(state_dict or model.state_dict()).to(device)
+         except KeyError:
+             sd = {k[7:]: v for k, v in state_dict["state_dict"].items()}
+             model = build_model(sd).to(device)
+
+         if str(device) == "cpu":
+             model.float()
+         return model, \
+             _transform(model.visual.input_resolution, is_train=True), \
+             _transform(model.visual.input_resolution, is_train=False)
+
+     # patch the device names
+     device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[])
+     device_node = [n for n in device_holder.graph.findAllNodes("prim::Constant") if "Device" in repr(n)][-1]
+
+     def patch_device(module):
+         graphs = [module.graph] if hasattr(module, "graph") else []
+         if hasattr(module, "forward1"):
+             graphs.append(module.forward1.graph)
+
+         for graph in graphs:
+             for node in graph.findAllNodes("prim::Constant"):
+                 if "value" in node.attributeNames() and str(node["value"]).startswith("cuda"):
+                     node.copyAttributes(device_node)
+
+     model.apply(patch_device)
+     patch_device(model.encode_image)
+     patch_device(model.encode_text)
+
+     # patch dtype to float32 on CPU
+     if str(device) == "cpu":
+         float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[])
+         float_input = list(float_holder.graph.findNode("aten::to").inputs())[1]
+         float_node = float_input.node()
+
+         def patch_float(module):
+             graphs = [module.graph] if hasattr(module, "graph") else []
+             if hasattr(module, "forward1"):
+                 graphs.append(module.forward1.graph)
+
+             for graph in graphs:
+                 for node in graph.findAllNodes("aten::to"):
+                     inputs = list(node.inputs())
+                     for i in [1, 2]:  # dtype can be the second or third argument to aten::to()
+                         if inputs[i].node()["value"] == 5:
+                             inputs[i].node().copyAttributes(float_node)
+
+         model.apply(patch_float)
+         patch_float(model.encode_image)
+         patch_float(model.encode_text)
+
+         model.float()
+
+     return model, \
+         _transform(model.input_resolution.item(), is_train=True), \
+         _transform(model.input_resolution.item(), is_train=False)
+
+
+ def tokenize(texts: Union[str, List[str]], context_length: int = 77) -> torch.LongTensor:
+     """
+     Returns the tokenized representation of given input string(s)
+     Parameters
+     ----------
+     texts : Union[str, List[str]]
+         An input string or a list of input strings to tokenize
+     context_length : int
+         The context length to use; all CLIP models use 77 as the context length
+     Returns
+     -------
+     A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length]
+     """
+     if isinstance(texts, str):
+         texts = [texts]
+
+     sot_token = _tokenizer.encoder["<start_of_text>"]
+     eot_token = _tokenizer.encoder["<end_of_text>"]
+     all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts]
+     result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)
+
+     for i, tokens in enumerate(all_tokens):
+         if len(tokens) > context_length:  # Truncate
+             tokens = tokens[:context_length]
+         result[i, :len(tokens)] = torch.tensor(tokens)
+
+     return result
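A typical call against the interface above, with the repo root on `PYTHONPATH` (note that this fork of `load` returns a train and an eval transform, and that `encode_text` in this fork returns a tuple; weights download to `~/.cache/clip` on first use):

```python
import torch
import clip

model, preprocess_train, preprocess_val = clip.load("ViT-B/16", device="cpu", jit=False)

tokens = clip.tokenize(["a photo of a dog", "a photo of a cat"])
print(tokens.shape)         # torch.Size([2, 77])

with torch.no_grad():
    text_features, _ = model.encode_text(tokens)   # (features, pre-projection states)
print(text_features.shape)  # torch.Size([2, 512]) for ViT-B/16
```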
clip/model.py ADDED
@@ -0,0 +1,486 @@
+ from collections import OrderedDict
+ from typing import Tuple, Union
+
+ import os
+ import json
+ import numpy as np
+ import torch
+ import torch.nn.functional as F
+ from torch import nn
+ from .adapter import Adapter
+ from torch.distributions.normal import Normal
+ from collections import Counter
+
+ global_taskid = 0
+ global_is_train = True
+
+ class Bottleneck(nn.Module):
+     expansion = 4
+
+     def __init__(self, inplanes, planes, stride=1):
+         super().__init__()
+
+         # all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1
+         self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False)
+         self.bn1 = nn.BatchNorm2d(planes)
+
+         self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False)
+         self.bn2 = nn.BatchNorm2d(planes)
+
+         self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity()
+
+         self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False)
+         self.bn3 = nn.BatchNorm2d(planes * self.expansion)
+
+         self.relu = nn.ReLU(inplace=True)
+         self.downsample = None
+         self.stride = stride
+
+         if stride > 1 or inplanes != planes * Bottleneck.expansion:
+             # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1
+             self.downsample = nn.Sequential(OrderedDict([
+                 ("-1", nn.AvgPool2d(stride)),
+                 ("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)),
+                 ("1", nn.BatchNorm2d(planes * self.expansion))
+             ]))
+
+     def forward(self, x: torch.Tensor):
+         identity = x
+
+         out = self.relu(self.bn1(self.conv1(x)))
+         out = self.relu(self.bn2(self.conv2(out)))
+         out = self.avgpool(out)
+         out = self.bn3(self.conv3(out))
+
+         if self.downsample is not None:
+             identity = self.downsample(x)
+
+         out += identity
+         out = self.relu(out)
+         return out
+
+
+ class AttentionPool2d(nn.Module):
+     def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None):
+         super().__init__()
+         self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5)
+         self.k_proj = nn.Linear(embed_dim, embed_dim)
+         self.q_proj = nn.Linear(embed_dim, embed_dim)
+         self.v_proj = nn.Linear(embed_dim, embed_dim)
+         self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
+         self.num_heads = num_heads
+
+     def forward(self, x):
+         x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3]).permute(2, 0, 1)  # NCHW -> (HW)NC
+         x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0)  # (HW+1)NC
+         x = x + self.positional_embedding[:, None, :].to(x.dtype)  # (HW+1)NC
+         x, _ = F.multi_head_attention_forward(
+             query=x, key=x, value=x,
+             embed_dim_to_check=x.shape[-1],
+             num_heads=self.num_heads,
+             q_proj_weight=self.q_proj.weight,
+             k_proj_weight=self.k_proj.weight,
+             v_proj_weight=self.v_proj.weight,
+             in_proj_weight=None,
+             in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]),
+             bias_k=None,
+             bias_v=None,
+             add_zero_attn=False,
+             dropout_p=0,
+             out_proj_weight=self.c_proj.weight,
+             out_proj_bias=self.c_proj.bias,
+             use_separate_proj_weight=True,
+             training=self.training,
+             need_weights=False
+         )
+
+         return x[0]
+
+
+ class ModifiedResNet(nn.Module):
+     """
+     A ResNet class that is similar to torchvision's but contains the following changes:
+     - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool.
+     - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1
+     - The final pooling layer is a QKV attention instead of an average pool
+     """
+
+     def __init__(self, layers, output_dim, heads, input_resolution=224, width=64):
+         super().__init__()
+         self.output_dim = output_dim
+         self.input_resolution = input_resolution
+
+         # the 3-layer stem
+         self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False)
+         self.bn1 = nn.BatchNorm2d(width // 2)
+         self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False)
+         self.bn2 = nn.BatchNorm2d(width // 2)
+         self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False)
+         self.bn3 = nn.BatchNorm2d(width)
+         self.avgpool = nn.AvgPool2d(2)
+         self.relu = nn.ReLU(inplace=True)
+
+         # residual layers
+         self._inplanes = width  # this is a *mutable* variable used during construction
+         self.layer1 = self._make_layer(width, layers[0])
+         self.layer2 = self._make_layer(width * 2, layers[1], stride=2)
+         self.layer3 = self._make_layer(width * 4, layers[2], stride=2)
+         self.layer4 = self._make_layer(width * 8, layers[3], stride=2)
+
+         embed_dim = width * 32  # the ResNet feature dimension
+         self.attnpool = AttentionPool2d(input_resolution // 32, embed_dim, heads, output_dim)
+
+     def _make_layer(self, planes, blocks, stride=1):
+         layers = [Bottleneck(self._inplanes, planes, stride)]
+
+         self._inplanes = planes * Bottleneck.expansion
+         for _ in range(1, blocks):
+             layers.append(Bottleneck(self._inplanes, planes))
+
+         return nn.Sequential(*layers)
+
+     def forward(self, x):
+         def stem(x):
+             for conv, bn in [(self.conv1, self.bn1), (self.conv2, self.bn2), (self.conv3, self.bn3)]:
+                 x = self.relu(bn(conv(x)))
+             x = self.avgpool(x)
+             return x
+
+         x = x.type(self.conv1.weight.dtype)
+         x = stem(x)
+         x = self.layer1(x)
+         x = self.layer2(x)
+         x = self.layer3(x)
+         x = self.layer4(x)
+         x = self.attnpool(x)
+
+         return x
+
+
+ class LayerNorm(nn.LayerNorm):
+     """Subclass torch's LayerNorm to handle fp16."""
+
+     def forward(self, x: torch.Tensor):
+         orig_type = x.dtype
+         ret = super().forward(x.type(torch.float32))
+         return ret.type(orig_type)
+
+
+ class QuickGELU(nn.Module):
+     def forward(self, x: torch.Tensor):
+         return x * torch.sigmoid(1.702 * x)
+
+
+ class ResidualAttentionBlock(nn.Module):
+     def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None, text_or_image=None, flag=False):
+         super().__init__()
+         self.register_buffer("mean", torch.tensor([0.0]))
+         self.register_buffer("std", torch.tensor([1.0]))
+         self.attn = nn.MultiheadAttention(d_model, n_head)
+         self.ln_1 = LayerNorm(d_model)
+         self.mlp = nn.Sequential(OrderedDict([
+             ("c_fc", nn.Linear(d_model, d_model * 4)),
+             ("gelu", QuickGELU()),
+             ("c_proj", nn.Linear(d_model * 4, d_model))
+         ]))
+         self.ln_2 = LayerNorm(d_model)
+         self.attn_mask = attn_mask
+         self.is_train = global_is_train
+         self.ffn_num = 64
+         self.softmax = nn.Softmax(1)
+         self.softplus = nn.Softplus()
+         self.noisy_gating = True
+         self.adaptmlp_list = nn.ModuleList()
+         self.text_or_image = text_or_image
+         self.flag = flag
+         self.adaptmlp = Adapter(d_model=d_model, dropout=0.1, bottleneck=self.ffn_num,
+                                 init_option='lora',
+                                 adapter_scalar=0.1,
+                                 adapter_layernorm_option='none',
+                                 )
+
+     def attention(self, x: torch.Tensor):
+         self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None
+         return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0]
+
+     def forward(self, x: torch.Tensor):
+         x = x + self.attention(self.ln_1(x))
+         x_re = x.permute(1, 0, 2)
+         down, adapt_x, up_proj_weight, down_proj_weight, up_proj_bias, down_proj_bias = self.adaptmlp(x_re, add_residual=False)
+         adapt_x = adapt_x.permute(1, 0, 2)
+         down = down.permute(1, 0, 2)
+         x = x + self.mlp(self.ln_2(x)) + adapt_x
+
+         return x, down, adapt_x, up_proj_weight, down_proj_weight, up_proj_bias, down_proj_bias
+
+
+ class Transformer(nn.Module):
+     def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None, text_or_image=None,
+                  flag=True):
+         super().__init__()
+         self.width = width
+         self.layers = layers
+         self.resblocks = nn.Sequential(*[ResidualAttentionBlock(width, heads, attn_mask, text_or_image, flag) for _ in range(layers)])
+         self.lora_feature = {}
+
+     def forward(self, x: torch.Tensor):
+         for i in range(len(self.resblocks)):
+             x, down, adapt_x, up_proj_weight, down_proj_weight, up_proj_bias, down_proj_bias = self.resblocks[i](x)
+             self.lora_feature[i] = adapt_x
+         return x
+
+
+ class VisualTransformer(nn.Module):
+     def __init__(self, input_resolution: int, patch_size: int, width: int, layers: int, heads: int, output_dim: int, text_or_image=None):
+         super().__init__()
+         self.input_resolution = input_resolution
+         self.output_dim = output_dim
+         # Added so this info is available. Should not change anything.
+         self.patch_size = patch_size
+         self.width = width
+         self.layers = layers
+         self.heads = heads
+
+         self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False)
+
+         scale = width ** -0.5
+         self.class_embedding = nn.Parameter(scale * torch.randn(width))
+         self.positional_embedding = nn.Parameter(scale * torch.randn((input_resolution // patch_size) ** 2 + 1, width))
+         self.ln_pre = LayerNorm(width)
+
+         self.transformer = Transformer(width, layers, heads, text_or_image=text_or_image, flag=True)
+
+         self.ln_post = LayerNorm(width)
+         self.proj = nn.Parameter(scale * torch.randn(width, output_dim))
+
+     def forward(self, x: torch.Tensor):
+         x = self.conv1(x)
+         x = x.reshape(x.shape[0], x.shape[1], -1)
+         x = x.permute(0, 2, 1)
+         x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1)  # shape = [*, grid ** 2 + 1, width]
+         x = x + self.positional_embedding.to(x.dtype)
+         x = self.ln_pre(x)
+
+         x = x.permute(1, 0, 2)  # NLD -> LND
+         x = self.transformer(x)
+         x_before_fusion = x
+         x = x.permute(1, 0, 2)  # LND -> NLD
+
+         x = self.ln_post(x[:, 0, :])
+
+         if self.proj is not None:
+             x = x @ self.proj
+
+         return x, x_before_fusion
+
+
+ class CLIP(nn.Module):
+     def __init__(self,
+                  embed_dim: int,
+                  # vision
+                  image_resolution: int,
+                  vision_layers: Union[Tuple[int, int, int, int], int],
+                  vision_width: int,
+                  vision_patch_size: int,
+                  # text
+                  context_length: int,
+                  vocab_size: int,
+                  transformer_width: int,
+                  transformer_heads: int,
+                  transformer_layers: int,
+                  baseline=False
+                  ):
+         super().__init__()
+         self.baseline = baseline
+
+         self.context_length = context_length
+
+         if isinstance(vision_layers, (tuple, list)):
+             vision_heads = vision_width * 32 // 64
+             self.visual = ModifiedResNet(
+                 layers=vision_layers,
+                 output_dim=embed_dim,
+                 heads=vision_heads,
+                 input_resolution=image_resolution,
+                 width=vision_width
+             )
+         else:
+             vision_heads = vision_width // 64
+             self.visual = VisualTransformer(
+                 input_resolution=image_resolution,
+                 patch_size=vision_patch_size,
+                 width=vision_width,
+                 layers=vision_layers,
+                 heads=vision_heads,
+                 output_dim=embed_dim,
+                 text_or_image='image',
+             )
+
+         self.transformer = Transformer(
+             width=transformer_width,
+             layers=transformer_layers,
+             heads=transformer_heads,
+             attn_mask=self.build_attention_mask(),
+             text_or_image='text',
+             flag=True,
+         )
+
+         self.vocab_size = vocab_size
+         self.token_embedding = nn.Embedding(vocab_size, transformer_width)
+         self.positional_embedding = nn.Parameter(torch.empty(self.context_length, transformer_width))
+         self.ln_final = LayerNorm(transformer_width)
+
+         self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim))
+         self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
+         # self.adapt_lamda = [torch.nn.Parameter(30 * torch.rand(1)) for _ in range(12)]
+
+         self.initialize_parameters()
+
+     def initialize_parameters(self):
+         nn.init.normal_(self.token_embedding.weight, std=0.02)
+         nn.init.normal_(self.positional_embedding, std=0.01)
+         self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
+
+         if isinstance(self.visual, ModifiedResNet):
+             if self.visual.attnpool is not None:
+                 std = self.visual.attnpool.c_proj.in_features ** -0.5
+                 nn.init.normal_(self.visual.attnpool.q_proj.weight, std=std)
+                 nn.init.normal_(self.visual.attnpool.k_proj.weight, std=std)
+                 nn.init.normal_(self.visual.attnpool.v_proj.weight, std=std)
+                 nn.init.normal_(self.visual.attnpool.c_proj.weight, std=std)
+
+             for resnet_block in [self.visual.layer1, self.visual.layer2, self.visual.layer3, self.visual.layer4]:
+                 for name, param in resnet_block.named_parameters():
+                     if name.endswith("bn3.weight"):
+                         nn.init.zeros_(param)
+
+         proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5)
+         attn_std = self.transformer.width ** -0.5
+         fc_std = (2 * self.transformer.width) ** -0.5
+         for block in self.transformer.resblocks:
+             nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
+             nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
+             nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
+             nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)
+
+         if self.text_projection is not None:
+             nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5)
+
+     def build_attention_mask(self):
+         # lazily create causal attention mask, with full attention between the vision tokens
+         # pytorch uses additive attention mask; fill with -inf
+         mask = torch.empty(self.context_length, self.context_length)
+         mask.fill_(float("-inf"))
+         mask.triu_(1)  # zero out the lower diagonal
+         return mask
+
+     @property
+     def dtype(self):
+         return self.visual.conv1.weight.dtype
+
+     def encode_image(self, image):
+         return self.visual(image.type(self.dtype))
+
+     def encode_text(self, text):
+
+         x = self.token_embedding(text).type(self.dtype)  # [batch_size, n_ctx, d_model]
+
+         x = x + self.positional_embedding.type(self.dtype)
+         x = x.permute(1, 0, 2)  # NLD -> LND
+         x = self.transformer(x)
+         x_before_fusion = x
+         x = x.permute(1, 0, 2)  # LND -> NLD
+         x = self.ln_final(x).type(self.dtype)
+
+         # take features from the eot embedding (eot_token is the highest number in each sequence)
+         x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection
+
+         return x, x_before_fusion
+
+     def forward(self, image, text, taskid, is_train):
+         global global_taskid, global_is_train
+         global_taskid = taskid
+         global_is_train = is_train
+         if image is None:
+             return self.encode_text(text)
+         elif text is None:
+             return self.encode_image(image)
+         image_features, x_img_before_fusion = self.encode_image(image)
+         text_features, x_txt_before_fusion = self.encode_text(text)
+         image_features = image_features / image_features.norm(dim=-1, keepdim=True)
+         text_features = text_features / text_features.norm(dim=-1, keepdim=True)
+
+         # if self.baseline:
+         logit_scale = self.logit_scale.exp()
+         logits_per_image = logit_scale * image_features @ text_features.t()
+         logits_per_text = logits_per_image.t()
+         return logits_per_image, logits_per_text
+
+
+ def convert_weights(model: nn.Module):
+     """Convert applicable model parameters to fp16"""
+
+     def _convert_weights_to_fp16(l):
+         if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)):
+             l.weight.data = l.weight.data.half()
+             if l.bias is not None:
+                 l.bias.data = l.bias.data.half()
+
+         if isinstance(l, nn.MultiheadAttention):
+             for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]:
+                 tensor = getattr(l, attr)
+                 if tensor is not None:
+                     tensor.data = tensor.data.half()
+
+         for name in ["text_projection", "proj"]:
+             if hasattr(l, name):
+                 attr = getattr(l, name)
+                 if attr is not None:
+                     attr.data = attr.data.half()
+
+     model.apply(_convert_weights_to_fp16)
+
+
+ def build_model(state_dict: dict):
+     vit = "visual.proj" in state_dict
+
+     if vit:
+         vision_width = state_dict["visual.conv1.weight"].shape[0]
+         vision_layers = len([k for k in state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")])
+         vision_patch_size = state_dict["visual.conv1.weight"].shape[-1]
+         grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5)
+         image_resolution = vision_patch_size * grid_size
+     else:
+         counts: list = [len(set(k.split(".")[2] for k in state_dict if k.startswith(f"visual.layer{b}"))) for b in [1, 2, 3, 4]]
+         vision_layers = tuple(counts)
+         vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0]
+         output_width = round((state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5)
+         vision_patch_size = None
+         assert output_width ** 2 + 1 == state_dict["visual.attnpool.positional_embedding"].shape[0]
+         image_resolution = output_width * 32
+
+     embed_dim = state_dict["text_projection"].shape[1]
+     context_length = state_dict["positional_embedding"].shape[0]
+     vocab_size = state_dict["token_embedding.weight"].shape[0]
+     transformer_width = state_dict["ln_final.weight"].shape[0]
+     transformer_heads = transformer_width // 64
+     transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith("transformer.resblocks")))
+
+     model = CLIP(
+         embed_dim,
+         image_resolution, vision_layers, vision_width, vision_patch_size,
+         context_length, vocab_size, transformer_width, transformer_heads, transformer_layers
+     )
+
+     for key in ["input_resolution", "context_length", "vocab_size"]:
+         if key in state_dict:
+             del state_dict[key]
+
+     model.load_state_dict(state_dict, strict=False)
+     for p in model.parameters():
+         p.data = p.data.float()
+     return model.eval()
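The modified `CLIP.forward` threads `taskid` and `is_train` through the module-level globals above and returns image-text logits; passing `None` for one modality returns that branch's output instead. A minimal call, with shapes assuming ViT-B/16 and a batch of 4 (the random image tensor is a stand-in for preprocessed batches):

```python
import torch
import clip

model, _, preprocess_val = clip.load("ViT-B/16", device="cpu", jit=False)

images = torch.randn(4, 3, 224, 224)             # stand-in for preprocess_val(...) batches
texts = clip.tokenize(["a photo of a dog"] * 4)

with torch.no_grad():
    logits_per_image, logits_per_text = model(images, texts, taskid=0, is_train=False)
print(logits_per_image.shape)                    # torch.Size([4, 4])

with torch.no_grad():
    text_feats, _ = model(None, texts, taskid=0, is_train=False)   # text branch only
```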
clip/tokenizer.py ADDED
@@ -0,0 +1,140 @@
1
+ import gzip
2
+ import html
3
+ import os
4
+ from functools import lru_cache
5
+
6
+ import ftfy
7
+ import regex as re
8
+
9
+
10
+ @lru_cache()
11
+ def default_bpe():
12
+ return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz")
13
+
14
+
15
+ @lru_cache()
16
+ def bytes_to_unicode():
17
+ """
18
+ Returns list of utf-8 byte and a corresponding list of unicode strings.
19
+ The reversible bpe codes work on unicode strings.
20
+ This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
21
+ When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
22
+ This is a significant percentage of your normal, say, 32K bpe vocab.
23
+ To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
24
+ This also avoids mapping to whitespace/control characters that the bpe code barfs on.
25
+ """
26
+ bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1))
27
+ cs = bs[:]
28
+ n = 0
29
+ for b in range(2**8):
30
+ if b not in bs:
31
+ bs.append(b)
32
+ cs.append(2**8+n)
33
+ n += 1
34
+ cs = [chr(n) for n in cs]
35
+ return dict(zip(bs, cs))
36
+
37
+
38
+ def get_pairs(word):
39
+ """Return set of symbol pairs in a word.
40
+ Word is represented as tuple of symbols (symbols being variable-length strings).
41
+ """
42
+ pairs = set()
43
+ prev_char = word[0]
44
+ for char in word[1:]:
45
+ pairs.add((prev_char, char))
46
+ prev_char = char
47
+ return pairs
48
+
49
+
50
+ def basic_clean(text):
51
+ text = ftfy.fix_text(text)
52
+ text = html.unescape(html.unescape(text))
53
+ return text.strip()
54
+
55
+
56
+ def whitespace_clean(text):
57
+ text = re.sub(r'\s+', ' ', text)
58
+ text = text.strip()
59
+ return text
60
+
61
+
62
+ class SimpleTokenizer(object):
63
+ def __init__(self, bpe_path: str = default_bpe(), special_tokens=None):
64
+ self.byte_encoder = bytes_to_unicode()
65
+ self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
66
+ merges = gzip.open(bpe_path).read().decode("utf-8").split('\n')
67
+ merges = merges[1:49152-256-2+1]
68
+ merges = [tuple(merge.split()) for merge in merges]
69
+ vocab = list(bytes_to_unicode().values())
70
+ vocab = vocab + [v+'</w>' for v in vocab]
71
+ for merge in merges:
72
+ vocab.append(''.join(merge))
73
+ if not special_tokens:
74
+ special_tokens = ['<start_of_text>', '<end_of_text>']
75
+ else:
76
+ special_tokens = ['<start_of_text>', '<end_of_text>'] + special_tokens
77
+ vocab.extend(special_tokens)
78
+ self.encoder = dict(zip(vocab, range(len(vocab))))
79
+ self.decoder = {v: k for k, v in self.encoder.items()}
80
+ self.bpe_ranks = dict(zip(merges, range(len(merges))))
81
+ self.cache = {t:t for t in special_tokens}
82
+ special = "|".join(special_tokens)
83
+ self.pat = re.compile(special + r"""|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE)
84
+
85
+ self.vocab_size = len(self.encoder)
86
+ self.all_special_ids = [self.encoder[t] for t in special_tokens]
87
+
88
+ def bpe(self, token):
89
+ if token in self.cache:
90
+ return self.cache[token]
91
+ word = tuple(token[:-1]) + ( token[-1] + '</w>',)
92
+ pairs = get_pairs(word)
93
+
94
+ if not pairs:
95
+ return token+'</w>'
96
+
97
+ while True:
98
+ bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf')))
99
+ if bigram not in self.bpe_ranks:
100
+ break
101
+ first, second = bigram
102
+ new_word = []
103
+ i = 0
104
+ while i < len(word):
105
+ try:
106
+ j = word.index(first, i)
107
+ new_word.extend(word[i:j])
108
+ i = j
109
+ except ValueError:  # `first` does not occur after position i
110
+ new_word.extend(word[i:])
111
+ break
112
+
113
+ if word[i] == first and i < len(word)-1 and word[i+1] == second:
114
+ new_word.append(first+second)
115
+ i += 2
116
+ else:
117
+ new_word.append(word[i])
118
+ i += 1
119
+ new_word = tuple(new_word)
120
+ word = new_word
121
+ if len(word) == 1:
122
+ break
123
+ else:
124
+ pairs = get_pairs(word)
125
+ word = ' '.join(word)
126
+ self.cache[token] = word
127
+ return word
128
+
129
+ def encode(self, text):
130
+ bpe_tokens = []
131
+ text = whitespace_clean(basic_clean(text)).lower()
132
+ for token in re.findall(self.pat, text):
133
+ token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
134
+ bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
135
+ return bpe_tokens
136
+
137
+ def decode(self, tokens):
138
+ text = ''.join([self.decoder[token] for token in tokens])
139
+ text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('</w>', ' ')
140
+ return text
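Note: a quick round-trip sketch for SimpleTokenizer above. encode lowercases and BPE-splits the input; decode re-joins byte-level tokens, so output whitespace is normalized rather than preserved:

    from clip.tokenizer import SimpleTokenizer

    tok = SimpleTokenizer()  # loads bpe_simple_vocab_16e6.txt.gz from this folder
    ids = tok.encode("A photo of a dog.")  # plain BPE ids, no <start_of_text>/<end_of_text>
    print(tok.decode(ids))                 # roughly "a photo of a dog . "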
configs/class/cifar100_10-10.yaml ADDED
@@ -0,0 +1,33 @@
1
+ hydra:
2
+ run:
3
+ dir: ./experiments/${scenario}/${dataset}_${initial_increment}-${increment}-${method}
4
+ job:
5
+ chdir: true
6
+
7
+ job_logging:
8
+ version: 1
9
+ formatters:
10
+ simple:
11
+ format: '%(message)s'
12
+
13
+ class_order: ""
14
+ dataset_root: ""
15
+ workdir: ""
16
+ log_path: "metrics.json"
17
+ model_name: "ViT-B/16"
18
+ prompt_template: "a bad photo of a {}."
19
+
20
+ batch_size: 128
21
+ increment: ${initial_increment}
22
+ initial_increment: 10
23
+ scenario: "class"
24
+ dataset: "cifar100"
25
+
26
+ weight_decay: 0.0
27
+ l2: 0
28
+ ce_method: 0
29
+
30
+ method: "DMNSP"
31
+
32
+ lr: 1e-3
33
+ ls: 0.0
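Note: these Hydra configs are composed at launch; a sketch of loading one programmatically, assuming Hydra >= 1.2 and values such as dataset_root supplied as overrides:

    from hydra import compose, initialize

    with initialize(config_path="configs/class", version_base=None):
        cfg = compose(config_name="cifar100_10-10",
                      overrides=["dataset_root=/data", "workdir=."])
    print(cfg.increment)  # ${initial_increment} resolves to 10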
configs/class/cifar100_2-2.yaml ADDED
@@ -0,0 +1,33 @@
1
+ hydra:
2
+ run:
3
+ dir: ./experiments/${scenario}/${dataset}_${initial_increment}-${increment}-${method}-${ls}
4
+ job:
5
+ chdir: true
6
+
7
+ job_logging:
8
+ version: 1
9
+ formatters:
10
+ simple:
11
+ format: '%(message)s'
12
+
13
+ class_order: ""
14
+ dataset_root: ""
15
+ workdir: ""
16
+ log_path: "metrics.json"
17
+ model_name: "ViT-B/16"
18
+ prompt_template: "a bad photo of a {}."
19
+
20
+ batch_size: 128
21
+ increment: ${initial_increment}
22
+ initial_increment: 2
23
+ scenario: "class"
24
+ dataset: "cifar100"
25
+
26
+ weight_decay: 0.0
27
+ l2: 0
28
+ ce_method: 0
29
+
30
+ method: "DMNSP"
31
+ lr: 1e-3
32
+ ls: 0.0
33
+
configs/class/cifar100_5-5.yaml ADDED
@@ -0,0 +1,37 @@
1
+ hydra:
2
+ run:
3
+ dir: ./experiments/${scenario}/${dataset}_${initial_increment}-${increment}-${method}
4
+ job:
5
+ chdir: true
6
+
7
+ job_logging:
8
+ version: 1
9
+ formatters:
10
+ simple:
11
+ format: '%(message)s'
12
+
13
+ class_order: ""
14
+ dataset_root: ""
15
+ workdir: ""
16
+ log_path: "metrics.json"
17
+ model_name: "ViT-B/16"
18
+ prompt_template: "a bad photo of a {}."
19
+
20
+ batch_size: 128
21
+ increment: ${initial_increment}
22
+ initial_increment: 5
23
+ scenario: "class"
24
+ dataset: "cifar100"
25
+
26
+
27
+
28
+ weight_decay: 0.0
29
+ l2: 0
30
+ ce_method: 0
31
+
32
+ method: "DMNSP"
33
+ lr: 1e-3
34
+ ls: 0.0
35
+
36
+
37
+
configs/class/tinyimagenet_100-10.yaml ADDED
@@ -0,0 +1,36 @@
1
+ hydra:
2
+ run:
3
+ dir: ./experiments/${scenario}/${dataset}_${initial_increment}-${increment}-${method}
4
+ job:
5
+ chdir: true
6
+
7
+ job_logging:
8
+ version: 1
9
+ formatters:
10
+ simple:
11
+ format: '%(message)s'
12
+
13
+ class_order: ""
14
+ dataset_root: ""
15
+ workdir: ""
16
+ log_path: "metrics.json"
17
+ model_name: "ViT-B/16"
18
+ prompt_template: "a bad photo of a {}."
19
+
20
+ batch_size: 128
21
+ initial_increment: 100
22
+ increment: 10
23
+ scenario: "class"
24
+ dataset: "tinyimagenet"
25
+
26
+ weight_decay: 0.0
27
+ l2: 0
28
+ ce_method: 0
29
+
30
+ method: "DMNSP"
31
+ lr: 1e-3
32
+ ls: 0.0
33
+ we:
34
+ avg_freq:
35
+ ref_dataset:
36
+ ref_sentences: random
configs/class/tinyimagenet_100-20.yaml ADDED
@@ -0,0 +1,36 @@
1
+ hydra:
2
+ run:
3
+ dir: ./experiments/${scenario}/${dataset}_${initial_increment}-${increment}-${method}
4
+ job:
5
+ chdir: true
6
+
7
+ job_logging:
8
+ version: 1
9
+ formatters:
10
+ simple:
11
+ format: '%(message)s'
12
+
13
+ class_order: ""
14
+ dataset_root: ""
15
+ workdir: ""
16
+ log_path: "metrics.json"
17
+ model_name: "ViT-B/16"
18
+ prompt_template: "a bad photo of a {}."
19
+
20
+ batch_size: 128
21
+ initial_increment: 100
22
+ increment: 20
23
+ scenario: "class"
24
+ dataset: "tinyimagenet"
25
+
26
+ weight_decay: 0.0
27
+ l2: 0
28
+ ce_method: 0
29
+
30
+ method: "DMNSP"
31
+ lr: 1e-3
32
+ ls: 0.0
33
+ we:
34
+ avg_freq:
35
+ ref_dataset:
36
+ ref_sentences: random
configs/class/tinyimagenet_100-5.yaml ADDED
@@ -0,0 +1,36 @@
1
+ hydra:
2
+ run:
3
+ dir: ./experiments/${scenario}/${dataset}_${initial_increment}-${increment}-${method}
4
+ job:
5
+ chdir: true
6
+
7
+ job_logging:
8
+ version: 1
9
+ formatters:
10
+ simple:
11
+ format: '%(message)s'
12
+
13
+ class_order: ""
14
+ dataset_root: ""
15
+ workdir: ""
16
+ log_path: "metrics.json"
17
+ model_name: "ViT-B/16"
18
+ prompt_template: "a bad photo of a {}."
19
+
20
+ batch_size: 128
21
+ initial_increment: 100
22
+ increment: 5
23
+ scenario: "class"
24
+ dataset: "tinyimagenet"
25
+
26
+ weight_decay: 0.0
27
+ l2: 0
28
+ ce_method: 0
29
+
30
+ method: "DMNSP"
31
+ lr: 1e-3
32
+ ls: 0.0
33
+ we:
34
+ avg_freq:
35
+ ref_dataset:
36
+ ref_sentences: random
continual_clip/__init__.py ADDED
File without changes
continual_clip/cc.py ADDED
@@ -0,0 +1,53 @@
1
+ import os
2
+ import pandas as pd
3
+ from PIL import Image
4
+ import torch
5
+ from torch.utils.data import (
6
+ DataLoader,
7
+ Dataset,
8
+ IterableDataset,
9
+ SubsetRandomSampler,
10
+ get_worker_info,
11
+ )
12
+ import clip.clip as clip
13
+
14
+
15
+ class CsvDataset(Dataset):
16
+ def __init__(self, input_filename, transforms, img_key, caption_key, sep="\t"):
17
+ df = pd.read_csv(input_filename, sep=sep)
18
+
19
+ self.location = os.path.dirname(input_filename)
20
+ self.images = df[img_key].tolist()
21
+ self.captions = df[caption_key].tolist()
22
+ self.transforms = transforms
23
+
24
+ def __len__(self):
25
+ return len(self.captions)
26
+
27
+ def __getitem__(self, idx):
28
+ image_path = os.path.join(self.location, str(self.images[idx]))
29
+ images = self.transforms(Image.open(image_path))
30
+ texts = clip.tokenize([str(self.captions[idx])])[0]
31
+ return images, texts
32
+
33
+
34
+ class conceptual_captions(Dataset):
35
+ def __init__(
36
+ self, transforms, location, batch_size, *args, num_workers=16, **kwargs
37
+ ):
38
+ file_name = "Validation_GCC-1.1.0-Validation_output.csv"
39
+ file_path = os.path.join(location, file_name)
40
+ self.template = lambda c: f"a photo of a {c}."
41
+ self.train_dataset = CsvDataset(
42
+ input_filename=file_path,
43
+ transforms=transforms,
44
+ img_key="filepath",
45
+ caption_key="title",
46
+ )
47
+
48
+ self.train_loader = torch.utils.data.DataLoader(
49
+ self.train_dataset,
50
+ batch_size=batch_size,
51
+ shuffle=True,
52
+ num_workers=num_workers,
53
+ )
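Note: a sketch of wiring conceptual_captions to CLIP preprocessing; the CSV file name is fixed above, so `location` only needs to contain it (the /data/cc path is hypothetical):

    import clip.clip as clip
    from continual_clip.cc import conceptual_captions

    _, _, preprocess = clip.load("ViT-B/16", jit=False)  # (model, train_tf, test_tf)
    cc = conceptual_captions(preprocess, location="/data/cc", batch_size=128, num_workers=4)
    images, texts = next(iter(cc.train_loader))          # (B, 3, 224, 224), (B, 77)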
continual_clip/clip_original/README.md ADDED
@@ -0,0 +1 @@
1
+ This folder is a lightly modified version of https://github.com/openai/CLIP.
continual_clip/clip_original/__init__.py ADDED
@@ -0,0 +1 @@
1
+ from .clip import *
continual_clip/clip_original/adapter.py ADDED
@@ -0,0 +1,75 @@
1
+ # --------------------------------------------------------
2
+ # References:
3
+ # https://github.com/jxhe/unify-parameter-efficient-tuning
4
+ # --------------------------------------------------------
5
+
6
+ import math
7
+ import torch
8
+ import torch.nn as nn
9
+
10
+
11
+ class Adapter(nn.Module):
12
+ def __init__(self,
13
+ d_model=None,
14
+ bottleneck=None,
15
+ dropout=0.0,
16
+ init_option="lora",
17
+ adapter_scalar="1.0",
18
+ adapter_layernorm_option="in"):
19
+ super().__init__()
20
+ self.n_embd = d_model
21
+ self.down_size = bottleneck
22
+
23
+ # apply layer norm before ("in") or after ("out") the bottleneck
24
+ self.adapter_layernorm_option = adapter_layernorm_option
25
+
26
+ self.adapter_layer_norm_before = None
27
+ if adapter_layernorm_option == "in" or adapter_layernorm_option == "out":
28
+ self.adapter_layer_norm_before = nn.LayerNorm(self.n_embd)
29
+
30
+ if adapter_scalar == "learnable_scalar":
31
+ self.scale = nn.Parameter(torch.ones(1))
32
+ else:
33
+ self.scale = float(adapter_scalar)
34
+
35
+ self.linear = nn.Linear(self.n_embd, self.n_embd)
36
+
37
+ self.down_proj = nn.Linear(self.n_embd, self.down_size)  # bottleneck width must match up_proj's input
38
+ self.non_linear_func = nn.ReLU()
39
+ self.up_proj = nn.Linear(self.down_size, self.n_embd)
40
+
41
+ self.dropout = dropout
42
+ if init_option == "bert":
43
+ raise NotImplementedError
44
+ elif init_option == "lora":
45
+ with torch.no_grad():
46
+ nn.init.kaiming_uniform_(self.down_proj.weight, a=math.sqrt(5))
47
+ nn.init.zeros_(self.up_proj.weight)
48
+ nn.init.zeros_(self.down_proj.bias)
49
+ nn.init.zeros_(self.up_proj.bias)
50
+ elif init_option == "linear":
51
+ with torch.no_grad():
52
+ nn.init.zeros_(self.linear.weight)
53
+
54
+ def forward(self, x, add_residual=True, residual=None):
55
+
56
+ residual = x if residual is None else residual
57
+ if self.adapter_layernorm_option == 'in': # none
58
+ x = self.adapter_layer_norm_before(x)
59
+
60
+ down = self.down_proj(x)
61
+ down = self.non_linear_func(down)
62
+ down = nn.functional.dropout(down, p=self.dropout, training=self.training)
63
+ up = self.up_proj(down)
64
+
65
+
66
+ up = up * self.scale
67
+
68
+ if self.adapter_layernorm_option == 'out': # none
69
+ up = self.adapter_layer_norm_before(up)
70
+
71
+ if add_residual:
72
+ output = up + residual
73
+ else:
74
+ output = up
75
+ return output
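Note: the Adapter above is a bottleneck (down-project, ReLU, up-project) with a LoRA-style zero init, so with add_residual=True it starts as an identity mapping. A minimal sketch:

    import torch
    from continual_clip.clip_original.adapter import Adapter

    adapter = Adapter(d_model=768, bottleneck=64, adapter_layernorm_option="in")
    x = torch.randn(10, 4, 768)    # (seq_len, batch, d_model)
    out = adapter(x)               # up_proj is zero-initialized at construction,
    print(torch.allclose(out, x))  # so the residual passes through unchanged: True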
continual_clip/clip_original/bpe_simple_vocab_16e6.txt.gz ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:924691ac288e54409236115652ad4aa250f48203de50a9e4722a6ecd48d6804a
3
+ size 1356917
continual_clip/clip_original/clip.py ADDED
@@ -0,0 +1,208 @@
1
+ # Code ported from https://github.com/openai/CLIP
2
+
3
+ import hashlib
4
+ import os
5
+ import urllib
6
+ import warnings
7
+ from typing import Union, List
8
+
9
+ import torch
10
+ from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize, RandomResizedCrop, InterpolationMode
11
+ from tqdm import tqdm
12
+ from .model import build_model
13
+ from clip.tokenizer import SimpleTokenizer as _Tokenizer
14
+
15
+ __all__ = ["available_models", "load", "tokenize"]
16
+ _tokenizer = _Tokenizer()
17
+
18
+ _MODELS = {
19
+ "RN50": "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt",
20
+ "RN101": "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt",
21
+ "RN50x4": "https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt",
22
+ "RN50x16": "https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt",
23
+ "ViT-B/32": "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt",
24
+ "ViT-B/16": "https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt",
25
+ }
26
+
27
+
28
+ def _download(url: str, root: str = os.path.expanduser("~/.cache/clip")):
29
+ os.makedirs(root, exist_ok=True)
30
+ filename = os.path.basename(url)
31
+
32
+ expected_sha256 = url.split("/")[-2]
33
+ download_target = os.path.join(root, filename)
34
+
35
+ if os.path.exists(download_target) and not os.path.isfile(download_target):
36
+ raise RuntimeError(f"{download_target} exists and is not a regular file")
37
+
38
+ if os.path.isfile(download_target):
39
+ if hashlib.sha256(open(download_target, "rb").read()).hexdigest() == expected_sha256:
40
+ return download_target
41
+ else:
42
+ warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file")
43
+
44
+ with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
45
+ with tqdm(total=int(source.info().get("Content-Length")), ncols=80, unit='iB', unit_scale=True) as loop:
46
+ while True:
47
+ buffer = source.read(8192)
48
+ if not buffer:
49
+ break
50
+
51
+ output.write(buffer)
52
+ loop.update(len(buffer))
53
+
54
+ if hashlib.sha256(open(download_target, "rb").read()).hexdigest() != expected_sha256:
55
+ raise RuntimeError("Model has been downloaded but the SHA256 checksum does not match")
56
+
57
+ return download_target
58
+
59
+ def _convert_to_rgb(image):
60
+ return image.convert('RGB')
61
+
62
+ def _transform(n_px: int, is_train: bool):
63
+ normalize = Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
64
+ if is_train:
65
+ return Compose([
66
+ RandomResizedCrop(n_px, scale=(0.9, 1.0), interpolation=InterpolationMode.BICUBIC),
67
+ _convert_to_rgb,
68
+ ToTensor(),
69
+ normalize,
70
+ ])
71
+ else:
72
+ return Compose([
73
+ Resize(n_px, interpolation=InterpolationMode.BICUBIC),
74
+ CenterCrop(n_px),
75
+ _convert_to_rgb,
76
+ ToTensor(),
77
+ normalize,
78
+ ])
79
+
80
+
81
+
82
+ def available_models() -> List[str]:
83
+ """Returns the names of available CLIP models"""
84
+ return list(_MODELS.keys())
85
+
86
+
87
+ def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", jit=True, is_train=False, pretrained=True):
88
+ """Load a CLIP model
89
+ Parameters
90
+ ----------
91
+ name : str
92
+ A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict
93
+ device : Union[str, torch.device]
94
+ The device to put the loaded model
95
+ jit : bool
96
+ Whether to load the optimized JIT model (default) or more hackable non-JIT model.
97
+ Returns
98
+ -------
99
+ model : torch.nn.Module
100
+ The CLIP model
101
+ preprocess : Callable[[PIL.Image], torch.Tensor]
102
+ A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input
103
+ """
104
+ if name in _MODELS:
105
+ model_path = _download(_MODELS[name])
106
+ elif os.path.isfile(name):
107
+ model_path = name
108
+ else:
109
+ raise RuntimeError(f"Model {name} not found; available models = {available_models()}")
110
+
111
+ try:
112
+ # loading JIT archive
113
+ model = torch.jit.load(model_path, map_location=device if jit else "cpu").eval()
114
+ state_dict = None
115
+ except RuntimeError:
116
+ # loading saved state dict
117
+ if jit:
118
+ warnings.warn(f"File {model_path} is not a JIT archive. Loading as a state dict instead")
119
+ jit = False
120
+ state_dict = torch.load(model_path, map_location="cpu")
121
+
122
+ if not jit:
123
+ try:
124
+ model = build_model(state_dict or model.state_dict()).to(device)
125
+ except KeyError:
126
+ sd = {k[7:]: v for k,v in state_dict["state_dict"].items()}
127
+ model = build_model(sd).to(device)
128
+
129
+ if str(device) == "cpu":
130
+ model.float()
131
+ return model, \
132
+ _transform(model.visual.input_resolution, is_train=True), \
133
+ _transform(model.visual.input_resolution, is_train=False)
134
+
135
+ # patch the device names
136
+ device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[])
137
+ device_node = [n for n in device_holder.graph.findAllNodes("prim::Constant") if "Device" in repr(n)][-1]
138
+
139
+ def patch_device(module):
140
+ graphs = [module.graph] if hasattr(module, "graph") else []
141
+ if hasattr(module, "forward1"):
142
+ graphs.append(module.forward1.graph)
143
+
144
+ for graph in graphs:
145
+ for node in graph.findAllNodes("prim::Constant"):
146
+ if "value" in node.attributeNames() and str(node["value"]).startswith("cuda"):
147
+ node.copyAttributes(device_node)
148
+
149
+ model.apply(patch_device)
150
+ patch_device(model.encode_image)
151
+ patch_device(model.encode_text)
152
+
153
+ # patch dtype to float32 on CPU
154
+ if str(device) == "cpu":
155
+ float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[])
156
+ float_input = list(float_holder.graph.findNode("aten::to").inputs())[1]
157
+ float_node = float_input.node()
158
+
159
+ def patch_float(module):
160
+ graphs = [module.graph] if hasattr(module, "graph") else []
161
+ if hasattr(module, "forward1"):
162
+ graphs.append(module.forward1.graph)
163
+
164
+ for graph in graphs:
165
+ for node in graph.findAllNodes("aten::to"):
166
+ inputs = list(node.inputs())
167
+ for i in [1, 2]: # dtype can be the second or third argument to aten::to()
168
+ if inputs[i].node()["value"] == 5:
169
+ inputs[i].node().copyAttributes(float_node)
170
+
171
+ model.apply(patch_float)
172
+ patch_float(model.encode_image)
173
+ patch_float(model.encode_text)
174
+
175
+ model.float()
176
+
177
+ return model, \
178
+ _transform(model.input_resolution.item(), is_train=True), \
179
+ _transform(model.input_resolution.item(), is_train=False)
180
+
181
+
182
+ def tokenize(texts: Union[str, List[str]], context_length: int = 77) -> torch.LongTensor:
183
+ """
184
+ Returns the tokenized representation of given input string(s)
185
+ Parameters
186
+ ----------
187
+ texts : Union[str, List[str]]
188
+ An input string or a list of input strings to tokenize
189
+ context_length : int
190
+ The context length to use; all CLIP models use 77 as the context length
191
+ Returns
192
+ -------
193
+ A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length]
194
+ """
195
+ if isinstance(texts, str):
196
+ texts = [texts]
197
+
198
+ sot_token = _tokenizer.encoder["<start_of_text>"]
199
+ eot_token = _tokenizer.encoder["<end_of_text>"]
200
+ all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts]
201
+ result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)
202
+
203
+ for i, tokens in enumerate(all_tokens):
204
+ if len(tokens) > context_length: # Truncate
205
+ tokens = tokens[:context_length]
206
+ result[i, :len(tokens)] = torch.tensor(tokens)
207
+
208
+ return result
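Note: unlike upstream OpenAI CLIP, load here returns three values (model, train transform, test transform), and tokenize pads/truncates to the fixed 77-token context. A short sketch:

    import torch
    from continual_clip.clip_original import clip

    model, train_tf, test_tf = clip.load("ViT-B/16", jit=False)
    tokens = clip.tokenize(["a photo of a cat", "a photo of a dog"])
    print(tokens.shape)  # torch.Size([2, 77])
    with torch.no_grad():
        text_features = model.encode_text(tokens.to(next(model.parameters()).device))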
continual_clip/clip_original/model.py ADDED
@@ -0,0 +1,568 @@
1
+ from collections import OrderedDict
2
+ from typing import Tuple, Union
3
+
4
+ import os
5
+ import json
6
+ import numpy as np
7
+ import torch
8
+ import torch.nn.functional as F
9
+ from torch import nn
10
+ from .adapter import Adapter
11
+ from torch.distributions.normal import Normal
12
+ from collections import Counter
13
+
14
+ global_taskid = 0
15
+ global_is_train = True
16
+
17
+ class Bottleneck(nn.Module):
18
+ expansion = 4
19
+
20
+ def __init__(self, inplanes, planes, stride=1):
21
+ super().__init__()
22
+
23
+ # all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1
24
+ self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False)
25
+ self.bn1 = nn.BatchNorm2d(planes)
26
+
27
+ self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False)
28
+ self.bn2 = nn.BatchNorm2d(planes)
29
+
30
+ self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity()
31
+
32
+ self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False)
33
+ self.bn3 = nn.BatchNorm2d(planes * self.expansion)
34
+
35
+ self.relu = nn.ReLU(inplace=True)
36
+ self.downsample = None
37
+ self.stride = stride
38
+
39
+ if stride > 1 or inplanes != planes * Bottleneck.expansion:
40
+ # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1
41
+ self.downsample = nn.Sequential(OrderedDict([
42
+ ("-1", nn.AvgPool2d(stride)),
43
+ ("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)),
44
+ ("1", nn.BatchNorm2d(planes * self.expansion))
45
+ ]))
46
+
47
+ def forward(self, x: torch.Tensor):
48
+ identity = x
49
+
50
+ out = self.relu(self.bn1(self.conv1(x)))
51
+ out = self.relu(self.bn2(self.conv2(out)))
52
+ out = self.avgpool(out)
53
+ out = self.bn3(self.conv3(out))
54
+
55
+ if self.downsample is not None:
56
+ identity = self.downsample(x)
57
+
58
+ out += identity
59
+ out = self.relu(out)
60
+ return out
61
+
62
+
63
+ class AttentionPool2d(nn.Module):
64
+ def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None):
65
+ super().__init__()
66
+ self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5)
67
+ self.k_proj = nn.Linear(embed_dim, embed_dim)
68
+ self.q_proj = nn.Linear(embed_dim, embed_dim)
69
+ self.v_proj = nn.Linear(embed_dim, embed_dim)
70
+ self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
71
+ self.num_heads = num_heads
72
+
73
+ def forward(self, x):
74
+ x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3]).permute(2, 0, 1) # NCHW -> (HW)NC
75
+ x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC
76
+ x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC
77
+ x, _ = F.multi_head_attention_forward(
78
+ query=x, key=x, value=x,
79
+ embed_dim_to_check=x.shape[-1],
80
+ num_heads=self.num_heads,
81
+ q_proj_weight=self.q_proj.weight,
82
+ k_proj_weight=self.k_proj.weight,
83
+ v_proj_weight=self.v_proj.weight,
84
+ in_proj_weight=None,
85
+ in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]),
86
+ bias_k=None,
87
+ bias_v=None,
88
+ add_zero_attn=False,
89
+ dropout_p=0,
90
+ out_proj_weight=self.c_proj.weight,
91
+ out_proj_bias=self.c_proj.bias,
92
+ use_separate_proj_weight=True,
93
+ training=self.training,
94
+ need_weights=False
95
+ )
96
+
97
+ return x[0]
98
+
99
+
100
+ class ModifiedResNet(nn.Module):
101
+ """
102
+ A ResNet class that is similar to torchvision's but contains the following changes:
103
+ - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool.
104
+ - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1
105
+ - The final pooling layer is a QKV attention instead of an average pool
106
+ """
107
+
108
+ def __init__(self, layers, output_dim, heads, input_resolution=224, width=64):
109
+ super().__init__()
110
+ self.output_dim = output_dim
111
+ self.input_resolution = input_resolution
112
+
113
+ # the 3-layer stem
114
+ self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False)
115
+ self.bn1 = nn.BatchNorm2d(width // 2)
116
+ self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False)
117
+ self.bn2 = nn.BatchNorm2d(width // 2)
118
+ self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False)
119
+ self.bn3 = nn.BatchNorm2d(width)
120
+ self.avgpool = nn.AvgPool2d(2)
121
+ self.relu = nn.ReLU(inplace=True)
122
+
123
+ # residual layers
124
+ self._inplanes = width # this is a *mutable* variable used during construction
125
+ self.layer1 = self._make_layer(width, layers[0])
126
+ self.layer2 = self._make_layer(width * 2, layers[1], stride=2)
127
+ self.layer3 = self._make_layer(width * 4, layers[2], stride=2)
128
+ self.layer4 = self._make_layer(width * 8, layers[3], stride=2)
129
+
130
+ embed_dim = width * 32 # the ResNet feature dimension
131
+ self.attnpool = AttentionPool2d(input_resolution // 32, embed_dim, heads, output_dim)
132
+
133
+ def _make_layer(self, planes, blocks, stride=1):
134
+ layers = [Bottleneck(self._inplanes, planes, stride)]
135
+
136
+ self._inplanes = planes * Bottleneck.expansion
137
+ for _ in range(1, blocks):
138
+ layers.append(Bottleneck(self._inplanes, planes))
139
+
140
+ return nn.Sequential(*layers)
141
+
142
+ def forward(self, x):
143
+ def stem(x):
144
+ for conv, bn in [(self.conv1, self.bn1), (self.conv2, self.bn2), (self.conv3, self.bn3)]:
145
+ x = self.relu(bn(conv(x)))
146
+ x = self.avgpool(x)
147
+ return x
148
+
149
+ x = x.type(self.conv1.weight.dtype)
150
+ x = stem(x)
151
+ x = self.layer1(x)
152
+ x = self.layer2(x)
153
+ x = self.layer3(x)
154
+ x = self.layer4(x)
155
+ x = self.attnpool(x)
156
+
157
+ return x
158
+
159
+
160
+ class LayerNorm(nn.LayerNorm):
161
+ """Subclass torch's LayerNorm to handle fp16."""
162
+
163
+ def forward(self, x: torch.Tensor):
164
+ orig_type = x.dtype
165
+ ret = super().forward(x.type(torch.float32))
166
+ return ret.type(orig_type)
167
+
168
+
169
+ class QuickGELU(nn.Module):
170
+ def forward(self, x: torch.Tensor):
171
+ return x * torch.sigmoid(1.702 * x)
172
+
173
+
174
+ class ResidualAttentionBlock(nn.Module):
175
+ def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None, text_or_image=None, flag=False):
176
+ super().__init__()
177
+ self.register_buffer("mean", torch.tensor([0.0]))
178
+ self.register_buffer("std", torch.tensor([1.0]))
179
+ self.attn = nn.MultiheadAttention(d_model, n_head)
180
+ self.ln_1 = LayerNorm(d_model)
181
+ self.mlp = nn.Sequential(OrderedDict([
182
+ ("c_fc", nn.Linear(d_model, d_model * 4)),
183
+ ("gelu", QuickGELU()),
184
+ ("c_proj", nn.Linear(d_model * 4, d_model))
185
+ ]))
186
+ self.ln_2 = LayerNorm(d_model)
187
+ self.attn_mask = attn_mask
188
+ self.is_train = global_is_train
189
+ self.step = 1
190
+ self.top_k = 2
191
+ self.ffn_num = 64
192
+ self.experts_num = 1
193
+ self.softmax = nn.Softmax(1)
194
+ self.softplus = nn.Softplus()
195
+ self.noisy_gating = True
196
+ self.adaptmlp_list = nn.ModuleList()
197
+ self.text_or_image = text_or_image
198
+ self.flag = flag
199
+ if text_or_image == 'text':
200
+ print('vanilla text transformer')
201
+ self.choose_map_text = torch.zeros([self.experts_num])
202
+ else:
203
+ print('vanilla image transformer')
204
+ self.choose_map_image = torch.zeros([self.experts_num])
205
+
206
+ # self.taskid = None
207
+ def attention(self, x: torch.Tensor):
208
+ self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None
209
+ return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0]
210
+
211
+ def cv_squared(self, x):
212
+ """The squared coefficient of variation of a sample.
213
+ Useful as a loss to encourage a positive distribution to be more uniform.
214
+ Epsilons added for numerical stability.
215
+ Returns 0 for an empty Tensor.
216
+ Args:
217
+ x: a `Tensor`.
218
+ Returns:
219
+ a `Scalar`.
220
+ """
221
+ eps = 1e-10
222
+ # with a single element there is no variance to measure; return 0
223
+
224
+ if x.shape[0] == 1:
225
+ return torch.tensor([0], device=x.device, dtype=x.dtype)
226
+ return x.float().var() / (x.float().mean()**2 + eps)
227
+
228
+ def _gates_to_load(self, gates):
229
+ """Compute the true load per expert, given the gates.
230
+ The load is the number of examples for which the corresponding gate is >0.
231
+ Args:
232
+ gates: a `Tensor` of shape [batch_size, n]
233
+ Returns:
234
+ a float32 `Tensor` of shape [n]
235
+ """
236
+ return (gates > 0).sum(0)
237
+
238
+ def _prob_in_top_k(self, clean_values, noisy_values, noise_stddev, noisy_top_values):
239
+ """Helper function to NoisyTopKGating.
240
+ Computes the probability that value is in top k, given different random noise.
241
+ This gives us a way of backpropagating from a loss that balances the number
242
+ of times each expert is in the top k experts per example.
243
+ In the case of no noise, pass in None for noise_stddev, and the result will
244
+ not be differentiable.
245
+ Args:
246
+ clean_values: a `Tensor` of shape [batch, n].
247
+ noisy_values: a `Tensor` of shape [batch, n]. Equal to clean values plus
248
+ normally distributed noise with standard deviation noise_stddev.
249
+ noise_stddev: a `Tensor` of shape [batch, n], or None
250
+ noisy_top_values: a `Tensor` of shape [batch, m].
251
+ "values" Output of tf.top_k(noisy_top_values, m). m >= k+1
252
+ Returns:
253
+ a `Tensor` of shape [batch, n].
254
+ """
255
+ # print('1231', clean_values)  # all NaN
256
+ batch = clean_values.size(0)
257
+ m = noisy_top_values.size(1)
258
+ top_values_flat = noisy_top_values.flatten()
259
+
260
+ threshold_positions_if_in = torch.arange(batch, device=clean_values.device) * m + self.top_k
261
+ threshold_if_in = torch.unsqueeze(torch.gather(top_values_flat, 0, threshold_positions_if_in), 1)
262
+ is_in = torch.gt(noisy_values, threshold_if_in)
263
+ threshold_positions_if_out = threshold_positions_if_in - 1
264
+ threshold_if_out = torch.unsqueeze(torch.gather(top_values_flat, 0, threshold_positions_if_out), 1)
265
+ # is each value currently in the top k.
266
+ normal = Normal(self.mean, self.std)
267
+ #
268
+
269
+ prob_if_in = normal.cdf((clean_values - threshold_if_in)/noise_stddev)
270
+ prob_if_out = normal.cdf((clean_values - threshold_if_out)/noise_stddev)
271
+ prob = torch.where(is_in, prob_if_in, prob_if_out)
272
+ return prob
273
+
274
+ def noisy_top_k_gating(self, x, train, w_gate, w_noise, noise_epsilon=1e-2):
275
+ """Noisy top-k gating.
276
+ See paper: https://arxiv.org/abs/1701.06538.
277
+ Args:
278
+ x: input Tensor with shape [batch_size, input_size]
279
+ train: a boolean - we only add noise at training time.
280
+ noise_epsilon: a float
281
+ Returns:
282
+ gates: a Tensor with shape [batch_size, num_experts]
283
+ load: a Tensor with shape [num_experts]
284
+ """
285
+
286
+ clean_logits = x @ w_gate.to(x)
287
+ if self.noisy_gating and train:
288
+ raw_noise_stddev = x @ w_noise.to(x)
289
+ noise_stddev = ((self.softplus(raw_noise_stddev) + noise_epsilon))
290
+ noisy_logits = clean_logits + (torch.randn_like(clean_logits) * noise_stddev)
291
+ logits = noisy_logits
292
+ else:
293
+ logits = clean_logits
294
+ # calculate topk + 1 that will be needed for the noisy gates
295
+ top_logits, top_indices = logits.topk(min(self.top_k + 1, self.experts_num), dim=1)
296
+ top_k_logits = top_logits[:, :self.top_k]
297
+ top_k_indices = top_indices[:, :self.top_k]
298
+ top_k_gates = self.softmax(top_k_logits)
299
+ zeros = torch.zeros_like(logits)
300
+ gates = zeros.scatter(1, top_k_indices, top_k_gates)
301
+ if self.noisy_gating and self.top_k < self.experts_num and train:  # currently unused
302
+ load = (self._prob_in_top_k(clean_logits, noisy_logits, noise_stddev, top_logits)).sum(0)
303
+ else:
304
+ load = self._gates_to_load(gates)
305
+ return gates, load
306
+
307
+ def forward(self, x: torch.Tensor):
308
+ x = x + self.attention(self.ln_1(x))
309
+ x = x + self.mlp(self.ln_2(x))
310
+ return x
311
+
312
+
313
+ class Transformer(nn.Module):
314
+ def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None, text_or_image=None,
315
+ flag=True):
316
+ super().__init__()
317
+ self.width = width
318
+ self.layers = layers
319
+ self.resblocks = nn.Sequential(*[ResidualAttentionBlock(width, heads, attn_mask, text_or_image, flag) for _ in range(layers)])
320
+
321
+ def forward(self, x: torch.Tensor):
322
+ return self.resblocks(x)
323
+
324
+
325
+ class VisualTransformer(nn.Module):
326
+ def __init__(self, input_resolution: int, patch_size: int, width: int, layers: int, heads: int, output_dim: int, text_or_image=None):
327
+ super().__init__()
328
+ self.input_resolution = input_resolution
329
+ self.output_dim = output_dim
330
+ # Added so this info is available. should not change anything.
331
+ self.patch_size = patch_size
332
+ self.width = width
333
+ self.layers = layers
334
+ self.heads = heads
335
+
336
+ self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False)
337
+
338
+ scale = width ** -0.5
339
+ self.class_embedding = nn.Parameter(scale * torch.randn(width))
340
+ self.positional_embedding = nn.Parameter(scale * torch.randn((input_resolution // patch_size) ** 2 + 1, width))
341
+ self.ln_pre = LayerNorm(width)
342
+
343
+ self.transformer = Transformer(width, layers, heads, text_or_image=text_or_image, flag=True)
344
+
345
+ self.ln_post = LayerNorm(width)
346
+ self.proj = nn.Parameter(scale * torch.randn(width, output_dim))
347
+
348
+ def forward(self, x: torch.Tensor):
349
+ x = self.conv1(x)
350
+ x = x.reshape(x.shape[0], x.shape[1], -1)
351
+ x = x.permute(0, 2, 1)
352
+ x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1) # shape = [*, grid ** 2 + 1, width]
353
+ x = x + self.positional_embedding.to(x.dtype)
354
+ x = self.ln_pre(x)
355
+
356
+ x = x.permute(1, 0, 2) # NLD -> LND
357
+ x = self.transformer(x)
358
+ x = x.permute(1, 0, 2) # LND -> NLD
359
+
360
+ x = self.ln_post(x[:, 0, :])
361
+
362
+ if self.proj is not None:
363
+ x = x @ self.proj
364
+
365
+ return x
366
+
367
+
368
+ class CLIP(nn.Module):
369
+ def __init__(self,
370
+ embed_dim: int,
371
+ # vision
372
+ image_resolution: int,
373
+ vision_layers: Union[Tuple[int, int, int, int], int],
374
+ vision_width: int,
375
+ vision_patch_size: int,
376
+ # text
377
+ context_length: int,
378
+ vocab_size: int,
379
+ transformer_width: int,
380
+ transformer_heads: int,
381
+ transformer_layers: int,
382
+ baseline = False
383
+ ):
384
+ super().__init__()
385
+ self.baseline = baseline
386
+
387
+ self.context_length = context_length
388
+
389
+ if isinstance(vision_layers, (tuple, list)):
390
+ vision_heads = vision_width * 32 // 64
391
+ self.visual = ModifiedResNet(
392
+ layers=vision_layers,
393
+ output_dim=embed_dim,
394
+ heads=vision_heads,
395
+ input_resolution=image_resolution,
396
+ width=vision_width
397
+ )
398
+ else:
399
+ vision_heads = vision_width // 64
400
+ self.visual = VisualTransformer(
401
+ input_resolution=image_resolution,
402
+ patch_size=vision_patch_size,
403
+ width=vision_width,
404
+ layers=vision_layers,
405
+ heads=vision_heads,
406
+ output_dim=embed_dim,
407
+ text_or_image='image',
408
+ )
409
+
410
+ # self.transformer = Transformer(
411
+ # width=transformer_width,
412
+ # layers=transformer_layers,
413
+ # heads=transformer_heads,
414
+ # attn_mask=self.build_attention_mask(),
415
+ # text_or_image='text'
416
+ # )
417
+ self.transformer = Transformer(
418
+ width=transformer_width,
419
+ layers=transformer_layers,
420
+ heads=transformer_heads,
421
+ attn_mask=self.build_attention_mask(),
422
+ text_or_image='text',
423
+ flag = True,
424
+ )
425
+
426
+
427
+ self.vocab_size = vocab_size
428
+ self.token_embedding = nn.Embedding(vocab_size, transformer_width)
429
+ self.positional_embedding = nn.Parameter(torch.empty(self.context_length, transformer_width))
430
+ self.ln_final = LayerNorm(transformer_width)
431
+
432
+ self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim))
433
+ self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
434
+
435
+
436
+ self.initialize_parameters()
437
+
438
+ def initialize_parameters(self):
439
+ nn.init.normal_(self.token_embedding.weight, std=0.02)
440
+ nn.init.normal_(self.positional_embedding, std=0.01)
441
+ self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
442
+
443
+ if isinstance(self.visual, ModifiedResNet):
444
+ if self.visual.attnpool is not None:
445
+ std = self.visual.attnpool.c_proj.in_features ** -0.5
446
+ nn.init.normal_(self.visual.attnpool.q_proj.weight, std=std)
447
+ nn.init.normal_(self.visual.attnpool.k_proj.weight, std=std)
448
+ nn.init.normal_(self.visual.attnpool.v_proj.weight, std=std)
449
+ nn.init.normal_(self.visual.attnpool.c_proj.weight, std=std)
450
+
451
+ for resnet_block in [self.visual.layer1, self.visual.layer2, self.visual.layer3, self.visual.layer4]:
452
+ for name, param in resnet_block.named_parameters():
453
+ if name.endswith("bn3.weight"):
454
+ nn.init.zeros_(param)
455
+
456
+ proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5)
457
+ attn_std = self.transformer.width ** -0.5
458
+ fc_std = (2 * self.transformer.width) ** -0.5
459
+ for block in self.transformer.resblocks:
460
+ nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
461
+ nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
462
+ nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
463
+ nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)
464
+
465
+ if self.text_projection is not None:
466
+ nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5)
467
+
468
+ def build_attention_mask(self):
469
+ # lazily create causal attention mask, with full attention between the vision tokens
470
+ # pytorch uses additive attention mask; fill with -inf
471
+ mask = torch.empty(self.context_length, self.context_length)
472
+ mask.fill_(float("-inf"))
473
+ mask.triu_(1) # zero out the lower diagonal
474
+ return mask
475
+
476
+ @property
477
+ def dtype(self):
478
+ return self.visual.conv1.weight.dtype
479
+
480
+ def encode_image(self, image):
481
+ return self.visual(image.type(self.dtype))
482
+
483
+ def encode_text(self, text):
484
+
485
+ x = self.token_embedding(text).type(self.dtype) # [batch_size, n_ctx, d_model]
486
+
487
+ x = x + self.positional_embedding.type(self.dtype)
488
+ x = x.permute(1, 0, 2) # NLD -> LND
489
+ x = self.transformer(x)
490
+ x = x.permute(1, 0, 2) # LND -> NLD
491
+ x = self.ln_final(x).type(self.dtype)
492
+
493
+ # take features from the eot embedding (eot_token is the highest number in each sequence)
494
+ x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection
495
+
496
+ return x
497
+
498
+ def forward(self, image, text, taskid, is_train):
499
+ global global_taskid, global_is_train
500
+ global_taskid = taskid
501
+ global_is_train = is_train
502
+ return self.encode_image(image)
503
+
504
+
505
+
506
+ def convert_weights(model: nn.Module):
507
+ """Convert applicable model parameters to fp16"""
508
+
509
+ def _convert_weights_to_fp16(l):
510
+ if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)):
511
+ l.weight.data = l.weight.data.half()
512
+ if l.bias is not None:
513
+ l.bias.data = l.bias.data.half()
514
+
515
+ if isinstance(l, nn.MultiheadAttention):
516
+ for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]:
517
+ tensor = getattr(l, attr)
518
+ if tensor is not None:
519
+ tensor.data = tensor.data.half()
520
+
521
+ for name in ["text_projection", "proj"]:
522
+ if hasattr(l, name):
523
+ attr = getattr(l, name)
524
+ if attr is not None:
525
+ attr.data = attr.data.half()
526
+
527
+ model.apply(_convert_weights_to_fp16)
528
+
529
+
530
+ def build_model(state_dict: dict):
531
+ vit = "visual.proj" in state_dict
532
+
533
+ if vit:
534
+ vision_width = state_dict["visual.conv1.weight"].shape[0]
535
+ vision_layers = len([k for k in state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")])
536
+ vision_patch_size = state_dict["visual.conv1.weight"].shape[-1]
537
+ grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5)
538
+ image_resolution = vision_patch_size * grid_size
539
+ else:
540
+ counts: list = [len(set(k.split(".")[2] for k in state_dict if k.startswith(f"visual.layer{b}"))) for b in [1, 2, 3, 4]]
541
+ vision_layers = tuple(counts)
542
+ vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0]
543
+ output_width = round((state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5)
544
+ vision_patch_size = None
545
+ assert output_width ** 2 + 1 == state_dict["visual.attnpool.positional_embedding"].shape[0]
546
+ image_resolution = output_width * 32
547
+
548
+ embed_dim = state_dict["text_projection"].shape[1]
549
+ context_length = state_dict["positional_embedding"].shape[0]
550
+ vocab_size = state_dict["token_embedding.weight"].shape[0]
551
+ transformer_width = state_dict["ln_final.weight"].shape[0]
552
+ transformer_heads = transformer_width // 64
553
+ transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith(f"transformer.resblocks")))
554
+
555
+ model = CLIP(
556
+ embed_dim,
557
+ image_resolution, vision_layers, vision_width, vision_patch_size,
558
+ context_length, vocab_size, transformer_width, transformer_heads, transformer_layers
559
+ )
560
+
561
+ for key in ["input_resolution", "context_length", "vocab_size"]:
562
+ if key in state_dict:
563
+ del state_dict[key]
564
+
565
+ model.load_state_dict(state_dict, strict=False)
566
+ for p in model.parameters():
567
+ p.data = p.data.float()
568
+ return model.eval()
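Note: noisy_top_k_gating above follows the sparsely-gated MoE recipe of Shazeer et al. (arXiv:1701.06538): learned Gaussian noise is added to the gate logits at train time, only the top-k experts receive softmax mass, and cv_squared can serve as a load-balancing loss. A standalone sketch under stated assumptions (the block defaults to a single expert, so we widen it to 8 hypothetical experts and supply caller-owned w_gate/w_noise):

    import torch
    from continual_clip.clip_original.model import ResidualAttentionBlock

    block = ResidualAttentionBlock(d_model=768, n_head=12, text_or_image="image")
    block.experts_num, block.top_k = 8, 2          # class defaults are 1 expert / top-2
    w_gate = torch.nn.Parameter(torch.zeros(768, 8))
    w_noise = torch.nn.Parameter(torch.zeros(768, 8))

    x = torch.randn(16, 768)                       # (batch, d_model)
    gates, load = block.noisy_top_k_gating(x, train=True, w_gate=w_gate, w_noise=w_noise)
    balance_loss = block.cv_squared(gates.sum(0)) + block.cv_squared(load)
    print(gates.shape)                             # (16, 8), at most 2 non-zeros per row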
continual_clip/clip_original/tokenizer.py ADDED
@@ -0,0 +1,140 @@
1
+ import gzip
2
+ import html
3
+ import os
4
+ from functools import lru_cache
5
+
6
+ import ftfy
7
+ import regex as re
8
+
9
+
10
+ @lru_cache()
11
+ def default_bpe():
12
+ return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz")
13
+
14
+
15
+ @lru_cache()
16
+ def bytes_to_unicode():
17
+ """
18
+ Returns a list of utf-8 bytes and a corresponding list of unicode strings.
19
+ The reversible bpe codes work on unicode strings.
20
+ This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
21
+ When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
22
+ This is a significant percentage of your normal, say, 32K bpe vocab.
23
+ To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
24
+ This also avoids mapping to whitespace/control characters that the bpe code barfs on.
25
+ """
26
+ bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1))
27
+ cs = bs[:]
28
+ n = 0
29
+ for b in range(2**8):
30
+ if b not in bs:
31
+ bs.append(b)
32
+ cs.append(2**8+n)
33
+ n += 1
34
+ cs = [chr(n) for n in cs]
35
+ return dict(zip(bs, cs))
36
+
37
+
38
+ def get_pairs(word):
39
+ """Return set of symbol pairs in a word.
40
+ Word is represented as tuple of symbols (symbols being variable-length strings).
41
+ """
42
+ pairs = set()
43
+ prev_char = word[0]
44
+ for char in word[1:]:
45
+ pairs.add((prev_char, char))
46
+ prev_char = char
47
+ return pairs
48
+
49
+
50
+ def basic_clean(text):
51
+ text = ftfy.fix_text(text)
52
+ text = html.unescape(html.unescape(text))
53
+ return text.strip()
54
+
55
+
56
+ def whitespace_clean(text):
57
+ text = re.sub(r'\s+', ' ', text)
58
+ text = text.strip()
59
+ return text
60
+
61
+
62
+ class SimpleTokenizer(object):
63
+ def __init__(self, bpe_path: str = default_bpe(), special_tokens=None):
64
+ self.byte_encoder = bytes_to_unicode()
65
+ self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
66
+ merges = gzip.open(bpe_path).read().decode("utf-8").split('\n')
67
+ merges = merges[1:49152-256-2+1]
68
+ merges = [tuple(merge.split()) for merge in merges]
69
+ vocab = list(bytes_to_unicode().values())
70
+ vocab = vocab + [v+'</w>' for v in vocab]
71
+ for merge in merges:
72
+ vocab.append(''.join(merge))
73
+ if not special_tokens:
74
+ special_tokens = ['<start_of_text>', '<end_of_text>']
75
+ else:
76
+ special_tokens = ['<start_of_text>', '<end_of_text>'] + special_tokens
77
+ vocab.extend(special_tokens)
78
+ self.encoder = dict(zip(vocab, range(len(vocab))))
79
+ self.decoder = {v: k for k, v in self.encoder.items()}
80
+ self.bpe_ranks = dict(zip(merges, range(len(merges))))
81
+ self.cache = {t:t for t in special_tokens}
82
+ special = "|".join(special_tokens)
83
+ self.pat = re.compile(special + r"""|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE)
84
+
85
+ self.vocab_size = len(self.encoder)
86
+ self.all_special_ids = [self.encoder[t] for t in special_tokens]
87
+
88
+ def bpe(self, token):
89
+ if token in self.cache:
90
+ return self.cache[token]
91
+ word = tuple(token[:-1]) + ( token[-1] + '</w>',)
92
+ pairs = get_pairs(word)
93
+
94
+ if not pairs:
95
+ return token+'</w>'
96
+
97
+ while True:
98
+ bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf')))
99
+ if bigram not in self.bpe_ranks:
100
+ break
101
+ first, second = bigram
102
+ new_word = []
103
+ i = 0
104
+ while i < len(word):
105
+ try:
106
+ j = word.index(first, i)
107
+ new_word.extend(word[i:j])
108
+ i = j
109
+ except ValueError:  # `first` does not occur after position i
110
+ new_word.extend(word[i:])
111
+ break
112
+
113
+ if word[i] == first and i < len(word)-1 and word[i+1] == second:
114
+ new_word.append(first+second)
115
+ i += 2
116
+ else:
117
+ new_word.append(word[i])
118
+ i += 1
119
+ new_word = tuple(new_word)
120
+ word = new_word
121
+ if len(word) == 1:
122
+ break
123
+ else:
124
+ pairs = get_pairs(word)
125
+ word = ' '.join(word)
126
+ self.cache[token] = word
127
+ return word
128
+
129
+ def encode(self, text):
130
+ bpe_tokens = []
131
+ text = whitespace_clean(basic_clean(text)).lower()
132
+ for token in re.findall(self.pat, text):
133
+ token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
134
+ bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
135
+ return bpe_tokens
136
+
137
+ def decode(self, tokens):
138
+ text = ''.join([self.decoder[token] for token in tokens])
139
+ text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('</w>', ' ')
140
+ return text
continual_clip/datasets.py ADDED
@@ -0,0 +1,124 @@
1
+
2
+
3
+ import os
4
+ import torch.nn as nn
5
+
6
+ from continuum import ClassIncremental, InstanceIncremental
7
+ from continuum.datasets import (
8
+ CIFAR100, ImageNet100, TinyImageNet200, ImageFolderDataset, Core50,
9
+ fgvc_aircraft, Caltech101, DTD, EuroSAT, flowers102, food101,
10
+ MNIST, OxfordPet, SUN397
11
+
12
+ )
13
+ from .utils import get_dataset_class_names
14
+
15
+
16
+ class ImageNet1000(ImageFolderDataset):
17
+ """Continuum dataset for datasets with a tree-like structure.
18
+ :param train_folder: The folder of the train data.
19
+ :param test_folder: The folder of the test data.
20
+ :param download: Dummy parameter.
21
+ """
22
+
23
+ def __init__(
24
+ self,
25
+ data_path: str,
26
+ train: bool = True,
27
+ download: bool = False,
28
+ ):
29
+ super().__init__(data_path=data_path, train=train, download=download)
30
+
31
+ def get_data(self):
32
+ if self.train:
33
+ self.data_path = os.path.join(self.data_path, "train")
34
+ else:
35
+ self.data_path = os.path.join(self.data_path, "val")
36
+ return super().get_data()
37
+
38
+
39
+ def get_dataset(cfg, is_train, transforms=None):
40
+ if cfg.dataset == "cifar100":
41
+ data_path = os.path.join(cfg.dataset_root, cfg.dataset)
42
+ dataset = CIFAR100(
43
+ data_path=data_path,
44
+ download=True,
45
+ train=is_train,
46
+ # transforms=transforms
47
+ )
48
+ classes_names = dataset.dataset.classes
49
+
50
+ # elif cfg.dataset == "tiny-imagenet-200":
51
+ elif cfg.dataset == "tinyimagenet":
52
+
53
+ data_path = os.path.join(cfg.dataset_root, cfg.dataset)
54
+ dataset = TinyImageNet200(
55
+ data_path,
56
+ train=is_train,
57
+ download=True
58
+ )
59
+ classes_names = get_dataset_class_names(cfg.workdir, cfg.dataset)
60
+
61
+ elif cfg.dataset == "imagenet100":
62
+ data_path = cfg.dataset_root
63
+ # data_path = os.path.join(cfg.dataset_root, "ImageNet")
64
+ dataset = ImageNet100(
65
+ data_path,
66
+ train=is_train,
67
+ data_subset=os.path.join(cfg.workdir, "dataset_reqs/imagenet100_splits", "train_100.txt" if is_train else "val_100.txt")  # split lists shipped in dataset_reqs/
68
+ )
69
+ classes_names = get_dataset_class_names(cfg.workdir, cfg.dataset)
70
+
71
+ elif cfg.dataset == "imagenet1000":
72
+ data_path = os.path.join(cfg.dataset_root, cfg.dataset)
73
+ dataset = ImageNet1000(
74
+ data_path,
75
+ train=is_train
76
+ )
77
+ classes_names = get_dataset_class_names(cfg.workdir, cfg.dataset)
78
+
79
+ elif cfg.dataset == "core50":
80
+ data_path = os.path.join(cfg.dataset_root, cfg.dataset)
81
+ dataset = Core50(
82
+ data_path,
83
+ scenario="domains",
84
+ classification="category",
85
+ train=is_train
86
+ )
87
+ classes_names = [
88
+ "plug adapters", "mobile phones", "scissors", "light bulbs", "cans",
89
+ "glasses", "balls", "markers", "cups", "remote controls"
90
+ ]
91
+
92
+ else:
93
+ raise ValueError(f"'{cfg.dataset}' is an invalid dataset.")
94
+
95
+ return dataset, classes_names
96
+
97
+
98
+ def build_cl_scenarios(cfg, is_train, transforms):  # returns (scenario, classes_names)
99
+
100
+ dataset, classes_names = get_dataset(cfg, is_train)
101
+
102
+ if cfg.scenario == "class":
103
+ scenario = ClassIncremental(
104
+ dataset,
105
+ initial_increment=cfg.initial_increment,
106
+ increment=cfg.increment,
107
+ transformations=transforms.transforms, # Convert Compose into list
108
+ class_order=cfg.class_order,
109
+ )
110
+
111
+ elif cfg.scenario == "domain":
112
+ scenario = InstanceIncremental(
113
+ dataset,
114
+ transformations=transforms.transforms,
115
+ )
116
+
117
+ elif cfg.scenario == "task-agnostic":
118
+ raise NotImplementedError("The task-agnostic scenario has not been implemented yet.")
119
+
120
+ else:
121
+ raise ValueError(f"You have entered `{cfg.scenario}`, which is not a defined scenario; "
122
+ "please choose from {{'class', 'domain', 'task-agnostic'}}.")
123
+
124
+ return scenario, classes_names
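Note: build_cl_scenarios returns a continuum scenario indexable by task; a sketch of iterating the class-incremental stream (cfg and train_transform are assumed from the caller, e.g. an OmegaConf config like configs/class/cifar100_10-10.yaml and the transform returned by clip.load):

    from torch.utils.data import DataLoader
    from continual_clip.datasets import build_cl_scenarios

    scenario, classes_names = build_cl_scenarios(cfg, is_train=True, transforms=train_transform)
    for task_id, task_set in enumerate(scenario):
        loader = DataLoader(task_set, batch_size=cfg.batch_size, shuffle=True)
        for images, labels, task_ids in loader:  # continuum TaskSets yield (x, y, t)
            pass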
continual_clip/dynamic_dataset.py ADDED
@@ -0,0 +1,108 @@
1
+ import os
2
+ import re
3
+
4
+ import torch
5
+ from tqdm import tqdm
6
+ import numpy as np
7
+ # from torchvision import datasets
8
+ from . import datasets, utils
9
+ # import clip.clip as clip
10
+ import clip
11
+
12
+
13
+ class DynamicDataset():
14
+
15
+ def __init__(self, cfg):
16
+ self.ref_database = {}  # key = dataset/task id, value = 4D tensor (num_images, 3, 224, 224)
17
+ self.ref_names = [] # collect the name of the dataset
18
+ self.ref_model, _, self.test_preprocess = clip.load(cfg.model_name, jit=False)
19
+ self.cur_dataset = None
20
+ self.memory_size = 5000
21
+ self.batch_id = 0
22
+
23
+ def update(self, dataset, load):
24
+ # load is a model directly
25
+ self.cur_dataset = dataset
26
+ if not load:  # first round (not used for class-incremental CL)
27
+ new_dataset = self.getNewDataset()
28
+ self.ref_database[self.cur_dataset] = new_dataset[:self.memory_size]
29
+ self.ref_names.append(self.cur_dataset)
30
+ else: # other rounds
31
+ self.ref_model = load
32
+ self.reduceExampleSet()
33
+ self.constructExampleSet()
34
+
35
+ def reduceExampleSet(self):
36
+ print("Reducing Example Set")
37
+ K, t = self.memory_size, len(self.ref_names)+1
38
+ m = K // t
39
+ for dataset in self.ref_names:
40
+ self.ref_database[dataset] = self.ref_database[dataset][:m]
41
+
42
+ def constructExampleSet(self):
43
+ # breakpoint()
44
+ print("Constructing Example Set")
45
+ self.ref_names.append(self.batch_id)
46
+ new_dataset = torch.tensor(self.getNewDataset())
47
+ image_feature = []
48
+ num = new_dataset.shape[0]
49
+
50
+ print("[Constructing] Calculating Distance")
51
+ for ndx in tqdm(np.arange(num)):
52
+ img = torch.unsqueeze(new_dataset[ndx], dim=0)
53
+ img = img.cuda()
54
+ img_feature = self.ref_model(img, None)
55
+ image_feature.append(img_feature.cpu().detach().tolist())
56
+ image_feature = torch.tensor(image_feature)
57
+ image_feature = torch.squeeze(image_feature, dim=1)
58
+ image_feature = image_feature / image_feature.norm(dim=-1, keepdim=True)
59
+ image_feature = np.array(image_feature.cpu().detach())
60
+ image_feature_average = image_feature.mean(axis=0)
61
+
62
+ K, t = self.memory_size, len(self.ref_names)
63
+ m = K - K // t
64
+ update_dataset = []
65
+ if not m:
66
+ m = self.memory_size
67
+ cur_embedding_sum = None
68
+ print("[Constructing] Collecting Examples")
69
+ for k in tqdm(np.arange(min(m, len(image_feature)))):
70
+ if not k:
71
+ index = np.argmin(
72
+ np.sum((image_feature_average - image_feature)**2, axis=1)
73
+ )
74
+ cur_embedding_sum = image_feature[index]
75
+ update_dataset.append((new_dataset.cpu())[index].tolist())
76
+ image_feature = np.delete(image_feature, index, axis=0)
77
+ else:
78
+ index = np.argmin(
79
+ np.sum((
80
+ image_feature_average - (1/(k+1))*(image_feature + cur_embedding_sum)
81
+ )**2, axis=1)
82
+ )
83
+ cur_embedding_sum += image_feature[index]
84
+ update_dataset.append((new_dataset.cpu())[index].tolist())
85
+ image_feature = np.delete(image_feature, index, axis=0)
86
+
87
+ self.ref_database[self.batch_id] = update_dataset
88
+ print("finishing current task", self.batch_id)
89
+ self.batch_id = self.batch_id + 1
90
+
91
+ def getNewDataset(self):
92
+ samples = []
93
+ count = 0
94
+ for sample in tqdm(self.cur_dataset):
95
+ if count == 10000:
96
+ return samples
97
+ count += 1
98
+ samples.append(sample[0].tolist())
99
+ return samples
100
+
101
+ def get(self):
102
+ print("Getting Reference Images")
103
+ value = list(self.ref_database.values())
104
+ out = []
105
+ for i in tqdm(value):
106
+ out += i
107
+ return torch.tensor(out)
108
+
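
The selection rule in constructExampleSet is the herding criterion from iCaRL: at step k, pick the candidate that keeps the mean of the chosen exemplars closest to the overall feature mean. A standalone sketch of that rule on random features (no CLIP model required):

    import numpy as np

    rng = np.random.default_rng(0)
    feats = rng.normal(size=(50, 8))
    feats /= np.linalg.norm(feats, axis=1, keepdims=True)
    mu = feats.mean(axis=0)

    chosen, running_sum = [], np.zeros(8)
    pool = feats.copy()
    for k in range(10):
        # distance between the overall mean and the mean if each candidate joined
        d = np.sum((mu - (pool + running_sum) / (k + 1)) ** 2, axis=1)
        idx = int(np.argmin(d))
        running_sum += pool[idx]
        chosen.append(pool[idx])
        pool = np.delete(pool, idx, axis=0)
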
continual_clip/models.py ADDED
@@ -0,0 +1,228 @@
+ from omegaconf import DictConfig
+ from tqdm import tqdm
+ import torch.nn.functional as F
+ import clip.clip as clip
+ import torch
+ import torch.nn as nn
+ from torch.utils.data import DataLoader
+
+ from .utils import get_class_ids_per_task, get_class_names
+ from . import utils
+ from .dynamic_dataset import DynamicDataset
+
+ DEFAULT_THRESHOLD = 0.985
+ TOP_SELECT = 1       # number of leading singular vectors kept per update
+ EPOCH_NUM = 4
+ TOP_K_RATIO = 0.1    # fraction of most-similar directions used for scaling
+ LAMBDA_SCALE = 30
+ LAYER_NUM = 12       # transformer blocks in the ViT visual encoder
+
+
+ class ClassIncremental(nn.Module):
+     def __init__(self, cfg, device, origin_flag, jit=False):
+         super().__init__()
+         self.prompt_template = cfg.prompt_template
+         self.device = device
+         self.classes_names = None
+         self.origin_flag = origin_flag
+         self.model, self.transforms, _ = clip.load(cfg.model_name, device=device, jit=jit)
+         self.ref_model = None
+         self.class_ids_per_task = list(get_class_ids_per_task(cfg))
+         self.current_class_names = []
+         self.text_tokens = None
+         self.dynamic_dataset = DynamicDataset(cfg)
+         self.prev_gradients = None
+         self.visual_cur_matrix = {}  # per-layer accumulated activation covariance
+         self.visual_U = {}           # per-layer basis of previous-task feature subspaces
+         self.loss_list = []
+
+     def forward(self, image, taskid):
+         with torch.no_grad():
+             logits_per_image, _ = self.model(image, self.text_tokens, 0, is_train=False)
+             probs = logits_per_image.softmax(dim=-1)
+         return probs
+
+     def adaptation(self, task_id, cfg, train_dataset, train_classes_names, world):
+         self.current_class_names += get_class_names(self.classes_names, self.class_ids_per_task[task_id])
+         self.text_tokens = clip.tokenize(
+             [self.prompt_template.format(c) for c in self.current_class_names]
+         ).to(self.device)
+         if cfg.method != "zeroshot":
+             self.train(task_id, cfg, train_dataset, train_classes_names, world)
+
+     def train(self, task_id, cfg, train_dataset, train_classes_names, world):
+
+         train_loader = DataLoader(train_dataset[task_id:task_id + 1],
+                                   batch_size=cfg.batch_size,
+                                   shuffle=True, num_workers=8)
+
+         train_iter = iter(train_loader)
+         num_batches = len(train_loader)
+         total_iterations = EPOCH_NUM * num_batches
+
+         # Only the adapter parameters are trainable; the CLIP backbone is frozen.
+         for k, v in self.model.named_parameters():
+             if "adapt" not in k:
+                 v.requires_grad = False
+
+         params = [v for k, v in self.model.named_parameters() if "adapt" in k]
+         params_name = [k for k, v in self.model.named_parameters() if "adapt" in k]
+         print('======== trainable params ============', params_name)
+
+         optimizer = torch.optim.AdamW(params, lr=cfg.lr, weight_decay=cfg.weight_decay)
+         scheduler = utils.cosine_lr(optimizer, cfg.lr, 30, total_iterations)
+         self.model = self.model.to(self.device)
+
+         classnames = get_class_names(self.classes_names, self.class_ids_per_task[task_id])
+         print(classnames)
+         texts = [self.prompt_template.format(c) for c in classnames]
+         texts = clip.tokenize(texts).to(self.device)
+
+         self.model.train()
+
+         batch_count = 0
+         lamda = [[0 for _ in range(LAYER_NUM)] for _ in range(LAYER_NUM)]
+         for iteration in tqdm(range(total_iterations + 1)):
+             scheduler(iteration)
+             try:
+                 inputs, targets, task_ids = next(train_iter)
+             except StopIteration:
+                 train_iter = iter(train_loader)
+                 inputs, targets, task_ids = next(train_iter)
+
+             # Remap the global class ids of the current task to [0, increment).
+             if cfg.dataset == "tinyimagenet" and task_id != 0:
+                 shift = 100 + (task_id - 1) * cfg.increment
+             elif cfg.dataset == "imagenet100" and task_id != 0:
+                 shift = cfg.initial_increment + (task_id - 1) * cfg.increment
+             else:
+                 shift = task_id * cfg.increment
+             targets -= shift
+
+             inputs, targets = inputs.to(self.device), targets.to(self.device)
+             # Image and text branches are encoded separately inside the model.
+             logits_per_image, _ = self.model(inputs, texts, 0, is_train=True)
+
+             loss = F.cross_entropy(logits_per_image, targets, label_smoothing=cfg.ls)
+             self.loss_list.append(loss.item())  # store the scalar, not the graph
+             print('CELoss: {}'.format(loss.item()))
+             optimizer.zero_grad()
+             loss.backward()
+
+             if task_id != 0:
+                 if batch_count == 0:
+                     # Once per task: compare the current task's dominant feature
+                     # direction of every layer against the stored subspaces and
+                     # derive a per-(layer, subspace) scaling factor.
+                     for j in range(LAYER_NUM):
+                         activation_visual = self.model.visual.transformer.lora_feature[j]
+                         activation_visual = torch.bmm(activation_visual.detach().permute(1, 2, 0),
+                                                       activation_visual.detach().permute(1, 0, 2)).sum(dim=0)
+                         U_visual, S, Vh = torch.linalg.svd(activation_visual, full_matrices=False)
+                         U_visual = U_visual[:, :TOP_SELECT]
+
+                         for k in range(LAYER_NUM):
+                             v_visual = self.visual_U[k]
+
+                             normalized_vector_visual = U_visual / torch.norm(U_visual)
+                             similarities_visual = []
+                             for column_visual in v_visual.t():
+                                 normalized_column_visual = column_visual / torch.norm(column_visual)
+                                 cos_sim_visual = torch.dot(normalized_vector_visual.squeeze(),
+                                                            normalized_column_visual.squeeze())
+                                 similarities_visual.append(cos_sim_visual)
+
+                             dot_products_visual = torch.mean(
+                                 torch.topk(torch.stack(similarities_visual),
+                                            int(len(similarities_visual) * TOP_K_RATIO))[0])
+                             lamda[j][k] = torch.exp(-dot_products_visual) * LAMBDA_SCALE
+
+                 batch_count = batch_count + 1
+                 # Project the adapter gradients onto the stored subspaces and
+                 # rescale them so updates interfere less with previous tasks.
+                 for name, params in self.model.named_parameters():
+                     for i in range(LAYER_NUM):
+                         if 'visual' in name and 'adapt' in name and 'down' in name and 'weight' in name:
+                             v = self.visual_U[i]
+                             v_ = torch.mm(params.grad.data, v)
+                             params.grad.data = torch.mm(v_, v.T) * lamda[int(name.split(".")[3])][i]
+                         elif 'visual' in name and 'adapt' in name and 'up' in name and 'weight' in name:
+                             v = self.visual_U[i]
+                             v_ = torch.mm(v.T, params.grad.data)
+                             params.grad.data = torch.mm(v, v_) * lamda[int(name.split(".")[3])][i]
+
+             optimizer.step()
+
+         torch.cuda.empty_cache()
+
+         # After training the task, update the per-layer feature subspaces from
+         # one batch of the task's data.
+         train_loader_ = DataLoader(train_dataset[task_id:task_id + 1],
+                                    batch_size=128,
+                                    shuffle=True, num_workers=8)
+         counts = 0
+         models = self.model.to(self.device)
+         for inputs, targets, task_ids in tqdm(train_loader_):
+             inputs = inputs.to(self.device)
+             with torch.no_grad():
+                 outputs = models(inputs, texts, 0, is_train=False)
+
+             for i in range(LAYER_NUM):
+                 activation = models.visual.transformer.lora_feature[i]
+                 activation = torch.bmm(activation.detach().permute(1, 2, 0),
+                                        activation.detach().permute(1, 0, 2)).sum(dim=0)
+                 if len(self.visual_cur_matrix) == i:
+                     # First task: initialize the covariance and its basis.
+                     self.visual_cur_matrix[i] = activation
+                     U, S, Vh = torch.linalg.svd(activation, full_matrices=False)
+                     self.visual_U[i] = U[:, TOP_SELECT:]
+                 else:
+                     # Later tasks: append the new task's basis vectors.
+                     U1, S1, Vh1 = torch.linalg.svd(activation, full_matrices=False)
+                     self.visual_U[i] = torch.cat((self.visual_U[i], U1[:, TOP_SELECT:]), dim=1)
+
+             counts = counts + 1
+             if counts == 1:
+                 break  # a single batch is enough to estimate the subspace
+
+         torch.cuda.empty_cache()
+         self.model.eval()
+
+
+ class DomainIncremental(nn.Module):
+     pass
+
+
+ class TaskAgnostic(nn.Module):
+     pass
+
+
+ def load_model(cfg: DictConfig, device: torch.device, origin_flag) -> nn.Module:
+     r"""Load a CLIP model for the different continual scenarios.
+
+     Arguments:
+         cfg (DictConfig): Experiment configuration.
+         device (torch.device): Device to train (or) evaluate the model on.
+
+     Returns:
+         nn.Module: The scenario-specific CLIP model.
+     """
+     if cfg.scenario == "class":
+         return ClassIncremental(cfg, device, origin_flag)
+     elif cfg.scenario == "domain":
+         return DomainIncremental(cfg, device)
+     elif cfg.scenario == "task-agnostic":
+         return TaskAgnostic(cfg, device)
+     else:
+         raise ValueError(f"""
+             `{cfg.scenario}` is not a valid scenario,
+             please choose from ['class', 'domain', 'task-agnostic'].
+         """)
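
The gradient update above has the form grad <- (grad U) U^T * lambda: each adapter gradient is projected onto the column space of a stored basis U and rescaled. A standalone sketch of that projection on random matrices; the dimensions, rank, and scale below are illustrative assumptions, not values from the repository:

    import torch

    torch.manual_seed(0)
    d, r, k = 64, 8, 5          # feature dim, adapter rank, basis size (assumed)
    grad = torch.randn(r, d)    # gradient of a 'down' adapter weight
    U = torch.linalg.qr(torch.randn(d, k)).Q   # orthonormal stored basis
    lam = 0.5                   # similarity-derived scale, like lamda[j][k]

    projected = (grad @ U) @ U.T * lam   # same form as the params.grad.data update
    # The projection keeps only the components of grad inside span(U):
    assert torch.allclose(projected @ U, (grad @ U) * lam, atol=1e-5)
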
continual_clip/utils.py ADDED
@@ -0,0 +1,210 @@
+ import os
+ import random
+
+ import yaml
+ import numpy as np
+ import torch
+ import torch.nn.functional as F
+ from omegaconf import DictConfig, OmegaConf
+
+ from clip.tokenizer import SimpleTokenizer as _Tokenizer
+
+ _tokenizer = _Tokenizer()
+
+
+ def get_class_order(file_name: str) -> list:
+     r"""Read the class order for the incremental scenario from a YAML file."""
+     with open(file_name, "r") as f:
+         data = yaml.safe_load(f)
+     return data["class_order"]
+
+
+ def get_class_ids_per_task(args):
+     yield args.class_order[:args.initial_increment]
+     for i in range(args.initial_increment, len(args.class_order), args.increment):
+         yield args.class_order[i:i + args.increment]
+
+
+ def get_class_names(classes_names, class_ids_per_task):
+     return [classes_names[class_id] for class_id in class_ids_per_task]
+
+
+ def get_dataset_class_names(workdir, dataset_name, long=False):
+     with open(os.path.join(workdir, "dataset_reqs", f"{dataset_name}_classes.txt"), "r") as f:
+         lines = f.read().splitlines()
+     return [line.split("\t")[-1] for line in lines]
+
+
+ def save_config(config: DictConfig) -> None:
+     OmegaConf.save(config, "config.yaml")
+
+
+ def get_workdir(path):
+     split_path = path.split("/")
+     workdir_idx = split_path.index("cil")
+     return "/".join(split_path[:workdir_idx + 1])
+
+
+ # ---------- optimizer and evaluation helpers ----------
+
+ def assign_learning_rate(param_group, new_lr):
+     param_group["lr"] = new_lr
+
+
+ def _warmup_lr(base_lr, warmup_length, step):
+     return base_lr * (step + 1) / warmup_length
+
+
+ def cosine_lr(optimizer, base_lrs, warmup_length, steps):
+     # Linear warmup followed by a cosine decay to zero.
+     if not isinstance(base_lrs, list):
+         base_lrs = [base_lrs for _ in optimizer.param_groups]
+     assert len(base_lrs) == len(optimizer.param_groups)
+
+     def _lr_adjuster(step):
+         for param_group, base_lr in zip(optimizer.param_groups, base_lrs):
+             if step < warmup_length:
+                 lr = _warmup_lr(base_lr, warmup_length, step)
+             else:
+                 e = step - warmup_length
+                 es = steps - warmup_length
+                 lr = 0.5 * (1 + np.cos(np.pi * e / es)) * base_lr
+             assign_learning_rate(param_group, lr)
+
+     return _lr_adjuster
+
+
+ def accuracy(output, target, topk=(1,)):
+     pred = output.topk(max(topk), 1, True, True)[1].t()
+     correct = pred.eq(target.view(1, -1).expand_as(pred))
+     return [
+         float(correct[:k].reshape(-1).float().sum(0, keepdim=True).cpu().numpy())
+         for k in topk
+     ]
+
+
+ def torch_save(classifier, save_path):
+     if os.path.dirname(save_path) != "":
+         os.makedirs(os.path.dirname(save_path), exist_ok=True)
+     torch.save({"state_dict": classifier.state_dict()}, save_path)
+     print("Checkpoint saved to", save_path)
+
+
+ def torch_load(classifier, save_path, device=None):
+     checkpoint = torch.load(save_path)
+     missing_keys, unexpected_keys = classifier.load_state_dict(
+         checkpoint["state_dict"], strict=False
+     )
+     if len(missing_keys) > 0 or len(unexpected_keys) > 0:
+         print("Missing keys:", missing_keys)
+         print("Unexpected keys:", unexpected_keys)
+     print("Checkpoint loaded from", save_path)
+     if device is not None:
+         classifier = classifier.to(device)
+     return classifier
+
+
+ def get_logits(inputs, classifier):
+     assert callable(classifier)
+     if hasattr(classifier, "to"):
+         classifier = classifier.to(inputs.device)
+     return classifier(inputs)
+
+
+ def get_probs(inputs, classifier):
+     if hasattr(classifier, "predict_proba"):
+         probs = classifier.predict_proba(inputs.detach().cpu().numpy())
+         return torch.from_numpy(probs)
+     logits = get_logits(inputs, classifier)
+     return logits.softmax(dim=1)
+
+
+ class LabelSmoothing(torch.nn.Module):
+     def __init__(self, smoothing=0.0):
+         super().__init__()
+         self.confidence = 1.0 - smoothing
+         self.smoothing = smoothing
+
+     def forward(self, x, target):
+         logprobs = torch.nn.functional.log_softmax(x, dim=-1)
+         nll_loss = -logprobs.gather(dim=-1, index=target.unsqueeze(1)).squeeze(1)
+         smooth_loss = -logprobs.mean(dim=-1)
+         loss = self.confidence * nll_loss + self.smoothing * smooth_loss
+         return loss.mean()
+
+
+ def seed_all(seed):
+     torch.manual_seed(seed)
+     torch.cuda.manual_seed_all(seed)
+     np.random.seed(seed)
+     random.seed(seed)
+     torch.backends.cudnn.deterministic = True
+     torch.backends.cudnn.benchmark = False
+
+
+ def num_parameters(model):
+     return sum(p.numel() for p in model.parameters() if p.requires_grad)
+
+
+ def batch(iterable, n=64):
+     l = len(iterable)
+     for ndx in range(0, l, n):
+         yield iterable[ndx:min(ndx + n, l)]
+
+
+ def merge_we(model_0, model_1, sma_count):
+     # Running weight average of model_0 into model_1.
+     for param_q, param_k in zip(model_0.parameters(), model_1.parameters()):
+         param_k.data = (param_k.data * sma_count + param_q.data) / (1.0 + sma_count)
+     return model_1
+
+
+ def wise_we(model_0, model_1, sma_count, model_n, alpha=0.95):
+     # WiSE-style merge: average model_0 into model_1, then interpolate with model_n.
+     for param_q, param_k, param_n in zip(model_0.parameters(), model_1.parameters(), model_n.parameters()):
+         param_k.data = (
+             (param_k.data * sma_count + param_q.data) / (1.0 + sma_count)
+         ) * alpha + param_n.data * (1 - alpha)
+     return model_1
+
+
+ def merge_we_router(model_0, model_1, sma_count):
+     # Average only the router / noise parameters.
+     for param_q, param_k, name_q, name_k in zip(model_0.parameters(), model_1.parameters(),
+                                                 model_0.named_parameters(), model_1.named_parameters()):
+         if "router" in name_k[0] or "noise" in name_k[0]:
+             param_k.data = (param_k.data * sma_count + param_q.data) / (1.0 + sma_count)
+     return model_1
+
+
+ def moving_avg(model_0, model_1, alpha=0.999):
+     for param_q, param_k in zip(model_0.parameters(), model_1.parameters()):
+         param_q.data = param_q.data * alpha + param_k.data * (1 - alpha)
+
+
+ def l2_loss(model, model_ref):
+     loss = 0.0
+     for param_q, param_k in zip(model.parameters(), model_ref.parameters()):
+         loss += F.mse_loss(param_q, param_k.detach(), reduction="sum")
+     return loss
+
+
+ def virtual_vocab(length=10, n_class=1000):
+     # Random token sequences wrapped with start/end tokens, padded to 77 slots.
+     voc_len = len(_tokenizer.encoder)
+     texts = torch.randint(0, voc_len, (n_class, length))
+     start = torch.full((n_class, 1), _tokenizer.encoder["<start_of_text>"])
+     end = torch.full((n_class, 1), _tokenizer.encoder["<end_of_text>"])
+     zeros = torch.zeros((n_class, 75 - length), dtype=torch.long)
+     texts = torch.cat([start, texts, end, zeros], dim=1)
+     return texts
+
+
+ def distillation(t, s, T=2):
+     # Distillation loss with soft teacher targets p at temperature T.
+     p = F.softmax(t / T, dim=1)
+     loss = F.cross_entropy(s / T, p, reduction="mean") * (T ** 2)
+     return loss
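
As a quick check of the schedule above, a minimal sketch that drives cosine_lr (as defined in this file) on a dummy optimizer; the learning rate, warmup length, and step count below are illustrative values only:

    import torch

    params = [torch.nn.Parameter(torch.zeros(1))]
    opt = torch.optim.AdamW(params, lr=1e-3)
    adjust = cosine_lr(opt, base_lrs=1e-3, warmup_length=5, steps=20)

    for step in range(20):
        adjust(step)
        # linear warmup for 5 steps, then cosine decay toward zero
        print(step, opt.param_groups[0]["lr"])
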
dataset_reqs/imagenet1000_classes.txt ADDED
@@ -0,0 +1,1000 @@
1
+ 0 tench
2
+ 1 goldfish
3
+ 2 great white shark
4
+ 3 tiger shark
5
+ 4 hammerhead shark
6
+ 5 electric ray
7
+ 6 stingray
8
+ 7 rooster
9
+ 8 hen
10
+ 9 ostrich
11
+ 10 brambling
12
+ 11 goldfinch
13
+ 12 house finch
14
+ 13 junco
15
+ 14 indigo bunting
16
+ 15 American robin
17
+ 16 bulbul
18
+ 17 jay
19
+ 18 magpie
20
+ 19 chickadee
21
+ 20 American dipper
22
+ 21 kite (bird of prey)
23
+ 22 bald eagle
24
+ 23 vulture
25
+ 24 great grey owl
26
+ 25 fire salamander
27
+ 26 smooth newt
28
+ 27 newt
29
+ 28 spotted salamander
30
+ 29 axolotl
31
+ 30 American bullfrog
32
+ 31 tree frog
33
+ 32 tailed frog
34
+ 33 loggerhead sea turtle
35
+ 34 leatherback sea turtle
36
+ 35 mud turtle
37
+ 36 terrapin
38
+ 37 box turtle
39
+ 38 banded gecko
40
+ 39 green iguana
41
+ 40 Carolina anole
42
+ 41 desert grassland whiptail lizard
43
+ 42 agama
44
+ 43 frilled-necked lizard
45
+ 44 alligator lizard
46
+ 45 Gila monster
47
+ 46 European green lizard
48
+ 47 chameleon
49
+ 48 Komodo dragon
50
+ 49 Nile crocodile
51
+ 50 American alligator
52
+ 51 triceratops
53
+ 52 worm snake
54
+ 53 ring-necked snake
55
+ 54 eastern hog-nosed snake
56
+ 55 smooth green snake
57
+ 56 kingsnake
58
+ 57 garter snake
59
+ 58 water snake
60
+ 59 vine snake
61
+ 60 night snake
62
+ 61 boa constrictor
63
+ 62 African rock python
64
+ 63 Indian cobra
65
+ 64 green mamba
66
+ 65 sea snake
67
+ 66 Saharan horned viper
68
+ 67 eastern diamondback rattlesnake
69
+ 68 sidewinder rattlesnake
70
+ 69 trilobite
71
+ 70 harvestman
72
+ 71 scorpion
73
+ 72 yellow garden spider
74
+ 73 barn spider
75
+ 74 European garden spider
76
+ 75 southern black widow
77
+ 76 tarantula
78
+ 77 wolf spider
79
+ 78 tick
80
+ 79 centipede
81
+ 80 black grouse
82
+ 81 ptarmigan
83
+ 82 ruffed grouse
84
+ 83 prairie grouse
85
+ 84 peafowl
86
+ 85 quail
87
+ 86 partridge
88
+ 87 african grey parrot
89
+ 88 macaw
90
+ 89 sulphur-crested cockatoo
91
+ 90 lorikeet
92
+ 91 coucal
93
+ 92 bee eater
94
+ 93 hornbill
95
+ 94 hummingbird
96
+ 95 jacamar
97
+ 96 toucan
98
+ 97 duck
99
+ 98 red-breasted merganser
100
+ 99 goose
101
+ 100 black swan
102
+ 101 tusker
103
+ 102 echidna
104
+ 103 platypus
105
+ 104 wallaby
106
+ 105 koala
107
+ 106 wombat
108
+ 107 jellyfish
109
+ 108 sea anemone
110
+ 109 brain coral
111
+ 110 flatworm
112
+ 111 nematode
113
+ 112 conch
114
+ 113 snail
115
+ 114 slug
116
+ 115 sea slug
117
+ 116 chiton
118
+ 117 chambered nautilus
119
+ 118 Dungeness crab
120
+ 119 rock crab
121
+ 120 fiddler crab
122
+ 121 red king crab
123
+ 122 American lobster
124
+ 123 spiny lobster
125
+ 124 crayfish
126
+ 125 hermit crab
127
+ 126 isopod
128
+ 127 white stork
129
+ 128 black stork
130
+ 129 spoonbill
131
+ 130 flamingo
132
+ 131 little blue heron
133
+ 132 great egret
134
+ 133 bittern bird
135
+ 134 crane bird
136
+ 135 limpkin
137
+ 136 common gallinule
138
+ 137 American coot
139
+ 138 bustard
140
+ 139 ruddy turnstone
141
+ 140 dunlin
142
+ 141 common redshank
143
+ 142 dowitcher
144
+ 143 oystercatcher
145
+ 144 pelican
146
+ 145 king penguin
147
+ 146 albatross
148
+ 147 grey whale
149
+ 148 killer whale
150
+ 149 dugong
151
+ 150 sea lion
152
+ 151 Chihuahua
153
+ 152 Japanese Chin
154
+ 153 Maltese
155
+ 154 Pekingese
156
+ 155 Shih Tzu
157
+ 156 King Charles Spaniel
158
+ 157 Papillon
159
+ 158 toy terrier
160
+ 159 Rhodesian Ridgeback
161
+ 160 Afghan Hound
162
+ 161 Basset Hound
163
+ 162 Beagle
164
+ 163 Bloodhound
165
+ 164 Bluetick Coonhound
166
+ 165 Black and Tan Coonhound
167
+ 166 Treeing Walker Coonhound
168
+ 167 English foxhound
169
+ 168 Redbone Coonhound
170
+ 169 borzoi
171
+ 170 Irish Wolfhound
172
+ 171 Italian Greyhound
173
+ 172 Whippet
174
+ 173 Ibizan Hound
175
+ 174 Norwegian Elkhound
176
+ 175 Otterhound
177
+ 176 Saluki
178
+ 177 Scottish Deerhound
179
+ 178 Weimaraner
180
+ 179 Staffordshire Bull Terrier
181
+ 180 American Staffordshire Terrier
182
+ 181 Bedlington Terrier
183
+ 182 Border Terrier
184
+ 183 Kerry Blue Terrier
185
+ 184 Irish Terrier
186
+ 185 Norfolk Terrier
187
+ 186 Norwich Terrier
188
+ 187 Yorkshire Terrier
189
+ 188 Wire Fox Terrier
190
+ 189 Lakeland Terrier
191
+ 190 Sealyham Terrier
192
+ 191 Airedale Terrier
193
+ 192 Cairn Terrier
194
+ 193 Australian Terrier
195
+ 194 Dandie Dinmont Terrier
196
+ 195 Boston Terrier
197
+ 196 Miniature Schnauzer
198
+ 197 Giant Schnauzer
199
+ 198 Standard Schnauzer
200
+ 199 Scottish Terrier
201
+ 200 Tibetan Terrier
202
+ 201 Australian Silky Terrier
203
+ 202 Soft-coated Wheaten Terrier
204
+ 203 West Highland White Terrier
205
+ 204 Lhasa Apso
206
+ 205 Flat-Coated Retriever
207
+ 206 Curly-coated Retriever
208
+ 207 Golden Retriever
209
+ 208 Labrador Retriever
210
+ 209 Chesapeake Bay Retriever
211
+ 210 German Shorthaired Pointer
212
+ 211 Vizsla
213
+ 212 English Setter
214
+ 213 Irish Setter
215
+ 214 Gordon Setter
216
+ 215 Brittany dog
217
+ 216 Clumber Spaniel
218
+ 217 English Springer Spaniel
219
+ 218 Welsh Springer Spaniel
220
+ 219 Cocker Spaniel
221
+ 220 Sussex Spaniel
222
+ 221 Irish Water Spaniel
223
+ 222 Kuvasz
224
+ 223 Schipperke
225
+ 224 Groenendael dog
226
+ 225 Malinois
227
+ 226 Briard
228
+ 227 Australian Kelpie
229
+ 228 Komondor
230
+ 229 Old English Sheepdog
231
+ 230 Shetland Sheepdog
232
+ 231 collie
233
+ 232 Border Collie
234
+ 233 Bouvier des Flandres dog
235
+ 234 Rottweiler
236
+ 235 German Shepherd Dog
237
+ 236 Dobermann
238
+ 237 Miniature Pinscher
239
+ 238 Greater Swiss Mountain Dog
240
+ 239 Bernese Mountain Dog
241
+ 240 Appenzeller Sennenhund
242
+ 241 Entlebucher Sennenhund
243
+ 242 Boxer
244
+ 243 Bullmastiff
245
+ 244 Tibetan Mastiff
246
+ 245 French Bulldog
247
+ 246 Great Dane
248
+ 247 St. Bernard
249
+ 248 husky
250
+ 249 Alaskan Malamute
251
+ 250 Siberian Husky
252
+ 251 Dalmatian
253
+ 252 Affenpinscher
254
+ 253 Basenji
255
+ 254 pug
256
+ 255 Leonberger
257
+ 256 Newfoundland dog
258
+ 257 Great Pyrenees dog
259
+ 258 Samoyed
260
+ 259 Pomeranian
261
+ 260 Chow Chow
262
+ 261 Keeshond
263
+ 262 brussels griffon
264
+ 263 Pembroke Welsh Corgi
265
+ 264 Cardigan Welsh Corgi
266
+ 265 Toy Poodle
267
+ 266 Miniature Poodle
268
+ 267 Standard Poodle
269
+ 268 Mexican hairless dog (xoloitzcuintli)
270
+ 269 grey wolf
271
+ 270 Alaskan tundra wolf
272
+ 271 red wolf or maned wolf
273
+ 272 coyote
274
+ 273 dingo
275
+ 274 dhole
276
+ 275 African wild dog
277
+ 276 hyena
278
+ 277 red fox
279
+ 278 kit fox
280
+ 279 Arctic fox
281
+ 280 grey fox
282
+ 281 tabby cat
283
+ 282 tiger cat
284
+ 283 Persian cat
285
+ 284 Siamese cat
286
+ 285 Egyptian Mau
287
+ 286 cougar
288
+ 287 lynx
289
+ 288 leopard
290
+ 289 snow leopard
291
+ 290 jaguar
292
+ 291 lion
293
+ 292 tiger
294
+ 293 cheetah
295
+ 294 brown bear
296
+ 295 American black bear
297
+ 296 polar bear
298
+ 297 sloth bear
299
+ 298 mongoose
300
+ 299 meerkat
301
+ 300 tiger beetle
302
+ 301 ladybug
303
+ 302 ground beetle
304
+ 303 longhorn beetle
305
+ 304 leaf beetle
306
+ 305 dung beetle
307
+ 306 rhinoceros beetle
308
+ 307 weevil
309
+ 308 fly
310
+ 309 bee
311
+ 310 ant
312
+ 311 grasshopper
313
+ 312 cricket insect
314
+ 313 stick insect
315
+ 314 cockroach
316
+ 315 praying mantis
317
+ 316 cicada
318
+ 317 leafhopper
319
+ 318 lacewing
320
+ 319 dragonfly
321
+ 320 damselfly
322
+ 321 red admiral butterfly
323
+ 322 ringlet butterfly
324
+ 323 monarch butterfly
325
+ 324 small white butterfly
326
+ 325 sulphur butterfly
327
+ 326 gossamer-winged butterfly
328
+ 327 starfish
329
+ 328 sea urchin
330
+ 329 sea cucumber
331
+ 330 cottontail rabbit
332
+ 331 hare
333
+ 332 Angora rabbit
334
+ 333 hamster
335
+ 334 porcupine
336
+ 335 fox squirrel
337
+ 336 marmot
338
+ 337 beaver
339
+ 338 guinea pig
340
+ 339 common sorrel horse
341
+ 340 zebra
342
+ 341 pig
343
+ 342 wild boar
344
+ 343 warthog
345
+ 344 hippopotamus
346
+ 345 ox
347
+ 346 water buffalo
348
+ 347 bison
349
+ 348 ram (adult male sheep)
350
+ 349 bighorn sheep
351
+ 350 Alpine ibex
352
+ 351 hartebeest
353
+ 352 impala (antelope)
354
+ 353 gazelle
355
+ 354 arabian camel
356
+ 355 llama
357
+ 356 weasel
358
+ 357 mink
359
+ 358 European polecat
360
+ 359 black-footed ferret
361
+ 360 otter
362
+ 361 skunk
363
+ 362 badger
364
+ 363 armadillo
365
+ 364 three-toed sloth
366
+ 365 orangutan
367
+ 366 gorilla
368
+ 367 chimpanzee
369
+ 368 gibbon
370
+ 369 siamang
371
+ 370 guenon
372
+ 371 patas monkey
373
+ 372 baboon
374
+ 373 macaque
375
+ 374 langur
376
+ 375 black-and-white colobus
377
+ 376 proboscis monkey
378
+ 377 marmoset
379
+ 378 white-headed capuchin
380
+ 379 howler monkey
381
+ 380 titi monkey
382
+ 381 Geoffroy's spider monkey
383
+ 382 common squirrel monkey
384
+ 383 ring-tailed lemur
385
+ 384 indri
386
+ 385 Asian elephant
387
+ 386 African bush elephant
388
+ 387 red panda
389
+ 388 giant panda
390
+ 389 snoek fish
391
+ 390 eel
392
+ 391 silver salmon
393
+ 392 rock beauty fish
394
+ 393 clownfish
395
+ 394 sturgeon
396
+ 395 gar fish
397
+ 396 lionfish
398
+ 397 pufferfish
399
+ 398 abacus
400
+ 399 abaya
401
+ 400 academic gown
402
+ 401 accordion
403
+ 402 acoustic guitar
404
+ 403 aircraft carrier
405
+ 404 airliner
406
+ 405 airship
407
+ 406 altar
408
+ 407 ambulance
409
+ 408 amphibious vehicle
410
+ 409 analog clock
411
+ 410 apiary
412
+ 411 apron
413
+ 412 trash can
414
+ 413 assault rifle
415
+ 414 backpack
416
+ 415 bakery
417
+ 416 balance beam
418
+ 417 balloon
419
+ 418 ballpoint pen
420
+ 419 Band-Aid
421
+ 420 banjo
422
+ 421 baluster / handrail
423
+ 422 barbell
424
+ 423 barber chair
425
+ 424 barbershop
426
+ 425 barn
427
+ 426 barometer
428
+ 427 barrel
429
+ 428 wheelbarrow
430
+ 429 baseball
431
+ 430 basketball
432
+ 431 bassinet
433
+ 432 bassoon
434
+ 433 swimming cap
435
+ 434 bath towel
436
+ 435 bathtub
437
+ 436 station wagon
438
+ 437 lighthouse
439
+ 438 beaker
440
+ 439 military hat (bearskin or shako)
441
+ 440 beer bottle
442
+ 441 beer glass
443
+ 442 bell tower
444
+ 443 baby bib
445
+ 444 tandem bicycle
446
+ 445 bikini
447
+ 446 ring binder
448
+ 447 binoculars
449
+ 448 birdhouse
450
+ 449 boathouse
451
+ 450 bobsleigh
452
+ 451 bolo tie
453
+ 452 poke bonnet
454
+ 453 bookcase
455
+ 454 bookstore
456
+ 455 bottle cap
457
+ 456 hunting bow
458
+ 457 bow tie
459
+ 458 brass memorial plaque
460
+ 459 bra
461
+ 460 breakwater
462
+ 461 breastplate
463
+ 462 broom
464
+ 463 bucket
465
+ 464 buckle
466
+ 465 bulletproof vest
467
+ 466 high-speed train
468
+ 467 butcher shop
469
+ 468 taxicab
470
+ 469 cauldron
471
+ 470 candle
472
+ 471 cannon
473
+ 472 canoe
474
+ 473 can opener
475
+ 474 cardigan
476
+ 475 car mirror
477
+ 476 carousel
478
+ 477 tool kit
479
+ 478 cardboard box / carton
480
+ 479 car wheel
481
+ 480 automated teller machine
482
+ 481 cassette
483
+ 482 cassette player
484
+ 483 castle
485
+ 484 catamaran
486
+ 485 CD player
487
+ 486 cello
488
+ 487 mobile phone
489
+ 488 chain
490
+ 489 chain-link fence
491
+ 490 chain mail
492
+ 491 chainsaw
493
+ 492 storage chest
494
+ 493 chiffonier
495
+ 494 bell or wind chime
496
+ 495 china cabinet
497
+ 496 Christmas stocking
498
+ 497 church
499
+ 498 movie theater
500
+ 499 cleaver
501
+ 500 cliff dwelling
502
+ 501 cloak
503
+ 502 clogs
504
+ 503 cocktail shaker
505
+ 504 coffee mug
506
+ 505 coffeemaker
507
+ 506 spiral or coil
508
+ 507 combination lock
509
+ 508 computer keyboard
510
+ 509 candy store
511
+ 510 container ship
512
+ 511 convertible
513
+ 512 corkscrew
514
+ 513 cornet
515
+ 514 cowboy boot
516
+ 515 cowboy hat
517
+ 516 cradle
518
+ 517 construction crane
519
+ 518 crash helmet
520
+ 519 crate
521
+ 520 infant bed
522
+ 521 Crock Pot
523
+ 522 croquet ball
524
+ 523 crutch
525
+ 524 cuirass
526
+ 525 dam
527
+ 526 desk
528
+ 527 desktop computer
529
+ 528 rotary dial telephone
530
+ 529 diaper
531
+ 530 digital clock
532
+ 531 digital watch
533
+ 532 dining table
534
+ 533 dishcloth
535
+ 534 dishwasher
536
+ 535 disc brake
537
+ 536 dock
538
+ 537 dog sled
539
+ 538 dome
540
+ 539 doormat
541
+ 540 drilling rig
542
+ 541 drum
543
+ 542 drumstick
544
+ 543 dumbbell
545
+ 544 Dutch oven
546
+ 545 electric fan
547
+ 546 electric guitar
548
+ 547 electric locomotive
549
+ 548 entertainment center
550
+ 549 envelope
551
+ 550 espresso machine
552
+ 551 face powder
553
+ 552 feather boa
554
+ 553 filing cabinet
555
+ 554 fireboat
556
+ 555 fire truck
557
+ 556 fire screen
558
+ 557 flagpole
559
+ 558 flute
560
+ 559 folding chair
561
+ 560 football helmet
562
+ 561 forklift
563
+ 562 fountain
564
+ 563 fountain pen
565
+ 564 four-poster bed
566
+ 565 freight car
567
+ 566 French horn
568
+ 567 frying pan
569
+ 568 fur coat
570
+ 569 garbage truck
571
+ 570 gas mask or respirator
572
+ 571 gas pump
573
+ 572 goblet
574
+ 573 go-kart
575
+ 574 golf ball
576
+ 575 golf cart
577
+ 576 gondola
578
+ 577 gong
579
+ 578 gown
580
+ 579 grand piano
581
+ 580 greenhouse
582
+ 581 radiator grille
583
+ 582 grocery store
584
+ 583 guillotine
585
+ 584 hair clip
586
+ 585 hair spray
587
+ 586 half-track
588
+ 587 hammer
589
+ 588 hamper
590
+ 589 hair dryer
591
+ 590 hand-held computer
592
+ 591 handkerchief
593
+ 592 hard disk drive
594
+ 593 harmonica
595
+ 594 harp
596
+ 595 combine harvester
597
+ 596 hatchet
598
+ 597 holster
599
+ 598 home theater
600
+ 599 honeycomb
601
+ 600 hook
602
+ 601 hoop skirt
603
+ 602 gymnastic horizontal bar
604
+ 603 horse-drawn vehicle
605
+ 604 hourglass
606
+ 605 iPod
607
+ 606 clothes iron
608
+ 607 carved pumpkin
609
+ 608 jeans
610
+ 609 jeep
611
+ 610 T-shirt
612
+ 611 jigsaw puzzle
613
+ 612 rickshaw
614
+ 613 joystick
615
+ 614 kimono
616
+ 615 knee pad
617
+ 616 knot
618
+ 617 lab coat
619
+ 618 ladle
620
+ 619 lampshade
621
+ 620 laptop computer
622
+ 621 lawn mower
623
+ 622 lens cap
624
+ 623 letter opener
625
+ 624 library
626
+ 625 lifeboat
627
+ 626 lighter
628
+ 627 limousine
629
+ 628 ocean liner
630
+ 629 lipstick
631
+ 630 slip-on shoe
632
+ 631 lotion
633
+ 632 music speaker
634
+ 633 loupe magnifying glass
635
+ 634 sawmill
636
+ 635 magnetic compass
637
+ 636 messenger bag
638
+ 637 mailbox
639
+ 638 tights
640
+ 639 one-piece bathing suit
641
+ 640 manhole cover
642
+ 641 maraca
643
+ 642 marimba
644
+ 643 mask
645
+ 644 matchstick
646
+ 645 maypole
647
+ 646 maze
648
+ 647 measuring cup
649
+ 648 medicine cabinet
650
+ 649 megalith
651
+ 650 microphone
652
+ 651 microwave oven
653
+ 652 military uniform
654
+ 653 milk can
655
+ 654 minibus
656
+ 655 miniskirt
657
+ 656 minivan
658
+ 657 missile
659
+ 658 mitten
660
+ 659 mixing bowl
661
+ 660 mobile home
662
+ 661 ford model t
663
+ 662 modem
664
+ 663 monastery
665
+ 664 monitor
666
+ 665 moped
667
+ 666 mortar and pestle
668
+ 667 graduation cap
669
+ 668 mosque
670
+ 669 mosquito net
671
+ 670 vespa
672
+ 671 mountain bike
673
+ 672 tent
674
+ 673 computer mouse
675
+ 674 mousetrap
676
+ 675 moving van
677
+ 676 muzzle
678
+ 677 metal nail
679
+ 678 neck brace
680
+ 679 necklace
681
+ 680 baby pacifier
682
+ 681 notebook computer
683
+ 682 obelisk
684
+ 683 oboe
685
+ 684 ocarina
686
+ 685 odometer
687
+ 686 oil filter
688
+ 687 pipe organ
689
+ 688 oscilloscope
690
+ 689 overskirt
691
+ 690 bullock cart
692
+ 691 oxygen mask
693
+ 692 product packet / packaging
694
+ 693 paddle
695
+ 694 paddle wheel
696
+ 695 padlock
697
+ 696 paintbrush
698
+ 697 pajamas
699
+ 698 palace
700
+ 699 pan flute
701
+ 700 paper towel
702
+ 701 parachute
703
+ 702 parallel bars
704
+ 703 park bench
705
+ 704 parking meter
706
+ 705 railroad car
707
+ 706 patio
708
+ 707 payphone
709
+ 708 pedestal
710
+ 709 pencil case
711
+ 710 pencil sharpener
712
+ 711 perfume
713
+ 712 Petri dish
714
+ 713 photocopier
715
+ 714 plectrum
716
+ 715 Pickelhaube
717
+ 716 picket fence
718
+ 717 pickup truck
719
+ 718 pier
720
+ 719 piggy bank
721
+ 720 pill bottle
722
+ 721 pillow
723
+ 722 ping-pong ball
724
+ 723 pinwheel
725
+ 724 pirate ship
726
+ 725 drink pitcher
727
+ 726 block plane
728
+ 727 planetarium
729
+ 728 plastic bag
730
+ 729 plate rack
731
+ 730 farm plow
732
+ 731 plunger
733
+ 732 Polaroid camera
734
+ 733 pole
735
+ 734 police van
736
+ 735 poncho
737
+ 736 pool table
738
+ 737 soda bottle
739
+ 738 plant pot
740
+ 739 potter's wheel
741
+ 740 power drill
742
+ 741 prayer rug
743
+ 742 printer
744
+ 743 prison
745
+ 744 missile
746
+ 745 projector
747
+ 746 hockey puck
748
+ 747 punching bag
749
+ 748 purse
750
+ 749 quill
751
+ 750 quilt
752
+ 751 race car
753
+ 752 racket
754
+ 753 radiator
755
+ 754 radio
756
+ 755 radio telescope
757
+ 756 rain barrel
758
+ 757 recreational vehicle
759
+ 758 fishing casting reel
760
+ 759 reflex camera
761
+ 760 refrigerator
762
+ 761 remote control
763
+ 762 restaurant
764
+ 763 revolver
765
+ 764 rifle
766
+ 765 rocking chair
767
+ 766 rotisserie
768
+ 767 eraser
769
+ 768 rugby ball
770
+ 769 ruler measuring stick
771
+ 770 sneaker
772
+ 771 safe
773
+ 772 safety pin
774
+ 773 salt shaker
775
+ 774 sandal
776
+ 775 sarong
777
+ 776 saxophone
778
+ 777 scabbard
779
+ 778 weighing scale
780
+ 779 school bus
781
+ 780 schooner
782
+ 781 scoreboard
783
+ 782 CRT monitor
784
+ 783 screw
785
+ 784 screwdriver
786
+ 785 seat belt
787
+ 786 sewing machine
788
+ 787 shield
789
+ 788 shoe store
790
+ 789 shoji screen / room divider
791
+ 790 shopping basket
792
+ 791 shopping cart
793
+ 792 shovel
794
+ 793 shower cap
795
+ 794 shower curtain
796
+ 795 ski
797
+ 796 balaclava ski mask
798
+ 797 sleeping bag
799
+ 798 slide rule
800
+ 799 sliding door
801
+ 800 slot machine
802
+ 801 snorkel
803
+ 802 snowmobile
804
+ 803 snowplow
805
+ 804 soap dispenser
806
+ 805 soccer ball
807
+ 806 sock
808
+ 807 solar thermal collector
809
+ 808 sombrero
810
+ 809 soup bowl
811
+ 810 keyboard space bar
812
+ 811 space heater
813
+ 812 space shuttle
814
+ 813 spatula
815
+ 814 motorboat
816
+ 815 spider web
817
+ 816 spindle
818
+ 817 sports car
819
+ 818 spotlight
820
+ 819 stage
821
+ 820 steam locomotive
822
+ 821 through arch bridge
823
+ 822 steel drum
824
+ 823 stethoscope
825
+ 824 scarf
826
+ 825 stone wall
827
+ 826 stopwatch
828
+ 827 stove
829
+ 828 strainer
830
+ 829 tram
831
+ 830 stretcher
832
+ 831 couch
833
+ 832 stupa
834
+ 833 submarine
835
+ 834 suit
836
+ 835 sundial
837
+ 836 sunglasses
838
+ 837 sunglasses
839
+ 838 sunscreen
840
+ 839 suspension bridge
841
+ 840 mop
842
+ 841 sweatshirt
843
+ 842 swim trunks / shorts
844
+ 843 swing
845
+ 844 electrical switch
846
+ 845 syringe
847
+ 846 table lamp
848
+ 847 tank
849
+ 848 tape player
850
+ 849 teapot
851
+ 850 teddy bear
852
+ 851 television
853
+ 852 tennis ball
854
+ 853 thatched roof
855
+ 854 front curtain
856
+ 855 thimble
857
+ 856 threshing machine
858
+ 857 throne
859
+ 858 tile roof
860
+ 859 toaster
861
+ 860 tobacco shop
862
+ 861 toilet seat
863
+ 862 torch
864
+ 863 totem pole
865
+ 864 tow truck
866
+ 865 toy store
867
+ 866 tractor
868
+ 867 semi-trailer truck
869
+ 868 tray
870
+ 869 trench coat
871
+ 870 tricycle
872
+ 871 trimaran
873
+ 872 tripod
874
+ 873 triumphal arch
875
+ 874 trolleybus
876
+ 875 trombone
877
+ 876 hot tub
878
+ 877 turnstile
879
+ 878 typewriter keyboard
880
+ 879 umbrella
881
+ 880 unicycle
882
+ 881 upright piano
883
+ 882 vacuum cleaner
884
+ 883 vase
885
+ 884 vaulted or arched ceiling
886
+ 885 velvet fabric
887
+ 886 vending machine
888
+ 887 vestment
889
+ 888 viaduct
890
+ 889 violin
891
+ 890 volleyball
892
+ 891 waffle iron
893
+ 892 wall clock
894
+ 893 wallet
895
+ 894 wardrobe
896
+ 895 military aircraft
897
+ 896 sink
898
+ 897 washing machine
899
+ 898 water bottle
900
+ 899 water jug
901
+ 900 water tower
902
+ 901 whiskey jug
903
+ 902 whistle
904
+ 903 hair wig
905
+ 904 window screen
906
+ 905 window shade
907
+ 906 Windsor tie
908
+ 907 wine bottle
909
+ 908 airplane wing
910
+ 909 wok
911
+ 910 wooden spoon
912
+ 911 wool
913
+ 912 split-rail fence
914
+ 913 shipwreck
915
+ 914 sailboat
916
+ 915 yurt
917
+ 916 website
918
+ 917 comic book
919
+ 918 crossword
920
+ 919 traffic or street sign
921
+ 920 traffic light
922
+ 921 dust jacket
923
+ 922 menu
924
+ 923 plate
925
+ 924 guacamole
926
+ 925 consomme
927
+ 926 hot pot
928
+ 927 trifle
929
+ 928 ice cream
930
+ 929 popsicle
931
+ 930 baguette
932
+ 931 bagel
933
+ 932 pretzel
934
+ 933 cheeseburger
935
+ 934 hot dog
936
+ 935 mashed potatoes
937
+ 936 cabbage
938
+ 937 broccoli
939
+ 938 cauliflower
940
+ 939 zucchini
941
+ 940 spaghetti squash
942
+ 941 acorn squash
943
+ 942 butternut squash
944
+ 943 cucumber
945
+ 944 artichoke
946
+ 945 bell pepper
947
+ 946 cardoon
948
+ 947 mushroom
949
+ 948 Granny Smith apple
950
+ 949 strawberry
951
+ 950 orange
952
+ 951 lemon
953
+ 952 fig
954
+ 953 pineapple
955
+ 954 banana
956
+ 955 jackfruit
957
+ 956 cherimoya (custard apple)
958
+ 957 pomegranate
959
+ 958 hay
960
+ 959 carbonara
961
+ 960 chocolate syrup
962
+ 961 dough
963
+ 962 meatloaf
964
+ 963 pizza
965
+ 964 pot pie
966
+ 965 burrito
967
+ 966 red wine
968
+ 967 espresso
969
+ 968 tea cup
970
+ 969 eggnog
971
+ 970 mountain
972
+ 971 bubble
973
+ 972 cliff
974
+ 973 coral reef
975
+ 974 geyser
976
+ 975 lakeshore
977
+ 976 promontory
978
+ 977 sandbar
979
+ 978 beach
980
+ 979 valley
981
+ 980 volcano
982
+ 981 baseball player
983
+ 982 bridegroom
984
+ 983 scuba diver
985
+ 984 rapeseed
986
+ 985 daisy
987
+ 986 yellow lady's slipper
988
+ 987 corn
989
+ 988 acorn
990
+ 989 rose hip
991
+ 990 horse chestnut seed
992
+ 991 coral fungus
993
+ 992 agaric
994
+ 993 gyromitra
995
+ 994 stinkhorn mushroom
996
+ 995 earth star fungus
997
+ 996 hen of the woods mushroom
998
+ 997 bolete
999
+ 998 corn cob
1000
+ 999 toilet paper
dataset_reqs/imagenet100_classes.txt ADDED
@@ -0,0 +1,100 @@
1
+ 0 n01440764 tench
2
+ 1 n01443537 goldfish
3
+ 2 n01484850 great white shark
4
+ 3 n01491361 tiger shark
5
+ 4 n01494475 hammerhead shark
6
+ 5 n01496331 electric ray
7
+ 6 n01498041 stingray
8
+ 7 n01514668 rooster
9
+ 8 n01514859 hen
10
+ 9 n01518878 ostrich
11
+ 10 n01530575 brambling
12
+ 11 n01531178 goldfinch
13
+ 12 n01532829 house finch
14
+ 13 n01534433 junco
15
+ 14 n01537544 indigo bunting
16
+ 15 n01558993 American robin
17
+ 16 n01560419 bulbul
18
+ 17 n01580077 jay
19
+ 18 n01582220 magpie
20
+ 19 n01592084 chickadee
21
+ 20 n01601694 American dipper
22
+ 21 n01608432 kite (bird of prey)
23
+ 22 n01614925 bald eagle
24
+ 23 n01616318 vulture
25
+ 24 n01622779 great grey owl
26
+ 25 n01629819 fire salamander
27
+ 26 n01630670 smooth newt
28
+ 27 n01631663 newt
29
+ 28 n01632458 spotted salamander
30
+ 29 n01632777 axolotl
31
+ 30 n01641577 American bullfrog
32
+ 31 n01644373 tree frog
33
+ 32 n01644900 tailed frog
34
+ 33 n01664065 loggerhead sea turtle
35
+ 34 n01665541 leatherback sea turtle
36
+ 35 n01667114 mud turtle
37
+ 36 n01667778 terrapin
38
+ 37 n01669191 box turtle
39
+ 38 n01675722 banded gecko
40
+ 39 n01677366 green iguana
41
+ 40 n01682714 Carolina anole
42
+ 41 n01685808 desert grassland whiptail lizard
43
+ 42 n01687978 agama
44
+ 43 n01688243 frilled-necked lizard
45
+ 44 n01689811 alligator lizard
46
+ 45 n01692333 Gila monster
47
+ 46 n01693334 European green lizard
48
+ 47 n01694178 chameleon
49
+ 48 n01695060 Komodo dragon
50
+ 49 n01697457 Nile crocodile
51
+ 50 n01698640 American alligator
52
+ 51 n01704323 triceratops
53
+ 52 n01728572 worm snake
54
+ 53 n01728920 ring-necked snake
55
+ 54 n01729322 eastern hog-nosed snake
56
+ 55 n01729977 smooth green snake
57
+ 56 n01734418 kingsnake
58
+ 57 n01735189 garter snake
59
+ 58 n01737021 water snake
60
+ 59 n01739381 vine snake
61
+ 60 n01740131 night snake
62
+ 61 n01742172 boa constrictor
63
+ 62 n01744401 African rock python
64
+ 63 n01748264 Indian cobra
65
+ 64 n01749939 green mamba
66
+ 65 n01751748 sea snake
67
+ 66 n01753488 Saharan horned viper
68
+ 67 n01755581 eastern diamondback rattlesnake
69
+ 68 n01756291 sidewinder rattlesnake
70
+ 69 n01768244 trilobite
71
+ 70 n01770081 harvestman
72
+ 71 n01770393 scorpion
73
+ 72 n01773157 yellow garden spider
74
+ 73 n01773549 barn spider
75
+ 74 n01773797 European garden spider
76
+ 75 n01774384 southern black widow
77
+ 76 n01774750 tarantula
78
+ 77 n01775062 wolf spider
79
+ 78 n01776313 tick
80
+ 79 n01784675 centipede
81
+ 80 n01795545 black grouse
82
+ 81 n01796340 ptarmigan
83
+ 82 n01797886 ruffed grouse
84
+ 83 n01798484 prairie grouse
85
+ 84 n01806143 peafowl
86
+ 85 n01806567 quail
87
+ 86 n01807496 partridge
88
+ 87 n01817953 african grey parrot
89
+ 88 n01818515 macaw
90
+ 89 n01819313 sulphur-crested cockatoo
91
+ 90 n01820546 lorikeet
92
+ 91 n01824575 coucal
93
+ 92 n01828970 bee eater
94
+ 93 n01829413 hornbill
95
+ 94 n01833805 hummingbird
96
+ 95 n01843065 jacamar
97
+ 96 n01843383 toucan
98
+ 97 n01847000 duck
99
+ 98 n01855032 red-breasted merganser
100
+ 99 n01855672 goose
dataset_reqs/imagenet100_splits/train_100.txt ADDED
The diff for this file is too large to render. See raw diff
 
dataset_reqs/imagenet100_splits/val_100.txt ADDED
The diff for this file is too large to render. See raw diff
 
dataset_reqs/tinyimagenet_classes.txt ADDED
@@ -0,0 +1,200 @@
1
+ 0 n02124075 Egyptian Mau
2
+ 1 n04067472 fishing casting reel
3
+ 2 n04540053 volleyball
4
+ 3 n04099969 rocking chair
5
+ 4 n07749582 lemon
6
+ 5 n01641577 American bullfrog
7
+ 6 n02802426 basketball
8
+ 7 n09246464 cliff
9
+ 8 n07920052 espresso
10
+ 9 n03970156 plunger
11
+ 10 n03891332 parking meter
12
+ 11 n02106662 German Shepherd Dog
13
+ 12 n03201208 dining table
14
+ 13 n02279972 monarch butterfly
15
+ 14 n02132136 brown bear
16
+ 15 n04146614 school bus
17
+ 16 n07873807 pizza
18
+ 17 n02364673 guinea pig
19
+ 18 n04507155 umbrella
20
+ 19 n03854065 pipe organ
21
+ 20 n03838899 oboe
22
+ 21 n03733131 maypole
23
+ 22 n01443537 goldfish
24
+ 23 n07875152 pot pie
25
+ 24 n03544143 hourglass
26
+ 25 n09428293 beach
27
+ 26 n03085013 computer keyboard
28
+ 27 n02437312 arabian camel
29
+ 28 n07614500 ice cream
30
+ 29 n03804744 metal nail
31
+ 30 n04265275 space heater
32
+ 31 n02963159 cardigan
33
+ 32 n02486410 baboon
34
+ 33 n01944390 snail
35
+ 34 n09256479 coral reef
36
+ 35 n02058221 albatross
37
+ 36 n04275548 spider web
38
+ 37 n02321529 sea cucumber
39
+ 38 n02769748 backpack
40
+ 39 n02099712 Labrador Retriever
41
+ 40 n07695742 pretzel
42
+ 41 n02056570 king penguin
43
+ 42 n02281406 sulphur butterfly
44
+ 43 n01774750 tarantula
45
+ 44 n02509815 red panda
46
+ 45 n03983396 soda bottle
47
+ 46 n07753592 banana
48
+ 47 n04254777 sock
49
+ 48 n02233338 cockroach
50
+ 49 n04008634 missile
51
+ 50 n02823428 beer bottle
52
+ 51 n02236044 praying mantis
53
+ 52 n03393912 freight car
54
+ 53 n07583066 guacamole
55
+ 54 n04074963 remote control
56
+ 55 n01629819 fire salamander
57
+ 56 n09332890 lakeshore
58
+ 57 n02481823 chimpanzee
59
+ 58 n03902125 payphone
60
+ 59 n03404251 fur coat
61
+ 60 n09193705 mountain
62
+ 61 n03637318 lampshade
63
+ 62 n04456115 torch
64
+ 63 n02666196 abacus
65
+ 64 n03796401 moving van
66
+ 65 n02795169 barrel
67
+ 66 n02123045 tabby cat
68
+ 67 n01855672 goose
69
+ 68 n01882714 koala
70
+ 69 n02917067 high-speed train
71
+ 70 n02988304 CD player
72
+ 71 n04398044 teapot
73
+ 72 n02843684 birdhouse
74
+ 73 n02423022 gazelle
75
+ 74 n02669723 academic gown
76
+ 75 n04465501 tractor
77
+ 76 n02165456 ladybug
78
+ 77 n03770439 miniskirt
79
+ 78 n02099601 Golden Retriever
80
+ 79 n04486054 triumphal arch
81
+ 80 n02950826 cannon
82
+ 81 n03814639 neck brace
83
+ 82 n04259630 sombrero
84
+ 83 n03424325 gas mask or respirator
85
+ 84 n02948072 candle
86
+ 85 n03179701 desk
87
+ 86 n03400231 frying pan
88
+ 87 n02206856 bee
89
+ 88 n03160309 dam
90
+ 89 n01984695 spiny lobster
91
+ 90 n03977966 police van
92
+ 91 n03584254 iPod
93
+ 92 n04023962 punching bag
94
+ 93 n02814860 lighthouse
95
+ 94 n01910747 jellyfish
96
+ 95 n04596742 wok
97
+ 96 n03992509 potter's wheel
98
+ 97 n04133789 sandal
99
+ 98 n03937543 pill bottle
100
+ 99 n02927161 butcher shop
101
+ 100 n01945685 slug
102
+ 101 n02395406 pig
103
+ 102 n02125311 cougar
104
+ 103 n03126707 construction crane
105
+ 104 n04532106 vestment
106
+ 105 n02268443 dragonfly
107
+ 106 n02977058 automated teller machine
108
+ 107 n07734744 mushroom
109
+ 108 n03599486 rickshaw
110
+ 109 n04562935 water tower
111
+ 110 n03014705 storage chest
112
+ 111 n04251144 snorkel
113
+ 112 n04356056 sunglasses
114
+ 113 n02190166 fly
115
+ 114 n03670208 limousine
116
+ 115 n02002724 black stork
117
+ 116 n02074367 dugong
118
+ 117 n04285008 sports car
119
+ 118 n04560804 water jug
120
+ 119 n04366367 suspension bridge
121
+ 120 n02403003 ox
122
+ 121 n07615774 popsicle
123
+ 122 n04501370 turnstile
124
+ 123 n03026506 Christmas stocking
125
+ 124 n02906734 broom
126
+ 125 n01770393 scorpion
127
+ 126 n04597913 wooden spoon
128
+ 127 n03930313 picket fence
129
+ 128 n04118538 rugby ball
130
+ 129 n04179913 sewing machine
131
+ 130 n04311004 through arch bridge
132
+ 131 n02123394 Persian cat
133
+ 132 n04070727 refrigerator
134
+ 133 n02793495 barn
135
+ 134 n02730930 apron
136
+ 135 n02094433 Yorkshire Terrier
137
+ 136 n04371430 swim trunks / shorts
138
+ 137 n04328186 stopwatch
139
+ 138 n03649909 lawn mower
140
+ 139 n04417672 thatched roof
141
+ 140 n03388043 fountain
142
+ 141 n01774384 southern black widow
143
+ 142 n02837789 bikini
144
+ 143 n07579787 plate
145
+ 144 n04399382 teddy bear
146
+ 145 n02791270 barbershop
147
+ 146 n03089624 candy store
148
+ 147 n02814533 station wagon
149
+ 148 n04149813 scoreboard
150
+ 149 n07747607 orange
151
+ 150 n03355925 flagpole
152
+ 151 n01983481 American lobster
153
+ 152 n04487081 trolleybus
154
+ 153 n03250847 drumstick
155
+ 154 n03255030 dumbbell
156
+ 155 n02892201 brass memorial plaque
157
+ 156 n02883205 bow tie
158
+ 157 n03100240 convertible
159
+ 158 n02415577 bighorn sheep
160
+ 159 n02480495 orangutan
161
+ 160 n01698640 American alligator
162
+ 161 n01784675 centipede
163
+ 162 n04376876 syringe
164
+ 163 n03444034 go-kart
165
+ 164 n01917289 brain coral
166
+ 165 n01950731 sea slug
167
+ 166 n03042490 cliff dwelling
168
+ 167 n07711569 mashed potatoes
169
+ 168 n04532670 viaduct
170
+ 169 n03763968 military uniform
171
+ 170 n07768694 pomegranate
172
+ 171 n02999410 chain
173
+ 172 n03617480 kimono
174
+ 173 n06596364 comic book
175
+ 174 n01768244 trilobite
176
+ 175 n02410509 bison
177
+ 176 n03976657 pole
178
+ 177 n01742172 boa constrictor
179
+ 178 n03980874 poncho
180
+ 179 n02808440 bathtub
181
+ 180 n02226429 grasshopper
182
+ 181 n02231487 stick insect
183
+ 182 n02085620 Chihuahua
184
+ 183 n01644900 tailed frog
185
+ 184 n02129165 lion
186
+ 185 n02699494 altar
187
+ 186 n03837869 obelisk
188
+ 187 n02815834 beaker
189
+ 188 n07720875 bell pepper
190
+ 189 n02788148 baluster / handrail
191
+ 190 n02909870 bucket
192
+ 191 n03706229 magnetic compass
193
+ 192 n07871810 meatloaf
194
+ 193 n03447447 gondola
195
+ 194 n02113799 Standard Poodle
196
+ 195 n12267677 acorn
197
+ 196 n03662601 lifeboat
198
+ 197 n02841315 binoculars
199
+ 198 n07715103 cauliflower
200
+ 199 n02504458 African bush elephant
main.py ADDED
@@ -0,0 +1,104 @@
+ import os
+ import json
+ import random
+ import logging
+ import statistics
+
+ import hydra
+ import numpy as np
+ import torch
+ from omegaconf import DictConfig
+ from tqdm import tqdm
+ from continuum.metrics import Logger
+ from torch.utils.data import DataLoader, DistributedSampler
+
+ from continual_clip import utils
+ from continual_clip.models import load_model
+ from continual_clip.datasets import build_cl_scenarios
+
+ WORLD_NUM = 1
+ # Seeds: 386, 2345, 157 (performance might vary slightly across machines)
+ RANDOM_SEED = 386
+
+
+ def set_seed(seed):
+     random.seed(seed)
+     np.random.seed(seed)
+     torch.manual_seed(seed)
+     torch.cuda.manual_seed(seed)
+     torch.cuda.manual_seed_all(seed)
+     os.environ['PYTHONHASHSEED'] = str(seed)
+     torch.backends.cudnn.deterministic = True
+     torch.backends.cudnn.benchmark = False
+
+
+ @hydra.main(config_path=None, config_name=None, version_base="1.1")
+ def continual_clip(cfg: DictConfig) -> None:
+
+     set_seed(RANDOM_SEED)
+
+     cfg.workdir = "/***/DMNSP/cil"
+     cfg.dataset_root = os.path.join(cfg.workdir, cfg.dataset_root)
+
+     utils.save_config(cfg)
+     cfg.class_order = utils.get_class_order(os.path.join(cfg.workdir, cfg.class_order))
+     origin_flag = False
+     devices = [0]
+     model = load_model(cfg, devices[0], origin_flag)
+
+     eval_dataset, classes_names = build_cl_scenarios(
+         cfg, is_train=False, transforms=model.transforms
+     )
+     print(eval_dataset)
+     train_dataset, train_classes_names = build_cl_scenarios(
+         cfg, is_train=True, transforms=model.transforms
+     )
+     model.classes_names = classes_names
+
+     print("Using devices", devices)
+     model = torch.nn.DataParallel(model, device_ids=devices)
+
+     # Truncate the log file before the run starts.
+     with open(cfg.log_path, 'w+') as f:
+         pass
+
+     acc_list = []
+     forgetting_list = []
+     metric_logger = Logger(list_subsets=["test"])
+     world = WORLD_NUM
+
+     for task_id, _ in enumerate(eval_dataset):
+
+         logging.info(f"Evaluation for task {task_id} has started.")
+
+         # The task id is passed into the model through adaptation().
+         model.module.adaptation(task_id, cfg, train_dataset, train_classes_names, world)
+         eval_sampler = DistributedSampler(eval_dataset[:task_id + 1], num_replicas=world, rank=0)
+         eval_loader = DataLoader(eval_dataset[:task_id + 1], batch_size=64,
+                                  sampler=eval_sampler, num_workers=8)
+
+         for inputs, targets, task_ids in tqdm(eval_loader):
+             inputs, targets = inputs.cuda(device=devices[0]), targets.cuda(device=devices[0])
+             outputs = model.module.cuda(devices[0])(inputs, task_ids)
+             metric_logger.add([outputs.cpu().argmax(dim=1), targets.cpu(), task_ids], subset="test")
+
+         acc_list.append(100 * metric_logger.accuracy)
+         forgetting_list.append(100 * metric_logger.forgetting)
+
+         with open(cfg.log_path, 'a+') as f:
+             f.write(json.dumps({
+                 'task': task_id,
+                 'acc': round(100 * metric_logger.accuracy, 2),
+                 'avg_acc': round(100 * metric_logger.average_incremental_accuracy, 2),
+                 'forgetting': round(100 * metric_logger.forgetting, 6),
+                 'acc_per_task': [round(100 * acc_t, 2) for acc_t in metric_logger.accuracy_per_task],
+                 'bwt': round(100 * metric_logger.backward_transfer, 2),
+                 'fwt': round(100 * metric_logger.forward_transfer, 2),
+             }) + '\n')
+         metric_logger.end_task()
+
+     with open(cfg.log_path, 'a+') as f:
+         f.write(json.dumps({
+             'last_Cifar100': round(acc_list[-1], 2),
+             'avg_Cifar100': round(statistics.mean(acc_list), 2),
+             'avg_forgetting': round(statistics.mean(forgetting_list), 2)
+         }) + '\n')
+
+
+ if __name__ == "__main__":
+     continual_clip()
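
The metric bookkeeping above relies on continuum's Logger, fed with (predictions, targets, task ids) triples exactly as in the eval loop. A toy sketch on synthetic predictions (the class and task ids are made up for illustration):

    import torch
    from continuum.metrics import Logger

    logger = Logger(list_subsets=["test"])
    preds = torch.tensor([0, 1, 1, 2])      # hypothetical argmax outputs
    targets = torch.tensor([0, 1, 2, 2])
    task_ids = torch.tensor([0, 0, 0, 0])
    logger.add([preds, targets, task_ids], subset="test")
    print(logger.accuracy)  # 0.75 on this toy batch
    logger.end_task()
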
requirements.txt ADDED
@@ -0,0 +1,19 @@
+ continuum
+ hydra-core==1.2.0
+ numpy
+ oauthlib==3.2.1
+ omegaconf==2.2.3
+ open-clip-torch==1.3.0
+ pandas==1.4.3
+ Pillow==9.2.0
+ pipreqs==0.4.11
+ scikit-image==0.19.3
+ scikit-learn==1.1.1
+ scipy==1.8.1
+ tensorboard==2.10.0
+ timm @ git+https://github.com/Arnav0400/pytorch-image-models.git@ceea7127c1ef608179ba06eaeddc22ad3ef22de0
+ tokenizers==0.12.1
+ tqdm==4.64.0
+ transformers==4.21.1
+ ftfy
+ regex
run_cifar100-10-10.sh ADDED
@@ -0,0 +1,9 @@
+ #!/bin/bash
+
+ python main.py \
+     --config-path /DMNSP/configs/class \
+     --config-name cifar100_10-10.yaml \
+     dataset_root="/data/**/" \
+     class_order="/DMNSP/class_orders/cifar100.yaml"
templates/__init__.py ADDED
File without changes
templates/fmow_template.py ADDED
@@ -0,0 +1,20 @@
1
+ from .template_utils import append_proper_article
2
+
3
+ fmow_template = [
4
+ lambda c : f"satellite photo of a {c}.",
5
+ lambda c : f"aerial photo of a {c}.",
6
+ lambda c : f"satellite photo of {append_proper_article(c)}.",
7
+ lambda c : f"aerial photo of {append_proper_article(c)}.",
8
+ lambda c : f"satellite photo of a {c} in asia.",
9
+ lambda c : f"aerial photo of a {c} in asia.",
10
+ lambda c : f"satellite photo of a {c} in africa.",
11
+ lambda c : f"aerial photo of a {c} in africa.",
12
+ lambda c : f"satellite photo of a {c} in the americas.",
13
+ lambda c : f"aerial photo of a {c} in the americas.",
14
+ lambda c : f"satellite photo of a {c} in europe.",
15
+ lambda c : f"aerial photo of a {c} in europe.",
16
+ lambda c : f"satellite photo of a {c} in oceania.",
17
+ lambda c : f"aerial photo of a {c} in oceania.",
18
+ lambda c: f"a photo of a {c}.",
19
+ lambda c: f"{c}.",
20
+ ]
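Every template module here is just a list of `str -> str` callables, so expanding one class name into its full prompt set is a single comprehension. A minimal sketch; "park" is an arbitrary illustrative class name, not a real FMoW label:

    from templates.fmow_template import fmow_template

    prompts = [template("park") for template in fmow_template]
    print(len(prompts))  # 16
    print(prompts[0])    # satellite photo of a park.
    print(prompts[2])    # satellite photo of a park.  (via append_proper_article)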
templates/iwildcam_template.py ADDED
@@ -0,0 +1,4 @@
+ iwildcam_template = [
+     lambda c: f"a photo of {c}.",
+     lambda c: f"{c} in the wild.",
+ ]
templates/openai_imagenet_template.py ADDED
@@ -0,0 +1,82 @@
+ openai_imagenet_template = [
+     lambda c: f'a bad photo of a {c}.',
+     lambda c: f'a photo of many {c}.',
+     lambda c: f'a sculpture of a {c}.',
+     lambda c: f'a photo of the hard to see {c}.',
+     lambda c: f'a low resolution photo of the {c}.',
+     lambda c: f'a rendering of a {c}.',
+     lambda c: f'graffiti of a {c}.',
+     lambda c: f'a bad photo of the {c}.',
+     lambda c: f'a cropped photo of the {c}.',
+     lambda c: f'a tattoo of a {c}.',
+     lambda c: f'the embroidered {c}.',
+     lambda c: f'a photo of a hard to see {c}.',
+     lambda c: f'a bright photo of a {c}.',
+     lambda c: f'a photo of a clean {c}.',
+     lambda c: f'a photo of a dirty {c}.',
+     lambda c: f'a dark photo of the {c}.',
+     lambda c: f'a drawing of a {c}.',
+     lambda c: f'a photo of my {c}.',
+     lambda c: f'the plastic {c}.',
+     lambda c: f'a photo of the cool {c}.',
+     lambda c: f'a close-up photo of a {c}.',
+     lambda c: f'a black and white photo of the {c}.',
+     lambda c: f'a painting of the {c}.',
+     lambda c: f'a painting of a {c}.',
+     lambda c: f'a pixelated photo of the {c}.',
+     lambda c: f'a sculpture of the {c}.',
+     lambda c: f'a bright photo of the {c}.',
+     lambda c: f'a cropped photo of a {c}.',
+     lambda c: f'a plastic {c}.',
+     lambda c: f'a photo of the dirty {c}.',
+     lambda c: f'a jpeg corrupted photo of a {c}.',
+     lambda c: f'a blurry photo of the {c}.',
+     lambda c: f'a photo of the {c}.',
+     lambda c: f'a good photo of the {c}.',
+     lambda c: f'a rendering of the {c}.',
+     lambda c: f'a {c} in a video game.',
+     lambda c: f'a photo of one {c}.',
+     lambda c: f'a doodle of a {c}.',
+     lambda c: f'a close-up photo of the {c}.',
+     lambda c: f'a photo of a {c}.',
+     lambda c: f'the origami {c}.',
+     lambda c: f'the {c} in a video game.',
+     lambda c: f'a sketch of a {c}.',
+     lambda c: f'a doodle of the {c}.',
+     lambda c: f'a origami {c}.',
+     lambda c: f'a low resolution photo of a {c}.',
+     lambda c: f'the toy {c}.',
+     lambda c: f'a rendition of the {c}.',
+     lambda c: f'a photo of the clean {c}.',
+     lambda c: f'a photo of a large {c}.',
+     lambda c: f'a rendition of a {c}.',
+     lambda c: f'a photo of a nice {c}.',
+     lambda c: f'a photo of a weird {c}.',
+     lambda c: f'a blurry photo of a {c}.',
+     lambda c: f'a cartoon {c}.',
+     lambda c: f'art of a {c}.',
+     lambda c: f'a sketch of the {c}.',
+     lambda c: f'a embroidered {c}.',
+     lambda c: f'a pixelated photo of a {c}.',
+     lambda c: f'itap of the {c}.',
+     lambda c: f'a jpeg corrupted photo of the {c}.',
+     lambda c: f'a good photo of a {c}.',
+     lambda c: f'a plushie {c}.',
+     lambda c: f'a photo of the nice {c}.',
+     lambda c: f'a photo of the small {c}.',
+     lambda c: f'a photo of the weird {c}.',
+     lambda c: f'the cartoon {c}.',
+     lambda c: f'art of the {c}.',
+     lambda c: f'a drawing of the {c}.',
+     lambda c: f'a photo of the large {c}.',
+     lambda c: f'a black and white photo of a {c}.',
+     lambda c: f'the plushie {c}.',
+     lambda c: f'a dark photo of a {c}.',
+     lambda c: f'itap of a {c}.',
+     lambda c: f'graffiti of the {c}.',
+     lambda c: f'a toy {c}.',
+     lambda c: f'itap of my {c}.',
+     lambda c: f'a photo of a cool {c}.',
+     lambda c: f'a photo of a small {c}.',
+     lambda c: f'a tattoo of the {c}.',
+ ]
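These 80 prompt variants are typically averaged into a single text embedding per class, the prompt-ensembling recipe from the CLIP paper. A minimal sketch of that step, assuming the vendored `clip` package mirrors OpenAI's reference API (`clip.load`, `clip.tokenize`, `encode_text`); this is an illustration, not this repo's exact pipeline:

    import torch
    import clip  # vendored copy in this repo; assumed to follow OpenAI's API
    from templates.openai_imagenet_template import openai_imagenet_template

    device = "cuda" if torch.cuda.is_available() else "cpu"
    model, _preprocess = clip.load("ViT-B/16", device=device)

    @torch.no_grad()
    def build_zeroshot_weights(classnames):
        weights = []
        for name in classnames:
            # Tokenize all 80 prompt variants for this class at once.
            texts = clip.tokenize([t(name) for t in openai_imagenet_template]).to(device)
            emb = model.encode_text(texts)
            emb = emb / emb.norm(dim=-1, keepdim=True)
            mean = emb.mean(dim=0)              # average over prompts...
            weights.append(mean / mean.norm())  # ...then re-normalize
        return torch.stack(weights, dim=1)      # (embed_dim, num_classes)

Dotting normalized image features against this matrix then yields zero-shot logits.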
templates/simple_template.py ADDED
@@ -0,0 +1,3 @@
+ simple_template = [
+     lambda c: f"a photo of a {c}."
+ ]
templates/template_utils.py ADDED
@@ -0,0 +1,28 @@
+ def get_plural(name):
+     name = name.replace('_', ' ')  # class names may use underscores
+     if name[-2:] == 'sh':          # brush -> brushes
+         name = name + 'es'
+     elif name[-2:] == 'ch':        # church -> churches
+         name = name + 'es'
+     elif name[-1:] == 'y':         # berry -> berries
+         name = name[:-1] + 'ies'
+     elif name[-1:] == 's':         # bus -> buses
+         name = name + 'es'
+     elif name[-1:] == 'x':         # box -> boxes
+         name = name + 'es'
+     elif name[-3:] == 'man':       # policeman -> policemen
+         name = name[:-3] + 'men'
+     elif name == 'mouse':
+         name = 'mice'
+     elif name[-1:] == 'f':         # wolf -> wolves
+         name = name[:-1] + 'ves'
+     else:
+         name = name + 's'
+     return name
+
+
+ def append_proper_article(name):
+     name = name.replace('_', ' ')
+     if name[0] in 'aeiou':         # vowel-initial -> "an"
+         return 'an ' + name
+     return 'a ' + name
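Both helpers are pure string functions, so a few spot checks pin down their behavior. A minimal sketch:

    from templates.template_utils import get_plural, append_proper_article

    assert get_plural("wolf") == "wolves"          # -f   -> -ves
    assert get_plural("bus") == "buses"            # -s   -> -es
    assert get_plural("berry") == "berries"        # -y   -> -ies
    assert get_plural("policeman") == "policemen"  # -man -> -men
    assert append_proper_article("orange") == "an orange"
    assert append_proper_article("polar_bear") == "a polar bear"

Note that the `-y` rule also fires after vowels (`get_plural("monkey")` returns `"monkies"`), a known limitation of heuristics this simple.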
templates/testing_template.py ADDED
@@ -0,0 +1,83 @@
+ testing_template = [
+     lambda c: f'a photo of the number: "{c}".',
+     lambda c: f'a bad photo of a {c}.',
+     lambda c: f'a photo of many {c}.',
+     lambda c: f'a sculpture of a {c}.',
+     lambda c: f'a photo of the hard to see {c}.',
+     lambda c: f'a low resolution photo of the {c}.',
+     lambda c: f'a rendering of a {c}.',
+     lambda c: f'graffiti of a {c}.',
+     lambda c: f'a bad photo of the {c}.',
+     lambda c: f'a cropped photo of the {c}.',
+     lambda c: f'a tattoo of a {c}.',
+     lambda c: f'the embroidered {c}.',
+     lambda c: f'a photo of a hard to see {c}.',
+     lambda c: f'a bright photo of a {c}.',
+     lambda c: f'a photo of a clean {c}.',
+     lambda c: f'a photo of a dirty {c}.',
+     lambda c: f'a dark photo of the {c}.',
+     lambda c: f'a drawing of a {c}.',
+     lambda c: f'a photo of my {c}.',
+     lambda c: f'the plastic {c}.',
+     lambda c: f'a photo of the cool {c}.',
+     lambda c: f'a close-up photo of a {c}.',
+     lambda c: f'a black and white photo of the {c}.',
+     lambda c: f'a painting of the {c}.',
+     lambda c: f'a painting of a {c}.',
+     lambda c: f'a pixelated photo of the {c}.',
+     lambda c: f'a sculpture of the {c}.',
+     lambda c: f'a bright photo of the {c}.',
+     lambda c: f'a cropped photo of a {c}.',
+     lambda c: f'a plastic {c}.',
+     lambda c: f'a photo of the dirty {c}.',
+     lambda c: f'a jpeg corrupted photo of a {c}.',
+     lambda c: f'a blurry photo of the {c}.',
+     lambda c: f'a photo of the {c}.',
+     lambda c: f'a good photo of the {c}.',
+     lambda c: f'a rendering of the {c}.',
+     lambda c: f'a {c} in a video game.',
+     lambda c: f'a photo of one {c}.',
+     lambda c: f'a doodle of a {c}.',
+     lambda c: f'a close-up photo of the {c}.',
+     lambda c: f'a photo of a {c}.',
+     lambda c: f'the origami {c}.',
+     lambda c: f'the {c} in a video game.',
+     lambda c: f'a sketch of a {c}.',
+     lambda c: f'a doodle of the {c}.',
+     lambda c: f'a origami {c}.',
+     lambda c: f'a low resolution photo of a {c}.',
+     lambda c: f'the toy {c}.',
+     lambda c: f'a rendition of the {c}.',
+     lambda c: f'a photo of the clean {c}.',
+     lambda c: f'a photo of a large {c}.',
+     lambda c: f'a rendition of a {c}.',
+     lambda c: f'a photo of a nice {c}.',
+     lambda c: f'a photo of a weird {c}.',
+     lambda c: f'a blurry photo of a {c}.',
+     lambda c: f'a cartoon {c}.',
+     lambda c: f'art of a {c}.',
+     lambda c: f'a sketch of the {c}.',
+     lambda c: f'a embroidered {c}.',
+     lambda c: f'a pixelated photo of a {c}.',
+     lambda c: f'itap of the {c}.',
+     lambda c: f'a jpeg corrupted photo of the {c}.',
+     lambda c: f'a good photo of a {c}.',
+     lambda c: f'a plushie {c}.',
+     lambda c: f'a photo of the nice {c}.',
+     lambda c: f'a photo of the small {c}.',
+     lambda c: f'a photo of the weird {c}.',
+     lambda c: f'the cartoon {c}.',
+     lambda c: f'art of the {c}.',
+     lambda c: f'a drawing of the {c}.',
+     lambda c: f'a photo of the large {c}.',
+     lambda c: f'a black and white photo of a {c}.',
+     lambda c: f'the plushie {c}.',
+     lambda c: f'a dark photo of a {c}.',
+     lambda c: f'itap of a {c}.',
+     lambda c: f'graffiti of the {c}.',
+     lambda c: f'a toy {c}.',
+     lambda c: f'itap of my {c}.',
+     lambda c: f'a photo of a cool {c}.',
+     lambda c: f'a photo of a small {c}.',
+     lambda c: f'a tattoo of the {c}.',
+ ]