import numpy as np
import torch
import torch.nn as nn
import PIL
import os
from typing import List
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
class LinearHerdingBuffer:
    """Replay buffer populated via herding selection (iCaRL-style).

    The buffer keeps parallel lists of image references and labels.
    `update` selects, for each class of the current task, the samples whose
    running feature mean best approximates the class feature mean, and
    `reduce_old_data` shrinks the per-class quota as more classes arrive so
    the total stays within `buffer_size`.
    """

    def __init__(self, buffer_size, batch_size):
        # Maximum total number of samples kept across all classes.
        self.buffer_size = buffer_size
        # Kept for interface compatibility; not used by herding itself.
        self.strategy = None
        self.batch_size = batch_size
        # Parallel lists: images holds sample references (presumably file
        # paths — confirm against the dataset class), labels the class ids.
        self.images, self.labels = [], []
        self.total_classes = 0

    def is_empty(self):
        """Return True when the buffer holds no samples."""
        # images and labels are kept in lockstep, so checking one suffices.
        return len(self.labels) == 0

    def clear(self):
        """Drop every stored sample."""
        self.images = []
        self.labels = []

    def get_all_data(self):
        """Return the buffered (images, labels) as a pair of np.arrays."""
        return np.array(self.images), np.array(self.labels)

    def add_data(self, data: List, targets: List):
        """Append samples and their matching labels to the buffer."""
        self.images.extend(data)
        self.labels.extend(targets)

    def _samples_per_class(self, total_cls_num: int) -> int:
        """Per-class quota, clamped to at least 1 so the buffer never empties.

        Shared by `reduce_old_data` and `herding_select`, which previously
        duplicated this warning/clamp logic.
        """
        samples_per_class = self.buffer_size // total_cls_num
        if samples_per_class == 0:
            print(
                f"Warning: Buffer size ({self.buffer_size}) is too small for total classes ({total_cls_num}). ",
                f"Samples per class will be set to 1, to avoid empty buffer."
            )
            samples_per_class = 1
        return samples_per_class

    def update(self, model:nn.Module, train_loader, val_transform, task_idx:int,
               total_cls_num:int, cur_cls_indexes, device):
        """Add herding-selected samples of the current task to the buffer.

        Note: `herding_select` mutates `train_loader.dataset` in place
        (removes non-current-task samples and swaps in `val_transform`),
        and the returned indexes are global in that mutated dataset.
        """
        chosen_indexes = self.herding_select(model, train_loader, val_transform,
                                             task_idx, total_cls_num, cur_cls_indexes,
                                             device)
        cur_task_dataset = train_loader.dataset
        new_images = []
        new_labels = []
        for i in chosen_indexes:
            new_images.append(cur_task_dataset.images[i])
            new_labels.append(cur_task_dataset.labels[i])
        self.add_data(new_images, new_labels)

    def reduce_old_data(self, task_idx:int, total_cls_num:int) -> None:
        """Subsample previously stored classes down to the new per-class quota.

        No-op for the first task (task_idx == 0). Keeps the FIRST
        `samples_per_class` entries of each class, which matters because
        herding appends samples in decreasing order of importance.
        """
        samples_per_class = self._samples_per_class(total_cls_num)
        if task_idx > 0:
            buffer_X, buffer_Y = self.get_all_data()
            self.clear()
            for y in np.unique(buffer_Y):
                idx = (buffer_Y == y)
                selected_X, selected_Y = buffer_X[idx], buffer_Y[idx]
                self.add_data(
                    data=selected_X[:samples_per_class],
                    targets=selected_Y[:samples_per_class],
                )

    def herding_select(self, model:nn.Module, train_loader, val_transform,
                       task_idx:int, total_cls_num:int, cur_cls_indexes, device):
        """Pick exemplar indexes for the current task classes via herding.

        Mutates `train_loader.dataset` in place: keeps only samples of the
        classes in `cur_cls_indexes` and resets its transform to
        `val_transform`. Returns a list of global indexes into that
        (mutated) dataset.
        """
        def remove_buffer_sample_in_dataset(dataset, cur_cls_indexes):
            # Keep only the samples belonging to the current task classes;
            # one class at a time so same-class samples stay contiguous,
            # which `begin_index` below relies on.
            new_labels = []
            new_images = []
            for i in cur_cls_indexes:
                ind = np.array(dataset.labels) == i
                new_images.extend(list(np.array(dataset.images)[ind]))
                new_labels.extend(list(np.array(dataset.labels)[ind]))
            dataset.labels = new_labels
            dataset.images = new_images

        # get dataset containing buffer samples
        dataset = train_loader.dataset
        # remove buffer samples and only keep current-task classes
        remove_buffer_sample_in_dataset(dataset, cur_cls_indexes)
        # reset the transform to evaluation-time preprocessing
        dataset.trfms = val_transform
        # get loader for herding
        loader = DataLoader(
            dataset,
            # `shuffle = False` is required: otherwise the generated indexes
            # will not match the positions of the samples in the dataset.
            shuffle = False,
            batch_size = 32,
            # `drop_last = False` is required, otherwise some samples are lost
            drop_last = False
        )
        # how many samples per class do we want
        samples_per_class = self._samples_per_class(total_cls_num)
        # compute a normalized feature for every remaining training sample
        extracted_features = []
        extracted_targets = []
        with torch.no_grad():
            model.eval()
            for data in loader:
                image = data['image'].to(device)
                label = data['label'].to(device)
                # NOTE(review): some model variants expose extract_vector()
                # instead of backbone(...)['features'] — confirm which API
                # the model in use provides.
                feats = model.backbone(image)['features']
                feats = feats / feats.norm(dim=1).view(-1, 1)  # Feature normalization
                extracted_features.append(feats)
                extracted_targets.append(label)
        extracted_features = (torch.cat(extracted_features)).cpu()
        extracted_targets = (torch.cat(extracted_targets)).cpu()
        result = []
        for curr_cls in np.unique(extracted_targets):
            cls_ind = np.where(extracted_targets == curr_cls)[0]
            cls_feats = extracted_features[cls_ind]
            mean_feat = cls_feats.mean(0, keepdim=True)
            running_sum = torch.zeros_like(mean_feat)
            i = 0
            # Offset of this class within the dataset: the indexes we return
            # are global, and same-class samples are contiguous (guaranteed
            # by remove_buffer_sample_in_dataset above).
            begin_index = cls_ind[0]
            while i < samples_per_class and i < cls_feats.shape[0]:
                # Greedy herding step: pick the sample that brings the
                # running mean closest to the class mean.
                cost = (mean_feat - (cls_feats + running_sum) / (i + 1)).norm(2, 1)
                idx_min = cost.argmin().item()
                global_index = idx_min + begin_index
                result.append(global_index)
                running_sum += cls_feats[idx_min:idx_min + 1]
                # Inflate the chosen feature so it is never selected again.
                cls_feats[idx_min] = cls_feats[idx_min] + 1e6
                i += 1
        return result
|