| import numpy as np | |
| import torch | |
| import copy | |
| from collections import Counter | |
| from torch.utils.data import DataLoader | |
def random_update(datasets, buffer):
    """Refresh the replay buffer with a uniform random subset.

    Pools the incoming dataset's samples with the buffer's current
    contents, draws one random permutation, and keeps the first
    ``buffer.buffer_size`` entries.

    Args:
        datasets: object exposing ``images`` and ``labels`` lists.
        buffer: replay buffer exposing ``images``, ``labels`` and
            ``buffer_size``; its contents are replaced in place.
    """
    pool_images = np.asarray(datasets.images + buffer.images)
    pool_labels = np.asarray(datasets.labels + buffer.labels)
    # One shared permutation keeps each image paired with its label.
    keep = np.random.permutation(len(pool_labels))[:buffer.buffer_size]
    buffer.images = pool_images[keep].tolist()
    buffer.labels = pool_labels[keep].tolist()
def herding_update(datasets, buffer, feature_extractor, device):
    """Rebuild the replay buffer with iCaRL-style herding selection.

    Pools the incoming dataset with the current buffer contents, then
    keeps an equal per-class share of exemplars chosen by
    ``construct_examplar``.

    Args:
        datasets: dataset object exposing ``images`` and ``labels``
            lists; a shallow copy is reused as a scratch dataset for
            feature extraction.
        buffer: replay buffer with ``images``, ``labels``,
            ``buffer_size`` and ``total_classes``; mutated in place.
        feature_extractor: callable returning ``{'features': Tensor}``
            for a batch of images.
        device: torch device used during feature extraction.
    """
    print("Using Herding Update Strategy")
    # Split the buffer evenly across every class seen so far.
    per_classes = buffer.buffer_size // buffer.total_classes
    selected_images, selected_labels = [], []
    images = np.array(datasets.images + buffer.images)
    labels = np.array(datasets.labels + buffer.labels)
    for cls in range(buffer.total_classes):
        cls_images_idx = np.where(labels == cls)
        cls_images, cls_labels = images[cls_images_idx], labels[cls_images_idx]
        cls_selected_images, cls_selected_labels = construct_examplar(
            copy.copy(datasets), cls_images, cls_labels,
            feature_extractor, per_classes, device)
        selected_images.extend(cls_selected_images)
        selected_labels.extend(cls_selected_labels)
    # BUG FIX: report the composition of the buffer being installed.
    # The original counted the old `buffer.labels` (pre-update), so the
    # printout described the buffer from before this call.
    label_counter = Counter(selected_labels)
    print("\nBuffer composition per class:")
    for cls in sorted(label_counter.keys()):
        print(f" Class {cls:3d} : {label_counter[cls]} samples")
    buffer.images, buffer.labels = selected_images, selected_labels
def construct_examplar(datasets, images, labels, feature_extractor, per_classes, device):
    """Select up to ``per_classes`` exemplars of one class via iCaRL herding.

    Greedily picks samples whose running feature average stays closest to
    the class feature mean (Rebuffi et al., iCaRL, Algorithm 4).

    Args:
        datasets: scratch dataset object; its ``images``/``labels`` are
            overwritten so a DataLoader can batch this class's samples.
        images: array of this class's images (or image references).
        labels: array of this class's labels, aligned with ``images``.
        feature_extractor: callable returning ``{'features': Tensor}``.
        per_classes: number of exemplars to keep for this class.
        device: torch device for feature extraction.

    Returns:
        Tuple ``(selected_images, selected_labels)`` as Python lists, or
        the inputs unchanged when there are too few samples to select.
    """
    if len(images) <= per_classes:
        # Not enough samples to herd; keep everything.
        print(labels[0], len(images), per_classes)
        return images, labels
    datasets.images, datasets.labels = images, labels
    dataloader = DataLoader(datasets, shuffle=False, batch_size=32, drop_last=False)
    with torch.no_grad():
        features = []
        for data in dataloader:
            imgs = data['image'].to(device)
            # Keep numpy arrays directly; the tolist() round-trip in the
            # original produced the same concatenated result, slower.
            features.append(feature_extractor(imgs)['features'].cpu().numpy())
        features = np.concatenate(features)
    selected_images, selected_labels = [], []
    selected_features = []
    class_mean = np.mean(features, axis=0)
    for k in range(1, per_classes + 1):
        if len(selected_features) == 0:
            S = np.zeros_like(features[0])
        else:
            # BUG FIX: iCaRL herding uses the *sum* of already-selected
            # features: argmin_x || mu - (phi(x) + sum_selected) / k ||.
            # The original used np.mean here, which over-weights the
            # candidate feature relative to the running exemplar set.
            S = np.sum(np.array(selected_features), axis=0)
        mu_p = (S + features) / k
        i = np.argmin(np.sqrt(np.sum((class_mean - mu_p) ** 2, axis=1)))
        selected_images.append(images[i])
        selected_labels.append(labels[i])
        selected_features.append(features[i])
        features = np.delete(features, i, axis=0)
        # BUG FIX: axis=0 keeps multi-dimensional image arrays intact;
        # without it np.delete flattens the array before deleting, which
        # corrupts every later images[i] lookup. Identical for 1-D input.
        images = np.delete(images, i, axis=0)
        labels = np.delete(labels, i, axis=0)
    return selected_images, selected_labels