import numpy as np
import torch
import torch.nn as nn

from collections import OrderedDict
from collections.abc import Iterable

class ERBuffer(nn.Module):
    """ experience replay buffer with pluggable add / sample strategies """

    def __init__(self, capacity):
        super().__init__()

        # names of the registered storage tensors, e.g. 'bx', 'by', 'bt'
        self.buffers = []

        self.cap = capacity
        self.buffer_size = capacity
        self.current_index = 0
        self.n_seen_so_far = 0
        self.is_full = 0

        # default strategies; callers can rebind these, e.g.
        # `self.add = self.add_balanced` or `self.sample = self.sample_balanced`
        self.add = self.add_reservoir
        self.sample = self.sample_random

    def __len__(self):
        return self.current_index

    @property
    def device(self):
        # nn.Module does not define `.device`, yet the methods below rely on it;
        # derive it from the first registered buffer (CPU before any exist)
        if len(self.buffers) > 0:
            return getattr(self, self.buffers[0]).device
        return torch.device('cpu')

    def add_buffer(self, name, dtype, size):
        """ used to add extra containers (e.g. for logit storage) """

        tmp = torch.zeros(size=(self.cap,) + size, dtype=dtype).to(self.device)
        self.register_buffer(f'b{name}', tmp)
        self.buffers += [f'b{name}']
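
    # Usage sketch: `size` is the per-sample shape; e.g. for storing logits of
    # a (hypothetical) 10-class model, `buf.add_buffer('logit', torch.float32, (10,))`
    # registers a (capacity x 10) tensor reachable as `self.blogit`.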

    def _init_buffers(self, batch):
        """ lazily create one storage tensor per key the first time a batch arrives """
        created = 0

        for name, tensor in batch.items():
            bname = f'b{name}'
            if bname not in self.buffers:
                # scalars (e.g. the task id) become one entry per slot
                if not isinstance(tensor, torch.Tensor):
                    tensor = torch.from_numpy(np.array([tensor]))

                self.add_buffer(name, tensor.dtype, tensor.shape[1:])
                created += 1

                print(f'created buffer {name}\t {tensor.dtype}, {tensor.shape[1:]}')

        assert created in [0, len(batch)], 'not all buffers created at the same time'

    def add_reservoir(self, batch):
        """ reservoir sampling: each seen point is kept with probability cap / n_seen """

        self._init_buffers(batch)

        n_elem = batch['x'].shape[0]

        place_left = max(0, self.cap - self.current_index)

        # draw a candidate slot in [0, n_seen_so_far) for every incoming point
        indices = torch.FloatTensor(n_elem).to(self.device)
        indices = indices.uniform_(0, self.n_seen_so_far).long()

        if place_left > 0:
            # buffer not full yet: fill the remaining empty slots first
            upper_bound = min(place_left, n_elem)
            indices[:upper_bound] = torch.arange(upper_bound) + self.current_index

        # only points whose candidate slot lands inside the buffer are written
        valid_indices = (indices < self.cap).long()
        idx_new_data = valid_indices.nonzero().squeeze(-1)
        idx_buffer = indices[idx_new_data]

        self.n_seen_so_far += n_elem
        self.current_index = min(self.n_seen_so_far, self.cap)

        if idx_buffer.numel() == 0:
            return

        for name, data in batch.items():
            buffer = getattr(self, f'b{name}')

            if isinstance(data, Iterable):
                buffer[idx_buffer] = data[idx_new_data]
            else:
                # scalars (e.g. the task id) broadcast over the selected slots
                buffer[idx_buffer] = data
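
    # Minimal usage sketch for the reservoir path (shapes are illustrative
    # assumptions, not fixed by the class):
    #   buf = ERBuffer(capacity=1000)
    #   buf.add({'x': torch.randn(32, 3, 32, 32),
    #            'y': torch.randint(0, 10, (32,)),
    #            't': 0})    # scalar task id, broadcast to every written slot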

    def add_balanced(self, batch):
        """ insertion that trims the most populous classes once over capacity """
        self._init_buffers(batch)

        n_elem = batch['x'].size(0)

        self.n_seen_so_far += n_elem
        self.current_index = min(self.n_seen_so_far, self.cap)

        # prepend the incoming data; trimming happens below if we overflow
        for name, data in batch.items():
            buffer = getattr(self, f'b{name}')

            if not isinstance(data, Iterable):
                data = buffer.new(size=(n_elem, *buffer.shape[1:])).fill_(data)

            buffer = torch.cat((data, buffer))[:self.n_seen_so_far]
            setattr(self, f'b{name}', buffer)

        n_samples_over = buffer.size(0) - self.cap

        # still under capacity: nothing to remove
        if n_samples_over <= 0:
            return

        # decide how many samples to remove per class, always taking from the
        # currently largest class
        class_count = self.by.bincount()
        rem_per_class = torch.zeros_like(class_count)

        while rem_per_class.sum() < n_samples_over:
            max_idx = class_count.argmax()
            rem_per_class[max_idx] += 1
            class_count[max_idx] -= 1

        # collect the indices to remove; since new data is prepended, the
        # trailing indices of each class are its oldest samples
        classes_trimmed = rem_per_class.nonzero().flatten()
        idx_remove = []

        for cls in classes_trimmed:
            cls_idx = (self.by == cls).nonzero().view(-1)
            idx_remove += [cls_idx[-rem_per_class[cls]:]]

        idx_remove = torch.cat(idx_remove)
        idx_mask = torch.zeros(buffer.size(0), dtype=torch.bool, device=self.device)
        idx_mask[idx_remove] = True

        # drop the masked slots from every stored container
        for name, data in batch.items():
            buffer = getattr(self, f'b{name}')
            buffer = buffer[~idx_mask]
            setattr(self, f'b{name}', buffer)
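
    # Worked example of the trimming rule: with capacity 5, stored labels
    # [0, 0, 0, 1, 1, 2] give class counts [3, 2, 1] and one sample over;
    # the loop removes a single (oldest) sample from class 0, the largest class.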

    def add_queue(self, batch):
        """ FIFO insertion: overwrite the oldest slots once the buffer is full """
        self._init_buffers(batch)

        if not hasattr(self, 'queue_ptr'):
            self.queue_ptr = 0

        n_elem = batch['x'].size(0)

        # wrap the write indices around the end of the buffer
        # (assumes n_elem <= cap, otherwise later writes clobber earlier ones)
        indices = ((torch.arange(n_elem) + self.queue_ptr) % self.cap).to(self.device)

        for name, data in batch.items():
            buffer = getattr(self, f'b{name}')
            buffer[indices] = data

        # advance the pointer and the bookkeeping counters
        self.queue_ptr = (self.queue_ptr + n_elem) % self.cap
        self.n_seen_so_far += n_elem
        self.current_index = min(self.n_seen_so_far, self.cap)
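
    # e.g. with capacity 4 and queue_ptr at 3, a batch of 2 writes to slots
    # [3, 0]: the indices wrap and the oldest entry (slot 0) is overwritten.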

    def sample_random(self, amt, exclude_task=None, **kwargs):
        """ uniform sampling over the filled part of the buffer """
        buffers = OrderedDict()

        if exclude_task is not None:
            assert hasattr(self, 'bt')
            valid_indices = torch.where(self.bt != exclude_task)[0]
            valid_indices = valid_indices[valid_indices < self.current_index]
            for buffer_name in self.buffers:
                buffers[buffer_name[1:]] = getattr(self, buffer_name)[valid_indices]
        else:
            for buffer_name in self.buffers:
                buffers[buffer_name[1:]] = getattr(self, buffer_name)[:self.current_index]

        n_selected = buffers['x'].size(0)
        if n_selected <= amt:
            # fewer candidates than requested: return everything available
            assert n_selected > 0, 'trying to sample from an empty buffer'
            return buffers
        else:
            idx_np = np.random.choice(buffers['x'].size(0), amt, replace=False)
            indices = torch.from_numpy(idx_np).to(self.bx.device)

            return OrderedDict({k: v[indices] for (k, v) in buffers.items()})
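
    # e.g. `buf.sample(64, exclude_task=2)` draws uniformly from all stored
    # samples whose task id differs from 2 (requires the 'bt' task buffer).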

    def sample_balanced(self, amt, exclude_task=None, **kwargs):
        """ sampling where each stored class receives (near) equal probability mass """
        buffers = OrderedDict()

        if exclude_task is not None:
            assert hasattr(self, 'bt')
            valid_indices = torch.where(self.bt != exclude_task)[0]
            # restrict to slots that have actually been written
            valid_indices = valid_indices[valid_indices < self.current_index]
            for buffer_name in self.buffers:
                buffers[buffer_name[1:]] = getattr(self, buffer_name)[valid_indices]
        else:
            for buffer_name in self.buffers:
                buffers[buffer_name[1:]] = getattr(self, buffer_name)[:self.current_index]

        class_count = buffers['y'].bincount()

        # each class gets total mass 1 / n_classes, split evenly over its samples
        class_sample_p = 1. / class_count.float() / class_count.size(0)
        per_sample_p = class_sample_p.gather(0, buffers['y'])
        indices = torch.multinomial(per_sample_p, amt)

        return OrderedDict({k: v[indices] for (k, v) in buffers.items()})
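
    # e.g. with stored label counts [30, 10], each class gets mass 1/2, so
    # per-sample weights are 1/60 vs. 1/20; torch.multinomial renormalizes the
    # weights, so both classes are drawn with roughly equal frequency.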

    def sample_pos_neg(self, inc_data, task_free=True, same_task_neg=True):
        """ draw one positive and one negative per incoming point, from buffer + batch """

        x = inc_data['x']
        label = inc_data['y']
        task = torch.zeros_like(label).fill_(inc_data['t'])

        # build an "augmented" pool containing the stored and the incoming data
        bx = torch.cat((self.bx[:self.current_index], x))
        by = torch.cat((self.by[:self.current_index], label))
        bt = torch.cat((self.bt[:self.current_index], task))
        bidx = torch.arange(bx.size(0)).to(bx.device)

        # pool_size x batch_size comparison matrices
        same_label = label.view(1, -1) == by.view(-1, 1)
        same_task = task.view(1, -1) == bt.view(-1, 1)
        same_ex = bidx[-x.size(0):].view(1, -1) == bidx.view(-1, 1)

        task_labels = label.unique()

        if task_free:
            # without task boundaries, treat "same task" as "label seen in this batch"
            same_task = torch.zeros_like(same_task)

            for label_ in task_labels:
                label_exp = label_.view(1, -1).expand_as(same_task)
                same_task = same_task | (label_exp == by.view(-1, 1))

        # a positive shares the label but is not the anchor itself
        valid_pos = same_label & ~same_ex

        if same_task_neg:
            valid_neg = ~same_label & same_task
        else:
            valid_neg = ~same_label

        # anchors lacking a positive or a negative candidate are flagged invalid
        has_valid_pos = valid_pos.sum(0) > 0
        has_valid_neg = valid_neg.sum(0) > 0

        invalid_idx = ~has_valid_pos | ~has_valid_neg

        if invalid_idx.sum() > 0:
            # temporarily allow everything so that multinomial does not fail
            valid_pos[:, invalid_idx] = 1
            valid_neg[:, invalid_idx] = 1

        # mark the invalid anchors so the caller can mask them out
        is_invalid = torch.zeros_like(label).bool()
        is_invalid[invalid_idx] = 1

        # sample one positive and one negative index per anchor
        pos_idx = torch.multinomial(valid_pos.float().T, 1).squeeze(1)
        neg_idx = torch.multinomial(valid_neg.float().T, 1).squeeze(1)

        # number of unique points that actually require a forward pass
        n_fwd = torch.stack((bidx[-x.size(0):], pos_idx, neg_idx), 1)[~invalid_idx].unique().size(0)

        return bx[pos_idx], bx[neg_idx], by[pos_idx], by[neg_idx], is_invalid, n_fwd
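
    # Caller sketch (`model` and `triplet_loss` are hypothetical, shown only to
    # illustrate how the invalid mask is meant to be used):
    #   pos_x, neg_x, pos_y, neg_y, invalid, n_fwd = buf.sample_pos_neg(inc_data)
    #   loss = triplet_loss(model(inc_data['x']), model(pos_x), model(neg_x))
    #   loss = loss[~invalid].mean()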

    def sample_minimal_pos_neg(self, inc_data, task_free=True, same_task_neg=True):
        """ prefer positives / negatives from the incoming batch itself to minimize forward passes """

        x = inc_data['x']
        label = inc_data['y']
        task = torch.zeros_like(label).fill_(inc_data['t'])

        bidx = torch.arange(x.size(0)).to(x.device)

        # batch_size x batch_size comparison matrices (candidates come only
        # from the incoming batch, unlike sample_pos_neg)
        same_label = label.view(1, -1) == label.view(-1, 1)
        same_task = task.view(1, -1) == task.view(-1, 1)
        same_ex = bidx.view(1, -1) == bidx.view(-1, 1)

        task_labels = label.unique()

        if task_free:
            # without task boundaries, treat "same task" as "label seen in this batch"
            same_task = torch.zeros_like(same_task)

            for label_ in task_labels:
                label_exp = label_.view(1, -1).expand_as(same_task)
                same_task = same_task | (label_exp == label.view(-1, 1))

        # a positive shares the label but is not the anchor itself
        valid_pos = same_label & ~same_ex

        if same_task_neg:
            valid_neg = ~same_label & same_task
        else:
            valid_neg = ~same_label

        # anchors lacking a positive or a negative candidate are flagged invalid
        has_valid_pos = valid_pos.sum(0) > 0
        has_valid_neg = valid_neg.sum(0) > 0

        invalid_idx = ~has_valid_pos | ~has_valid_neg

        if invalid_idx.any():
            # temporarily allow everything so that multinomial does not fail
            valid_pos[:, invalid_idx] = 1
            valid_neg[:, invalid_idx] = 1

        # mark the invalid anchors so the caller can mask them out
        is_invalid = torch.zeros_like(label).bool()
        is_invalid[invalid_idx] = 1

        # sample one positive and one negative index per anchor
        pos_idx = torch.multinomial(valid_pos.float().T, 1).squeeze(1)
        neg_idx = torch.multinomial(valid_neg.float().T, 1).squeeze(1)

        pos_x, neg_x = x[pos_idx], x[neg_idx]
        pos_y, neg_y = label[pos_idx], label[neg_idx]

        # number of unique points that actually require a forward pass
        n_fwd = torch.stack((bidx, pos_idx, neg_idx), 1)[~invalid_idx].unique().size(0)

        # retry anchors that found no candidates in the batch against the full
        # buffer + batch pool
        if invalid_idx.any():
            invalid_data = OrderedDict()
            invalid_data['x'] = x[invalid_idx]
            invalid_data['y'] = label[invalid_idx]
            invalid_data['t'] = inc_data['t']

            n_pos_x, n_neg_x, n_pos_y, n_neg_y, n_is_invalid, n_new_fwd = \
                self.sample_pos_neg(invalid_data, task_free=task_free, same_task_neg=same_task_neg)

            # chained masking like `pos_x[invalid_idx][~n_is_invalid] = ...` writes
            # into a temporary copy, so resolve to absolute indices first
            retry_idx = invalid_idx.nonzero().squeeze(-1)
            fixed_idx = retry_idx[~n_is_invalid]

            pos_x[fixed_idx] = n_pos_x[~n_is_invalid]
            neg_x[fixed_idx] = n_neg_x[~n_is_invalid]
            pos_y[fixed_idx] = n_pos_y[~n_is_invalid]
            neg_y[fixed_idx] = n_neg_y[~n_is_invalid]

            # anchors recovered on the retry are no longer invalid in the returned mask
            is_invalid[fixed_idx] = False

            n_fwd += n_new_fwd

        return pos_x, neg_x, pos_y, neg_y, is_invalid, n_fwd
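

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original module). Exercises reservoir
# insertion plus random and balanced sampling; the shapes, the 10-class setup
# and the 3-task loop are illustrative assumptions only.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    buf = ERBuffer(capacity=100)

    for t in range(3):                       # three hypothetical tasks
        for _ in range(5):
            buf.add({'x': torch.randn(32, 3, 32, 32),
                     'y': torch.randint(0, 10, (32,)),
                     't': t})

    print(f'buffer holds {len(buf)} / {buf.cap} slots, saw {buf.n_seen_so_far} points')

    sampled = buf.sample(16)                 # -> OrderedDict with keys 'x', 'y', 't'
    print({k: tuple(v.shape) for k, v in sampled.items()})

    balanced = buf.sample_balanced(16)
    print('balanced label counts:', balanced['y'].bincount())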