import copy
import itertools
import logging
from collections import defaultdict
from enum import Enum
from typing import Any, Callable, Dict, Iterable, List, Optional, Set, Type, Union

import torch
from fvcore.common.param_scheduler import CosineParamScheduler, MultiStepParamScheduler

_GradientClipperInput = Union[torch.Tensor, Iterable[torch.Tensor]]
_GradientClipper = Callable[[_GradientClipperInput], None]


class GradientClipType(Enum):
    VALUE = "value"
    NORM = "norm"


def _create_gradient_clipper(cfg) -> _GradientClipper:
    """
    Creates a gradient clipping closure that clips by value or by norm,
    according to the provided config.
    """
    cfg = copy.deepcopy(cfg)

    def clip_grad_norm(p: _GradientClipperInput):
        torch.nn.utils.clip_grad_norm_(p, cfg.CLIP_VALUE, cfg.NORM_TYPE)

    def clip_grad_value(p: _GradientClipperInput):
        torch.nn.utils.clip_grad_value_(p, cfg.CLIP_VALUE)

    _GRADIENT_CLIP_TYPE_TO_CLIPPER = {
        GradientClipType.VALUE: clip_grad_value,
        GradientClipType.NORM: clip_grad_norm,
    }
    return _GRADIENT_CLIP_TYPE_TO_CLIPPER[GradientClipType(cfg.CLIP_TYPE)]


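# Illustrative usage sketch (not part of the original module): how the clipper closure
# returned by `_create_gradient_clipper` might be exercised. `SimpleNamespace` merely
# stands in for the `cfg.SOLVER.CLIP_GRADIENTS` CfgNode; the model below is hypothetical.
def _example_create_gradient_clipper():
    from types import SimpleNamespace

    clip_cfg = SimpleNamespace(CLIP_TYPE="norm", CLIP_VALUE=1.0, NORM_TYPE=2.0)
    clipper = _create_gradient_clipper(clip_cfg)

    model = torch.nn.Linear(4, 2)
    model(torch.randn(8, 4)).sum().backward()
    clipper(model.parameters())  # clips gradients in-place by their total L2 norm

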
def _generate_optimizer_class_with_gradient_clipping(
    optimizer: Type[torch.optim.Optimizer],
    *,
    per_param_clipper: Optional[_GradientClipper] = None,
    global_clipper: Optional[_GradientClipper] = None,
) -> Type[torch.optim.Optimizer]:
    """
    Dynamically creates a new type that inherits from the given optimizer type
    and overrides its `step` method to apply gradient clipping before the update.
    """
    assert (
        per_param_clipper is None or global_clipper is None
    ), "Not allowed to use both per-parameter clipping and global clipping"

    def optimizer_wgc_step(self, closure=None):
        if per_param_clipper is not None:
            for group in self.param_groups:
                for p in group["params"]:
                    per_param_clipper(p)
        else:
            # global clipping: chain parameters from all param groups so the
            # clipper (e.g. a norm-based one) sees them jointly
            all_params = itertools.chain(*[g["params"] for g in self.param_groups])
            global_clipper(all_params)
        super(type(self), self).step(closure)

    OptimizerWithGradientClip = type(
        optimizer.__name__ + "WithGradientClip",
        (optimizer,),
        {"step": optimizer_wgc_step},
    )
    return OptimizerWithGradientClip


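# Illustrative usage sketch (not part of the original module): wrapping `torch.optim.SGD`
# with a global norm clipper, so the clipper sees parameters from all param groups at once.
# The model and the max_norm value below are assumptions for demonstration only.
def _example_generate_optimizer_class_with_gradient_clipping():
    SGDWithGradientClip = _generate_optimizer_class_with_gradient_clipping(
        torch.optim.SGD,
        global_clipper=lambda params: torch.nn.utils.clip_grad_norm_(params, max_norm=1.0),
    )

    model = torch.nn.Linear(4, 2)
    opt = SGDWithGradientClip(model.parameters(), lr=0.1)
    model(torch.randn(8, 4)).sum().backward()
    opt.step()  # gradients are jointly clipped to norm 1.0, then SGD updates

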
def maybe_add_gradient_clipping(
    cfg, optimizer: Type[torch.optim.Optimizer]
) -> Type[torch.optim.Optimizer]:
    """
    If gradient clipping is enabled through config options, wraps the existing
    optimizer type in a new dynamically created class OptimizerWithGradientClip
    that inherits the given optimizer and overrides the `step` method to
    include gradient clipping.

    Args:
        cfg: CfgNode, configuration options
        optimizer: a torch.optim.Optimizer instance, or a subclass (type) of it

    Return:
        type: either the input `optimizer` (if gradient clipping is disabled), or
        a subclass of it with gradient clipping included in the `step` method.
    """
    if not cfg.SOLVER.CLIP_GRADIENTS.ENABLED:
        return optimizer
    if isinstance(optimizer, torch.optim.Optimizer):
        optimizer_type = type(optimizer)
    else:
        assert issubclass(optimizer, torch.optim.Optimizer), optimizer
        optimizer_type = optimizer

    grad_clipper = _create_gradient_clipper(cfg.SOLVER.CLIP_GRADIENTS)
    OptimizerWithGradientClip = _generate_optimizer_class_with_gradient_clipping(
        optimizer_type, per_param_clipper=grad_clipper
    )
    if isinstance(optimizer, torch.optim.Optimizer):
        # mutate the instance's class in place so existing references to the
        # optimizer keep working; somewhat hacky but avoids rebuilding optimizer state
        optimizer.__class__ = OptimizerWithGradientClip
        return optimizer
    else:
        return OptimizerWithGradientClip
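

# Illustrative usage sketch (not part of the original module): enabling clipping through a
# minimal config-like object. `SimpleNamespace` stands in for a detectron2 CfgNode, and the
# model, learning rate, and clip value are assumptions for demonstration only.
def _example_maybe_add_gradient_clipping():
    from types import SimpleNamespace

    cfg = SimpleNamespace(
        SOLVER=SimpleNamespace(
            CLIP_GRADIENTS=SimpleNamespace(
                ENABLED=True, CLIP_TYPE="value", CLIP_VALUE=1.0, NORM_TYPE=2.0
            )
        )
    )
    model = torch.nn.Linear(4, 2)
    optimizer_cls = maybe_add_gradient_clipping(cfg, torch.optim.SGD)
    opt = optimizer_cls(model.parameters(), lr=0.1)
    model(torch.randn(8, 4)).sum().backward()
    opt.step()  # each gradient element is clipped to [-1, 1] before the SGD update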