| | import numpy as np |
| | import pdb |
| | import torch |
| | import torch.nn as nn |
| | import torch.nn.functional as F |
| |
|
| | |
| |
|
class OnlineBuffer(nn.Module):
    """Reservoir-sampling replay buffer for online continual learning.

    Stores up to ``buffer_size`` (input, target, task-id) triples. Incoming
    batches fill the buffer sequentially until it is full; after that, each
    new sample overwrites a uniformly random stored slot with probability
    ``buffer_size / n_seen_so_far``, so the buffer remains an (approximately)
    uniform sample of the whole stream seen so far.
    """

    def __init__(self, buffer_size, batch_size, input_size):
        """
        Args:
            buffer_size: maximum number of samples the buffer can hold.
            batch_size:  kept for interface compatibility (unused here).
            input_size:  shape of one input sample, e.g. ``(C, H, W)``.
        """
        super().__init__()

        self.place_left = True
        self.strategy = None
        self.buffer_size = buffer_size
        # Fix: original printed buffer_size twice.
        print('buffer has %d slots' % buffer_size)

        buf_data = torch.FloatTensor(buffer_size, *input_size).fill_(0)
        buf_targets = torch.LongTensor(buffer_size).fill_(0)
        buf_tasks = torch.LongTensor(buffer_size).fill_(0)

        self.current_index = 0   # next free slot while the buffer is filling
        self.n_seen_so_far = 0   # total number of samples offered so far
        self.is_full = 0
        self.total_classes = 0

        # Registered as buffers so they move with Module.to()/.cuda() and
        # are included in state_dict().
        self.register_buffer('buf_data', buf_data)
        self.register_buffer('buf_targets', buf_targets)
        self.register_buffer('buf_tasks', buf_tasks)

    def tensor_to_device(self, device):
        """Move all storage tensors to ``device``.

        Fix: ``Tensor.to()`` is NOT in-place — the original called it and
        discarded the result, leaving the buffers on their old device.
        """
        self.device = device
        self.buf_data = self.buf_data.to(device)
        self.buf_targets = self.buf_targets.to(device)
        self.buf_tasks = self.buf_tasks.to(device)

    def add_reservoir(self, x, y, task):
        """Insert a batch ``(x, y)`` from ``task`` via reservoir sampling.

        Args:
            x:    input batch, shape ``(B, *input_size)``.
            y:    integer targets, shape ``(B,)``.
            task: integer task id stored alongside each sample.
        """
        n_elem = x.size(0)

        self.device = x.device
        place_left = max(0, self.buffer_size - self.current_index)
        offset = min(place_left, n_elem)

        if place_left:
            # Sequential fill of the remaining free slots.
            self.buf_data[self.current_index: self.current_index + offset].data.copy_(x[:offset])
            self.buf_targets[self.current_index: self.current_index + offset].data.copy_(y[:offset])
            self.buf_tasks[self.current_index: self.current_index + offset].fill_(task)
            self.current_index += offset
            self.n_seen_so_far += offset

            # Whole batch fit into the free space -- nothing left to sample.
            if offset == n_elem:
                return

        self.place_left = False

        # Leftover samples compete for random slots. (Equivalent to the
        # original's x[place_left:] since offset == place_left here.)
        x, y = x[offset:], y[offset:]

        # Draw a candidate slot in [0, n_seen_so_far) for each leftover
        # sample; only draws that land inside the buffer get written.
        indices = torch.FloatTensor(x.size(0)).to(x.device).uniform_(0, self.n_seen_so_far).long()
        valid_indices = (indices < self.buf_data.size(0)).long()

        idx_new_data = valid_indices.nonzero().squeeze(-1)
        idx_buffer = indices[idx_new_data]

        self.n_seen_so_far += x.size(0)

        if idx_buffer.numel() == 0:
            return

        # Internal invariants (fix: original used ``assert cond,
        # pdb.set_trace()``, dropping into an interactive debugger).
        assert idx_buffer.max() < self.buf_data.size(0), 'buffer index out of range'
        assert idx_new_data.max() < x.size(0), 'batch index out of range'

        if self.buf_data.device != x.device:
            self.buf_data = self.buf_data.to(x.device)
            self.buf_targets = self.buf_targets.to(x.device)
            self.buf_tasks = self.buf_tasks.to(x.device)

        # Scatter the surviving samples into their reserved slots.
        self.buf_data[idx_buffer] = x[idx_new_data]
        self.buf_targets[idx_buffer] = y[idx_new_data]
        self.buf_tasks[idx_buffer] = task

    def sample(self, amount, exclude_task=None, ret_ind=False):
        """Sample up to ``amount`` stored triples without replacement.

        Args:
            amount:       number of samples requested; if fewer are
                          available, all of them are returned.
            exclude_task: if given, stored samples with this task id are
                          filtered out before sampling.
            ret_ind:      also return the indices of the sampled items.

        Returns:
            ``(data, targets, tasks)`` tensors, plus indices if ``ret_ind``.
        """
        if self.buf_data.device != self.device:
            self.buf_data = self.buf_data.to(self.device)
            self.buf_targets = self.buf_targets.to(self.device)
            self.buf_tasks = self.buf_tasks.to(self.device)

        # Only slots [0, current_index) hold real data.
        bx = self.buf_data[:self.current_index]
        by = self.buf_targets[:self.current_index]
        bt = self.buf_tasks[:self.current_index]

        if exclude_task is not None:
            # Fix: original read the nonexistent ``self.t`` (AttributeError)
            # and filtered over the whole buffer, including never-written
            # slots whose task id defaults to 0.
            valid = (bt != exclude_task).nonzero().squeeze(-1)
            bx, by, bt = bx[valid], by[valid], bt[valid]

        if bx.size(0) < amount:
            if ret_ind:
                return bx, by, bt, torch.from_numpy(np.arange(bx.size(0)))
            return bx, by, bt

        indices = torch.from_numpy(np.random.choice(bx.size(0), amount, replace=False))
        indices = indices.to(self.device)

        if ret_ind:
            return bx[indices], by[indices], bt[indices], indices
        return bx[indices], by[indices], bt[indices]