File size: 4,346 Bytes
5fee096
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
import numpy as np
import pdb
import torch
import torch.nn as nn
import torch.nn.functional as F

# modified from https://github.com/gydpku/OCM/blob/main/buffer.py

class OnlineBuffer(nn.Module):
    """Fixed-capacity store of (sample, target, task id) triples.

    Slots are filled front to back until the buffer is full. After that,
    each incoming sample draws a random index in ``[0, n_seen_so_far)`` and
    overwrites that slot only if the index falls inside the buffer.

    The three storage tensors are registered as module buffers so that they
    are included when the object is saved with ``torch.save``.
    """

    def __init__(self, buffer_size, batch_size, input_size):
        """Allocate zero-initialised storage.

        Args:
            buffer_size: Number of slots in the buffer.
            batch_size: Accepted for interface compatibility; not used here.
            input_size: Iterable of ints giving the shape of one sample.
        """
        super().__init__()

        self.place_left = True
        self.strategy = None
        self.buffer_size = buffer_size
        print('buffer has %d slots' % buffer_size, buffer_size)

        buf_data = torch.zeros(buffer_size, *input_size, dtype=torch.float32)
        buf_targets = torch.zeros(buffer_size, dtype=torch.long)
        buf_tasks = torch.zeros(buffer_size, dtype=torch.long)

        self.current_index = 0
        self.n_seen_so_far = 0
        self.is_full = 0
        self.total_classes = 0
        # Defined up front so `sample` works before any data has been added.
        self.device = buf_data.device
        # registering as buffer allows us to save the object using `torch.save`
        self.register_buffer('buf_data', buf_data)
        self.register_buffer('buf_targets', buf_targets)
        self.register_buffer('buf_tasks', buf_tasks)

    def _move_buffers(self, device):
        """Rebind the storage tensors to their copies on `device`."""
        # `Tensor.to` returns a tensor instead of moving in place, so the
        # result has to be assigned back. It is a no-op when the tensor
        # already lives on `device`.
        self.buf_data = self.buf_data.to(device)
        self.buf_targets = self.buf_targets.to(device)
        self.buf_tasks = self.buf_tasks.to(device)

    def tensor_to_device(self, device):
        """Move the storage tensors to `device` and remember it for `sample`."""
        self.device = device
        self._move_buffers(device)

    def add_reservoir(self, x, y, task):
        """Insert a batch into the buffer.

        Free slots are filled first, in order. Samples that do not fit each
        draw a random index in ``[0, n_seen_so_far)`` and replace the slot at
        that index when it lies inside the buffer; otherwise they are dropped.

        Args:
            x: Tensor of samples; the first dimension indexes the samples.
            y: Tensor of targets with the same first dimension as `x`.
            task: Task id stored alongside every sample of this batch.

        Raises:
            ValueError: If `x` and `y` differ in their first dimension.
        """
        n_elem = x.size(0)
        if y.size(0) != n_elem:
            raise ValueError(
                'x and y must have the same number of samples, got %d and %d'
                % (n_elem, y.size(0))
            )

        self.device = x.device
        # Keep storage on the same device as the incoming data so the indexed
        # writes below never mix devices.
        self._move_buffers(x.device)

        place_left = max(0, self.buffer_size - self.current_index)

        # Storage writes must not be recorded by autograd.
        with torch.no_grad():
            if place_left:
                offset = min(place_left, n_elem)
                start, stop = self.current_index, self.current_index + offset

                self.buf_data[start:stop].copy_(x[:offset])
                self.buf_targets[start:stop].copy_(y[:offset])
                self.buf_tasks[start:stop].fill_(task)
                self.current_index += offset
                self.n_seen_so_far += offset

                if offset == n_elem:
                    return

            self.place_left = False

            # remove what is already in the buffer
            x, y = x[place_left:], y[place_left:]
            n_rest = x.size(0)

            indices = torch.empty(n_rest, dtype=torch.float32, device=x.device)
            indices = indices.uniform_(0, self.n_seen_so_far).long()

            # Only draws that land inside the buffer lead to a replacement.
            idx_new_data = (indices < self.buf_data.size(0)).nonzero().squeeze(-1)
            idx_buffer = indices[idx_new_data]

            self.n_seen_so_far += n_rest

            if idx_buffer.numel() == 0:
                return

            self.buf_data[idx_buffer] = x[idx_new_data]
            self.buf_targets[idx_buffer] = y[idx_new_data]
            self.buf_tasks[idx_buffer] = task

    def sample(self, amount, exclude_task=None, ret_ind=False):
        """Draw up to `amount` stored items without replacement.

        Only the filled part of the buffer is considered. If fewer than
        `amount` items are eligible, all of them are returned in storage
        order.

        Args:
            amount: Number of items requested.
            exclude_task: If not None, items stored with this task id are
                left out.
            ret_ind: If True, also return the indices of the returned items.
                The indices refer to positions within the eligible items
                (after the task filter), not to raw buffer slots.

        Returns:
            ``(data, targets, tasks)`` or, when `ret_ind` is True,
            ``(data, targets, tasks, indices)``.
        """
        # Objects restored from an older save may lack `device`.
        device = getattr(self, 'device', self.buf_data.device)
        self._move_buffers(device)

        filled = self.current_index
        bx = self.buf_data[:filled]
        by = self.buf_targets[:filled]
        bt = self.buf_tasks[:filled]

        if exclude_task is not None:
            # squeeze(-1) keeps a 1-D index even when exactly one item matches.
            keep = (bt != exclude_task).nonzero().squeeze(-1)
            bx, by, bt = bx[keep], by[keep], bt[keep]

        n_available = bx.size(0)

        if n_available < amount:
            if ret_ind:
                return bx, by, bt, torch.arange(n_available, device=bx.device)
            return bx, by, bt

        # NumPy's RNG is kept so that seeding via `np.random.seed` still
        # controls which items are drawn.
        indices = torch.from_numpy(np.random.choice(n_available, amount, replace=False))
        indices = indices.long().to(bx.device)

        if ret_ind:
            return bx[indices], by[indices], bt[indices], indices
        return bx[indices], by[indices], bt[indices]