| |
| import math |
| from typing import Optional |
|
|
| import torch |
| import torch.nn as nn |
| from mmengine.model import ExponentialMovingAverage |
| from torch import Tensor |
|
|
| from mmdet.registry import MODELS |
|
|
|
|
@MODELS.register_module()
class ExpMomentumEMA(ExponentialMovingAverage):
    """Exponential moving average (EMA) with exponential momentum strategy,
    which is used in YOLOX.

    Args:
        model (nn.Module): The model to be averaged.
        momentum (float): The momentum used for updating the EMA parameters.
            The EMA parameters are updated with the formula:
            `averaged_param = (1-momentum) * averaged_param + momentum *
            source_param`. Defaults to 0.0002.
        gamma (int): Use a larger momentum early in training and gradually
            anneal to a smaller value to update the EMA model smoothly. The
            momentum is calculated as
            `(1 - momentum) * exp(-(1 + steps) / gamma) + momentum`.
            Must be greater than 0. Defaults to 2000.
        interval (int): Interval between two updates. Defaults to 1.
        device (torch.device, optional): If provided, the averaged model will
            be stored on the :attr:`device`. Defaults to None.
        update_buffers (bool): if True, it will compute running averages for
            both the parameters and the buffers of the model. Defaults to
            False.

    Raises:
        AssertionError: If ``gamma`` is not greater than 0.
    """

    def __init__(self,
                 model: nn.Module,
                 momentum: float = 0.0002,
                 gamma: int = 2000,
                 interval: int = 1,
                 device: Optional[torch.device] = None,
                 update_buffers: bool = False) -> None:
        # Raised explicitly instead of using an `assert` statement so the
        # check is not stripped under `python -O`; a non-positive gamma would
        # otherwise divide by zero or yield a momentum above 1 in `avg_func`.
        # AssertionError is kept so existing callers see the same exception.
        # The check runs first so an invalid config fails before any parent
        # initialisation work is done.
        if not gamma > 0:
            raise AssertionError(
                f'gamma must be greater than 0, but got {gamma}')
        super().__init__(
            model=model,
            momentum=momentum,
            interval=interval,
            device=device,
            update_buffers=update_buffers)
        self.gamma = gamma

    def avg_func(self, averaged_param: Tensor, source_param: Tensor,
                 steps: int) -> None:
        """Compute the moving average of the parameters using the exponential
        momentum strategy.

        ``averaged_param`` is updated in place; nothing is returned.

        Args:
            averaged_param (Tensor): The averaged parameters.
            source_param (Tensor): The source parameters.
            steps (int): The number of times the parameters have been
                updated.
        """
        # The exponential term decays towards 0 as `steps` grows, so the
        # effective momentum anneals from ~1 down to `self.momentum`.
        decay = math.exp(-float(1 + steps) / self.gamma)
        momentum = (1 - self.momentum) * decay + self.momentum
        # In-place linear interpolation:
        # averaged_param += momentum * (source_param - averaged_param)
        averaged_param.lerp_(source_param, momentum)
|
|