import torch
import torch.distributed as dist
import torch.nn as nn
from torch._utils import (_flatten_dense_tensors, _take_tensors,
                          _unflatten_dense_tensors)
|
|
from annotator.uniformer.mmcv.utils import TORCH_VERSION, digit_version
from .registry import MODULE_WRAPPERS
from .scatter_gather import scatter_kwargs


@MODULE_WRAPPERS.register_module()
class MMDistributedDataParallel(nn.Module):
    """An ``nn.Module`` wrapper that mirrors ``DistributedDataParallel``.

    On construction it broadcasts the parameters and buffers of ``module``
    from rank 0 to all other processes so that every rank starts from the
    same state. Gradients are not all-reduced by this wrapper; gradient
    synchronization must be handled elsewhere in the training loop.
    """

    def __init__(self,
                 module,
                 dim=0,
                 broadcast_buffers=True,
                 bucket_cap_mb=25):
        super(MMDistributedDataParallel, self).__init__()
        self.module = module
        self.dim = dim
        self.broadcast_buffers = broadcast_buffers

        # Coalesce broadcasts into buckets of at most ``bucket_cap_mb`` MB.
        self.broadcast_bucket_size = bucket_cap_mb * 1024 * 1024
        self._sync_params()
|
|
    def _dist_broadcast_coalesced(self, tensors, buffer_size):
        # Group tensors into buckets of at most ``buffer_size`` bytes so
        # that many small broadcasts are coalesced into a few large ones.
        # (The loop variable is named ``bucket`` to avoid shadowing the
        # ``tensors`` argument.)
        for bucket in _take_tensors(tensors, buffer_size):
            flat_tensors = _flatten_dense_tensors(bucket)
            dist.broadcast(flat_tensors, 0)
            for tensor, synced in zip(
                    bucket, _unflatten_dense_tensors(flat_tensors, bucket)):
                tensor.copy_(synced)
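
    # A minimal sketch of the coalescing helpers used above (shapes are
    # arbitrary, chosen purely for illustration): ``_flatten_dense_tensors``
    # concatenates the flattened tensors into one 1-D buffer, and
    # ``_unflatten_dense_tensors`` returns same-shaped views into it.
    #
    #   >>> a, b = torch.zeros(2, 2), torch.zeros(3)
    #   >>> _flatten_dense_tensors([a, b]).shape
    #   torch.Size([7])
    #   >>> [t.shape for t in _unflatten_dense_tensors(
    #   ...     _flatten_dense_tensors([a, b]), [a, b])]
    #   [torch.Size([2, 2]), torch.Size([3])]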
|
|
    def _sync_params(self):
        # Broadcast every ``state_dict`` entry (parameters and buffers)
        # from rank 0 so that all processes hold identical weights.
        module_states = list(self.module.state_dict().values())
        if len(module_states) > 0:
            self._dist_broadcast_coalesced(module_states,
                                           self.broadcast_bucket_size)
        if self.broadcast_buffers:
            # PyTorch < 1.0 exposed buffers only via the private
            # ``_all_buffers()``; newer versions provide ``buffers()``.
            if (TORCH_VERSION != 'parrots'
                    and digit_version(TORCH_VERSION) < digit_version('1.0')):
                buffers = [b.data for b in self.module._all_buffers()]
            else:
                buffers = [b.data for b in self.module.buffers()]
            if len(buffers) > 0:
                self._dist_broadcast_coalesced(buffers,
                                               self.broadcast_bucket_size)
|
|
    def scatter(self, inputs, kwargs, device_ids):
        return scatter_kwargs(inputs, kwargs, device_ids, dim=self.dim)
|
|
    def forward(self, *inputs, **kwargs):
        # Scatter the inputs to the current device, then run the wrapped
        # module directly (one device per process).
        inputs, kwargs = self.scatter(inputs, kwargs,
                                      [torch.cuda.current_device()])
        return self.module(*inputs[0], **kwargs[0])
|
|
    def train_step(self, *inputs, **kwargs):
        inputs, kwargs = self.scatter(inputs, kwargs,
                                      [torch.cuda.current_device()])
        output = self.module.train_step(*inputs[0], **kwargs[0])
        return output
|
|
    def val_step(self, *inputs, **kwargs):
        inputs, kwargs = self.scatter(inputs, kwargs,
                                      [torch.cuda.current_device()])
        output = self.module.val_step(*inputs[0], **kwargs[0])
        return output
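
# A minimal usage sketch, assuming one process per GPU with the default
# process group already initialized (e.g. launched via ``torchrun``) and
# the current CUDA device set for this rank. ``MyModel`` and ``batch`` are
# hypothetical names used purely for illustration:
#
#   import torch
#   import torch.distributed as dist
#
#   dist.init_process_group(backend='nccl')
#   torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())
#
#   model = MMDistributedDataParallel(MyModel().cuda())
#   loss = model(batch)  # scattered to the current device by ``forward``
#
# Gradient all-reduce is not performed by this wrapper; it must be handled
# separately in the training loop (e.g. by an optimizer hook).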
|
|