| """
|
| @inproceedings{DBLP:conf/iccv/ShiY23,
|
| title = {Prototype Reminiscence and Augmented Asymmetric Knowledge Aggregation for Non-Exemplar Class-Incremental Learning},
|
| author = {Shi, Wuxuan and Ye, Mang},
|
| booktitle = {2023 IEEE/CVF International Conference on Computer Vision (ICCV)},
|
| pages = {1772-1781},
|
| publisher = {Computer Vision Foundation / {IEEE}},
|
| year = {2023}
|
| }
|
|
|
| https://openaccess.thecvf.com/content/ICCV2023/papers/Shi_Prototype_Reminiscence_and_Augmented_Asymmetric_Knowledge_Aggregation_for_Non-Exemplar_Class-Incremental_ICCV_2023_paper.pdf
|
|
|
| Adapted from https://github.com/ShiWuxuan/PRAKA
|
| """
|
|
|
| from torch.nn import functional as F
|
| import os
|
| import numpy as np
|
| import torch
|
| import torch.nn as nn
|
| import math
|
| import copy
|
| from core.model import Finetune
|
|
|
class joint_network(nn.Module):
    '''
    Backbone plus the two heads used by PRAKA:

    - `classifier`: standard classification head with `numclass` outputs.
    - `fc`: self-supervised joint head scoring every (class, rotation)
      pair, hence `numclass * 4` outputs (rotations 0/90/180/270 degrees).

    Code Reference:
    https://github.com/ShiWuxuan/PRAKA/blob/master/Cifar100/myNetwork.py
    '''

    def __init__(self, numclass, feature_extractor, feat_dim=512):
        '''
        Parameters:
        - numclass (int): number of classes in the initial task.
        - feature_extractor (nn.Module): backbone mapping images to
          `feat_dim`-dimensional feature vectors.
        - feat_dim (int): dimensionality of the extracted features.
          Defaults to 512 (a ResNet-18-style backbone), matching the
          original implementation's hard-coded value.
        '''
        super(joint_network, self).__init__()
        self.feature = feature_extractor
        self.fc = nn.Linear(feat_dim, numclass * 4, bias=True)
        self.classifier = nn.Linear(feat_dim, numclass, bias=True)

    def forward(self, input):
        '''
        Return `classifier` logits for a batch of images (the joint head
        `fc` is not applied here; it is used explicitly during training).

        Code Reference:
        https://github.com/ShiWuxuan/PRAKA/blob/master/Cifar100/myNetwork.py
        '''
        x = self.feature(input)
        x = self.classifier(x)
        return x

    def Incremental_learning(self, numclass):
        '''
        Grow both heads to accommodate `numclass` total classes, keeping
        the weights and biases already learned for the old classes.

        Parameters:
        - numclass (int): total number of classes after the current task,
          including both old and new classes.

        Notes:
        - The `fc` head's output dimension becomes `numclass * 4` (one
          output per (class, rotation) pair).
        - The `classifier` head's output dimension becomes `numclass`.
        - The new layers are created on CPU; the caller is expected to
          move the model back to the target device afterwards.

        Code Reference:
        https://github.com/ShiWuxuan/PRAKA/blob/master/Cifar100/myNetwork.py
        '''
        self.fc = self._expand_linear(self.fc, numclass * 4)
        self.classifier = self._expand_linear(self.classifier, numclass)

    @staticmethod
    def _expand_linear(layer, new_out_features):
        '''Return a wider copy of `layer` whose leading rows reuse its weights/bias.'''
        old_out_features = layer.out_features
        new_layer = nn.Linear(layer.in_features, new_out_features, bias=True)
        new_layer.weight.data[:old_out_features] = layer.weight.data
        new_layer.bias.data[:old_out_features] = layer.bias.data
        return new_layer

    def feature_extractor(self, inputs):
        '''
        Return backbone features without applying either head.

        Code Reference:
        https://github.com/ShiWuxuan/PRAKA/blob/master/Cifar100/myNetwork.py
        '''
        return self.feature(inputs)
