| |
| import functools |
| import warnings |
| from collections import abc |
| from inspect import getfullargspec |
|
|
| import numpy as np |
| import torch |
| import torch.nn as nn |
|
|
| from annotator.uniformer.mmcv.utils import TORCH_VERSION, digit_version |
| from .dist_utils import allreduce_grads as _allreduce_grads |
|
|
| try: |
| |
| |
| |
| |
| from torch.cuda.amp import autocast |
| except ImportError: |
| pass |
|
|
|
|
def cast_tensor_type(inputs, src_type, dst_type):
    """Recursively convert Tensor in inputs from src_type to dst_type.

    Args:
        inputs: Inputs that to be casted.
        src_type (torch.dtype): Source type.
        dst_type (torch.dtype): Destination type.

    Returns:
        The same type with inputs, but all contained Tensors have been cast.
    """
    # Leave whole modules alone; only loose tensors/containers are converted.
    if isinstance(inputs, nn.Module):
        return inputs
    if isinstance(inputs, torch.Tensor):
        # NOTE: the cast is applied unconditionally; src_type is kept for
        # interface compatibility with callers.
        return inputs.to(dst_type)
    # Strings and numpy arrays are iterable but must be returned untouched,
    # so check them before the generic Iterable branch.
    if isinstance(inputs, (str, np.ndarray)):
        return inputs
    if isinstance(inputs, abc.Mapping):
        return type(inputs)(
            {key: cast_tensor_type(value, src_type, dst_type)
             for key, value in inputs.items()})
    if isinstance(inputs, abc.Iterable):
        return type(inputs)(
            cast_tensor_type(item, src_type, dst_type) for item in inputs)
    # Scalars and any other type pass through unchanged.
    return inputs
|
|
|
|
def auto_fp16(apply_to=None, out_fp32=False):
    """Decorator to enable fp16 training automatically.

    This decorator is useful when you write custom modules and want to support
    mixed precision training. If inputs arguments are fp32 tensors, they will
    be converted to fp16 automatically. Arguments other than fp32 tensors are
    ignored. If you are using PyTorch >= 1.6, torch.cuda.amp is used as the
    backend, otherwise, original mmcv implementation will be adopted.

    Args:
        apply_to (Iterable, optional): The argument names to be converted.
            `None` indicates all arguments.
        out_fp32 (bool): Whether to convert the output back to fp32.

    Example:

        >>> import torch.nn as nn
        >>> class MyModule1(nn.Module):
        >>>
        >>>     # Convert x and y to fp16
        >>>     @auto_fp16()
        >>>     def forward(self, x, y):
        >>>         pass

        >>> import torch.nn as nn
        >>> class MyModule2(nn.Module):
        >>>
        >>>     # convert pred to fp16
        >>>     @auto_fp16(apply_to=('pred', ))
        >>>     def do_something(self, pred, others):
        >>>         pass
    """

    def auto_fp16_wrapper(old_func):

        @functools.wraps(old_func)
        def new_func(*args, **kwargs):
            # The decorated callable must be a bound method of an nn.Module
            # so that the `fp16_enabled` flag can be looked up on `self`.
            if not isinstance(args[0], torch.nn.Module):
                raise TypeError('@auto_fp16 can only be used to decorate the '
                                'method of nn.Module')
            # Pass straight through unless fp16 has been switched on.
            if not (hasattr(args[0], 'fp16_enabled') and args[0].fp16_enabled):
                return old_func(*args, **kwargs)

            spec = getfullargspec(old_func)
            # `apply_to=None` means every named argument is a candidate.
            cast_names = spec.args if apply_to is None else apply_to

            # Cast the selected positional arguments fp32 -> fp16. zip stops
            # at the shorter sequence, so surplus positional values beyond
            # the named parameters are not forwarded (same as the original
            # index loop over `spec.args[:len(args)]`).
            new_args = [
                cast_tensor_type(value, torch.float, torch.half)
                if name in cast_names else value
                for name, value in zip(spec.args, args)
            ]
            # Keyword arguments get the same treatment.
            new_kwargs = {
                name: (cast_tensor_type(value, torch.float, torch.half)
                       if name in cast_names else value)
                for name, value in kwargs.items()
            }

            # PyTorch >= 1.6 ships native AMP; run the call under autocast.
            if (TORCH_VERSION != 'parrots' and
                    digit_version(TORCH_VERSION) >= digit_version('1.6.0')):
                with autocast(enabled=True):
                    output = old_func(*new_args, **new_kwargs)
            else:
                output = old_func(*new_args, **new_kwargs)

            # Optionally cast the result back to fp32 for downstream code.
            if out_fp32:
                output = cast_tensor_type(output, torch.half, torch.float)
            return output

        return new_func

    return auto_fp16_wrapper
