from contextlib import contextmanager
from typing import Optional, Union

import torch
import torch.nn as nn

from mmengine.device import (is_cuda_available, is_mlu_available,
                             is_musa_available, is_npu_available)
from mmengine.optim.optimizer import OptimWrapper
from mmengine.registry import OPTIM_WRAPPERS
from mmengine.utils import digit_version
from mmengine.utils.dl_utils import TORCH_VERSION

if is_npu_available():
    from torch.npu.amp import GradScaler
elif is_mlu_available():
    from torch.mlu.amp import GradScaler
else:
    from torch.cuda.amp import GradScaler


@OPTIM_WRAPPERS.register_module()
class AmpOptimWrapper(OptimWrapper):
    """A subclass of :class:`OptimWrapper` that supports automatic mixed
    precision training based on ``torch.cuda.amp``.

    ``AmpOptimWrapper`` provides a unified interface with ``OptimWrapper``,
    so it can be used in the same way as ``OptimWrapper``.

    Warnings:
        ``AmpOptimWrapper`` requires PyTorch >= 1.6.

    Args:
        loss_scale (float or str or dict): The initial configuration of
            `torch.cuda.amp.GradScaler`. See the available arguments in
            `PyTorch AMP <https://pytorch.org/docs/stable/amp.html?highlight=gradscaler#torch.cuda.amp.GradScaler>`_. # noqa: E501
            Defaults to ``dynamic``.

            - "dynamic": Initialize GradScaler without any arguments.
            - float: Initialize GradScaler with ``init_scale``.
            - dict: Initialize GradScaler with more detailed configuration.

        dtype (str or torch.dtype, optional): The data type to autocast in
            amp. If a ``str`` is given, it will be converted to
            ``torch.dtype``. Valid ``str`` values are `'float16'`,
            `'bfloat16'`, `'float32'` and `'float64'`. If set to ``None``,
            the default data type will be used. Defaults to None.
            `New in version 0.6.1.`
        use_fsdp (bool): Use ``ShardedGradScaler`` when it is True. It should
            be enabled when using ``FullyShardedDataParallel``.
            Defaults to False.
            `New in version 0.8.0.`
        **kwargs: Keyword arguments passed to OptimWrapper.

    Warnings:
        The ``dtype`` argument is only available with PyTorch >= 1.10.0.
        With older versions it will be ignored.

    Note:
        If you use ``IterBasedRunner`` and enable gradient accumulation,
        the original ``max_iters`` should be multiplied by
        ``accumulative_counts``.
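
    Examples:
        >>> # Construction sketches; ``optimizer`` is assumed to be an
        >>> # already-built ``torch.optim.Optimizer`` and is not defined
        >>> # in this file.
        >>> # Dynamic loss scaling (the default):
        >>> optim_wrapper = AmpOptimWrapper(optimizer=optimizer)
        >>> # Static loss scaling with a fixed initial scale:
        >>> optim_wrapper = AmpOptimWrapper(
        ...     optimizer=optimizer, loss_scale=512., dtype='float16')
        >>> # Detailed GradScaler configuration via a dict:
        >>> optim_wrapper = AmpOptimWrapper(
        ...     optimizer=optimizer,
        ...     loss_scale=dict(init_scale=512., growth_interval=100))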
| | """ |

    valid_dtypes = ('float16', 'bfloat16', 'float32', 'float64')

    def __init__(self,
                 loss_scale: Union[str, float, dict] = 'dynamic',
                 dtype: Optional[Union[str, torch.dtype]] = None,
                 use_fsdp: bool = False,
                 **kwargs):
        assert digit_version(TORCH_VERSION) >= digit_version('1.6.0'), (
            '`torch.cuda.amp` is only available when PyTorch version >= 1.6')
        assert (is_cuda_available() or is_npu_available()
                or is_mlu_available() or is_musa_available()), (
            '``AmpOptimWrapper`` is only available for training '
            'on gpu, npu, mlu or musa devices')
        super().__init__(**kwargs)
        self._scale_update_param = None

        if use_fsdp:
            if digit_version(torch.__version__) >= digit_version('2.0.0'):
                from torch.distributed.fsdp.sharded_grad_scaler import \
                    ShardedGradScaler
                scaler_type = ShardedGradScaler
            else:
                raise RuntimeError(
                    'PyTorch >= 2.0.0 is required when `use_fsdp=True`')
        else:
            scaler_type = GradScaler

        # bfloat16 shares float32's exponent range, so gradients do not
        # underflow and loss scaling can be disabled.
        enable_loss_scaler = dtype != torch.bfloat16

        if loss_scale == 'dynamic':
            # If loss_scale is a string, it must be 'dynamic', meaning
            # dynamic loss scaling is used.
            self.loss_scaler = scaler_type(enabled=enable_loss_scaler)
        elif isinstance(loss_scale, float):
            # Static loss scaling with a fixed scale factor.
            self._scale_update_param = loss_scale
            self.loss_scaler = scaler_type(
                init_scale=loss_scale, enabled=enable_loss_scaler)
        elif isinstance(loss_scale, dict):
            # Detailed GradScaler configuration; honor an explicit
            # ``enabled`` key but still disable scaling for bfloat16.
            loss_scale['enabled'] = loss_scale.pop('enabled',
                                                   True) and enable_loss_scaler
            self.loss_scaler = scaler_type(**loss_scale)
        else:
            raise TypeError('loss_scale must be of type float, dict, or '
                            f'"dynamic", but got {loss_scale}')

        # Convert a ``str`` dtype to the corresponding ``torch.dtype``.
        if isinstance(dtype, str):
            assert dtype in self.valid_dtypes, (
                f'dtype should be any of {self.valid_dtypes}, got {dtype}')
            dtype = getattr(torch, dtype)

        assert dtype is None or isinstance(dtype, torch.dtype), (
            f'dtype should be None or instance of torch.dtype, got {dtype}')
        self.cast_dtype = dtype

    def backward(self, loss: torch.Tensor, **kwargs):
        """Perform gradient backpropagation with :attr:`loss_scaler`.

        Args:
            loss (torch.Tensor): The loss of the current iteration.
            kwargs: Keyword arguments passed to :meth:`torch.Tensor.backward`.
        """
        self.loss_scaler.scale(loss).backward(**kwargs)
        self._inner_count += 1

    def step(self, **kwargs):
        """Update parameters with :attr:`loss_scaler`.

        Args:
            kwargs: Keyword arguments passed to
                :meth:`torch.optim.Optimizer.step`.
        """
        if self.clip_grad_kwargs:
            # Unscale gradients before clipping; otherwise the clipping
            # threshold would apply to the scaled gradients.
            self.loss_scaler.unscale_(self.optimizer)
            self._clip_grad()
        self.loss_scaler.step(self.optimizer, **kwargs)
        self.loss_scaler.update(self._scale_update_param)

    def state_dict(self) -> dict:
        """Get the state dictionary of :attr:`optimizer` and
        :attr:`loss_scaler`.

        Based on the state dictionary of the optimizer, the returned state
        dictionary will add a key named "loss_scaler".

        Returns:
            dict: The merged state dict of :attr:`loss_scaler` and
            :attr:`optimizer`.
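
        Examples:
            >>> # A minimal sketch; ``optim_wrapper`` is assumed to be an
            >>> # already-constructed ``AmpOptimWrapper`` instance.
            >>> ckpt = optim_wrapper.state_dict()
            >>> 'loss_scaler' in ckpt
            True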
| | """ |
| | |
| | state_dict = super().state_dict() |
| | state_dict['loss_scaler'] = self.loss_scaler.state_dict() |
| | return state_dict |

    def load_state_dict(self, state_dict: dict):
        """Load and parse the state dictionary of :attr:`optimizer` and
        :attr:`loss_scaler`.

        If ``state_dict`` contains the key "loss_scaler", :attr:`loss_scaler`
        will load the corresponding values. Otherwise, only the
        :attr:`optimizer` will load the state dictionary.

        Args:
            state_dict (dict): The state dict of :attr:`optimizer` and
                :attr:`loss_scaler`.
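
        Examples:
            >>> # Resume sketch; ``optim_wrapper`` and ``ckpt`` are assumed
            >>> # to come from the ``state_dict`` example above.
            >>> optim_wrapper.load_state_dict(ckpt)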
| | """ |
| | if 'loss_scaler' in state_dict: |
| | self.loss_scaler.load_state_dict(state_dict.pop('loss_scaler')) |
| |
|
| | if 'base_param_settings' in state_dict: |
| | self.base_param_settings = state_dict.pop('base_param_settings') |
| |
|
| | |
| | self.optimizer.load_state_dict(state_dict) |

    @contextmanager
    def optim_context(self, model: nn.Module):
        """Enable the context for mixed precision training, and the context
        for disabling gradient synchronization during gradient accumulation.

        Args:
            model (nn.Module): The training model.
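
        Examples:
            >>> # Usage sketch; ``optim_wrapper``, ``model`` and ``inputs``
            >>> # are assumed to exist and are not defined in this file.
            >>> with optim_wrapper.optim_context(model):
            ...     loss = model(inputs)
            >>> optim_wrapper.update_params(loss)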
| | """ |
| | from mmengine.runner.amp import autocast |
| | with super().optim_context(model), autocast(dtype=self.cast_dtype): |
| | yield |