| from typing import Optional |
|
|
| import torch |
| from torch import Tensor |
| from torch.distributed import ProcessGroup |
|
|
| |
| |
| |
| |
# When the public collective names are missing (older PyTorch), fall back to the
# private `_all_gather_base` / `_reduce_scatter_base` implementations so the
# helpers below can always call the public names.
if not hasattr(torch.distributed, "all_gather_into_tensor"):
    torch.distributed.all_gather_into_tensor = torch.distributed._all_gather_base
if not hasattr(torch.distributed, "reduce_scatter_tensor"):
    torch.distributed.reduce_scatter_tensor = torch.distributed._reduce_scatter_base
|
|
|
|
| |
def all_gather_raw(input_: Tensor, process_group: ProcessGroup, async_op: bool = False):
    """Concatenate ``input_`` from every rank of ``process_group`` along dim 0.

    Args:
        input_: this rank's local chunk; it is made contiguous before the collective.
        process_group: group whose ranks participate in the gather.
        async_op: forwarded to ``torch.distributed.all_gather_into_tensor``.

    Returns:
        ``(output, handle)``. ``output`` has ``world_size * input_.shape[0]`` rows
        and the same trailing dims, dtype and device as ``input_``. ``handle`` is
        whatever the collective returns (a work handle when ``async_op`` is True).
        Callers must wait on it before reading ``output``.
    """
    n_ranks = torch.distributed.get_world_size(process_group)
    gathered_shape = (n_ranks * input_.shape[0], *input_.shape[1:])
    gathered = torch.empty(gathered_shape, dtype=input_.dtype, device=input_.device)
    work = torch.distributed.all_gather_into_tensor(
        gathered, input_.contiguous(), group=process_group, async_op=async_op
    )
    return gathered, work
|
|
|
|
| |
def reduce_scatter_raw(input_: Tensor, process_group: ProcessGroup, async_op: bool = False):
    """Reduce ``input_`` across ``process_group`` and keep this rank's dim-0 shard.

    Args:
        input_: full-size tensor on every rank; dim 0 must be divisible by the
            group's world size (checked with an ``assert``).
        process_group: group whose ranks participate in the reduce-scatter.
        async_op: forwarded to ``torch.distributed.reduce_scatter_tensor``.

    Returns:
        ``(output, handle)``. ``output`` has ``input_.shape[0] // world_size`` rows
        and the same trailing dims, dtype and device as ``input_``. ``handle`` is
        whatever the collective returns (a work handle when ``async_op`` is True).
    """
    n_ranks = torch.distributed.get_world_size(process_group)
    assert input_.shape[0] % n_ranks == 0
    rows_per_rank = input_.shape[0] // n_ranks
    scattered = torch.empty(
        (rows_per_rank, *input_.shape[1:]), dtype=input_.dtype, device=input_.device
    )
    work = torch.distributed.reduce_scatter_tensor(
        scattered, input_.contiguous(), group=process_group, async_op=async_op
    )
    return scattered, work
|
|
|
|
| |
def all_reduce_raw(input_: Tensor, process_group: ProcessGroup, async_op: bool = False):
    """All-reduce ``input_`` across ``process_group`` (default reduce op).

    Returns:
        ``(reduced, handle)``. ``reduced`` is the contiguous tensor the collective
        wrote into. ``handle`` is whatever ``torch.distributed.all_reduce`` returns
        (a work handle when ``async_op`` is True).

    Note:
        ``torch.distributed.all_reduce`` operates in place. ``Tensor.contiguous()``
        returns the same tensor when it is already contiguous, so in that case the
        caller's ``input_`` itself is overwritten with the reduced values.
    """
    buffer = input_.contiguous()
    work = torch.distributed.all_reduce(buffer, group=process_group, async_op=async_op)
    return buffer, work
|
|
|
|
class AllGatherFunc(torch.autograd.Function):
    """Gather the input from the sequence-parallel region and concatenate along dim 0.

    Forward all-gathers ``input_`` over ``process_group``. Backward reduce-scatters
    the incoming gradient, so each rank receives the reduced gradient for its own
    dim-0 shard.
    """

    @staticmethod
    def forward(ctx, input_: Tensor, process_group: ProcessGroup) -> Tensor:
        # Keep the group around; backward must run the matching collective on it.
        ctx.process_group = process_group
        gathered = all_gather_raw(input_, process_group)[0]
        return gathered

    @staticmethod
    def backward(ctx, grad_output: Tensor):
        group = ctx.process_group
        grad_shard = reduce_scatter_raw(grad_output, group)[0]
        # The process_group argument receives no gradient.
        return grad_shard, None


