| | |
| |
|
| | import torch |
| | from torch.distributed.rpc import is_available |
| |
|
| | from mmengine.dist import is_main_process |
| | from mmengine.utils import digit_version |
| | from mmengine.utils.dl_utils import TORCH_VERSION |
| |
|
| | try: |
| | from torch.distributed.optim import \ |
| | ZeroRedundancyOptimizer as _ZeroRedundancyOptimizer |
| | except ImportError: |
| | _ZeroRedundancyOptimizer = object |
| |
|
| | from .builder import OPTIMIZERS |
| |
|
| |
|
| | @OPTIMIZERS.register_module() |
| | class ZeroRedundancyOptimizer(_ZeroRedundancyOptimizer): |
| | """A wrapper class of :class:`ZeroRedundancyOptimizer` that gets a |
| | optimizer type as string. |
| | |
| | This class wraps an arbitrary :class:`torch.optim.Optimizer` and shards its |
| | states across ranks in the group as described by ZeRO_. The local optimizer |
| | instance in each rank is only responsible for updating approximately |
| | ``1 / world_size`` parameters and hence only needs to keep |
| | ``1 / world_size`` optimizer states. After parameters are updated locally, |
| | each rank will broadcast its parameters to all other peers to keep all |
| | model replicas in the same state. ``ZeroRedundancyOptimizer`` can be used |
| | in conjunction with :class:`torch.nn.parallel.DistributedDataParallel` to |
| | reduce per-rank peak memory consumption. |
| | |
| | ``ZeroRedundancyOptimizer`` uses a sorted-greedy algorithm to pack a number |
| | of parameters at each rank. Each parameter belongs to a single rank and is |
| | not divided among ranks. The partition is arbitrary and might not match the |
| | the parameter registration or usage order. |
| | |
| | Warnings: |
| | ``ZeroRedundancyOptimizer`` requires PyTorch >= 1.8. |
| | |
| | Warnings: |
| | ``ZeroRedundancyOptimizer`` requires PyTorch >= 1.12 to enable param |
| | groups. |
| | |
| | Args: |
| | params (``Iterable``): an ``Iterable`` of :class:`torch.Tensor` s |
| | or :class:`dict` s giving all parameters, which will be sharded |
| | across ranks. |
| | optimizer_type (str): the string of the local optimizer class. |
| | |
| | .. _ZeRO: https://arxiv.org/abs/1910.02054 |
| | """ |
| |
|
| | def __init__(self, params, optimizer_type: str, **kwargs): |
| | assert digit_version(TORCH_VERSION) >= digit_version('1.8.0'), ( |
| | '`torch.distributed.optim.ZeroReundancyOptimizer` is only ' |
| | 'available when pytorch version >= 1.8.') |
| | assert is_available(), 'torch.distributed.rpc is not available.' |
| | |
| | params = list(params) |
| | assert ( |
| | all(isinstance(p, torch.Tensor) for p in params) |
| | or digit_version(TORCH_VERSION) >= digit_version('1.12.0')), ( |
| | 'PyTorch ZeroRedundancyOptimizer started to support param ' |
| | 'groups since 1.12.0. Please update your pytorch version to ' |
| | 'enable this feature, or disable param groups by deleting ' |
| | '`paramwise_cfg` filed in config file.') |
| | optimizer_class = getattr(torch.optim, optimizer_type) |
| | |
| | |
| | |
| | super().__init__(params, optimizer_class, **kwargs) |
| |
|
| | def state_dict(self): |
| | """Consolidate `state_dict`s from ranks to save the `state_dict`.""" |
| | self.consolidate_state_dict() |
| | state_dict = super().state_dict() if is_main_process() else dict() |
| | return state_dict |
| |
|