|
|
|
class PRAKA(nn.Module):
    '''
    Prototype Reminiscence and Augmented Asymmetric Knowledge Aggregation
    (PRAKA, ICCV 2023) for non-exemplar class-incremental learning.

    Training combines: (1) a rotation-based self-supervised joint head and a
    plain classifier, (2) asymmetric distillation of rotation-aggregated
    joint predictions into the classifier, and (3) replay of "reminisced"
    old-class prototypes mixed with current features (no stored exemplars).
    '''

    def __init__(self, backbone, feat_dim, num_class, **kwargs):
        # kwargs must provide: 'device', 'init_cls_num', 'inc_cls_num',
        # 'temp', 'protoAug_weight', 'kd_weight'.
        super().__init__()
        self.device = kwargs['device']
        self.kwargs = kwargs
        # Spatial side length used to reshape rotated images; assumes
        # 32x32 (CIFAR-style) RGB inputs — TODO confirm for other datasets.
        self.size = 32

        encoder = backbone
        self.model = joint_network(kwargs["init_cls_num"], encoder)
        self.radius = 0
        self.prototype = None    # dict: class label -> mean feature; filled by protoSave
        self.numsamples = None   # dict: class label -> sample count; filled by protoSave
        self.numclass = kwargs["init_cls_num"]   # total classes after the current task
        self.task_size = kwargs["inc_cls_num"]   # classes added per incremental task
        self.old_model = None    # frozen deep copy of the previous-task model

        self.task_idx = 0

    def before_task(self, task_idx, buffer, train_loader, test_loaders):
        # Expand both heads to the new class count (numclass was advanced in
        # after_task), then move the possibly CPU-created layers to the device.
        self.task_idx = task_idx
        if task_idx > 0:
            self.model.Incremental_learning(self.numclass)
        self.model.to(self.device)

    def observe(self, data):
        '''
        Processes a batch of training data to compute predictions, accuracy, and loss.

        Parameters:
        - data: Dictionary containing the batch of training samples
            - 'image': Tensor of input images
            - 'label': Tensor of ground truth labels

        Returns:
        - predictions: Tensor of predicted class labels for the input images
        - accuracy: Float value representing the accuracy of the model on the current batch
        - loss: Float value representing the computed loss for the batch

        Description:
        This function is called during the training phase. It performs the following steps:
        1. Extracts the images and labels from the provided data dictionary and transfers them to the device.
        2. Augments the images by rotating them by 0, 90, 180, and 270 degrees, and creates corresponding labels for these augmented images.
        3. Computes the loss using the augmented images and labels.
        4. Evaluates the model's performance on the current batch by calculating the accuracy and loss.
        5. Returns the predictions, accuracy, and loss for the batch.

        Example Usage:
        predictions, accuracy, loss = observe(data)
        '''
        images, labels = data['image'].to(self.device), data['label'].to(self.device)

        # Stack the 4 rotations per sample along a new dim then flatten, so
        # rows are interleaved: sample0@0deg, sample0@90, ..., sample1@0deg, ...
        images = torch.stack([torch.rot90(images, k, (2, 3)) for k in range(4)], 1)
        images = images.view(-1, 3, self.size, self.size)

        # Joint label for (class c, rotation k) is c*4 + k, matching the fc head.
        joint_labels = torch.stack([labels * 4 + k for k in range(4)], 1).view(-1)
        if self.task_idx == 0:
            old_class = 0
        else:
            # Number of classes learned in all previous tasks.
            old_class = self.kwargs['init_cls_num'] + self.kwargs['inc_cls_num'] * (self.task_idx - 1)

        loss, single_preds = self._compute_loss(images, joint_labels, labels, old_class)

        preds = torch.argmax(single_preds, dim=-1)
        return preds, (preds == labels).sum().item() / len(labels), loss

    def inference(self, data):
        '''
        Performs inference on a batch of test samples and computes the classification results and accuracy.

        Parameters:
        - data: Dictionary containing the batch of test samples
            - 'image': Tensor of input images
            - 'label': Tensor of ground truth labels

        Returns:
        - predictions: Tensor of predicted class labels for the input images
        - accuracy: Float value representing the accuracy of the model on the current batch

        Example Usage:
        predictions, accuracy = inference(data)
        '''
        imgs, labels = data['image'].to(self.device), data['label'].to(self.device)

        # Inference uses only the plain classifier head (joint_network.forward).
        preds = torch.argmax(self.model(imgs), dim=-1)

        return preds, (preds == labels).sum().item() / len(labels)

    def _compute_loss(self, imgs, joint_labels, labels, old_class=0):
        '''
        Computes the loss for a batch of images and labels.

        Parameters:
        - imgs: Tensor of input images, rotation-augmented and interleaved
          (4 consecutive rows per original sample, rotations 0/90/180/270)
        - joint_labels: Tensor of labels for images augmented with rotations (0°, 90°, 180°, 270°)
        - labels: Tensor of ground truth labels for the (un-rotated) images
        - old_class: Integer indicating the number of old classes (default is 0)

        Returns:
        - loss: Scalar tensor representing the total computed loss
        - preds: Tensor of predictions for the original (non-augmented) images

        Example Usage:
        loss, preds = self._compute_loss(imgs, joint_labels, labels, old_class)

        Code Reference:
        https://github.com/ShiWuxuan/PRAKA/blob/master/Cifar100/jointSSL.py
        '''
        feature = self.model.feature(imgs)

        # The joint head scores all rows; the plain classifier is evaluated
        # only on the un-rotated images — every 4th row, since the 4
        # rotations of each sample are interleaved.
        joint_preds = self.model.fc(feature)
        single_preds = self.model.classifier(feature)[::4]
        joint_preds, joint_labels, single_preds, labels = joint_preds.to(self.device), joint_labels.to(self.device), single_preds.to(self.device), labels.to(self.device)
        joint_loss = nn.CrossEntropyLoss()(joint_preds/self.kwargs["temp"], joint_labels)
        single_loss = nn.CrossEntropyLoss()(single_preds/self.kwargs["temp"], labels)

        # Aggregate joint predictions over the 4 rotations: rows i::4 are the
        # images rotated by i*90 degrees and columns i::4 are the
        # (class, rotation=i) logits, so averaging these slices yields one
        # per-class score per original sample.
        agg_preds = 0
        for i in range(4):
            agg_preds = agg_preds + joint_preds[i::4, i::4] / 4

        # Asymmetric distillation: pull classifier outputs toward the
        # (detached) rotation-aggregated joint predictions.
        distillation_loss = F.kl_div(F.log_softmax(single_preds, 1),
                                     F.softmax(agg_preds.detach(), 1),
                                     reduction='batchmean')
        if old_class == 0:
            # First task: no old model / prototypes yet.
            return joint_loss + single_loss + distillation_loss, single_preds
        else:
            # Feature-space knowledge distillation: L2 distance between the
            # whole current and frozen-old feature batches.
            feature_old = self.old_model.feature(imgs)

            loss_kd = torch.dist(feature, feature_old, 2)

            # Prototype reminiscence: synthesize old-class features by mixing
            # stored prototypes with (detached) features of current samples.
            proto_aug = []
            proto_aug_label = []
            old_class_list = list(self.prototype.keys())
            for _ in range(feature.shape[0] // 4):
                i = np.random.randint(0, feature.shape[0])
                np.random.shuffle(old_class_list)
                lam = np.random.beta(0.5, 0.5)
                if lam > 0.6:
                    # Cap the mixing coefficient so the prototype dominates.
                    lam = lam * 0.6

                if np.random.random() >= 0.5:
                    # Extrapolate away from the current feature...
                    temp = (1 + lam) * self.prototype[old_class_list[0]] - lam * feature.detach().cpu().numpy()[i]
                else:
                    # ...or interpolate toward it.
                    temp = (1 - lam) * self.prototype[old_class_list[0]] + lam * feature.detach().cpu().numpy()[i]

                proto_aug.append(temp)
                proto_aug_label.append(old_class_list[0])

            proto_aug = torch.from_numpy(np.float32(np.asarray(proto_aug))).float().to(self.device)
            proto_aug_label = torch.from_numpy(np.asarray(proto_aug_label)).to(self.device)
            aug_preds = self.model.classifier(proto_aug)
            joint_aug_preds = self.model.fc(proto_aug)
            # Columns ::4 of the joint head are the rotation-0 logits, the
            # appropriate target since prototypes come from un-rotated images.
            agg_preds = joint_aug_preds[:, ::4]
            aug_distillation_loss = F.kl_div(F.log_softmax(aug_preds, 1),
                                             F.softmax(agg_preds.detach(), 1),
                                             reduction='batchmean')

            # proto_aug_label * 4 maps a class label to its
            # (class, rotation=0) joint label.
            loss_protoAug = nn.CrossEntropyLoss()(aug_preds/self.kwargs["temp"], proto_aug_label) + nn.CrossEntropyLoss()(joint_aug_preds/self.kwargs["temp"], proto_aug_label*4) + aug_distillation_loss

            return joint_loss + single_loss + distillation_loss + self.kwargs["protoAug_weight"]*loss_protoAug + self.kwargs["kd_weight"]*loss_kd, single_preds

    def after_task(self, task_idx, buffer, train_loader, test_loaders):
        '''
        Perform operations after completing the training for a specific task.
        1. Save the class prototypes computed from the current model.
        2. Advance the running class counter for the next task.
        3. Keep a frozen deep copy of the current model as the old model
           for distillation in subsequent tasks.

        Parameters:
        - task_idx (int): The index of the current task.
        - buffer: Data buffer for storing samples (not used in this function).
        - train_loader (DataLoader): DataLoader for the training dataset of the current task.
        - test_loaders (list of DataLoader): List of DataLoaders for test datasets of different tasks.

        Example Usage:
        self.after_task(task_idx, buffer, train_loader, test_loaders)
        '''
        self.protoSave(self.model, train_loader, self.task_idx)
        self.numclass += self.task_size

        self.old_model = copy.deepcopy(self.model)
        self.old_model.eval()

    def protoSave(self, model, loader, current_task):
        '''
        Save the class prototypes for the current task.

        This function extracts features from the data using the provided model and computes
        class prototypes based on these features. The prototypes are then saved to the class
        attributes. If it's the first task, the prototypes are initialized. For subsequent
        tasks, the prototypes are updated with new class information.

        Parameters:
        - model: The model used for feature extraction.
        - loader: DataLoader providing the dataset for the current task.
        - current_task (int): The index of the current task.

        Code Reference:
        https://github.com/ShiWuxuan/PRAKA/blob/master/Cifar100/jointSSL.py
        '''
        features = []
        labels = []
        model.eval()

        with torch.no_grad():
            for i, batch in enumerate(loader):
                images, target = batch['image'], batch['label']
                feature = model.feature(images.to(self.device))
                # Skip the last (incomplete) batch so the per-batch arrays
                # stack into a rectangular array for the reshape below.
                if feature.shape[0] == loader.batch_size:
                    labels.append(target.numpy())
                    features.append(feature.cpu().numpy())

        labels_set = np.unique(labels)
        labels = np.array(labels)
        # Flatten (num_batches, batch_size) -> (num_samples,).
        labels = np.reshape(labels, labels.shape[0] * labels.shape[1])
        features = np.array(features)
        # Flatten (num_batches, batch_size, feat_dim) -> (num_samples, feat_dim).
        features = np.reshape(features, (features.shape[0] * features.shape[1], features.shape[2]))

        prototype = {}
        class_label = []
        numsamples = {}

        # Prototype of a class = mean feature over its samples in this task.
        for item in labels_set:
            index = np.where(item == labels)[0]
            class_label.append(item)
            feature_classwise = features[index]
            prototype[item] = np.mean(feature_classwise, axis=0)

            numsamples[item] = feature_classwise.shape[0]
        if current_task == 0:
            self.prototype = prototype
            self.class_label = class_label
            self.numsamples = numsamples
        else:
            # Merge new-class prototypes into the stored ones; old-class
            # prototypes are kept unchanged (non-exemplar setting).
            self.prototype.update(prototype)
            self.class_label = np.concatenate((class_label, self.class_label), axis=0)
            self.numsamples.update(numsamples)

    def get_parameters(self, config):
        '''Return the trainable parameters of the joint network for the optimizer.'''
        return self.model.parameters()
|
|
|
|
|