| import numpy as np |
| import torch |
| import torch.nn as nn |
| import PIL |
| import os |
| from typing import List |
| from torch.utils.data import DataLoader |
| from torch.utils.data import Dataset |
|
|
class LinearHerdingBuffer:
    """Replay buffer for class-incremental learning with herding selection.

    Exemplars are chosen per class with the iCaRL-style herding strategy:
    greedily pick the samples whose running feature mean best approximates
    the true class mean in the model's feature space. Samples and labels are
    stored as parallel Python lists; the buffer holds at most ``buffer_size``
    samples, shared evenly across all classes seen so far.
    """

    def __init__(self, buffer_size, batch_size):
        # Maximum total number of exemplars kept across all classes.
        self.buffer_size = buffer_size
        self.strategy = None
        self.batch_size = batch_size
        # Parallel lists: images (presumably paths or arrays — depends on the
        # dataset class) and their labels.
        self.images, self.labels = [], []
        self.total_classes = 0

    def is_empty(self):
        """Return ``True`` when the buffer holds no samples."""
        return not self.labels

    def clear(self):
        """Drop every stored sample, leaving the buffer empty."""
        del self.images
        del self.labels
        self.images = []
        self.labels = []

    def get_all_data(self):
        """Return all buffered samples as ``(images, labels)`` numpy arrays."""
        return np.array(self.images), np.array(self.labels)

    def add_data(self, data: List, targets: List):
        """Append samples and their labels to the buffer.

        Args:
            data: samples to store.
            targets: label for each sample; same length as ``data``.
        """
        self.images.extend(data)
        self.labels.extend(targets)

    def _samples_per_class(self, total_cls_num: int) -> int:
        """Per-class exemplar quota, clamped to at least 1.

        Consolidates the warning logic previously duplicated in
        ``reduce_old_data`` and ``herding_select``.
        """
        samples_per_class = self.buffer_size // total_cls_num
        if samples_per_class == 0:
            print(
                f"Warning: Buffer size ({self.buffer_size}) is too small for "
                f"total classes ({total_cls_num}). Samples per class will be "
                f"set to 1, to avoid empty buffer."
            )
            samples_per_class = 1
        return samples_per_class

    def update(self, model: nn.Module, train_loader, val_transform,
               task_idx: int, total_cls_num: int, cur_cls_indexes, device):
        """Add herding-selected exemplars of the current task to the buffer.

        NOTE: ``herding_select`` mutates ``train_loader.dataset`` in place
        (only current-task samples are kept and its transform is replaced
        by ``val_transform``) before the exemplars are copied out of it.
        """
        chosen_indexes = self.herding_select(model, train_loader, val_transform,
                                             task_idx, total_cls_num,
                                             cur_cls_indexes, device)
        cur_task_dataset = train_loader.dataset
        new_images = [cur_task_dataset.images[i] for i in chosen_indexes]
        new_labels = [cur_task_dataset.labels[i] for i in chosen_indexes]
        self.add_data(new_images, new_labels)

    def reduce_old_data(self, task_idx: int, total_cls_num: int) -> None:
        """Shrink each stored class to the new per-class quota.

        When new classes arrive the per-class quota shrinks; only the first
        ``samples_per_class`` exemplars of each class are kept (herding
        appends exemplars in greedy-priority order, so a prefix is assumed
        to be the best subset). No-op for the first task (``task_idx == 0``).
        """
        samples_per_class = self._samples_per_class(total_cls_num)

        if task_idx > 0:
            buffer_X, buffer_Y = self.get_all_data()
            self.clear()
            for y in np.unique(buffer_Y):
                idx = (buffer_Y == y)
                selected_X, selected_Y = buffer_X[idx], buffer_Y[idx]
                self.add_data(
                    data=selected_X[:samples_per_class],
                    targets=selected_Y[:samples_per_class],
                )

    def herding_select(self, model: nn.Module, train_loader, val_transform,
                       task_idx: int, total_cls_num: int, cur_cls_indexes,
                       device):
        """Select exemplar indices for the current task via herding.

        Per class, greedily picks the samples whose running mean of
        L2-normalized backbone features stays closest to the class mean.

        WARNING: mutates ``train_loader.dataset`` in place — only samples
        whose label is in ``cur_cls_indexes`` are kept (grouped by class),
        and the dataset's transform is replaced by ``val_transform``.

        Returns:
            List of integer indices into the (filtered) dataset.
        """

        def _keep_only_current_classes(dataset, class_ids):
            # Filter the dataset in place so only ``class_ids`` samples
            # remain, grouped contiguously class by class.
            labels = np.array(dataset.labels)
            images = np.array(dataset.images)
            kept_images, kept_labels = [], []
            for cls in class_ids:
                mask = labels == cls
                kept_images.extend(list(images[mask]))
                kept_labels.extend(list(labels[mask]))
            dataset.labels = kept_labels
            dataset.images = kept_images

        dataset = train_loader.dataset
        _keep_only_current_classes(dataset, cur_cls_indexes)

        # Use evaluation-time transforms so extracted features are
        # deterministic (no train-time augmentation).
        dataset.trfms = val_transform

        loader = DataLoader(
            dataset,
            shuffle=False,      # preserve dataset order so indices stay valid
            batch_size=32,
            drop_last=False,
        )

        samples_per_class = self._samples_per_class(total_cls_num)

        # Extract L2-normalized backbone features for every remaining sample.
        extracted_features = []
        extracted_targets = []
        with torch.no_grad():
            model.eval()
            for data in loader:
                image = data['image'].to(device)
                label = data['label'].to(device)
                feats = model.backbone(image)['features']
                feats = feats / feats.norm(dim=1).view(-1, 1)
                extracted_features.append(feats)
                extracted_targets.append(label)
        extracted_features = torch.cat(extracted_features).cpu()
        extracted_targets = torch.cat(extracted_targets).cpu()

        result = []
        for curr_cls in np.unique(extracted_targets):
            cls_ind = np.where(extracted_targets == curr_cls)[0]
            # Advanced indexing copies, so bumping entries of ``cls_feats``
            # below does not corrupt ``extracted_features``.
            cls_feats = extracted_features[cls_ind]
            mean_feat = cls_feats.mean(0, keepdim=True)
            running_sum = torch.zeros_like(mean_feat)
            i = 0
            while i < samples_per_class and i < cls_feats.shape[0]:
                # Distance between the class mean and the mean obtained by
                # adding each remaining candidate to the selected set.
                cost = (mean_feat - (cls_feats + running_sum) / (i + 1)).norm(2, 1)
                idx_min = cost.argmin().item()
                # Map back through cls_ind instead of ``idx_min + cls_ind[0]``:
                # equivalent while the filter above groups classes
                # contiguously, but robust if that ordering ever changes.
                result.append(int(cls_ind[idx_min]))
                running_sum += cls_feats[idx_min:idx_min + 1]
                # Push the chosen sample far away so it is never re-selected.
                cls_feats[idx_min] = cls_feats[idx_min] + 1e6
                i += 1

        return result
| |
|
|
|
|