|
|
|
|
def force_fp32(apply_to=None, out_fp16=False):
    """Decorator to convert input arguments to fp32 in force.

    This decorator is useful when you write custom modules and want to support
    mixed precision training. If there are some inputs that must be processed
    in fp32 mode, then this decorator can handle it. If inputs arguments are
    fp16 tensors, they will be converted to fp32 automatically. Arguments other
    than fp16 tensors are ignored. If you are using PyTorch >= 1.6,
    torch.cuda.amp is used as the backend, otherwise, original mmcv
    implementation will be adopted.

    Args:
        apply_to (Iterable, optional): The argument names to be converted.
            `None` indicates all arguments.
        out_fp16 (bool): Whether to convert the output back to fp16.

    Example:

        >>> import torch.nn as nn
        >>> class MyModule1(nn.Module):
        >>>
        >>>     # Convert x and y to fp32
        >>>     @force_fp32()
        >>>     def loss(self, x, y):
        >>>         pass

        >>> import torch.nn as nn
        >>> class MyModule2(nn.Module):
        >>>
        >>>     # convert pred to fp32
        >>>     @force_fp32(apply_to=('pred', ))
        >>>     def post_process(self, pred, others):
        >>>         pass
    """

    def force_fp32_wrapper(old_func):

        @functools.wraps(old_func)
        def new_func(*args, **kwargs):
            # The decorated callable must be a bound method of an nn.Module
            # so that the `fp16_enabled` flag can be looked up on `self`.
            if not isinstance(args[0], torch.nn.Module):
                raise TypeError('@force_fp32 can only be used to decorate the '
                                'method of nn.Module')
            # Pass straight through unless fp16 has been switched on.
            if not (hasattr(args[0], 'fp16_enabled') and args[0].fp16_enabled):
                return old_func(*args, **kwargs)

            spec = getfullargspec(old_func)
            # `apply_to=None` means every named argument is a candidate.
            cast_names = spec.args if apply_to is None else apply_to

            # Cast the selected positional arguments fp16 -> fp32. zip stops
            # at the shorter sequence, so surplus positional values beyond
            # the named parameters are not forwarded (same as the original
            # index loop over `spec.args[:len(args)]`).
            new_args = [
                cast_tensor_type(value, torch.half, torch.float)
                if name in cast_names else value
                for name, value in zip(spec.args, args)
            ]
            # Keyword arguments get the same treatment.
            new_kwargs = {
                name: (cast_tensor_type(value, torch.half, torch.float)
                       if name in cast_names else value)
                for name, value in kwargs.items()
            }

            # With native AMP available, explicitly disable autocast so the
            # wrapped computation really runs in fp32.
            if (TORCH_VERSION != 'parrots' and
                    digit_version(TORCH_VERSION) >= digit_version('1.6.0')):
                with autocast(enabled=False):
                    output = old_func(*new_args, **new_kwargs)
            else:
                output = old_func(*new_args, **new_kwargs)

            # Optionally cast the result back to fp16 for downstream code.
            if out_fp16:
                output = cast_tensor_type(output, torch.float, torch.half)
            return output

        return new_func

    return force_fp32_wrapper
|
|
|
|
def allreduce_grads(params, coalesce=True, bucket_size_mb=-1):
    """Deprecated wrapper around ``mmcv.runner.allreduce_grads``.

    Emits a deprecation warning and forwards all arguments unchanged to the
    real implementation.

    Args:
        params: Parameters whose gradients are all-reduced (forwarded as-is).
        coalesce (bool): Forwarded to the real implementation. Default: True.
        bucket_size_mb (int): Forwarded to the real implementation.
            Default: -1.
    """
    # BUG FIX: the original called `warnings.warning`, which does not exist
    # (it raised AttributeError at runtime); the correct API is
    # `warnings.warn`. Also tag it with DeprecationWarning so filters work.
    warnings.warn(
        '"mmcv.runner.fp16_utils.allreduce_grads" is deprecated, and will be '
        'removed in v2.8. Please switch to "mmcv.runner.allreduce_grads',
        DeprecationWarning)
    _allreduce_grads(params, coalesce=coalesce, bucket_size_mb=bucket_size_mb)
