| | import numpy as np |
| | import torch |
| | import torch.nn as nn |
| |
|
| | from collections import OrderedDict |
| | from collections.abc import Iterable |
| |
|
class ERBuffer(nn.Module):
    """Experience-replay buffer backed by registered torch buffers.

    Storage tensors are created lazily from the first batch added; `add`
    and `sample` are rebindable aliases for the default strategies
    (reservoir insertion, uniform-random sampling).
    """

    def __init__(self, capacity):
        super().__init__()

        # names of the registered storage tensors ('bx', 'by', ...)
        self.buffers = []

        # `buffer_size` is kept as an external-facing alias of `cap`
        self.cap = self.buffer_size = capacity

        # filled-slot count / total items ever offered / full flag
        self.current_index = self.n_seen_so_far = self.is_full = 0

        # default insertion and sampling strategies
        self.add = self.add_reservoir
        self.sample = self.sample_random
|
| | def __len__(self): |
| | return self.current_index |
| |
|
| | def add_buffer(self, name, dtype, size): |
| | """ used to add extra containers (e.g. for logit storage) """ |
| |
|
| | tmp = torch.zeros(size=(self.cap,) + size, dtype=dtype).to(self.device) |
| | self.register_buffer(f'b{name}', tmp) |
| | self.buffers += [f'b{name}'] |
| |
|
| | def _init_buffers(self, batch): |
| | created = 0 |
| |
|
| | for name, tensor in batch.items(): |
| | bname = f'b{name}' |
| | if bname not in self.buffers: |
| | |
| | if not type(tensor) == torch.Tensor: |
| | tensor = torch.from_numpy(np.array([tensor])) |
| |
|
| | self.add_buffer(name, tensor.dtype, tensor.shape[1:]) |
| | created += 1 |
| |
|
| | print(f'created buffer {name}\t {tensor.dtype}, {tensor.shape[1:]}') |
| |
|
| | assert created in [0, len(batch)], 'not all buffers created at the same time' |
| |
|
    def add_reservoir(self, batch):
        """Reservoir-sampling insertion: while the buffer has free space, new
        items fill it sequentially; afterwards each item lands in a uniformly
        drawn slot in [0, n_seen_so_far) and is kept only if that slot falls
        inside the buffer — i.e. kept with probability cap / n_seen_so_far.

        `batch` is a dict of per-field data (must contain 'x'); fields are
        lazily registered as storage buffers on first use.
        """

        self._init_buffers(batch)

        n_elem = batch['x'].shape[0]

        # free slots remaining before the buffer is full
        place_left = max(0, self.cap - self.current_index)

        # candidate slot per incoming item, drawn in [0, n_seen_so_far)
        # NOTE(review): relies on self.device being set externally — it is
        # not defined in __init__
        indices = torch.FloatTensor(n_elem).to(self.device)
        indices = indices.uniform_(0, self.n_seen_so_far).long()

        if place_left > 0:
            # fill the empty tail of the buffer sequentially first
            upper_bound = min(place_left, n_elem)
            indices[:upper_bound] = torch.arange(upper_bound) + self.current_index

        # an item survives only if its drawn slot lands inside the buffer
        valid_indices = (indices < self.cap).long()
        idx_new_data = valid_indices.nonzero().squeeze(-1)   # batch rows kept
        idx_buffer = indices[idx_new_data]                   # slots they overwrite

        # counters are advanced regardless of how many items were kept
        self.n_seen_so_far += n_elem
        self.current_index = min(self.n_seen_so_far, self.cap)

        if idx_buffer.numel() == 0:
            return

        # write surviving rows into every registered field buffer
        # NOTE(review): duplicate slots within one batch mean later rows
        # overwrite earlier ones — presumably acceptable; confirm if exact
        # reservoir statistics matter
        for name, data in batch.items():
            buffer = getattr(self, f'b{name}')

            if isinstance(data, Iterable):
                buffer[idx_buffer] = data[idx_new_data]
            else:
                # scalar field (e.g. task id) broadcast to all selected slots
                buffer[idx_buffer] = data
