| import copy |
|
|
| import torch |
| import torch.nn as nn |
| import numpy as np |
| import os |
| import random |
|
|
| from tqdm import tqdm |
|
|
| from .backbone.clip import tokenize |
| from core.data import dataloader |
| from core.model import backbone |
| from core.model.finetune import Finetune |
| from torch.utils.data import DataLoader |
|
|
|
|
def get_class_ids_per_task(init_cls_num, inc_cls_num, class_order):
    """Yield the class ids assigned to each incremental task.

    The first task gets the leading ``init_cls_num`` entries of
    ``class_order``; every subsequent task gets the next ``inc_cls_num``
    entries (the final slice may be shorter).
    """
    # Task 0: the initial block of classes.
    yield class_order[:init_cls_num]
    # Later tasks: fixed-size windows over the remaining order.
    start = init_cls_num
    while start < len(class_order):
        yield class_order[start:start + inc_cls_num]
        start += inc_cls_num
|
|
def get_class_names(classes_names, prev_cls_num, accu_cls_num):
    """Return the class names with indices in [prev_cls_num, accu_cls_num)."""
    selected = []
    for idx in range(prev_cls_num, accu_cls_num):
        selected.append(classes_names[idx])
    return selected
|
|
def shrink_cov(cov):
    """Shrink a covariance matrix to improve its conditioning.

    Adds the mean diagonal value to every diagonal entry and the mean of
    the non-zero off-diagonal values to every off-diagonal entry.

    Args:
        cov: square covariance tensor.

    Returns:
        The shrunk covariance tensor (same shape/device as ``cov``).
    """
    diag_mean = torch.mean(torch.diagonal(cov))
    off_diag = cov.clone()
    off_diag.fill_diagonal_(0.0)
    mask = off_diag != 0.0
    # Bug fix: a purely diagonal covariance has no non-zero off-diagonal
    # entries, and the original 0/0 division produced NaNs that poisoned
    # the whole matrix. Fall back to 0 shrinkage of the off-diagonal.
    if mask.any():
        off_diag_mean = (off_diag * mask).sum() / mask.sum()
    else:
        off_diag_mean = torch.zeros((), device=cov.device, dtype=cov.dtype)
    iden = torch.eye(cov.shape[0], device=cov.device)
    alpha1 = 1  # weight of the diagonal shrinkage term
    alpha2 = 1  # weight of the off-diagonal shrinkage term
    cov_ = cov + (alpha1*diag_mean*iden) + (alpha2*off_diag_mean*(1-iden))
    return cov_
def sample(mean, cov, size, shrink=False):
    """Draw ``size`` Gaussian feature vectors ~ N(mean, cov).

    Standard-normal noise is colored through the Cholesky factor of the
    covariance (optionally shrunk first) and shifted by ``mean``.
    """
    # Draw the noise first so the RNG stream matches regardless of shrink.
    noise = torch.randn(size, mean.shape[-1], device=mean.device)
    cov_used = shrink_cov(cov) if shrink else cov
    chol = torch.linalg.cholesky(cov_used)
    # x ~ N(0, I)  =>  x @ L^T + mean ~ N(mean, L L^T) = N(mean, cov)
    return noise @ chol.t() + mean
