File size: 4,492 Bytes
5fee096
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
import os
import random
import numpy as np
import core.data.custom_transforms as cstf

from torchvision import datasets, transforms
from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
from .dataset import ContinualDatasets, ImbalancedDatasets
from .data import transform_classes
from PIL import Image
try:
    from torchvision.transforms import InterpolationMode
    BICUBIC = InterpolationMode.BICUBIC
except ImportError:
    BICUBIC = Image.BICUBIC

def _create_transforms(cfg):
    """Build a torchvision ``Compose`` pipeline from a config list.

    Args:
        cfg (list): Each element is a dict mapping a transform name to its
            keyword-argument dict, e.g. ``[{'Resize': {'size': 224}}]``.

    Returns:
        transforms.Compose: The assembled transform pipeline.
    """
    pipeline = []

    for entry in cfg:
        for name, kwargs in entry.items():
            # Strings naming an InterpolationMode member are converted to the
            # enum in place; any other string passes through unchanged.
            for key, value in kwargs.items():
                if isinstance(value, str):
                    try:
                        kwargs[key] = transforms.InterpolationMode[value]
                    except KeyError:
                        pass

            # Custom project transforms are appended as-is (not called with
            # kwargs); torchvision transforms are instantiated with them.
            if name in cstf.custom_trfm_names:
                pipeline.append(getattr(cstf, name))
            else:
                pipeline.append(getattr(transforms, name)(**kwargs))

    return transforms.Compose(pipeline)

def get_augment(config, mode='train'):
    """Resolve the augmentation/preprocessing pipeline for a dataset split.

    Resolution order:
      1. RAPF: a fixed CLIP-style preprocessing pipeline.
      2. An explicit ``{mode}_trfms`` list in the config (see
         ``_create_transforms``).
      3. Legacy fallback: dataset/backbone keyed defaults from
         ``transform_classes``.

    Args:
        config (dict): Parsed config dict.
        mode (str): Split name, e.g. 'train' or 'test'.

    Returns:
        transforms.Compose: The transform pipeline for this split.
    """
    # Special judge for RAPF: CLIP-style normalization constants.
    if config.get('is_rapf'):
        def _convert_image_to_rgb(image):
            return image.convert("RGB")
        n_px = config['image_size']

        return Compose([
            Resize(n_px, interpolation=BICUBIC),
            CenterCrop(n_px),
            _convert_image_to_rgb,
            ToTensor(),
            Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
        ])

    trfms_key = f'{mode}_trfms'
    if trfms_key in config:
        return _create_transforms(config[trfms_key])

    # TODO: currently keeping below part for backward compatibility, will be removed in future

    d = {'dataset': 'cifar',
         'backbone': 'resnet',
         'mode': mode}

    if 'dataset' in config:
        # Collapse all cifar variants onto the single 'cifar' transform set.
        d['dataset'] = 'cifar' if 'cifar' in config['dataset'] else config['dataset']

    backbone_name = config['backbone']['name'].lower()
    if 'vit' in backbone_name:
        d['backbone'] = 'vit'
    if 'alexnet' in backbone_name:
        d['backbone'] = 'alexnet'

    return transform_classes[d['dataset']].get_transform(d['backbone'], d['mode'])
    
def get_dataloader(config, mode, cls_map=None):
    '''
    Initialize the dataloaders for Continual Learning.

    Args:
        config (dict): Parsed config dict.
        mode (string): 'train' or 'test'.
        cls_map (dict): record the map between class and labels.

    Returns:
        Dataloaders (list): a list of dataloaders
    '''

    task_num = config['task_num']
    init_cls_num = config['init_cls_num']
    inc_cls_num = config['inc_cls_num']

    data_root = config['data_root']
    num_workers = config['num_workers']
    dataset = config['dataset']

    trfms = get_augment(config, mode)

    # A mode-specific batch size ('train_batch_size' / 'test_batch_size')
    # overrides the global 'batch_size' when present.
    batch_size = config.get(f'{mode}_batch_size', config['batch_size'])

    if dataset == 'tiny-imagenet':
        # tiny-imagenet ships its own code -> name mapping file; rebuild
        # cls_map from it regardless of what the caller passed in.
        cls_map = {}
        mapping_path = os.path.join(os.getcwd(), "core", "data", "dataset_reqs", "tinyimagenet_classes.txt")
        with open(mapping_path, "r") as f:
            for line in f:
                _, cls_code, cls_name = line.strip().split('\t')
                cls_map[cls_code] = cls_name

    elif cls_map is None and dataset != 'binary_cifar100':
        # Build a label -> class-name map from the split's directory listing,
        # permuted by an explicit 'class_order' (useful for debugging) or
        # randomly otherwise.
        cls_list = sorted(os.listdir(os.path.join(data_root, mode)))
        if 'class_order' in config:
            perm = config['class_order']
        else:
            perm = np.random.permutation(len(cls_list))
        cls_map = {label: cls_list[ori_label] for label, ori_label in enumerate(perm)}

    if mode == 'train' and 'imb_type' in config:
        # Generate long-tailed data to reproduce DAP.
        return ImbalancedDatasets(mode, task_num, init_cls_num, inc_cls_num, data_root, cls_map,
                                  trfms, batch_size, num_workers,
                                  config['imb_type'], config['imb_factor'], config['shuffle'])

    return ContinualDatasets(dataset, mode, task_num, init_cls_num, inc_cls_num, data_root,
                             cls_map, trfms, batch_size, num_workers, config)