| | |
| | from typing import Optional |
| |
|
| | import torch |
| | from mmengine.hooks import Hook |
| | from mmengine.runner import Runner |
| |
|
| | from mmdet.registry import HOOKS |
| |
|
| |
|
@HOOKS.register_module()
class CheckInvalidLossHook(Hook):
    """Check invalid loss hook.

    This hook regularly checks during training whether the loss is valid,
    i.e. contains only finite values (no ``inf`` and no ``NaN``).

    Args:
        interval (int): Checking interval (every k iterations).
            Defaults to 50.
    """

    def __init__(self, interval: int = 50) -> None:
        self.interval = interval

    def after_train_iter(self,
                         runner: Runner,
                         batch_idx: int,
                         data_batch: Optional[dict] = None,
                         outputs: Optional[dict] = None) -> None:
        """Regularly check whether the loss is valid every n iterations.

        Args:
            runner (:obj:`Runner`): The runner of the training process.
            batch_idx (int): The index of the current batch in the train
                loop.
            data_batch (dict, Optional): Data from dataloader.
                Defaults to None.
            outputs (dict, Optional): Outputs from model. When given, its
                ``'loss'`` entry is the tensor that gets checked.
                Defaults to None.

        Raises:
            AssertionError: If ``outputs['loss']`` contains an infinite or
                NaN value on a checked iteration.
        """
        if not self.every_n_train_iters(runner, self.interval):
            return

        # ``outputs`` is optional in the hook signature; without it there
        # is no loss to inspect, so skip instead of failing with an
        # unrelated TypeError.
        if outputs is None:
            return

        # ``.all()`` reduces to a single boolean, so the check works for a
        # scalar loss as well as for a loss tensor with several elements.
        if not torch.isfinite(outputs['loss']).all():
            message = 'loss become infinite or NaN!'
            runner.logger.info(message)
            # Raised explicitly (not via ``assert``) so the check is not
            # stripped under ``python -O`` and the error carries a message.
            raise AssertionError(message)
| |
|