import torch
import torch.distributed as dist


class GatherLayer(torch.autograd.Function):
    """
    Gather tensors from all processes, supporting backward propagation.
    https://github.com/Spijkervet/SimCLR/blob/master/simclr/modules/gather.py
    """

    @staticmethod
    def forward(ctx, input):
        ctx.save_for_backward(input)
        # dist.all_gather does not propagate gradients on its own; the
        # custom backward below restores them for this rank's slice.
        output = [torch.zeros_like(input) for _ in range(dist.get_world_size())]
        dist.all_gather(output, input)
        return tuple(output)

    @staticmethod
    def backward(ctx, *grads):
        (input,) = ctx.saved_tensors
        grad_out = torch.zeros_like(input)
        # Each rank receives only the gradient of the gathered slice it
        # contributed, indexed by its own rank.
        grad_out[:] = grads[dist.get_rank()]
        return grad_out


def dist_gather(x: torch.Tensor):
    """Gather `x` from all processes, keeping the result in the autograd graph."""
    if not dist.is_initialized():
        return x
    # all_gather needs at least a 1-d tensor.
    if x.dim() == 0:
        x = x.reshape(1)
    x_gather = GatherLayer.apply(x)
    x_gather = torch.cat(x_gather, dim=0)
    return x_gather

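# Illustrative usage sketch (an assumption, not part of the original module):
# in a contrastive loss, each rank compares its local embeddings against the
# embeddings gathered from all processes. `encoder`, `batch`, and
# `temperature` are hypothetical names.
#
#   local_emb = torch.nn.functional.normalize(encoder(batch), dim=-1)
#   all_emb = dist_gather(local_emb)  # (world_size * B, D), gradients flow
#   logits = local_emb @ all_emb.T / temperature

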
@torch.no_grad()
def dist_gather_nograd(x: torch.Tensor):
    """Gather `x` from all processes without tracking gradients."""
    if not dist.is_initialized():
        return x
    x_gather = [torch.zeros_like(x) for _ in range(get_world_size())]
    dist.all_gather(x_gather, x, async_op=False)
    x_gather = torch.cat(x_gather, dim=0)
    return x_gather

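# Illustrative usage sketch (an assumption, not part of the original module):
# aggregating per-rank evaluation scores when gradients are not needed; every
# rank must pass a tensor of identical shape. `local_scores` is hypothetical.
#
#   all_scores = dist_gather_nograd(local_scores)
#   if is_main():
#       print(all_scores.mean().item())

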
def get_rank():
    if not dist.is_available():
        return 0
    if not dist.is_initialized():
        return 0
    return dist.get_rank()


def is_main():
    return get_rank() == 0


def get_world_size():
    if not dist.is_initialized():
        return 1
    return dist.get_world_size()


def barrier():
    if dist.is_initialized():
        dist.barrier()

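# Illustrative usage sketch (an assumption, not part of the original module):
# the rank-0-writes pattern, where only the main process saves a checkpoint
# and the other ranks wait at the barrier. `state` and `path` are hypothetical.
#
#   if is_main():
#       torch.save(state, path)
#   barrier()  # ensure the file exists before any rank tries to read it

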
@torch.no_grad()
def varsize_gather_nograd(x: torch.Tensor):
    """Gather tensors of different sizes along the first dimension."""
    if not dist.is_initialized():
        return x

    # Exchange first-dimension sizes so every rank knows the maximum.
    size = torch.tensor([x.shape[0]], device=x.device, dtype=torch.int)
    allsizes = [torch.zeros_like(size) for _ in range(dist.get_world_size())]
    dist.all_gather(allsizes, size)
    allsizes = [s.item() for s in allsizes]
    max_size = max(allsizes)

    # all_gather requires identical shapes on every rank, so pad each
    # tensor up to the maximum first-dimension size before gathering.
    padded = torch.empty(max_size, *x.shape[1:], dtype=x.dtype, device=x.device)
    padded[: x.shape[0]] = x
    output = [torch.zeros_like(padded) for _ in range(dist.get_world_size())]
    dist.all_gather(output, padded)

    # Trim the padding off each rank's slice before concatenating.
    output = [tensor[: allsizes[k]] for k, tensor in enumerate(output)]
    output = torch.cat(output, dim=0)

    return output
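

# Illustrative usage sketch (an assumption, not part of the original module):
# unlike dist_gather_nograd, varsize_gather_nograd tolerates a different number
# of rows per rank, e.g. when a dataset does not split evenly across processes.
# `local_ids` is a hypothetical tensor whose first dimension varies by rank.
#
#   all_ids = varsize_gather_nograd(local_ids)  # concatenated along dim 0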