|
|
|
|
def wrap_fp16_model(model):
    """Wrap the FP32 model to FP16.

    If you are using PyTorch >= 1.6, torch.cuda.amp is used as the
    backend, otherwise, original mmcv implementation will be adopted.

    For PyTorch >= 1.6, this function will
    1. Set fp16 flag inside the model to True.

    Otherwise:
    1. Convert FP32 model to FP16.
    2. Remain some necessary layers to be FP32, e.g., normalization layers.
    3. Set `fp16_enabled` flag inside the model to True.

    Args:
        model (nn.Module): Model in FP32.
    """
    needs_manual_cast = (
        TORCH_VERSION == 'parrots'
        or digit_version(TORCH_VERSION) < digit_version('1.6.0'))
    if needs_manual_cast:
        # No native AMP available: cast all weights to half precision...
        model.half()
        # ...then restore normalization layers to fp32 (see patch_norm_fp32).
        patch_norm_fp32(model)
    # Flip the flag consumed by the @auto_fp16 / @force_fp32 decorators on
    # every submodule that declares it.
    for submodule in model.modules():
        if hasattr(submodule, 'fp16_enabled'):
            submodule.fp16_enabled = True
|
|
|
|
def patch_norm_fp32(module):
    """Recursively convert normalization layers from FP16 to FP32.

    Args:
        module (nn.Module): The modules to be converted in FP16.

    Returns:
        nn.Module: The converted module, the normalization layers have been
            converted to FP32.
    """
    if isinstance(module, (nn.modules.batchnorm._BatchNorm, nn.GroupNorm)):
        module.float()
        # BUG FIX: the original compared version strings lexicographically
        # (`torch.__version__ < '1.3'`), which wrongly treats e.g. '1.10'
        # as older than '1.3'. Compare parsed versions via digit_version,
        # consistent with the rest of this file, and skip the numeric
        # comparison on parrots where TORCH_VERSION is not a version string.
        if isinstance(module, nn.GroupNorm) or (
                TORCH_VERSION != 'parrots'
                and digit_version(TORCH_VERSION) < digit_version('1.3')):
            # Wrap forward so fp16 inputs are cast to fp32 for the norm and
            # the fp32 output is cast back to fp16 for surrounding layers.
            module.forward = patch_forward_method(module.forward, torch.half,
                                                  torch.float)
    for child in module.children():
        patch_norm_fp32(child)
    return module
|
|
|
|
def patch_forward_method(func, src_type, dst_type, convert_output=True):
    """Patch the forward method of a module.

    Args:
        func (callable): The original forward method.
        src_type (torch.dtype): Type of input arguments to be converted from.
        dst_type (torch.dtype): Type of input arguments to be converted to.
        convert_output (bool): Whether to convert the output back to src_type.

    Returns:
        callable: The patched forward method.
    """

    def new_forward(*args, **kwargs):
        # Cast every tensor argument to dst_type before delegating to the
        # original forward.
        converted_args = cast_tensor_type(args, src_type, dst_type)
        converted_kwargs = cast_tensor_type(kwargs, src_type, dst_type)
        result = func(*converted_args, **converted_kwargs)
        if not convert_output:
            return result
        # Cast the result back so callers keep seeing the original dtype.
        return cast_tensor_type(result, dst_type, src_type)

    return new_forward