# Functional form: all_gather(input_, process_group).
all_gather = AllGatherFunc.apply
|
|
|
|
class ReduceScatterFunc(torch.autograd.Function):
    """Reduce-scatter the input from the sequence-parallel region.

    Forward reduces ``input_`` across ``process_group`` (default reduce op) and
    returns only this rank's dim-0 shard (``input_.shape[0] // world_size`` rows).
    Backward is the dual operation: the per-rank gradient shards are all-gathered
    (concatenated along dim 0) back to the full input shape.
    """

    @staticmethod
    def forward(ctx, input_: Tensor, process_group: ProcessGroup) -> Tensor:
        # Saved for backward, which issues the matching all-gather on the same group.
        ctx.process_group = process_group
        output, _ = reduce_scatter_raw(input_, process_group)
        return output

    @staticmethod
    def backward(ctx, grad_output: Tensor):
        grad_input, _ = all_gather_raw(grad_output, ctx.process_group)
        # No gradient for the process_group argument.
        return grad_input, None


# Functional form: reduce_scatter(input_, process_group).
reduce_scatter = ReduceScatterFunc.apply
|
|
|
|
class AllReduceFunc(torch.autograd.Function):
    """All-reduce the input across the ranks of ``process_group``.

    Forward applies ``torch.distributed.all_reduce`` with its default reduce op.
    Backward passes ``grad_output`` through unchanged (identity). This assumes the
    upstream gradient is already identical on every rank — TODO confirm for the
    callers of ``all_reduce``.
    """

    @staticmethod
    def forward(ctx, input_: Tensor, process_group: ProcessGroup) -> Tensor:
        ctx.process_group = process_group
        # NOTE(review): all_reduce_raw reduces in place on input_.contiguous(), which
        # is input_ itself when already contiguous. The caller's tensor is therefore
        # presumably overwritten, and ctx.mark_dirty is not called. Confirm that this
        # aliasing is intended.
        output, _ = all_reduce_raw(input_, process_group)
        return output

    @staticmethod
    def backward(ctx, grad_output: Tensor):
        # Identity backward; no gradient for the process_group argument.
        return grad_output, None


# Functional form: all_reduce(input_, process_group).
all_reduce = AllReduceFunc.apply
|
|
|
|
def sync_shared_params(model: torch.nn.Module, process_group: ProcessGroup):
    """Broadcast every parameter flagged ``_shared_params`` from the group's rank 0.

    Parameters are visited in sorted-name order, so every rank issues the
    broadcasts in the same sequence. The source rank is the global rank of local
    rank 0 in ``process_group``. Broadcasts run under ``torch.no_grad()`` because
    they overwrite the parameters in place.
    """
    params_shared = sorted(
        (
            (name, param)
            for name, param in model.named_parameters()
            if getattr(param, "_shared_params", False)
        ),
        key=lambda item: item[0],
    )
    with torch.no_grad():
        for _, param in params_shared:
            torch.distributed.broadcast(
                param, src=torch.distributed.get_global_rank(process_group, 0), group=process_group
            )
|
|
|
|
| |
def allreduce_sequence_parallel_grad(model: torch.nn.Module, process_group: ProcessGroup):
    """All-reduce the gradients of parameters flagged ``_sequence_parallel``.

    Gradients are collected in sorted-name order, flattened into one buffer, and
    all-reduced over ``process_group`` in a single collective. The reduced values
    are then copied back into each ``.grad`` in place.

    Parameters whose ``.grad`` is ``None`` (frozen, or unused in this step) are
    skipped, because the flatten helper cannot handle ``None``. The set of flagged
    parameters that have gradients must match on every rank; otherwise the
    flattened buffers differ in size across ranks.
    """
    params_seqparallel = {
        name: p for name, p in model.named_parameters() if getattr(p, "_sequence_parallel", False)
    }
    grads = [p.grad for _, p in sorted(params_seqparallel.items()) if p.grad is not None]
    if not grads:
        return
    with torch.no_grad():
        # One coalesced collective instead of one all_reduce per parameter.
        coalesced = torch._utils._flatten_dense_tensors(grads)
        torch.distributed.all_reduce(coalesced, group=process_group)
        for buf, synced in zip(grads, torch._utils._unflatten_dense_tensors(coalesced, grads)):
            buf.copy_(synced)
|
|
|
|
def get_dim_for_local_rank(dim: int, world_size: int, local_rank: int, multiple_of: int = 1) -> int:
    """Return the slice of ``dim`` assigned to ``local_rank`` when split over ``world_size`` ranks.

    ``dim`` is split in whole units of ``multiple_of``. The first
    ``(dim // multiple_of) % world_size`` ranks each get one extra unit, so the
    split may be uneven. Any remainder ``dim % multiple_of`` is not assigned to
    any rank.
    """
    units_each, leftover_units = divmod(dim // multiple_of, world_size)
    if local_rank < leftover_units:
        units_each += 1
    return units_each * multiple_of
|
|