import numpy as np
import torch
import copy
from collections import Counter
from torch.utils.data import DataLoader

def random_update(datasets, buffer):
    """Refill the buffer with a uniform random subset of the pooled samples."""
    images = np.array(datasets.images + buffer.images)
    labels = np.array(datasets.labels + buffer.labels)

    # Shuffle the pooled indices and keep at most buffer_size of them.
    perm = np.random.permutation(len(labels))
    keep = perm[:buffer.buffer_size]

    buffer.images = images[keep].tolist()
    buffer.labels = labels[keep].tolist()

def herding_update(datasets, buffer, feature_extractor, device):
    """Refill the buffer with per-class exemplars chosen by herding."""
    print("Using Herding Update Strategy")

    per_classes = buffer.buffer_size // buffer.total_classes

    selected_images, selected_labels = [], []
    images = np.array(datasets.images + buffer.images)
    labels = np.array(datasets.labels + buffer.labels)

    for cls in range(buffer.total_classes):
        cls_images_idx = np.where(labels == cls)
        cls_images, cls_labels = images[cls_images_idx], labels[cls_images_idx]

        cls_selected_images, cls_selected_labels = construct_examplar(
            copy.copy(datasets), cls_images, cls_labels,
            feature_extractor, per_classes, device)
        selected_images.extend(cls_selected_images)
        selected_labels.extend(cls_selected_labels)

    buffer.images, buffer.labels = selected_images, selected_labels

    # Report the composition of the *updated* buffer. (Counting buffer.labels
    # before the assignment above would describe the pre-update state.)
    label_counter = Counter(buffer.labels)
    print("\nBuffer composition per class:")
    for cls in sorted(label_counter.keys()):
        print(f"  Class {cls:3d} : {label_counter[cls]} samples")

def construct_examplar(datasets, images, labels, feature_extractor, per_classes, device):
    """Select `per_classes` exemplars for one class via herding (iCaRL-style)."""
    if len(images) <= per_classes:
        print(f"Class {labels[0]}: only {len(images)} samples available "
              f"(requested {per_classes}), keeping all of them")
        return images, labels

    datasets.images, datasets.labels = images, labels
    dataloader = DataLoader(datasets, shuffle=False, batch_size=32, drop_last=False)

    # Extract features for every candidate sample of this class.
    with torch.no_grad():
        features = []
        for data in dataloader:
            imgs = data['image'].to(device)
            features.append(feature_extractor(imgs)['features'].cpu().numpy())

    features = np.concatenate(features)
    selected_images, selected_labels = [], []
    selected_features = []
    class_mean = np.mean(features, axis=0)

    for k in range(1, per_classes + 1):
        # Running *sum* of the features chosen so far; dividing (S + candidate)
        # by k below then yields the mean the selection would have if the
        # candidate were added. (Taking np.mean here, as the code previously
        # did, under-weights the already-selected features.)
        if len(selected_features) == 0:
            S = np.zeros_like(features[0])
        else:
            S = np.sum(np.array(selected_features), axis=0)

        # Pick the candidate that brings the selected mean closest to the
        # class mean.
        mu_p = (S + features) / k
        i = np.argmin(np.sqrt(np.sum((class_mean - mu_p) ** 2, axis=1)))

        selected_images.append(images[i])
        selected_labels.append(labels[i])
        selected_features.append(features[i])

        # Remove the chosen sample so it cannot be picked again. axis=0 keeps
        # np.delete from flattening multi-dimensional image arrays.
        features = np.delete(features, i, axis=0)
        images = np.delete(images, i, axis=0)
        labels = np.delete(labels, i, axis=0)

    return selected_images, selected_labels
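
# ----------------------------------------------------------------------
# Illustration (not part of the project API): the herding rule used in
# construct_examplar, isolated on toy features so it can be followed
# without a model or dataloader. At step k we pick the remaining point
# whose feature, pooled with those already chosen, moves the running
# mean closest to the class mean. `herding_order` is a hypothetical
# helper name introduced here for demonstration only.
def herding_order(features, m):
    """Return the indices of the first m herding-selected rows of `features`."""
    features = np.asarray(features, dtype=float)
    class_mean = features.mean(axis=0)
    chosen, chosen_sum = [], np.zeros(features.shape[1])
    remaining = list(range(len(features)))
    for k in range(1, m + 1):
        # Mean the selection would have if each remaining point were added.
        cand = (chosen_sum + features[remaining]) / k
        pos = int(np.argmin(np.linalg.norm(class_mean - cand, axis=1)))
        idx = remaining.pop(pos)
        chosen.append(idx)
        chosen_sum += features[idx]
    return chosen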