|
|
|
|
class LossScaler:
    """Class that manages loss scaling in mixed precision training which
    supports both dynamic or static mode.

    The implementation refers to
    https://github.com/NVIDIA/apex/blob/master/apex/fp16_utils/loss_scaler.py.
    Indirectly, by supplying ``mode='dynamic'`` for dynamic loss scaling.
    It's important to understand how :class:`LossScaler` operates.
    Loss scaling is designed to combat the problem of underflowing
    gradients encountered at long times when training fp16 networks.
    Dynamic loss scaling begins by attempting a very high loss
    scale. Ironically, this may result in OVERflowing gradients.
    If overflowing gradients are encountered, :class:`FP16_Optimizer` then
    skips the update step for this particular iteration/minibatch,
    and :class:`LossScaler` adjusts the loss scale to a lower value.
    If a certain number of iterations occur without overflowing gradients
    detected, :class:`LossScaler` increases the loss scale once more.
    In this way :class:`LossScaler` attempts to "ride the edge" of always
    using the highest loss scale possible without incurring overflow.

    Args:
        init_scale (float): Initial loss scale value, default: 2**32.
        scale_factor (float): Factor used when adjusting the loss scale.
            Default: 2.
        mode (str): Loss scaling mode. 'dynamic' or 'static'
        scale_window (int): Number of consecutive iterations without an
            overflow to wait before increasing the loss scale. Default: 1000.
    """

    def __init__(self,
                 init_scale=2**32,
                 mode='dynamic',
                 scale_factor=2.,
                 scale_window=1000):
        self.cur_scale = init_scale
        self.cur_iter = 0
        assert mode in ('dynamic',
                        'static'), 'mode can only be dynamic or static'
        self.mode = mode
        # -1 so that iteration 0 counts toward the first scale_window.
        self.last_overflow_iter = -1
        self.scale_factor = scale_factor
        self.scale_window = scale_window

    def has_overflow(self, params):
        """Check if params contain overflow.

        Args:
            params: Iterable of parameters; only those with a non-None
                ``.grad`` are inspected.

        Returns:
            bool: True if any gradient contains inf/NaN (dynamic mode only;
            static mode always returns False).
        """
        if self.mode != 'dynamic':
            return False
        for p in params:
            if p.grad is not None and LossScaler._has_inf_or_nan(p.grad.data):
                return True
        return False

    # BUG FIX: this helper took no `self` but lacked @staticmethod; it only
    # worked because it was always called via the class. Make it explicit.
    @staticmethod
    def _has_inf_or_nan(x):
        """Check if x contains inf or NaN values."""
        try:
            # A sum containing inf yields +/-inf and a sum containing NaN
            # yields NaN, so one reduction detects both cases.
            cpu_sum = float(x.float().sum())
        except RuntimeError as instance:
            # Some overflow cases surface as a "value cannot be converted"
            # RuntimeError; treat exactly that message as an overflow signal
            # and re-raise anything else.
            if 'value cannot be converted' not in instance.args[0]:
                raise
            return True
        else:
            # `cpu_sum != cpu_sum` is the NaN test.
            if cpu_sum == float('inf') or cpu_sum == -float('inf') \
                    or cpu_sum != cpu_sum:
                return True
            return False

    def update_scale(self, overflow):
        """Update the loss scale based on whether this iteration overflowed.

        On overflow the scale is reduced; after ``scale_window`` consecutive
        clean iterations it is increased. No-op in static mode.
        """
        if self.mode != 'dynamic':
            return
        if overflow:
            # Back off, but never let the scale drop below 1.
            self.cur_scale = max(self.cur_scale / self.scale_factor, 1)
            self.last_overflow_iter = self.cur_iter
        else:
            # Grow the scale again after `scale_window` clean iterations.
            if (self.cur_iter - self.last_overflow_iter) % \
                    self.scale_window == 0:
                self.cur_scale *= self.scale_factor
        self.cur_iter += 1

    def state_dict(self):
        """Returns the state of the scaler as a :class:`dict`."""
        return dict(
            cur_scale=self.cur_scale,
            cur_iter=self.cur_iter,
            mode=self.mode,
            last_overflow_iter=self.last_overflow_iter,
            scale_factor=self.scale_factor,
            scale_window=self.scale_window)

    def load_state_dict(self, state_dict):
        """Loads the loss_scaler state dict.

        Args:
            state_dict (dict): scaler state.
        """
        self.cur_scale = state_dict['cur_scale']
        self.cur_iter = state_dict['cur_iter']
        self.mode = state_dict['mode']
        self.last_overflow_iter = state_dict['last_overflow_iter']
        self.scale_factor = state_dict['scale_factor']
        self.scale_window = state_dict['scale_window']

    @property
    def loss_scale(self):
        """float: The current loss scale value."""
        return self.cur_scale
|
|