|
    def add_balanced(self, batch):
        """Insert a batch, then trim the most-populated classes so the buffer
        stays within capacity while keeping class counts as even as possible.

        NOTE(review): requires a 'y' field (reads self.by for class counts).
        """

        self._init_buffers(batch)

        n_elem = batch['x'].size(0)

        self.n_seen_so_far += n_elem
        self.current_index = min(self.n_seen_so_far, self.cap)

        # prepend new data to every field buffer, capping total length at
        # n_seen_so_far (drops the still-unwritten tail of a fresh buffer)
        for name, data in batch.items():
            buffer = getattr(self, f'b{name}')

            if not isinstance(data, Iterable):
                # scalar field: expand to one value per incoming row
                data = buffer.new(size=(n_elem, *buffer.shape[1:])).fill_(data)

            buffer = torch.cat((data, buffer))[:self.n_seen_so_far]
            setattr(self, f'b{name}', buffer)

        # `buffer` deliberately leaks from the loop above; all fields share
        # the same length, so the overflow count is the same for each
        n_samples_over = buffer.size(0) - self.cap

        # no need to remove anything
        if n_samples_over <= 0:
            return

        # greedily decide how many items to drop from each (largest) class
        class_count = self.by.bincount()
        rem_per_class = torch.zeros_like(class_count)

        while rem_per_class.sum() < n_samples_over:
            max_idx = class_count.argmax()
            rem_per_class[max_idx] += 1
            class_count[max_idx] -= 1

        # for each trimmed class, remove its highest-index occurrences —
        # since new data was prepended, those are the oldest items
        classes_trimmed = rem_per_class.nonzero().flatten()
        idx_remove = []

        for cls in classes_trimmed:
            cls_idx = (self.by == cls).nonzero().view(-1)
            idx_remove += [cls_idx[-rem_per_class[cls]:]]

        idx_remove = torch.cat(idx_remove)
        # torch.BoolTensor is allocated uninitialized — hence the explicit
        # fill_(0) before marking rows for removal
        # NOTE(review): relies on self.device being set externally
        idx_mask = torch.BoolTensor(buffer.size(0)).to(self.device)
        idx_mask.fill_(0)
        idx_mask[idx_remove] = 1

        # drop the masked rows from every field buffer
        for name, data in batch.items():
            buffer = getattr(self, f'b{name}')
            buffer = buffer[~idx_mask]
            setattr(self, f'b{name}', buffer)
