| import torch |
| import torch.nn as nn |
| import copy |
|
|
| from .vit_inflora import VisionTransformer, PatchEmbed, Block, resolve_pretrained_cfg, build_model_with_cfg, checkpoint_filter_fn |
|
|
class ViT_lora_co(VisionTransformer):
    """Vision Transformer whose blocks receive a task id (per-task LoRA).

    Thin wrapper over the project's ``VisionTransformer``: construction is a
    straight pass-through, and ``forward`` threads ``task_id`` plus the
    feature-capture flags into every transformer block.
    """

    def __init__(
            self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, global_pool='token',
            embed_dim=768, depth=12, num_heads=12, mlp_ratio=4., qkv_bias=True, representation_size=None,
            drop_rate=0., attn_drop_rate=0., drop_path_rate=0., weight_init='', init_values=None,
            embed_layer=PatchEmbed, norm_layer=None, act_layer=None, block_fn=Block, n_tasks=10, rank=64):
        # Pure delegation: every argument is forwarded unchanged to the base class.
        super().__init__(
            img_size=img_size, patch_size=patch_size, in_chans=in_chans,
            num_classes=num_classes, global_pool=global_pool, embed_dim=embed_dim,
            depth=depth, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias,
            representation_size=representation_size, drop_rate=drop_rate,
            attn_drop_rate=attn_drop_rate, drop_path_rate=drop_path_rate,
            weight_init=weight_init, init_values=init_values, embed_layer=embed_layer,
            norm_layer=norm_layer, act_layer=act_layer, block_fn=block_fn,
            n_tasks=n_tasks, rank=rank)

    def forward(self, x, task_id, register_blk=-1, get_feat=False, get_cur_feat=False):
        """Encode ``x`` for the given task.

        Returns the full token sequence after the final norm, together with a
        prompt-loss tensor. NOTE(review): nothing in this loop ever adds to
        ``prompt_loss``, so it is always zero; it is kept only to preserve the
        (tokens, loss) return contract expected by callers.
        """
        tokens = self.patch_embed(x)
        cls = self.cls_token.expand(tokens.shape[0], -1, -1)
        tokens = torch.cat((cls, tokens), dim=1)

        # Positional embedding is sliced so shorter token sequences also work.
        tokens = tokens + self.pos_embed[:, :tokens.size(1), :]
        tokens = self.pos_drop(tokens)

        prompt_loss = torch.zeros((1,), requires_grad=True).to(tokens.device)
        for idx, block in enumerate(self.blocks):
            # register_blk == idx flags exactly one block (or none, for -1).
            tokens = block(tokens, task_id, register_blk == idx,
                           get_feat=get_feat, get_cur_feat=get_cur_feat)

        tokens = self.norm(tokens)
        return tokens, prompt_loss
|
|
|
|
def _create_vision_transformer(variant, pretrained=False, **kwargs):
    """Build a ``ViT_lora_co`` for ``variant`` via timm's model builder.

    Mirrors timm's factory helper: resolves the pretrained config, reconciles
    ``representation_size`` with the requested head size, and delegates to
    ``build_model_with_cfg``.
    """
    if kwargs.get('features_only', None):
        raise RuntimeError(
            'features_only not implemented for Vision Transformer models.')

    cfg = resolve_pretrained_cfg(variant)
    requested_classes = kwargs.get('num_classes', cfg['num_classes'])
    repr_size = kwargs.pop('representation_size', None)
    # The pre-logits (representation) layer is only meaningful when the head
    # keeps its pretrained size; drop it when the class count changes.
    if repr_size is not None and requested_classes != cfg['num_classes']:
        repr_size = None

    return build_model_with_cfg(
        ViT_lora_co, variant, pretrained,
        pretrained_cfg=cfg,
        representation_size=repr_size,
        pretrained_filter_fn=checkpoint_filter_fn,
        # .npz checkpoints (original JAX weights) need timm's custom loader.
        pretrained_custom_load='npz' in cfg['url'],
        **kwargs)
