| | from typing import List, Optional, Union |
| |
|
| | import torch |
| | import torch.distributed |
| | from torchpack import distributed |
| |
|
| | from utils.misc import list_mean, list_sum |
| |
|
| | __all__ = ["ddp_reduce_tensor", "DistributedMetric"] |
| |
|
| |
|
def ddp_reduce_tensor(
    tensor: torch.Tensor, reduce="mean"
) -> Union[torch.Tensor, List[torch.Tensor]]:
    """Gather ``tensor`` from every rank in the process group and combine it.

    Args:
        tensor: local tensor to share across ranks (made contiguous before
            the collective call).
        reduce: how to combine the gathered tensors — ``"mean"``, ``"sum"``,
            ``"cat"`` (concatenate along dim 0), or ``"root"`` (rank-0 copy).
            Any other value returns the raw per-rank list.

    Returns:
        The combined tensor, or the list of per-rank tensors.
    """
    world_size = distributed.size()
    gathered = [torch.empty_like(tensor) for _ in range(world_size)]
    torch.distributed.all_gather(gathered, tensor.contiguous(), async_op=False)

    if reduce == "mean":
        return list_mean(gathered)
    if reduce == "sum":
        return list_sum(gathered)
    if reduce == "cat":
        return torch.cat(gathered, dim=0)
    if reduce == "root":
        return gathered[0]
    return gathered
| |
|
| |
|
class DistributedMetric(object):
    """Running average of a metric synchronized across distributed workers.

    Maintains global ``sum`` and ``count`` accumulators: every ``update``
    all-reduces the local contribution over the process group (via
    ``ddp_reduce_tensor``), so ``avg`` is the true global mean.
    """

    def __init__(self, name: Optional[str] = None, backend="ddp"):
        self.name = name        # optional label for logging/reporting
        self.sum = 0            # becomes a CUDA tensor after the first update
        self.count = 0          # global number of samples accumulated
        self.backend = backend  # only "ddp" is implemented

    def update(self, val: Union[torch.Tensor, int, float], delta_n=1):
        """Accumulate ``val`` from all ranks.

        Args:
            val: local metric value, assumed to be a mean over ``delta_n``
                samples (it is re-weighted by ``delta_n`` before summing).
            delta_n: number of samples that produced ``val``.

        Raises:
            NotImplementedError: if ``backend`` is not ``"ddp"``.
        """
        if isinstance(val, torch.Tensor):
            # Multiply out-of-place: `val *= delta_n` would silently mutate
            # the caller's tensor in place.
            val = val.detach() * delta_n
        else:
            val = torch.Tensor(1).fill_(val * delta_n).cuda()
        if self.backend == "ddp":
            self.count += ddp_reduce_tensor(
                torch.Tensor(1).fill_(delta_n).cuda(), reduce="sum"
            )
            self.sum += ddp_reduce_tensor(val, reduce="sum")
        else:
            raise NotImplementedError

    @property
    def avg(self):
        """Global mean (``sum / count``); a ``-1`` tensor before any update."""
        if self.count == 0:
            return torch.Tensor(1).fill_(-1)
        else:
            return self.sum / self.count
|