File size: 6,507 Bytes
5fee096
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
import numpy as np
import torch
import torch.nn as nn
import PIL
import os
from typing import List
from torch.utils.data import DataLoader
from torch.utils.data import Dataset

class LinearHerdingBuffer:
    """Replay buffer for class-incremental learning using herding selection.

    Exemplars are chosen iCaRL-style: for each class, greedily pick the
    samples whose running feature mean best approximates the true class
    mean. The buffer stores lightweight sample references (e.g. image
    paths) together with their integer labels, with a global capacity of
    ``buffer_size`` samples shared across all classes seen so far.
    """

    def __init__(self, buffer_size: int, batch_size: int):
        # Maximum number of exemplars kept across ALL classes combined.
        self.buffer_size = buffer_size
        self.strategy = None
        self.batch_size = batch_size
        # Parallel lists: images[i] is labelled by labels[i].
        self.images, self.labels = [], []
        self.total_classes = 0

    def is_empty(self) -> bool:
        """Return True when the buffer holds no samples."""
        return len(self.labels) == 0

    def clear(self) -> None:
        """Drop every stored sample."""
        # Rebinding is enough; the old lists are garbage-collected.
        self.images = []
        self.labels = []

    def get_all_data(self):
        """Return the buffered images and labels as a pair of np.arrays."""
        return np.array(self.images), np.array(self.labels)

    def add_data(self, data: List, targets: List) -> None:
        """Append samples and their labels to the buffer.

        No capacity check is performed here; callers are expected to run
        ``reduce_old_data`` to enforce the per-class budget.
        """
        self.images.extend(data)
        self.labels.extend(targets)

    def _samples_per_class(self, total_cls_num: int) -> int:
        """Return the per-class exemplar budget.

        Clamped to at least 1 so the buffer never ends up empty when there
        are more classes than buffer slots (capacity may then be exceeded).
        """
        samples_per_class = self.buffer_size // total_cls_num
        if samples_per_class == 0:
            print(
                f"Warning: Buffer size ({self.buffer_size}) is too small for total classes ({total_cls_num}). ",
                f"Samples per class will be set to 1, to avoid empty buffer."
            )
            samples_per_class = 1
        return samples_per_class

    def update(self, model: nn.Module, train_loader, val_transform, task_idx: int,
               total_cls_num: int, cur_cls_indexes, device) -> None:
        """Select herding exemplars for the current task and add them.

        NOTE: ``herding_select`` mutates ``train_loader.dataset`` in place
        (drops non-current-task samples, swaps the transform), and the
        returned indexes refer to that mutated dataset — which is why we
        must read images/labels back from the same dataset object here.
        """
        chosen_indexes = self.herding_select(model, train_loader, val_transform,
                                             task_idx, total_cls_num,
                                             cur_cls_indexes, device)

        cur_task_dataset = train_loader.dataset
        new_images = [cur_task_dataset.images[i] for i in chosen_indexes]
        new_labels = [cur_task_dataset.labels[i] for i in chosen_indexes]

        self.add_data(new_images, new_labels)

    def reduce_old_data(self, task_idx: int, total_cls_num: int) -> None:
        """Subsample previously stored classes down to the new budget.

        Keeps the FIRST ``samples_per_class`` exemplars of each class:
        herding appends the most representative samples first, so a prefix
        slice is the correct subsample order.
        """
        samples_per_class = self._samples_per_class(total_cls_num)

        # Nothing stored before the first task, so nothing to shrink.
        if task_idx > 0:
            buffer_X, buffer_Y = self.get_all_data()
            self.clear()
            for y in np.unique(buffer_Y):
                idx = (buffer_Y == y)
                selected_X, selected_Y = buffer_X[idx], buffer_Y[idx]
                self.add_data(
                    data=selected_X[:samples_per_class],
                    targets=selected_Y[:samples_per_class],
                )

    def herding_select(self, model: nn.Module, train_loader, val_transform,
                       task_idx: int, total_cls_num: int, cur_cls_indexes, device):
        """Return global dataset indexes of the herding exemplars.

        WARNING: mutates ``train_loader.dataset`` in place — every sample
        whose label is not in ``cur_cls_indexes`` is dropped and the
        dataset transform is replaced by ``val_transform``. The returned
        indexes are positions in this filtered dataset.
        """

        def remove_buffer_sample_in_dataset(dataset, cur_cls_indexes):
            # Keep only the current task's classes, grouped so each class
            # occupies a contiguous index range — the local->global index
            # mapping below depends on that contiguity.
            all_labels = np.array(dataset.labels)
            all_images = np.array(dataset.images)
            new_labels = []
            new_images = []
            for i in cur_cls_indexes:
                ind = all_labels == i
                new_images.extend(list(all_images[ind]))
                new_labels.extend(list(all_labels[ind]))
            dataset.labels = new_labels
            dataset.images = new_images

        dataset = train_loader.dataset
        remove_buffer_sample_in_dataset(dataset, cur_cls_indexes)

        # Use the deterministic evaluation transform for feature extraction.
        dataset.trfms = val_transform

        # shuffle=False keeps extracted features aligned with dataset order;
        # drop_last=False ensures no sample is silently skipped.
        loader = DataLoader(
            dataset,
            shuffle=False,
            batch_size=32,
            drop_last=False,
        )

        samples_per_class = self._samples_per_class(total_cls_num)

        # Extract L2-normalized features for every remaining sample.
        extracted_features = []
        extracted_targets = []
        with torch.no_grad():
            model.eval()
            for data in loader:
                image = data['image'].to(device)
                label = data['label'].to(device)
                # NOTE(review): some model wrappers expose this as
                # model.extract_vector(image) instead — confirm against
                # the backbone's actual interface.
                feats = model.backbone(image)['features']
                feats = feats / feats.norm(dim=1).view(-1, 1)  # feature normalization
                extracted_features.append(feats)
                extracted_targets.append(label)
        extracted_features = torch.cat(extracted_features).cpu()
        extracted_targets = torch.cat(extracted_targets).cpu()

        result = []
        for curr_cls in np.unique(extracted_targets):
            cls_ind = np.where(extracted_targets == curr_cls)[0]
            cls_feats = extracted_features[cls_ind]
            mean_feat = cls_feats.mean(0, keepdim=True)
            running_sum = torch.zeros_like(mean_feat)
            # Samples of one class are contiguous (see the filter above),
            # so local index + begin_index gives the global dataset index.
            begin_index = cls_ind[0]
            i = 0
            while i < samples_per_class and i < cls_feats.shape[0]:
                # Greedy herding step: choose the sample that moves the
                # running mean closest to the true class mean.
                cost = (mean_feat - (cls_feats + running_sum) / (i + 1)).norm(2, 1)
                idx_min = cost.argmin().item()
                result.append(idx_min + begin_index)
                # Accumulate BEFORE inflating, so the original feature
                # value enters the running sum.
                running_sum += cls_feats[idx_min:idx_min + 1]
                # Inflate the chosen feature so it is never picked again.
                cls_feats[idx_min] = cls_feats[idx_min] + 1e6
                i += 1

        return result