import torch
import numpy as np
import numpy.random as npr
import torch.distributed as dist
import math

from ...log_service import print_log
from ... import sync


def singleton(class_):
    """Decorator that turns a class into a lazily-created singleton."""
    instances = {}
    def getinstance(*args, **kwargs):
        if class_ not in instances:
            instances[class_] = class_(*args, **kwargs)
        return instances[class_]
    return getinstance


@singleton
class get_sampler(object):
    """Singleton registry mapping sampler class names to sampler classes."""
    def __init__(self):
        self.sampler = {}

    def register(self, sampler):
        self.sampler[sampler.__name__] = sampler

    def __call__(self, dataset, cfg):
        # String shortcuts build the default global sampler; otherwise cfg.type
        # selects a registered sampler and cfg.args is forwarded to it.
        if cfg == 'default_train':
            return GlobalDistributedSampler(dataset, shuffle=True, extend=False)
        elif cfg == 'default_eval':
            return GlobalDistributedSampler(dataset, shuffle=False, extend=True)
        else:
            t = cfg.type
            return self.sampler[t](dataset=dataset, **cfg.args)


def register():
    def wrapper(class_):
        get_sampler().register(class_)
        return class_
    return wrapper
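

# Illustrative sketch (comments only): how a new sampler could be registered and then
# built through the registry above. `MySampler` and the cfg values shown are hypothetical
# placeholders; the cfg object is assumed to expose .type and .args as used in
# get_sampler.__call__.
#
#   @register()
#   class MySampler(torch.utils.data.Sampler):
#       def __init__(self, dataset, seed=0):
#           ...
#
#   sampler = get_sampler()(dataset, cfg)   # e.g. cfg.type == 'MySampler', cfg.args == {'seed': 0}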


@register()
class GlobalDistributedSampler(torch.utils.data.Sampler):
    """
    A distributed sampler that syncs the sampling order across GPUs and nodes.
    """
    def __init__(self,
                 dataset,
                 shuffle=True,
                 extend=False,):
        """
        Arguments:
            dataset: Dataset used for sampling.
            shuffle: If True, the sampler shuffles the indices.
            extend: If True, the sampler extends the indices so they divide evenly
                across ranks; otherwise it truncates them to make the split even.
        """
        self.ddp = sync.is_ddp()
        self.rank = sync.get_rank('global')
        self.world_size = sync.get_world_size('global')
        self.dataset = dataset
        self.shuffle = shuffle
        self.extend = extend

        num_samples = len(dataset) // self.world_size
        if extend and (len(dataset) % self.world_size != 0):
            num_samples += 1
        self.num_samples = num_samples
        self.total_size = num_samples * self.world_size

    def __iter__(self):
        indices = self.get_sync_order()
        if self.extend:
            # Wrap around to the front so every rank gets exactly num_samples indices.
            indices = indices + indices[0:self.total_size - len(indices)]
        else:
            # Truncate so the indices divide evenly across ranks.
            indices = indices[0:self.total_size]
        # Interleave: rank r takes indices r, r + world_size, r + 2 * world_size, ...
        indices = indices[self.rank:len(indices):self.world_size]
        return iter(indices)

    def __len__(self):
        return self.num_samples

    def get_sync_order(self):
        if self.shuffle:
            # Rank 0 draws the permutation and broadcasts it so all ranks iterate
            # in the same order.
            indices = torch.randperm(len(self.dataset)).to(self.rank)
            if self.ddp:
                dist.broadcast(indices, src=0)
            indices = indices.to('cpu').tolist()
        else:
            indices = list(range(len(self.dataset)))
        print_log('Sampler : {}'.format(str(indices[0:5])))
        return indices
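

# Illustrative sketch (comments only, not part of this module's API): how one of these
# samplers would typically be handed to a torch DataLoader. `my_dataset` and the batch
# size are hypothetical placeholders, and sync / torch.distributed are presumed to be
# initialized before the sampler is constructed.
#
#   sampler = GlobalDistributedSampler(my_dataset, shuffle=True, extend=False)
#   loader = torch.utils.data.DataLoader(
#       my_dataset, batch_size=32, sampler=sampler, num_workers=4)
#   for batch in loader:
#       ...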


@register()
class LocalDistributedSampler(GlobalDistributedSampler):
    """
    A distributed sampler that syncs the sampling order across GPUs within a node,
    but not across nodes.
    """
    def __init__(self,
                 dataset,
                 shuffle=True,
                 extend=False,):
        super().__init__(dataset, shuffle, extend)
        # Override the global rank / world size with the node-local ones.
        self.rank = sync.get_rank('local')
        self.world_size = sync.get_world_size('local')

    def get_sync_order(self):
        if self.shuffle:
            # Local rank 0 draws the permutation and shares it with the other ranks
            # on the same node; each node draws its own order.
            if self.rank == 0:
                indices = list(npr.permutation(len(self.dataset)))
                sync.nodewise_sync().broadcast_r0(indices)
            else:
                indices = sync.nodewise_sync().broadcast_r0(None)
        else:
            indices = list(range(len(self.dataset)))
        print_log('Sampler : {}'.format(str(indices[0:5])))
        return indices
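

# Illustrative sketch (comments only, hypothetical numbers): with 2 nodes of 2 GPUs each,
# both local ranks on a node share the permutation drawn by that node's local rank 0,
# while the other node draws its own; the inherited __iter__ then slices by local rank.
#
#   node A, local rank 0 order: [7, 2, 5, 0, ...]
#   node A, local rank 1 order: [7, 2, 5, 0, ...]   (same order, different slice)
#   node B, local rank 0 order: [1, 9, 0, 4, ...]
#   node B, local rank 1 order: [1, 9, 0, 4, ...]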


@register()
class GroupSampler(torch.utils.data.Sampler):
    """
    A distributed sampler that assigns indices to replicas group by group.

    e.g. with group_size=3, num_replicas=2 in train mode:
        0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10
        ==> (group)           [0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10]
        ==> (distribute)      process0: [3, 4, 5], (leftover [6, 7, 8, 9, 10])
                              process1: [0, 1, 2]
        ==> (group leftover)  process0: [3, 4, 5], (leftover [6, 7], [8, 9], 10)
                              process1: [0, 1, 2]
        ==> (distribute)      process0: [3, 4, 5], [6, 7] (remove 10)
                              process1: [0, 1, 2], [8, 9]

    It also avoids batches of size 1:
        0, 1, 2, 3, 4, 5, 6, 7, 8
        ==> (group)           [0, 1, 2], [3, 4, 5], [6, 7, 8]
        ==> (distribute)      process0: [3, 4, 5], (leftover [6, 7, 8])
                              process1: [0, 1, 2]
        ==> (group leftover)  process0: [3, 4, 5], (leftover [6], [7], [8])
                              process1: [0, 1, 2]
        ==> (distribute)      process0: [3, 4, 5] (remove 6, 7, 8, because distributing
                              process1: [0, 1, 2]  them would give a batch size of 1)

    With group_size=3, num_replicas=2 in eval mode:
        0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10
        ==> (extend)          0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 10
        ==> (group)           [0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10, 10]
        ==> (distribute)      process0: [0, 1, 2], [6, 7, 8]
                              process1: [3, 4, 5], [9, 10, 10]
    """
    def __init__(self,
                 dataset,
                 group_size,
                 num_replicas=None,
                 rank=None,
                 mode='train',):
        if num_replicas is None:
            if not dist.is_available():
                raise ValueError('torch.distributed is not available; pass num_replicas explicitly')
            num_replicas = dist.get_world_size()
        if rank is None:
            if not dist.is_available():
                raise ValueError('torch.distributed is not available; pass rank explicitly')
            rank = dist.get_rank()

        self.dataset = dataset
        self.len_dataset = len(dataset)
        self.group_size = group_size
        self.num_replicas = num_replicas
        self.rank = rank
        self.mode = mode
        len_dataset = self.len_dataset

        if (len_dataset % num_replicas != 0) and (mode == 'train'):
            # Train: truncate so the indices divide evenly across replicas.
            aligned_indices = np.arange(len_dataset)[:-(len_dataset % num_replicas)]
            aligned_len_dataset = aligned_indices.shape[0]
        elif (len_dataset % num_replicas != 0) and (mode == 'eval'):
            # Eval: repeat the last index so nothing is dropped.
            extend = np.array([len_dataset-1 for _ in range(num_replicas - len_dataset % num_replicas)])
            aligned_indices = np.concatenate([np.arange(len_dataset), extend])
            aligned_len_dataset = aligned_indices.shape[0]
        else:
            aligned_indices = np.arange(len_dataset)
            aligned_len_dataset = len_dataset

        num_even_distributed_groups = aligned_len_dataset // (group_size * num_replicas)
        num_even = num_even_distributed_groups * group_size * num_replicas

        # Full groups of group_size, plus one (possibly shorter) leftover group per replica.
        self.regular_groups = aligned_indices[0:num_even].reshape(-1, group_size)
        self.leftover_groups = aligned_indices[num_even:].reshape(num_replicas, -1)

        if self.leftover_groups.size == 0:
            self.leftover_groups = None
        elif (self.leftover_groups.shape[-1] == 1) and (mode == 'train'):
            # Distributing the leftover would give each replica a batch of size 1; drop it.
            self.leftover_groups = None

        # Number of indices each replica yields per epoch (backs __len__).
        self.num_samples = (len(self.regular_groups) // num_replicas) * group_size
        if self.leftover_groups is not None:
            self.num_samples += self.leftover_groups.shape[-1]

        # Every index in a group gets a reference size ('ref_size') taken from the image
        # size of the group's middle element.
        for groupi in self.regular_groups:
            idx_lowerbd = groupi[0]
            idx_upperbd = groupi[-1]
            idx_reference = (idx_lowerbd + idx_upperbd) // 2
            for idx in groupi:
                dataset.load_info[idx]['ref_size'] = dataset.load_info[idx_reference]['image_size']
        if self.leftover_groups is not None:
            for groupi in self.leftover_groups:
                idx_lowerbd = groupi[0]
                idx_upperbd = groupi[-1]
                idx_reference = (idx_lowerbd + idx_upperbd) // 2
                for idx in groupi:
                    dataset.load_info[idx]['ref_size'] = dataset.load_info[idx_reference]['image_size']

    def concat(self, nparrays, axis=0):
        # Drop empty arrays before concatenating.
        nparrays = [i for i in nparrays if i.size > 0]
        return np.concatenate(nparrays, axis=axis)

    def __iter__(self):
        indices = self.get_sync_order()
        return iter(indices)

    def __len__(self):
        return self.num_samples

    def get_sync_order(self):
        mode = self.mode
        rank = self.rank
        num_replicas = self.num_replicas
        group_size = self.group_size
        num_groups = len(self.regular_groups)

        if mode == 'train':
            # Shuffle whole groups; rank 0 draws the permutation and broadcasts it
            # so every rank slices the same order.
            g_indices = torch.randperm(num_groups).to(rank)
            dist.broadcast(g_indices, src=0)
            g_indices = g_indices.to('cpu').tolist()
            num_groups_per_rank = num_groups // num_replicas
            groups = self.regular_groups[g_indices][num_groups_per_rank*rank : num_groups_per_rank*(rank+1)]
            indices = groups.flatten()

            if self.leftover_groups is not None:
                # Shuffle which replica gets which leftover group.
                leftg_indices = torch.randperm(len(self.leftover_groups)).to(rank)
                dist.broadcast(leftg_indices, src=0)
                leftg_indices = leftg_indices.to('cpu').tolist()
                last = self.leftover_groups[leftg_indices][rank]
                indices = np.concatenate([indices, last], axis=0)
        elif mode == 'eval':
            # Deterministic round-robin assignment of groups to replicas.
            groups = self.regular_groups.reshape(-1, num_replicas, group_size)[:, rank, :]
            indices = groups.flatten()
            if self.leftover_groups is not None:
                last = self.leftover_groups[rank]
                indices = np.concatenate([indices, last], axis=0)
        else:
            raise ValueError('mode must be "train" or "eval"')

        print_log('Sampler RANK {} : {}'.format(rank, str(indices[0:group_size+1])))
        return indices
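

# A minimal, self-contained sketch of the grouping arithmetic documented on GroupSampler,
# kept here for illustration only. `_grouping_demo` is a hypothetical helper that no
# sampler above uses; it reproduces the eval-mode example (11 indices, group_size=3,
# num_replicas=2) without the shuffling, rank syncing, or ref_size bookkeeping the real
# sampler performs.
def _grouping_demo(len_dataset=11, group_size=3, num_replicas=2):
    # Eval-mode alignment: repeat the last index until the total divides evenly.
    remainder = len_dataset % num_replicas
    extend = np.array([len_dataset - 1] * ((num_replicas - remainder) % num_replicas), dtype=np.int64)
    aligned = np.concatenate([np.arange(len_dataset), extend])
    # Split into full groups; whatever is left over is spread one chunk per replica.
    num_even = (aligned.shape[0] // (group_size * num_replicas)) * group_size * num_replicas
    regular = aligned[:num_even].reshape(-1, group_size)
    leftover = aligned[num_even:].reshape(num_replicas, -1)
    # Round-robin the regular groups across replicas, then append each replica's leftover.
    per_rank = {}
    for rank in range(num_replicas):
        chunk = regular.reshape(-1, num_replicas, group_size)[:, rank, :].flatten()
        per_rank[rank] = np.concatenate([chunk, leftover[rank]]).tolist()
    return per_rank  # defaults give {0: [0, 1, 2, 6, 7, 8], 1: [3, 4, 5, 9, 10, 10]}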