import os
import random
import numpy as np
import core.data.custom_transforms as cstf
from torchvision import datasets, transforms
from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
from .dataset import ContinualDatasets, ImbalancedDatasets
from .data import transform_classes
from PIL import Image
# torchvision >= 0.9 exposes InterpolationMode; older versions do not, so
# fall back to the PIL constant. Either way BICUBIC names a bicubic
# interpolation flag usable by transforms.Resize below.
try:
    from torchvision.transforms import InterpolationMode
    BICUBIC = InterpolationMode.BICUBIC
except ImportError:
    BICUBIC = Image.BICUBIC
def _create_transforms(cfg):
    """Build a ``transforms.Compose`` pipeline from a config list.

    ``cfg`` is a list of single-entry dicts, each mapping a transform name to
    its keyword-argument dict. Names listed in ``cstf.custom_trfm_names``
    resolve to objects in ``core.data.custom_transforms``; all other names are
    looked up on ``torchvision.transforms`` and instantiated with their
    parameters. String parameter values that name an ``InterpolationMode``
    member are converted to the enum in place (the config dict is mutated).
    """
    pipeline = []
    for entry in cfg:
        for name, params in entry.items():
            # Promote matching string values to InterpolationMode enums;
            # strings that are not member names are left untouched.
            for key, value in params.items():
                if isinstance(value, str):
                    try:
                        params[key] = transforms.InterpolationMode[value]
                    except KeyError:
                        pass
            if name in cstf.custom_trfm_names:
                # NOTE(review): custom transforms are appended as-is and the
                # params are NOT applied here — presumably they are already
                # callable objects; confirm against custom_transforms.
                trfm = getattr(cstf, name)
            else:
                trfm = getattr(transforms, name)(**params)
            pipeline.append(trfm)
    return transforms.Compose(pipeline)
def get_augment(config, mode='train'):
    """Return the image-transform pipeline for the given mode.

    Resolution order:
      1. RAPF-specific CLIP preprocessing when ``config['is_rapf']`` is truthy.
      2. An explicit ``'{mode}_trfms'`` transform list in the config.
      3. Legacy defaults selected from the dataset / backbone names
       (kept only for backward compatibility).

    Args:
        config (dict): Parsed config dict.
        mode (str): 'train' or 'test'. Defaults to 'train'.

    Returns:
        A composed torchvision transform pipeline.
    """
    # Special judge for RAPF: reproduce CLIP's standard preprocessing.
    if config.get('is_rapf'):
        def _convert_image_to_rgb(image):
            return image.convert("RGB")
        n_px = config['image_size']
        return Compose([
            transforms.Resize(n_px, interpolation=BICUBIC),
            CenterCrop(n_px),
            _convert_image_to_rgb,
            ToTensor(),
            # CLIP's published mean/std normalization constants.
            Normalize((0.48145466, 0.4578275, 0.40821073),
                      (0.26862954, 0.26130258, 0.27577711)),
        ])
    if f'{mode}_trfms' in config:
        return _create_transforms(config[f'{mode}_trfms'])
    # TODO: legacy fallback below kept for backward compatibility,
    # will be removed in the future.
    d = {'dataset': 'cifar',
         'backbone': 'resnet',
         'mode': mode}
    if 'dataset' in config:
        # Any cifar variant collapses to the shared 'cifar' transform class.
        d['dataset'] = 'cifar' if 'cifar' in config['dataset'] else config['dataset']
    backbone_name = config['backbone']['name'].lower()
    if 'vit' in backbone_name:
        d['backbone'] = 'vit'
    if 'alexnet' in backbone_name:
        d['backbone'] = 'alexnet'
    return transform_classes[d['dataset']].get_transform(d['backbone'], d['mode'])
def get_dataloader(config, mode, cls_map=None):
    '''
    Initialize the dataloaders for Continual Learning.

    Args:
        config (dict): Parsed config dict.
        mode (string): 'train' or 'test'.
        cls_map (dict): record the map between class and labels.

    Returns:
        Dataloaders (list): a list of dataloaders
    '''
    task_num = config['task_num']
    init_cls_num = config['init_cls_num']
    inc_cls_num = config['inc_cls_num']
    data_root = config['data_root']
    num_workers = config['num_workers']
    dataset = config['dataset']
    trfms = get_augment(config, mode)

    # A mode-specific batch size takes precedence over the generic one.
    if f'{mode}_batch_size' in config:
        batch_size = config[f'{mode}_batch_size']
    else:
        batch_size = config['batch_size']

    if dataset == 'tiny-imagenet':
        # Tiny-ImageNet ships a fixed class-code -> class-name mapping file.
        cls_map = {}
        class_file = os.path.join(os.getcwd(), "core", "data", "dataset_reqs",
                                  "tinyimagenet_classes.txt")
        with open(class_file, "r") as f:
            for line in f:
                _, cls_code, cls_name = line.strip().split('\t')
                cls_map[cls_code] = cls_name
    elif cls_map is None and dataset != 'binary_cifar100':
        # Map each label index to a class directory name, honoring an
        # explicit 'class_order' (useful for debugging/reproducibility)
        # or falling back to a random permutation.
        cls_list = sorted(os.listdir(os.path.join(data_root, mode)))
        if 'class_order' in config:
            perm = config['class_order']
        else:
            perm = np.random.permutation(len(cls_list))
        cls_map = {label: cls_list[ori_label]
                   for label, ori_label in enumerate(perm)}

    if mode == 'train' and 'imb_type' in config:
        # Generate long-tailed data to reproduce DAP.
        return ImbalancedDatasets(mode, task_num, init_cls_num, inc_cls_num,
                                  data_root, cls_map, trfms, batch_size,
                                  num_workers, config['imb_type'],
                                  config['imb_factor'], config['shuffle'])
    return ContinualDatasets(dataset, mode, task_num, init_cls_num, inc_cls_num,
                             data_root, cls_map, trfms, batch_size,
                             num_workers, config)