| | import numpy as np |
| | import torch |
| | import torch.nn as nn |
| | import PIL |
| | import os |
| | from typing import List |
| | from torch.utils.data import DataLoader |
| | from torch.utils.data import Dataset |
| |
|
class LinearHerdingBuffer:
    """Fixed-size rehearsal buffer filled by herding selection (iCaRL-style).

    Stores exemplar image references and their labels for all classes seen
    so far. The per-class quota is ``buffer_size // total_cls_num`` (floored
    at 1), so the quota shrinks as new classes arrive and the total number
    of stored exemplars never exceeds ``buffer_size`` by design.
    """

    def __init__(self, buffer_size, batch_size):
        # Maximum number of exemplars kept across all classes combined.
        self.buffer_size = buffer_size
        self.strategy = None
        self.batch_size = batch_size
        # Parallel lists: images holds sample references, labels their classes.
        self.images, self.labels = [], []
        self.total_classes = 0

    def is_empty(self):
        """Return True when the buffer holds no exemplars."""
        return len(self.labels) == 0

    def clear(self):
        """Drop every stored exemplar, releasing the old lists."""
        del self.images
        del self.labels
        self.images = []
        self.labels = []

    def get_all_data(self):
        """Return all stored exemplars as ``(images, labels)`` numpy arrays."""
        return np.array(self.images), np.array(self.labels)

    def add_data(self, data: List, targets: List):
        """Append new exemplars and their labels to the buffer.

        Args:
            data: sample references for the new exemplars.
            targets: class labels aligned with ``data`` (compared against
                integer class ids elsewhere in this class).
        """
        self.images.extend(data)
        self.labels.extend(targets)

    def _samples_per_class(self, total_cls_num: int) -> int:
        """Per-class exemplar quota, floored at 1 so the buffer never empties.

        Shared by reduce_old_data and herding_select so the warning and the
        floor behavior cannot drift apart.
        """
        samples_per_class = self.buffer_size // total_cls_num
        if samples_per_class == 0:
            print(
                f"Warning: Buffer size ({self.buffer_size}) is too small for total classes ({total_cls_num}). ",
                f"Samples per class will be set to 1, to avoid empty buffer."
            )
            samples_per_class = 1
        return samples_per_class

    def update(self, model: nn.Module, train_loader, val_transform,
               task_idx: int, total_cls_num: int, cur_cls_indexes, device):
        """Select exemplars for the current task via herding and store them.

        NOTE(review): herding_select mutates ``train_loader.dataset`` in
        place (filters it to the current classes and swaps in
        ``val_transform``), so the returned indexes are valid for the
        mutated dataset read below.
        """
        chosen_indexes = self.herding_select(model, train_loader, val_transform,
                                             task_idx, total_cls_num,
                                             cur_cls_indexes, device)

        cur_task_dataset = train_loader.dataset
        new_images = [cur_task_dataset.images[i] for i in chosen_indexes]
        new_labels = [cur_task_dataset.labels[i] for i in chosen_indexes]
        self.add_data(new_images, new_labels)

    def reduce_old_data(self, task_idx: int, total_cls_num: int) -> None:
        """Trim each stored class down to the new (smaller) per-class quota.

        Keeps the first ``samples_per_class`` exemplars of every class,
        preserving herding order (earliest-selected exemplars are the most
        representative). No-op on the first task (``task_idx == 0``).
        """
        samples_per_class = self._samples_per_class(total_cls_num)

        if task_idx > 0:
            buffer_X, buffer_Y = self.get_all_data()
            self.clear()
            for y in np.unique(buffer_Y):
                idx = (buffer_Y == y)
                selected_X, selected_Y = buffer_X[idx], buffer_Y[idx]
                self.add_data(
                    data=selected_X[:samples_per_class],
                    targets=selected_Y[:samples_per_class],
                )

    def herding_select(self, model: nn.Module, train_loader, val_transform,
                       task_idx: int, total_cls_num: int, cur_cls_indexes,
                       device):
        """Pick exemplar indexes for each current class by herding.

        Greedily selects, per class, the samples whose running feature mean
        best approximates the class mean of L2-normalized backbone features.

        WARNING: mutates ``train_loader.dataset`` in place — it keeps only
        samples of ``cur_cls_indexes`` (grouped by class) and replaces the
        dataset transforms with ``val_transform``. The returned indexes
        refer to the mutated dataset.
        """

        def remove_buffer_sample_in_dataset(dataset, cur_cls_indexes):
            # Keep only samples of the current task's classes, regrouped so
            # each class occupies a contiguous run (the index arithmetic in
            # the herding loop below relies on this contiguity).
            new_labels = []
            new_images = []
            for i in cur_cls_indexes:
                ind = np.array(dataset.labels) == i
                new_images.extend(list(np.array(dataset.images)[ind]))
                new_labels.extend(list(np.array(dataset.labels)[ind]))
            dataset.labels = new_labels
            dataset.images = new_images

        dataset = train_loader.dataset
        remove_buffer_sample_in_dataset(dataset, cur_cls_indexes)
        # Use evaluation transforms so features are deterministic.
        dataset.trfms = val_transform

        loader = DataLoader(
            dataset,
            shuffle=False,   # stable order: indexes map back to the dataset
            batch_size=32,
            drop_last=False  # every candidate sample must be featurized
        )

        samples_per_class = self._samples_per_class(total_cls_num)

        # Extract L2-normalized features for all candidate samples.
        extracted_features = []
        extracted_targets = []
        with torch.no_grad():
            model.eval()
            for data in loader:
                image = data['image'].to(device)
                label = data['label'].to(device)
                feats = model.backbone(image)['features']
                feats = feats / feats.norm(dim=1).view(-1, 1)
                extracted_features.append(feats)
                extracted_targets.append(label)
            extracted_features = (torch.cat(extracted_features)).cpu()
            extracted_targets = (torch.cat(extracted_targets)).cpu()

        result = []
        for curr_cls in np.unique(extracted_targets):
            cls_ind = np.where(extracted_targets == curr_cls)[0]
            cls_feats = extracted_features[cls_ind]
            mean_feat = cls_feats.mean(0, keepdim=True)
            running_sum = torch.zeros_like(mean_feat)
            i = 0
            # Samples of this class are contiguous after the filtering above,
            # so local index + begin_index is the dataset-level index.
            begin_index = cls_ind[0]
            while i < samples_per_class and i < cls_feats.shape[0]:
                # Distance between the class mean and the mean obtained by
                # adding each remaining candidate to the already-chosen set.
                cost = (mean_feat - (cls_feats + running_sum) / (i + 1)).norm(2, 1)
                idx_min = cost.argmin().item()
                global_index = idx_min + begin_index
                result.append(global_index)
                running_sum += cls_feats[idx_min:idx_min + 1]
                # Push the picked feature far away so it cannot be re-picked.
                cls_feats[idx_min] = cls_feats[idx_min] + 1e6
                i += 1

        return result
| | |
| |
|
| |
|