| | |
| | from mmengine.dist import all_reduce_params, is_distributed |
| | from mmengine.registry import HOOKS |
| | from .hook import Hook |
| |
|
| |
|
@HOOKS.register_module()
class SyncBuffersHook(Hook):
    """Keep model buffers (e.g. BN ``running_mean`` / ``running_var``)
    consistent across processes.

    Buffers are all-reduced with ``op='mean'`` after every training epoch,
    and additionally before a validation epoch when no such synchronization
    has happened since the previous validation.
    """

    priority = 'NORMAL'

    def __init__(self) -> None:
        # Set to True by ``after_train_epoch`` once it has synchronized the
        # buffers; read and cleared by ``before_val_epoch``.
        self.called_in_train = False
        # Result of ``is_distributed()`` captured once at construction time.
        self.distributed = is_distributed()

    def before_val_epoch(self, runner) -> None:
        """Synchronize buffers ahead of validation when still pending.

        If the buffers were not already all-reduced at the end of the
        preceding training epoch, all-reduce them now. This method will be
        called when using IterBasedTrainLoop.

        Args:
            runner (Runner): The runner of the training process.
        """
        if not self.distributed:
            # Nothing to synchronize; the flag is deliberately left as-is.
            return

        already_synced = self.called_in_train
        if not already_synced:
            buffers = runner.model.buffers()
            all_reduce_params(buffers, op='mean')

        # Clear the flag so the next validation checks again.
        self.called_in_train = False

    def after_train_epoch(self, runner) -> None:
        """Synchronize buffers once a training epoch has finished.

        Args:
            runner (Runner): The runner of the training process.
        """
        if not self.distributed:
            return

        buffers = runner.model.buffers()
        all_reduce_params(buffers, op='mean')
        # Record the sync so ``before_val_epoch`` can skip a redundant one.
        self.called_in_train = True
| |
|