import numpy as np
import pdb
import torch
import torch.nn as nn
import torch.nn.functional as F
# modified from https://github.com/gydpku/OCM/blob/main/buffer.py
class OnlineBuffer(nn.Module):
    """Reservoir-sampling replay buffer for online continual learning.

    Stores up to ``buffer_size`` (input, target, task-id) triplets.
    Incoming batches first fill empty slots in order; once the buffer is
    full, each remaining item overwrites a uniformly random slot with
    probability ``buffer_size / n_seen_so_far`` (classic reservoir sampling),
    so the buffer holds an approximately uniform sample of the stream.
    """

    def __init__(self, buffer_size, batch_size, input_size):
        """
        Args:
            buffer_size: maximum number of stored examples.
            batch_size: kept for caller compatibility (not used internally).
            input_size: per-example data shape, e.g. ``(C, H, W)``.
        """
        super().__init__()
        self.place_left = True
        self.strategy = None
        self.buffer_size = buffer_size
        # fixed: original passed buffer_size twice (once formatted, once raw)
        print('buffer has %d slots' % buffer_size)
        buf_data = torch.FloatTensor(buffer_size, *input_size).fill_(0)
        buf_targets = torch.LongTensor(buffer_size).fill_(0)
        buf_tasks = torch.LongTensor(buffer_size).fill_(0)
        self.current_index = 0   # number of filled slots (== buffer_size once full)
        self.n_seen_so_far = 0   # total stream items observed so far
        self.is_full = 0
        self.total_classes = 0
        # initialize device eagerly so sample() works even before the first
        # add_reservoir()/tensor_to_device() call (original left it unset)
        self.device = buf_data.device
        # registering as buffer allows us to save the object using `torch.save`
        self.register_buffer('buf_data', buf_data)
        self.register_buffer('buf_targets', buf_targets)
        self.register_buffer('buf_tasks', buf_tasks)

    def tensor_to_device(self, device):
        """Move all storage tensors to ``device``.

        Fixed: ``Tensor.to`` is out-of-place, so its result must be assigned
        back — the original discarded it, leaving the tensors where they were.
        """
        self.device = device
        self.buf_data = self.buf_data.to(device)
        self.buf_targets = self.buf_targets.to(device)
        self.buf_tasks = self.buf_tasks.to(device)

    def add_reservoir(self, x, y, task):
        """Insert a batch ``(x, y)`` from task ``task`` via reservoir sampling.

        Args:
            x: data tensor, shape (n, *input_size).
            y: target tensor, shape (n,).
            task: integer task identifier stamped on every inserted item.
        """
        n_elem = x.size(0)
        self.device = x.device
        # Phase 1: fill any remaining empty slots in order.
        place_left = max(0, self.buffer_size - self.current_index)
        offset = min(place_left, n_elem)
        if place_left:
            self.buf_data[self.current_index: self.current_index + offset].data.copy_(x[:offset])
            self.buf_targets[self.current_index: self.current_index + offset].data.copy_(y[:offset])
            self.buf_tasks[self.current_index: self.current_index + offset].fill_(task)
            self.current_index += offset
            self.n_seen_so_far += offset
            if offset == n_elem:
                return
        self.place_left = False
        self.is_full = int(self.current_index >= self.buffer_size)
        # Phase 2: reservoir-sample the leftover items. Each draws a uniform
        # position in [0, n_seen_so_far); only draws < buffer_size overwrite.
        # (equivalent to the original x[place_left:] — when control reaches
        # here, offset == place_left necessarily holds)
        x, y = x[offset:], y[offset:]
        indices = torch.FloatTensor(x.size(0)).to(x.device).uniform_(0, self.n_seen_so_far).long()
        valid_indices = (indices < self.buf_data.size(0)).long()
        idx_new_data = valid_indices.nonzero().squeeze(-1)
        idx_buffer = indices[idx_new_data]
        self.n_seen_so_far += x.size(0)
        if idx_buffer.numel() == 0:
            return
        # sanity checks (pdb.set_trace() debugger hooks removed)
        assert idx_buffer.max() < self.buf_data.size(0)
        assert idx_new_data.max() < x.size(0)
        if self.buf_data.device != x.device:
            self.buf_data = self.buf_data.to(x.device)
            self.buf_targets = self.buf_targets.to(x.device)
            self.buf_tasks = self.buf_tasks.to(x.device)
        self.buf_data[idx_buffer] = x[idx_new_data]
        self.buf_targets[idx_buffer] = y[idx_new_data]
        self.buf_tasks[idx_buffer] = task

    def sample(self, amount, exclude_task=None, ret_ind=False):
        """Uniformly sample up to ``amount`` stored (x, y, task) triplets.

        Args:
            amount: number of triplets requested; if fewer are available,
                all available triplets are returned.
            exclude_task: if given, slots stamped with this task id are
                never returned.
            ret_ind: if True, also return the sampled slot indices.

        Returns:
            (bx, by, bt) or (bx, by, bt, indices) when ``ret_ind`` is True.
        """
        if self.buf_data.device != self.device:
            self.buf_data = self.buf_data.to(self.device)
            self.buf_targets = self.buf_targets.to(self.device)
            self.buf_tasks = self.buf_tasks.to(self.device)
        if exclude_task is not None:
            # fixed: original referenced non-existent ``self.t`` (AttributeError);
            # also restrict to filled slots so empty slots are never sampled,
            # and use squeeze(-1) so a single match stays 1-D.
            valid_indices = (self.buf_tasks[:self.current_index] != exclude_task)
            valid_indices = valid_indices.nonzero().squeeze(-1)
            bx, by, bt = self.buf_data[valid_indices], self.buf_targets[valid_indices], self.buf_tasks[valid_indices]
        else:
            bx, by, bt = self.buf_data[:self.current_index], self.buf_targets[:self.current_index], self.buf_tasks[:self.current_index]
        if bx.size(0) < amount:
            # not enough stored data: return everything available
            if ret_ind:
                return bx, by, bt, torch.from_numpy(np.arange(bx.size(0)))
            else:
                return bx, by, bt
        else:
            indices = torch.from_numpy(np.random.choice(bx.size(0), amount, replace=False))
            indices = indices.to(self.device)
            if ret_ind:
                return bx[indices], by[indices], bt[indices], indices
            else:
                return bx[indices], by[indices], bt[indices]