|
|
def seed_everything(seed=0):
    """Seed every source of randomness in use so runs are reproducible."""
    os.environ['PYTHONHASHSEED'] = str(seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # Force deterministic cuDNN kernel selection (may cost some speed).
    torch.backends.cudnn.deterministic = True
|
|
|
|
| """ |
| This clas refer to the following repository: |
| https://github.com/linlany/RAPF |
| """ |
class ClassIncrementalCLIP(nn.Module):
    """Class-incremental classifier built on a frozen CLIP backbone (RAPF).

    Only the 512->512 bias-free linear ``adapter`` is trainable; the CLIP
    visual and text towers are shared from ``model`` and always run under
    ``torch.no_grad()``.  Logits are scaled cosine similarities between
    adapted image features and normalized text features of the
    prompt-formatted class names seen so far.
    """
    def __init__(self, model, **kwargs):
        """Wrap a loaded CLIP ``model`` for incremental learning.

        kwargs read here: ``device``, ``fp16`` (adapter compute dtype),
        ``mix_bias`` (bias term for ``mix_matrix`` weight blending),
        ``prompt_template`` (format string applied to each class name),
        ``init_cls_num`` / ``inc_cls_num`` (first / later task sizes).
        """
        super().__init__()
        device = kwargs['device']
        fp16 = kwargs['fp16']
        mix_bias = kwargs['mix_bias']
        self.prompt_template = kwargs['prompt_template']
        self.initial_increment = kwargs['init_cls_num']
        self.increment = kwargs['inc_cls_num']
        self.device = device
        # Full class-name list; only read in this file, so it is
        # presumably assigned externally before adaptation() runs — TODO confirm.
        self.classes_names = None

        # Re-exposed CLIP sub-modules (shared references, not copies).
        self.visual = model.visual
        self.transformer = model.transformer
        self.positional_embedding = model.positional_embedding
        self.token_embedding = model.token_embedding
        self.ln_final = model.ln_final
        self.text_projection = model.text_projection
        self.logit_scale = model.logit_scale

        self.current_class_names = []  # names seen so far, grown each task
        self.text_tokens = None        # tokenized prompts for those names
        self.dtype = torch.float16 if fp16 else torch.float32  # adapter compute dtype
        # The only trainable component: a bias-free linear feature adapter.
        self.adapter = nn.Linear(512, 512, bias=False ,device=device)
        self.clip_type = model.dtype   # parameter dtype of the CLIP backbone

        # Snapshot of the adapter taken at the start of each new task.
        self.old_adapter = None
        self.old_edge_samples = []
        self.old_edge_samples_labels = []
        self.old_edge_samples_nearest_labels = []

        # Per-class Gaussian statistics of pre-adapter image features,
        # filled by analyze_mean_cov() and used for feature replay.
        self.class_mean_list = []
        self.class_cov_list = []

        self.class_diff = None          # old-vs-new class-name distance matrix
        self.nearest_class = None
        self.class_edge_distance = []   # per-class (mean-min, max-mean, mean) of top distances
        self.mix_b = mix_bias

    def encode_text(self, text, prompt=False):
        """Run CLIP's text tower over tokenized ``text``.

        Mirrors CLIP's encode_text: token embedding + positional
        embedding, transformer (sequence-first), final LayerNorm, then
        project the feature taken at each sequence's highest token id
        (the EOT token under CLIP's tokenizer).  ``prompt`` is accepted
        but unused here.
        """
        x = self.token_embedding(text).type(self.clip_type)
        x = x + self.positional_embedding.type(self.clip_type)
        x = x.permute(1, 0, 2)  # NLD -> LND for the transformer
        x = self.transformer(x)
        x = x.permute(1, 0, 2)  # LND -> NLD
        x = self.ln_final(x)

        # Representation at the EOT position of each sequence, projected.
        x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection

        return x

    def encode_image(self, image):
        """Run CLIP's visual tower; input is cast to the backbone dtype."""
        image = image.to(self.clip_type)
        return self.visual(image)

    @torch.no_grad()
    def get_class_name_features(self):
        """Text features of the current class prompts, as float32 (no grad)."""
        class_name_features = self.encode_text(self.text_tokens)
        return class_name_features.type(torch.float32)

    def forward(self, image, ori_ima_f=False, memory_data=None, not_ini=False, edge_sample=None, prompt=False):
        """Compute class logits for ``image`` plus optional replay features.

        Args:
            image: image batch; cast to float16 before encoding.
            ori_ima_f: if True, also return the raw pre-adapter features.
            memory_data: optional pre-extracted feature vectors (Gaussian
                replay samples) appended to the batch before the adapter.
            not_ini: True after the first task; additionally returns the
                old adapter's output on ``memory_data`` (which must then
                be provided, or ``self.old_adapter(None)`` fails).
            edge_sample: optional synthesized boundary features whose
                adapted, normalized features are split off and returned.
            prompt: accepted but unused.

        Returns (depending on flags):
            not_ini & edge_sample: (logits, feats, old_memory_feats, edge_feats)
            not_ini:               (logits, feats, old_memory_feats, text_feats)
            ori_ima_f:             (logits, raw_image_feats, adapted_feats)
            otherwise:             (logits, feats, None, None)
        """
        image = image.type(torch.float16)
        # CLIP towers are frozen: no gradients through text encoding.
        with torch.no_grad():
            text_features = self.encode_text(self.text_tokens)

        with torch.no_grad():
            image_features = self.encode_image(image)
            original_image_features = image_features.clone()
        # Append replay/synthetic features so they share the adapter pass.
        if memory_data is not None:
            memory_data = memory_data.type(self.dtype)
            image_features = torch.cat([image_features, memory_data], dim=0)
        if edge_sample is not None:
            edge_sample = edge_sample.type(self.dtype)
            edge_num = edge_sample.shape[0]
            image_features = torch.cat([image_features, edge_sample], dim=0)

        # Only the adapter receives gradients (its input is detached).
        image_features = self.adapter(image_features.type(self.dtype).detach()).type(self.clip_type)

        image_features = image_features / image_features.norm(dim=1, keepdim=True)
        if edge_sample is not None:
            # Split the edge samples back off; they get their own loss.
            edge_sample_features = image_features[-edge_num:]
            image_features = image_features[:-edge_num]
        text_features = text_features / text_features.norm(dim=1, keepdim=True)

        # Scaled cosine similarity, as in CLIP.
        logit_scale = self.logit_scale.exp()
        logits_per_image = logit_scale * image_features @ text_features.t().type(image_features.dtype)

        probs = logits_per_image
        if not_ini:
            with torch.no_grad():
                # Old-adapter view of the replay features; the caller can
                # use it as an anchor/distillation target.
                old_memory_feature = self.old_adapter(memory_data)
                old_memory_feature = old_memory_feature / old_memory_feature.norm(dim=1, keepdim=True)
            if edge_sample is not None:
                return probs, image_features, old_memory_feature, edge_sample_features
            return probs, image_features, old_memory_feature, text_features
        if ori_ima_f:
            if memory_data is not None:
                # Drop replay rows so returned features align with `image`.
                image_features = image_features[:-memory_data.shape[0]]
            return probs, original_image_features, image_features
        return probs, image_features, None, None

    def adaptation(self, task_id, prev_cls_num, accu_cls_num, threshold=0):
        """Prepare the model for a new task.

        Extends the class-name list up to ``accu_cls_num`` classes and
        rebuilds the text tokens/features.  For task_id > 0, snapshots the
        adapter and detects "hard pairs": (old class, new class) whose
        normalized class-name features are closer than ``threshold``.
        """
        self.current_class_names += get_class_names(self.classes_names, prev_cls_num, accu_cls_num)
        self.text_tokens = tokenize(
            [self.prompt_template.format(c) for c in self.current_class_names]
        ).to(self.device)
        # Position of the highest token id per prompt (EOT); stored but
        # not used elsewhere in this file.
        self.text_end = self.text_tokens.max(dim=-1)[1]
        self.class_name_features = self.get_class_name_features()
        self.class_name_features = self.class_name_features / self.class_name_features.norm(dim=-1, p=2, keepdim=True)
        self.queue_empty = True
        self.hard_pairs = None
        if task_id>0:
            # Freeze a copy of the adapter as the "old" reference model.
            self.old_adapter = copy.deepcopy(self.adapter)
            # dist_list[k, j] = distance between old class k and the j-th
            # new class in class-name feature space.
            dist_list = []
            for k, class_name_feature in enumerate(self.class_name_features[:prev_cls_num]):
                diff = torch.cdist(self.class_name_features[prev_cls_num:].type(torch.float32), class_name_feature.unsqueeze(0).type(torch.float32)).squeeze()
                dist_list.append(diff)
            dist_list = torch.stack(dist_list)
            self.class_diff = dist_list
            mask = self.class_diff < threshold
            indices = torch.nonzero(mask)  # rows: (old class id, new class offset)
            # Shift new-class column offsets to absolute class ids.
            self.hard_new_class = torch.unique(indices[:,1]) + self.initial_increment+(task_id-1) * self.increment
            num_hard_class = self.hard_new_class.shape[0]  # NOTE: currently unused
            self.hard_pairs = indices
            self.hard_pairs[:,1] = self.hard_pairs[:,1]+self.initial_increment+(task_id-1) * self.increment

    def get_old_edge_samples(self, batch_size):
        """Return a random batch of stored edge samples with their labels."""
        random_select = torch.randperm(self.old_edge_samples.shape[0])[:batch_size]
        return self.old_edge_samples[random_select], self.old_edge_samples_labels[random_select], self.old_edge_samples_nearest_labels[random_select]

    def analyze_mean_cov(self, features, labels):
        """Record per-class Gaussian statistics of ``features``.

        For every class present in ``labels``, stores the feature mean and
        a ridge-regularized covariance, plus summary statistics of the 10
        largest distances-to-mean (a proxy for the class boundary radius).
        """
        label = torch.sort(torch.unique(labels))[0]
        for l in label:
            index = torch.nonzero(labels == l)
            index = index.squeeze()
            class_data = features[index]
            mean = class_data.mean(dim=0)
            # Small ridge keeps the covariance positive-definite so that
            # torch.linalg.cholesky in sample() succeeds.
            cov = torch.cov(class_data.t()) + 1e-4* torch.eye(class_data.shape[-1], device=class_data.device)
            distance = torch.cdist(class_data, mean.unsqueeze(0)).squeeze()
            max_distance = torch.sort(distance)[0][-10:]
            self.class_edge_distance.append((max_distance.mean()-max_distance.min(), max_distance.max() - max_distance.mean(), max_distance.mean()))
            self.class_mean_list.append(mean)
            self.class_cov_list.append(cov)

    def mix_matrix(self):
        """Blend the trained adapter with its pre-task snapshot.

        Works in the old weight's left-singular basis: entries of the new
        weight that changed little are pulled back toward the old weight,
        while strongly-changed entries are kept.  A larger ``self.mix_b``
        biases every entry toward keeping the new weight.
        """
        if self.old_adapter is not None:
            weight_new = self.adapter.weight.data
            weight_old = self.old_adapter.weight.data
            dist = (weight_new - weight_old).abs()  # NOTE: overwritten below, has no effect
            U_old, S_old, V_old = torch.linalg.svd(weight_old)
            # Project the new weight onto the old left-singular vectors;
            # diag(S_old) @ V_old is the old weight in that same basis.
            P_new = U_old.T @ weight_new
            dist = (P_new - torch.diag(S_old)@V_old).abs()
            # Per-entry mixing coefficient, roughly in [mix_b, 1].
            mask = dist / dist.max()
            mask += self.mix_b
            mask = torch.clamp(mask, max=1)
            # mask -> 1 keeps the new weight; mask -> 0 restores the old.
            right = P_new * mask + torch.diag(S_old)@V_old * (1-mask)
            weight = U_old @ right
            self.adapter.weight.data = weight
|
|
| """ |
| This clas refer to the following repository: |
| https://github.com/linlany/RAPF |
| """ |
class RAPF(nn.Module):
    """Training/inference driver for the RAPF class-incremental method.

    Owns a ClassIncrementalCLIP model and implements the per-task hooks
    (``before_task`` / ``observe`` / ``after_task`` / ``inference``) used
    by the surrounding framework.  Only the model's linear adapter is
    optimized (see ``get_parameters``).
    """
    def __init__(self, backbone, **kwargs):
        """Create the driver around a CLIP ``backbone``.

        kwargs read here: ``seed``, ``device``, ``init_cls_num``,
        ``inc_cls_num``, ``beta`` (replay sample-count multiplier),
        ``shrinkage`` (shrink covariances when sampling), ``threshold``
        (hard-pair distance threshold), ``train_batch_size``,
        ``batch_size``, ``num_workers``.  All kwargs are also forwarded
        to ClassIncrementalCLIP.
        """
        super().__init__()
        seed = kwargs['seed']
        seed_everything(seed)
        self.backbone = backbone
        self.kwargs = kwargs
        self.model = ClassIncrementalCLIP(self.backbone, **kwargs)
        self.device = kwargs['device']
        self.init_cls_num = kwargs['init_cls_num']
        self.inc_cls_num = kwargs['inc_cls_num']
        self.beta = kwargs['beta']
        self.shrinkage = kwargs['shrinkage']
        self.threshold = kwargs['threshold']
        self.train_batch_size = kwargs['train_batch_size']
        self.batch_size = kwargs['batch_size']
        self.num_workers = kwargs['num_workers']
        self.seed = seed

        # Class counters: classes seen before the current task, and the
        # accumulated total up to and including the current task.
        self.prev_cls_num = 0
        self.accu_cls_num = 0

    def before_task(self, task_id, buffer, train_loader, test_loaders):
        """Per-task setup: grow the class count and adapt the model."""
        self.task_id = task_id
        if self.task_id == 0:
            self.accu_cls_num = self.init_cls_num
        else:
            self.accu_cls_num += self.inc_cls_num

        self.model.adaptation(task_id, self.prev_cls_num, self.accu_cls_num, self.threshold)
        if self.task_id > 0:
            # Shuffled order over previously-seen classes; observe() walks
            # it to decide which old classes to replay in each batch.
            random_class_order_list = list(range(self.init_cls_num+(self.task_id-1)*self.inc_cls_num))
            random.shuffle(random_class_order_list)
            self.random_class_order_list = random_class_order_list

    def after_task(self, task_idx, buffer, train_loader, test_loaders):
        """Per-task teardown.

        Extracts pre-adapter image features over the whole train loader,
        records per-class mean/covariance statistics for later replay,
        mixes the adapter with its pre-task snapshot, and advances the
        seen-class counter.
        """
        sample_data = []
        sample_target = []
        sample_after_adapt_feature = []
        model = self.model
        for batch in tqdm(train_loader, total=len(train_loader)):
            feats = batch['image']
            target = batch['label']
            feats, target = feats.to(self.device), target.to(self.device)
            with torch.no_grad():
                # ori_ima_f=True -> (logits, raw features, adapted features)
                _, ori_ima_feat, after_adapt_feature = model(feats, ori_ima_f=True)
            sample_data.append(ori_ima_feat)
            sample_target.append(target)
            sample_after_adapt_feature.append(after_adapt_feature)
        sample_target = torch.cat(sample_target, dim=0)
        sample_data = torch.cat(sample_data, dim=0)
        # Collected but not used further here.
        sample_after_adapt_feature = torch.cat(sample_after_adapt_feature, dim=0)
        # Statistics are computed on the raw (pre-adapter) features.
        model.analyze_mean_cov(sample_data, sample_target)
        model.mix_matrix()
        self.prev_cls_num = self.accu_cls_num

    def get_parameters(self, config):
        """Only the linear adapter is trainable."""
        return self.model.adapter.parameters()

    def observe(self, data):
        """One training step on a batch.

        For tasks > 0, augments the real batch with Gaussian replay
        features for a few old classes and with "edge" samples for hard
        (old, new) class pairs; a hinge loss pushes each edge sample
        toward its own class-name feature and away from the confusable
        class's feature.  Cross-entropy is computed over real + replay
        rows; accuracy is reported on the real images only.

        Returns:
            (predicted_labels, accuracy, loss)
        """
        loss = torch.tensor(0.0).to(self.device)
        loss_c = torch.tensor(0.0).to(self.device)
        loss_hinge = torch.tensor(0.0).to(self.device)

        inputs = data['image']
        targets = data['label']
        inputs, targets = inputs.to(self.device), targets.to(self.device)
        sg_inputs = None
        edge_sample = None
        ori_targets = targets.clone()  # real-image targets, kept for accuracy
        model = self.model
        if self.task_id > 0:
            sg_inputs = []
            sg_targets = []

            # Choose 4 (or 2, depending on task size) old classes for this
            # batch by walking the shuffled order cyclically; 'batch_id' is
            # assumed to be supplied by the loader — TODO confirm.
            random_class_order_list = self.random_class_order_list
            batch_id = data['batch_id']
            if self.inc_cls_num == 5:
                list_for_one_batch = [random_class_order_list[batch_id*4%len(random_class_order_list)], random_class_order_list[(batch_id*4+1)%len(random_class_order_list)], random_class_order_list[(batch_id*4+2)%len(random_class_order_list)], random_class_order_list[(batch_id*4+3)%len(random_class_order_list)]]
            else:
                list_for_one_batch = [random_class_order_list[batch_id*2%len(random_class_order_list)], random_class_order_list[(batch_id*2+1)%len(random_class_order_list)]]

            # 10*beta replayed feature vectors per chosen old class, drawn
            # from the stored per-class Gaussian statistics.
            for i in list_for_one_batch:
                sg_inputs.append(sample(model.class_mean_list[i], model.class_cov_list[i],int(10*self.beta), shrink=self.shrinkage))
                sg_targets.append(torch.ones(int(10*self.beta), dtype=torch.long, device=self.device)*i)
            sg_inputs = torch.cat(sg_inputs, dim=0)
            sg_targets = torch.cat(sg_targets, dim=0)
            targets = torch.cat([targets, sg_targets], dim=0)
            if model.hard_pairs is not None and model.hard_pairs.shape[0] > 0:
                # Edge samples: drawn from the OLD class of each hard pair;
                # p_target is that old class, n_target the confusable new one.
                edge_sample = []
                edge_p_target = []
                edge_n_target = []
                for hard_pair in model.hard_pairs:
                    edge_sample.append(sample(model.class_mean_list[hard_pair[0]], model.class_cov_list[hard_pair[0]],int(20*self.beta), shrink=self.shrinkage))
                    edge_p_target.append(torch.ones(int(20*self.beta), dtype=torch.long, device=self.device)*hard_pair[0])
                    edge_n_target.append(torch.ones(int(20*self.beta), dtype=torch.long, device=self.device)*hard_pair[1])
                edge_sample = torch.cat(edge_sample, dim=0)
                edge_p_target = torch.cat(edge_p_target, dim=0)
                edge_n_target = torch.cat(edge_n_target, dim=0)
        if self.task_id > 0:
            not_ini = True
        else:
            not_ini = False
        outputs, _, __, edge_sample_features = model(inputs, memory_data=sg_inputs, not_ini=not_ini, edge_sample=edge_sample, prompt=False)

        if self.task_id > 0:
            if edge_sample is not None:
                # Hinge: cosine to own class-name feature must beat cosine
                # to the confusable class's feature by a margin of 0.1.
                edge_sample_features = edge_sample_features / edge_sample_features.norm(dim=-1, keepdim=True)
                edge_target_features = model.class_name_features[edge_p_target].type(edge_sample_features.dtype)
                edge_target_features = edge_target_features / edge_target_features.norm(dim=-1, keepdim=True)
                edge_nearest_class_features = model.class_name_features[edge_n_target].type(edge_sample_features.dtype)
                edge_nearest_class_features = edge_nearest_class_features / edge_nearest_class_features.norm(dim=-1, keepdim=True)
                loss_hinge = torch.relu(- (edge_sample_features * edge_target_features.clone().detach()).sum(-1) + (edge_sample_features * edge_nearest_class_features.clone().detach()).sum(-1) + 0.1).mean()
        # Cross-entropy over real rows plus (for tasks > 0) replay rows.
        loss_c = torch.nn.functional.cross_entropy(outputs, targets.detach())
        if edge_sample is not None:
            loss = loss_c + loss_hinge
        else:
            loss = loss_c

        # Accuracy is measured on the real images only (replay rows dropped).
        predicted_labels = outputs.argmax(dim=1)
        predicted_labels = predicted_labels[:ori_targets.size(0)]
        corrects = (predicted_labels == ori_targets).sum().item()
        total_predictions = ori_targets.size(0)
        accuracy = corrects / total_predictions
        return predicted_labels, accuracy, loss

    def inference(self, data):
        """Evaluate one batch; returns (softmax probabilities, accuracy)."""
        feats = data['image']
        target = data['label']
        feats, target = feats.to(self.device), target.to(self.device)
        model = self.model
        with torch.no_grad():
            outputs, _, __, ___ = model(feats, prompt=False)
        prob_outputs = torch.nn.functional.softmax(outputs, dim=-1)
        predicted_labels = prob_outputs.argmax(dim=1)
        corrects = (predicted_labels == target).sum().item()
        total_predictions = target.size(0)
        accurcy = corrects / total_predictions
        return prob_outputs, accurcy