File size: 2,861 Bytes
5fee096 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 | import numpy as np
import torch
import copy
from collections import Counter
from torch.utils.data import DataLoader
def random_update(datasets, buffer):
    """Refill ``buffer`` with a uniformly random subset of new and stored samples.

    The samples of ``datasets`` and those already held by ``buffer`` are pooled
    (list concatenation, new samples first), shuffled with
    ``np.random.permutation``, and the first ``buffer.buffer_size`` of the
    shuffled pool are kept. If the pool is smaller than ``buffer_size``, every
    sample is kept, in shuffled order.

    ``buffer.images`` and ``buffer.labels`` are replaced in place with plain
    Python lists (via ``ndarray.tolist``). Nothing is returned, and ``datasets``
    is not modified.
    """
    pooled_images = np.array(datasets.images + buffer.images)
    pooled_labels = np.array(datasets.labels + buffer.labels)

    # One shared permutation keeps every image paired with its own label.
    keep = np.random.permutation(len(pooled_labels))[:buffer.buffer_size]

    buffer.images = pooled_images[keep].tolist()
    buffer.labels = pooled_labels[keep].tolist()
def herding_update(datasets, buffer, feature_extractor, device):
    """Rebuild ``buffer`` with a per-class quota of herding-selected exemplars.

    The samples of ``datasets`` and those currently held by ``buffer`` are
    pooled, split by label for every class id in ``range(buffer.total_classes)``,
    and each class is reduced by ``construct_examplar`` to at most
    ``buffer.buffer_size // buffer.total_classes`` exemplars. Because of the
    floor division, up to ``total_classes - 1`` buffer slots may stay unused.
    ``buffer.images`` / ``buffer.labels`` are replaced in place, and the
    per-class composition of the NEW buffer is printed.

    Args:
        datasets: object exposing list-valued ``images`` and ``labels``. A shallow
            copy of it is handed to ``construct_examplar`` for every class, so
            the original object is not modified.
        buffer: object exposing ``images``, ``labels`` (lists), ``buffer_size``
            and ``total_classes``.
        feature_extractor: forwarded unchanged to ``construct_examplar``.
        device: forwarded unchanged to ``construct_examplar``.

    Raises:
        ZeroDivisionError: if ``buffer.total_classes`` is 0.
    """
    print("Using Herding Update Strategy")
    per_classes = buffer.buffer_size // buffer.total_classes
    images = np.array(datasets.images + buffer.images)
    labels = np.array(datasets.labels + buffer.labels)

    selected_images, selected_labels = [], []
    for cls in range(buffer.total_classes):
        in_class = labels == cls  # boolean mask over the pooled samples (axis 0)
        cls_selected_images, cls_selected_labels = construct_examplar(
            copy.copy(datasets), images[in_class], labels[in_class],
            feature_extractor, per_classes, device,
        )
        selected_images.extend(cls_selected_images)
        selected_labels.extend(cls_selected_labels)

    buffer.images, buffer.labels = selected_images, selected_labels

    # Report the composition of the buffer that was just built, so the counts
    # reflect the exemplars actually kept.
    label_counter = Counter(buffer.labels)
    print("\nBuffer composition per class:")
    for cls in sorted(label_counter):
        print(f" Class {cls:3d} : {label_counter[cls]} samples")
def construct_examplar(datasets, images, labels, feature_extractor, per_classes, device):
    """Pick ``per_classes`` exemplars of one class by greedy herding (iCaRL).

    Features for every candidate are extracted with ``feature_extractor``
    (batches of 32, in order, without gradient tracking). Exemplars are then
    chosen one at a time. At step ``k``, the candidate picked is the one whose
    feature, averaged with the ``k - 1`` features already selected, lies
    closest (Euclidean distance) to the mean feature of the whole class. That
    candidate is then removed from the pool.

    Args:
        datasets: dataset usable by ``torch.utils.data.DataLoader``. Its
            ``images`` / ``labels`` attributes are overwritten with the
            candidates, so callers should pass a copy. Each item must be a
            mapping with an ``'image'`` tensor entry.
        images: array of candidate samples, indexed along axis 0.
        labels: array of labels aligned with ``images`` along axis 0.
        feature_extractor: callable mapping a batch of images to a dict whose
            ``'features'`` entry is a 2-D ``(batch, dim)`` tensor.
        per_classes: number of exemplars to select.
        device: device the image batches are moved to before extraction.

    Returns:
        ``(selected_images, selected_labels)``. When there are at most
        ``per_classes`` candidates (including none), the inputs are returned
        unchanged. Otherwise, two lists of length ``per_classes`` are returned,
        in selection order.
    """
    if len(images) <= per_classes:
        # Not enough candidates to choose from: keep them all. The debug line is
        # skipped for an empty class, where ``labels[0]`` does not exist.
        if len(labels) > 0:
            print(labels[0], len(images), per_classes)
        return images, labels

    # Route the candidates through the caller's dataset so its item format /
    # transforms are used for feature extraction.
    datasets.images, datasets.labels = images, labels
    dataloader = DataLoader(datasets, shuffle=False, batch_size=32, drop_last=False)
    # NOTE(review): feature_extractor runs in whatever train/eval mode the caller
    # left it in. If it contains BatchNorm/Dropout, it presumably should be in
    # eval() here — confirm at the call site.
    with torch.no_grad():
        batches = [
            feature_extractor(data['image'].to(device))['features'].cpu().numpy()
            for data in dataloader
        ]
    # Herding accumulates sums over many steps; do that arithmetic in float64.
    features = np.concatenate(batches).astype(np.float64)
    class_mean = features.mean(axis=0)

    selected_images, selected_labels = [], []
    # Running SUM (not mean) of the selected features: the step-k candidate mean
    # is (sum of the k-1 selected + candidate) / k.
    selected_sum = np.zeros_like(features[0])
    for k in range(1, per_classes + 1):
        candidate_means = (selected_sum + features) / k
        i = int(np.argmin(np.linalg.norm(class_mean - candidate_means, axis=1)))
        selected_images.append(images[i])
        selected_labels.append(labels[i])
        selected_sum = selected_sum + features[i]
        # axis=0 removes whole samples; without it np.delete would flatten
        # multi-dimensional image arrays.
        features = np.delete(features, i, axis=0)
        images = np.delete(images, i, axis=0)
        labels = np.delete(labels, i, axis=0)
    return selected_images, selected_labels
|