|
| | def add_queue(self, batch): |
| | self._init_buffers(batch) |
| |
|
| | if not hasattr(self, 'queue_ptr'): |
| | self.queue_ptr = 0 |
| |
|
| | start_idx = self.queue_ptr |
| | end_idx = (start_idx + batch['x'].size(0)) % self.cap |
| |
|
| | for name, data in batch.items(): |
| | buffer = getattr(self, f'b{name}') |
| | buffer[start_idx:end_idx] = data |
| |
|
| | def sample_random(self, amt, exclude_task=None, **kwargs): |
| | buffers = OrderedDict() |
| |
|
| | if exclude_task is not None: |
| | assert hasattr(self, 'bt') |
| | valid_indices = torch.where(self.bt != exclude_task)[0] |
| | valid_indices = valid_indices[valid_indices < self.current_index] |
| | for buffer_name in self.buffers: |
| | buffers[buffer_name[1:]] = getattr(self, buffer_name)[valid_indices] |
| | else: |
| | for buffer_name in self.buffers: |
| | buffers[buffer_name[1:]] = getattr(self, buffer_name)[:self.current_index] |
| |
|
| | n_selected = buffers['x'].size(0) |
| | if n_selected <= amt: |
| | assert n_selected > 0 |
| | return buffers |
| | else: |
| | idx_np = np.random.choice(buffers['x'].size(0), amt, replace=False) |
| | indices = torch.from_numpy(idx_np).to(self.bx.device) |
| |
|
| | return OrderedDict({k:v[indices] for (k,v) in buffers.items()}) |
| |
|
| | def sample_balanced(self, amt, exclude_task=None, **kwargs): |
| | buffers = OrderedDict() |
| |
|
| | if exclude_task is not None: |
| | assert hasattr(self, 'bt') |
| | valid_indices = (self.bt != exclude_task).nonzero().squeeze() |
| | for buffer_name in self.buffers: |
| | buffers[buffer_name[1:]] = getattr(self, buffer_name)[valid_indices] |
| | else: |
| | for buffer_name in self.buffers: |
| | buffers[buffer_name[1:]] = getattr(self, buffer_name)[:self.current_index] |
| |
|
| | class_count = buffers['y'].bincount() |
| |
|
| | |
| | class_sample_p = 1. / class_count.float() / class_count.size(0) |
| | per_sample_p = class_sample_p.gather(0, buffers['y']) |
| | indices = torch.multinomial(per_sample_p, amt) |
| |
|
| | return OrderedDict({k:v[indices] for (k,v) in buffers.items()}) |
| |
|
    def sample_pos_neg(self, inc_data, task_free=True, same_task_neg=True):
        """For each incoming sample (anchor), draw one positive (same label,
        different example) and one negative (different label) from the union
        of the stored buffer and the incoming batch.

        Returns (pos_x, neg_x, pos_y, neg_y, is_invalid, n_fwd) where
        `is_invalid` flags anchors with no valid positive or negative (their
        returned pos/neg are arbitrary and must be ignored), and `n_fwd`
        counts the distinct pool examples used by the valid anchors.
        """

        x = inc_data['x']
        label = inc_data['y']
        task = torch.zeros_like(label).fill_(inc_data['t'])

        # augmented candidate pool: filled buffer slots + the incoming batch
        bx = torch.cat((self.bx[:self.current_index], x))
        by = torch.cat((self.by[:self.current_index], label))
        bt = torch.cat((self.bt[:self.current_index], task))
        bidx = torch.arange(bx.size(0)).to(bx.device)

        # pairwise masks of shape (pool_size, batch_size)
        same_label = label.view(1, -1) == by.view(-1, 1)
        same_task = task.view(1, -1) == bt.view(-1, 1)
        same_ex = bidx[-x.size(0):].view(1, -1) == bidx.view(-1, 1)

        task_labels = label.unique()
        real_same_task = same_task   # NOTE(review): kept but unused below

        # in task-free mode, "same task" is re-derived from label overlap:
        # a pool item counts as same-task if its label appears in the batch
        if task_free:
            same_task = torch.zeros_like(same_task)

            for label_ in task_labels:
                label_exp = label_.view(1, -1).expand_as(same_task)
                same_task = same_task | (label_exp == by.view(-1, 1))

        valid_pos = same_label & ~same_ex

        if same_task_neg:
            valid_neg = ~same_label & same_task
        else:
            valid_neg = ~same_label

        # anchors lacking any positive or any negative candidate
        has_valid_pos = valid_pos.sum(0) > 0
        has_valid_neg = valid_neg.sum(0) > 0

        invalid_idx = ~has_valid_pos | ~has_valid_neg

        if invalid_idx.sum() > 0:
            # give multinomial full support for these columns; their draws
            # are garbage by construction and are flagged via is_invalid
            valid_pos[:, invalid_idx] = 1
            valid_neg[:, invalid_idx] = 1

        is_invalid = torch.zeros_like(label).bool()
        is_invalid[invalid_idx] = 1

        # one positive / negative per anchor (transpose: anchors -> rows)
        pos_idx = torch.multinomial(valid_pos.float().T, 1).squeeze(1)
        neg_idx = torch.multinomial(valid_neg.float().T, 1).squeeze(1)

        # distinct pool indices touched by the valid anchors
        n_fwd = torch.stack((bidx[-x.size(0):], pos_idx, neg_idx), 1)[~invalid_idx].unique().size(0)

        return bx[pos_idx], \
               bx[neg_idx], \
               by[pos_idx], \
               by[neg_idx], \
               is_invalid, \
               n_fwd
