| from importlib import import_module |
| |
| from torch.utils.data import dataloader |
| from torch.utils.data import ConcatDataset |
| import torch |
| import random |
| |
class MyConcatDataset(ConcatDataset):
    """``ConcatDataset`` that can forward a scale index to its members."""

    def __init__(self, datasets):
        """Concatenate ``datasets`` exactly as ``ConcatDataset`` does."""
        super().__init__(datasets)

    def set_scale(self, idx_scale):
        """Call ``set_scale(idx_scale)`` on every member dataset that has it.

        Members without a ``set_scale`` attribute are left untouched.
        """
        for dataset in self.datasets:
            if not hasattr(dataset, 'set_scale'):
                continue
            dataset.set_scale(idx_scale)
|
|
class Data:
    """Build one evaluation ``DataLoader`` per requested test set.

    Attributes:
        loader_train: Always ``None``; this class builds no training loader.
        loader_test: List of ``DataLoader`` objects, one per entry of
            ``args.data_test``, in the same order.
    """

    def __init__(self, args):
        """Create the test loaders.

        Args:
            args: Options object. This constructor reads ``args.data_test``
                (iterable of dataset names) and ``args.n_threads`` (worker
                count for each loader); ``args`` is also passed through to
                the dataset constructor.

        Raises:
            NotImplementedError: If a name in ``args.data_test`` is not one
                of the supported benchmark sets.
        """
        self.loader_train = None
        self.loader_test = []
        for d in args.data_test:
            if d in ('Set5', 'Set14', 'B100', 'Urban100'):
                # Imported lazily so the module is only needed when a
                # benchmark set is actually requested.
                m = import_module('data.benchmark')
                testset = getattr(m, 'Benchmark')(args, name=d)
            else:
                # The original `assert NotImplementedError` never fired (the
                # class object is truthy), so an unknown name either hit a
                # NameError on `testset` or silently reused the previous
                # dataset. Fail loudly instead.
                raise NotImplementedError(
                    'Unsupported test dataset: {!r}'.format(d)
                )

            # Evaluation settings: one sample per batch, fixed order.
            self.loader_test.append(
                dataloader.DataLoader(
                    testset,
                    batch_size=1,
                    shuffle=False,
                    pin_memory=False,
                    num_workers=args.n_threads,
                )
            )
|
|