| | |
| | from typing import Optional, Sequence, Union |
| |
|
| | import torch |
| |
|
| | from mmengine.registry import HOOKS |
| | from ..device import is_cuda_available, is_musa_available |
| | from .hook import Hook |
| |
|
# Type alias for a batch yielded by a dataloader: a dict, tuple or list,
# or None when no batch is passed.
DATA_BATCH = Union[dict, tuple, list, None]
| |
|
| |
|
@HOOKS.register_module()
class EmptyCacheHook(Hook):
    """Release unoccupied cached accelerator memory at configurable points.

    The cache of the first available backend is emptied: CUDA is checked
    first, and MUSA is tried only when CUDA is not available. When neither
    backend is available, the hook does nothing. The ``mode`` argument of
    the hook callbacks is not consulted, so the same flags apply whichever
    mode the callbacks are invoked with.

    Args:
        before_epoch (bool): Whether to release cache before an epoch.
            Defaults to False.
        after_epoch (bool): Whether to release cache after an epoch.
            Defaults to True.
        after_iter (bool): Whether to release cache after an iteration.
            Defaults to False.
    """

    priority = 'NORMAL'

    def __init__(self,
                 before_epoch: bool = False,
                 after_epoch: bool = True,
                 after_iter: bool = False) -> None:
        self._do_before_epoch = before_epoch
        self._do_after_epoch = after_epoch
        self._do_after_iter = after_iter

    @staticmethod
    def _empty_cache() -> None:
        """Empty the allocator cache of the available backend, if any.

        Shared by every callback so the device-selection logic lives in one
        place.
        """
        # CUDA takes precedence; MUSA is only tried when CUDA is absent.
        if is_cuda_available():
            torch.cuda.empty_cache()
        elif is_musa_available():
            torch.musa.empty_cache()

    def _after_iter(self,
                    runner,
                    batch_idx: int,
                    data_batch: DATA_BATCH = None,
                    outputs: Optional[Union[dict, Sequence]] = None,
                    mode: str = 'train') -> None:
        """Empty cache after an iteration when ``after_iter`` was enabled.

        Args:
            runner (Runner): The runner of the training process.
            batch_idx (int): The index of the current batch in the loop.
            data_batch (dict or tuple or list, optional): Data from dataloader.
            outputs (dict or sequence, optional): Outputs from model.
            mode (str): Current mode of runner. Defaults to 'train'.
                Not used by this hook.
        """
        if self._do_after_iter:
            self._empty_cache()

    def _before_epoch(self, runner, mode: str = 'train') -> None:
        """Empty cache before an epoch when ``before_epoch`` was enabled.

        Args:
            runner (Runner): The runner of the training process.
            mode (str): Current mode of runner. Defaults to 'train'.
                Not used by this hook.
        """
        if self._do_before_epoch:
            self._empty_cache()

    def _after_epoch(self, runner, mode: str = 'train') -> None:
        """Empty cache after an epoch when ``after_epoch`` was enabled.

        Args:
            runner (Runner): The runner of the training process.
            mode (str): Current mode of runner. Defaults to 'train'.
                Not used by this hook.
        """
        if self._do_after_epoch:
            self._empty_cache()
| |
|