|
| | def sample_minimal_pos_neg(self, inc_data, task_free=True, same_task_neg=True): |
| | """ maximize choosing the incoming data to minimize forward passes """ |
| |
|
| | x = inc_data['x'] |
| | label = inc_data['y'] |
| | task = torch.zeros_like(label).fill_(inc_data['t']) |
| |
|
| | ''' |
| | # we need to create an "augmented" buffer containing the incoming data |
| | bx = torch.cat((self.bx[:self.current_index], x)) |
| | by = torch.cat((self.by[:self.current_index], label)) |
| | bt = torch.cat((self.bt[:self.current_index], task)) |
| | bidx = torch.arange(bx.size(0)).to(bx.device) |
| | |
| | # buf_size x label_size |
| | same_label = label.view(1, -1) == by.view(-1, 1) |
| | same_task = task.view(1, -1) == bt.view(-1, 1) |
| | same_ex = bidx[-x.size(0):].view(1, -1) == bidx.view(-1, 1) |
| | ''' |
| |
|
| | bidx = torch.arange(x.size(0)).to(x.device) |
| |
|
| | |
| | same_label = label.view(1, -1) == label.view(-1, 1) |
| | same_task = task.view(1, -1) == task.view(-1, 1) |
| | same_ex = bidx.view(1, -1) == bidx.view(-1, 1) |
| |
|
| | task_labels = label.unique() |
| | real_same_task = same_task |
| |
|
| | |
| | |
| | if task_free: |
| | same_task = torch.zeros_like(same_task) |
| |
|
| | for label_ in task_labels: |
| | label_exp = label_.view(1, -1).expand_as(same_task) |
| | same_task = same_task | (label_exp == label.view(-1, 1)) |
| |
|
| | valid_pos = same_label & ~same_ex |
| |
|
| | if same_task_neg: |
| | valid_neg = ~same_label & same_task |
| | else: |
| | valid_neg = ~same_label |
| |
|
| | |
| | has_valid_pos = valid_pos.sum(0) > 0 |
| | has_valid_neg = valid_neg.sum(0) > 0 |
| |
|
| | invalid_idx = ~has_valid_pos | ~has_valid_neg |
| |
|
| | if invalid_idx.any(): |
| | |
| | valid_pos[:, invalid_idx] = 1 |
| | valid_neg[:, invalid_idx] = 1 |
| |
|
| | |
| | is_invalid = torch.zeros_like(label).bool() |
| | is_invalid[invalid_idx] = 1 |
| |
|
| | |
| | pos_idx = torch.multinomial(valid_pos.float().T, 1).squeeze(1) |
| | neg_idx = torch.multinomial(valid_neg.float().T, 1).squeeze(1) |
| |
|
| | |
| | pos_x, neg_x = x[pos_idx], x[neg_idx] |
| | pos_y, neg_y = label[pos_idx], label[neg_idx] |
| |
|
| | n_fwd = torch.stack((bidx, pos_idx, neg_idx), 1)[~invalid_idx].unique().size(0) |
| |
|
| | |
| | if invalid_idx.any(): |
| | |
| | invalid_data = OrderedDict() |
| | invalid_data['x'] = x[invalid_idx] |
| | invalid_data['y'] = label[invalid_idx] |
| | invalid_data['t'] = inc_data['t'] |
| |
|
| | n_pos_x, n_neg_x, n_pos_y, n_neg_y, n_is_invalid, n_new_fwd = \ |
| | self.sample_pos_neg(invalid_data, task_free=task_free, same_task_neg=same_task_neg) |
| |
|
| | |
| | pos_x[invalid_idx][~n_is_invalid].data.copy_(n_pos_x[~n_is_invalid]) |
| | neg_x[invalid_idx][~n_is_invalid].data.copy_(n_neg_x[~n_is_invalid]) |
| | pos_y[invalid_idx][~n_is_invalid].data.copy_(n_pos_y[~n_is_invalid]) |
| | neg_y[invalid_idx][~n_is_invalid].data.copy_(n_neg_y[~n_is_invalid]) |
| |
|
| | invalid_idx[invalid_idx].data.copy_(n_is_invalid) |
| |
|
| | n_fwd += n_new_fwd |
| |
|
| | return pos_x, neg_x, pos_y, neg_y, is_invalid, n_fwd |
| |
|