|
|
|
|
class SiNet_vit(nn.Module):
    """Continual-learning network: one shared LoRA-ViT encoder plus one linear
    classifier head per task session.

    Required keys in ``args``: ``total_sessions`` (number of tasks), ``rank``
    (LoRA rank), ``init_cls`` (classes per session; every head has this
    width), and ``embd_dim`` (encoder output dimension).
    """

    def __init__(self, **args):
        super().__init__()
        model_kwargs = dict(patch_size=16, embed_dim=768, depth=12,
                            num_heads=12, n_tasks=args["total_sessions"], rank=args["rank"])
        self.image_encoder = _create_vision_transformer(
            'vit_base_patch16_224_in21k', pretrained=True, **model_kwargs)
        # Classes introduced per session; all heads share this output width.
        # (The original code assigned 1 here and immediately overwrote it —
        # the dead assignment has been removed.)
        self.class_num = args["init_cls"]
        # One head per session; only heads up to self.numtask are ever used.
        self.classifier_pool = nn.ModuleList([
            nn.Linear(args["embd_dim"], self.class_num, bias=True)
            for _ in range(args["total_sessions"])
        ])
        # Snapshot copies of the heads, filled by classifier_backup().
        self.classifier_pool_backup = nn.ModuleList([
            nn.Linear(args["embd_dim"], self.class_num, bias=True)
            for _ in range(args["total_sessions"])
        ])
        # 1-based count of sessions seen so far; incremented by update_fc().
        self.numtask = 0

    @property
    def feature_dim(self):
        """Dimensionality of the encoder's output features."""
        return self.image_encoder.out_dim

    def extract_vector(self, image, task=None):
        """Return the CLS-token feature of ``image``.

        ``task`` selects which task's adapters to use; ``None`` means the
        current task (``self.numtask - 1``).
        """
        if task is None:
            image_features, _ = self.image_encoder(image, self.numtask - 1)
        else:
            image_features, _ = self.image_encoder(image, task)
        return image_features[:, 0, :]

    def forward(self, image, get_feat=False, get_cur_feat=False, fc_only=False):
        """Run the encoder and the current task's classifier head.

        With ``fc_only=True``, ``image`` is treated as a pre-computed feature
        tensor and only the heads of all seen tasks are applied, their logits
        concatenated along dim 1.

        Returns a dict with ``logits`` (current head only), ``features``
        (CLS token), and ``prompt_loss`` (passed through from the encoder).
        """
        if fc_only:
            fc_outs = [self.classifier_pool[ti](image)
                       for ti in range(self.numtask)]
            return torch.cat(fc_outs, dim=1)

        image_features, prompt_loss = self.image_encoder(
            image, task_id=self.numtask - 1, get_feat=get_feat, get_cur_feat=get_cur_feat)
        image_features = image_features[:, 0, :]
        image_features = image_features.view(image_features.size(0), -1)
        # Only the head of the current session produces training logits.
        logits = [self.classifier_pool[self.numtask - 1](image_features)]

        return {
            'logits': torch.cat(logits, dim=1),
            'features': image_features,
            'prompt_loss': prompt_loss
        }

    def interface(self, image):
        """Inference path: concatenated logits from all heads seen so far."""
        image_features, _ = self.image_encoder(image, task_id=self.numtask - 1)

        image_features = image_features[:, 0, :]
        image_features = image_features.view(image_features.size(0), -1)

        logits = [head(image_features)
                  for head in self.classifier_pool[:self.numtask]]
        return torch.cat(logits, 1)

    def update_fc(self, nb_classes):
        """Advance to the next task session.

        ``nb_classes`` is accepted for interface compatibility but unused:
        heads are pre-allocated in __init__ with a fixed width.
        """
        self.numtask += 1

    def classifier_backup(self, task_id):
        """Snapshot the weights of head ``task_id`` into the backup pool."""
        self.classifier_pool_backup[task_id].load_state_dict(
            self.classifier_pool[task_id].state_dict())

    def classifier_recall(self):
        """Restore all heads from ``self.old_state_dict``.

        NOTE(review): ``old_state_dict`` is never set in this class — it is
        presumably attached by external training code before this is called;
        verify against the caller.
        """
        self.classifier_pool.load_state_dict(self.old_state_dict)

    def copy(self):
        """Return a deep copy of this network."""
        return copy.deepcopy(self)

    def freeze(self):
        """Disable gradients for all parameters and switch to eval mode."""
        for param in self.parameters():
            param.requires_grad = False
        self.eval()

        return self
|
|