| | |
| | from typing import Dict, Optional, Sequence, Union |
| |
|
| | from mmengine import is_method_overridden |
| |
|
# Type of a batch yielded by a dataloader: a dict, tuple or list, or None
# when no batch is passed. Equivalent to Optional[Union[dict, tuple, list]].
DATA_BATCH = Union[dict, tuple, list, None]
| |
|
| |
|
class Hook:
    """Base hook class.

    All hooks should inherit from this class. A subclass implements the
    stage methods it cares about (see :attr:`stages`). The default
    train/val/test epoch and iteration stages delegate to the private
    helpers ``_before_epoch``, ``_after_epoch``, ``_before_iter`` and
    ``_after_iter`` (passing a ``mode`` string), so a subclass may override
    one of those helpers instead of all three mode-specific stages.
    """

    priority = 'NORMAL'
    # Names of every public stage method a hook can override.
    # ``get_triggered_stages`` checks these for overrides and also uses this
    # tuple as the canonical order of the stages it returns.
    stages = ('before_run', 'after_load_checkpoint', 'before_train',
              'before_train_epoch', 'before_train_iter', 'after_train_iter',
              'after_train_epoch', 'before_val', 'before_val_epoch',
              'before_val_iter', 'after_val_iter', 'after_val_epoch',
              'after_val', 'before_save_checkpoint', 'after_train',
              'before_test', 'before_test_epoch', 'before_test_iter',
              'after_test_iter', 'after_test_epoch', 'after_test', 'after_run')

    def before_run(self, runner) -> None:
        """All subclasses should override this method, if they need any
        operations before the training, validation or testing process.

        Args:
            runner (Runner): The runner of the training, validation or testing
                process.
        """

    def after_run(self, runner) -> None:
        """All subclasses should override this method, if they need any
        operations after the training, validation or testing process.

        Args:
            runner (Runner): The runner of the training, validation or testing
                process.
        """

    def before_train(self, runner) -> None:
        """All subclasses should override this method, if they need any
        operations before train.

        Args:
            runner (Runner): The runner of the training process.
        """

    def after_train(self, runner) -> None:
        """All subclasses should override this method, if they need any
        operations after train.

        Args:
            runner (Runner): The runner of the training process.
        """

    def before_val(self, runner) -> None:
        """All subclasses should override this method, if they need any
        operations before validation.

        Args:
            runner (Runner): The runner of the validation process.
        """

    def after_val(self, runner) -> None:
        """All subclasses should override this method, if they need any
        operations after validation.

        Args:
            runner (Runner): The runner of the validation process.
        """

    def before_test(self, runner) -> None:
        """All subclasses should override this method, if they need any
        operations before testing.

        Args:
            runner (Runner): The runner of the testing process.
        """

    def after_test(self, runner) -> None:
        """All subclasses should override this method, if they need any
        operations after testing.

        Args:
            runner (Runner): The runner of the testing process.
        """

    def before_save_checkpoint(self, runner, checkpoint: dict) -> None:
        """All subclasses should override this method, if they need any
        operations before saving the checkpoint.

        Args:
            runner (Runner): The runner of the training, validation or testing
                process.
            checkpoint (dict): Model's checkpoint.
        """

    def after_load_checkpoint(self, runner, checkpoint: dict) -> None:
        """All subclasses should override this method, if they need any
        operations after loading the checkpoint.

        Args:
            runner (Runner): The runner of the training, validation or testing
                process.
            checkpoint (dict): Model's checkpoint.
        """

    def before_train_epoch(self, runner) -> None:
        """All subclasses should override this method, if they need any
        operations before each training epoch.

        By default, delegates to ``_before_epoch`` with ``mode='train'``.

        Args:
            runner (Runner): The runner of the training process.
        """
        self._before_epoch(runner, mode='train')

    def before_val_epoch(self, runner) -> None:
        """All subclasses should override this method, if they need any
        operations before each validation epoch.

        By default, delegates to ``_before_epoch`` with ``mode='val'``.

        Args:
            runner (Runner): The runner of the validation process.
        """
        self._before_epoch(runner, mode='val')

    def before_test_epoch(self, runner) -> None:
        """All subclasses should override this method, if they need any
        operations before each test epoch.

        By default, delegates to ``_before_epoch`` with ``mode='test'``.

        Args:
            runner (Runner): The runner of the testing process.
        """
        self._before_epoch(runner, mode='test')

    def after_train_epoch(self, runner) -> None:
        """All subclasses should override this method, if they need any
        operations after each training epoch.

        By default, delegates to ``_after_epoch`` with ``mode='train'``.

        Args:
            runner (Runner): The runner of the training process.
        """
        self._after_epoch(runner, mode='train')

    def after_val_epoch(self,
                        runner,
                        metrics: Optional[Dict[str, float]] = None) -> None:
        """All subclasses should override this method, if they need any
        operations after each validation epoch.

        By default, delegates to ``_after_epoch`` with ``mode='val'``;
        ``metrics`` is not forwarded to that helper.

        Args:
            runner (Runner): The runner of the validation process.
            metrics (Dict[str, float], optional): Evaluation results of all
                metrics on validation dataset. The keys are the names of the
                metrics, and the values are corresponding results.
                Defaults to None.
        """
        self._after_epoch(runner, mode='val')

    def after_test_epoch(self,
                         runner,
                         metrics: Optional[Dict[str, float]] = None) -> None:
        """All subclasses should override this method, if they need any
        operations after each test epoch.

        By default, delegates to ``_after_epoch`` with ``mode='test'``;
        ``metrics`` is not forwarded to that helper.

        Args:
            runner (Runner): The runner of the testing process.
            metrics (Dict[str, float], optional): Evaluation results of all
                metrics on test dataset. The keys are the names of the
                metrics, and the values are corresponding results.
                Defaults to None.
        """
        self._after_epoch(runner, mode='test')

    def before_train_iter(self,
                          runner,
                          batch_idx: int,
                          data_batch: DATA_BATCH = None) -> None:
        """All subclasses should override this method, if they need any
        operations before each training iteration.

        By default, delegates to ``_before_iter`` with ``mode='train'``.

        Args:
            runner (Runner): The runner of the training process.
            batch_idx (int): The index of the current batch in the train loop.
            data_batch (dict or tuple or list, optional): Data from dataloader.
                Defaults to None.
        """
        self._before_iter(
            runner, batch_idx=batch_idx, data_batch=data_batch, mode='train')

    def before_val_iter(self,
                        runner,
                        batch_idx: int,
                        data_batch: DATA_BATCH = None) -> None:
        """All subclasses should override this method, if they need any
        operations before each validation iteration.

        By default, delegates to ``_before_iter`` with ``mode='val'``.

        Args:
            runner (Runner): The runner of the validation process.
            batch_idx (int): The index of the current batch in the val loop.
            data_batch (dict or tuple or list, optional): Data from dataloader.
                Defaults to None.
        """
        self._before_iter(
            runner, batch_idx=batch_idx, data_batch=data_batch, mode='val')

    def before_test_iter(self,
                         runner,
                         batch_idx: int,
                         data_batch: DATA_BATCH = None) -> None:
        """All subclasses should override this method, if they need any
        operations before each test iteration.

        By default, delegates to ``_before_iter`` with ``mode='test'``.

        Args:
            runner (Runner): The runner of the testing process.
            batch_idx (int): The index of the current batch in the test loop.
            data_batch (dict or tuple or list, optional): Data from dataloader.
                Defaults to None.
        """
        self._before_iter(
            runner, batch_idx=batch_idx, data_batch=data_batch, mode='test')

    def after_train_iter(self,
                         runner,
                         batch_idx: int,
                         data_batch: DATA_BATCH = None,
                         outputs: Optional[dict] = None) -> None:
        """All subclasses should override this method, if they need any
        operations after each training iteration.

        By default, delegates to ``_after_iter`` with ``mode='train'``.

        Args:
            runner (Runner): The runner of the training process.
            batch_idx (int): The index of the current batch in the train loop.
            data_batch (dict or tuple or list, optional): Data from dataloader.
                Defaults to None.
            outputs (dict, optional): Outputs from model. Defaults to None.
        """
        self._after_iter(
            runner,
            batch_idx=batch_idx,
            data_batch=data_batch,
            outputs=outputs,
            mode='train')

    def after_val_iter(self,
                       runner,
                       batch_idx: int,
                       data_batch: DATA_BATCH = None,
                       outputs: Optional[Sequence] = None) -> None:
        """All subclasses should override this method, if they need any
        operations after each validation iteration.

        By default, delegates to ``_after_iter`` with ``mode='val'``.

        Args:
            runner (Runner): The runner of the validation process.
            batch_idx (int): The index of the current batch in the val loop.
            data_batch (dict or tuple or list, optional): Data from dataloader.
                Defaults to None.
            outputs (Sequence, optional): Outputs from model. Defaults to None.
        """
        self._after_iter(
            runner,
            batch_idx=batch_idx,
            data_batch=data_batch,
            outputs=outputs,
            mode='val')

    def after_test_iter(self,
                        runner,
                        batch_idx: int,
                        data_batch: DATA_BATCH = None,
                        outputs: Optional[Sequence] = None) -> None:
        """All subclasses should override this method, if they need any
        operations after each test iteration.

        By default, delegates to ``_after_iter`` with ``mode='test'``.

        Args:
            runner (Runner): The runner of the testing process.
            batch_idx (int): The index of the current batch in the test loop.
            data_batch (dict or tuple or list, optional): Data from dataloader.
                Defaults to None.
            outputs (Sequence, optional): Outputs from model. Defaults to None.
        """
        self._after_iter(
            runner,
            batch_idx=batch_idx,
            data_batch=data_batch,
            outputs=outputs,
            mode='test')

    def _before_epoch(self, runner, mode: str = 'train') -> None:
        """All subclasses should override this method, if they need any
        operations before each epoch.

        Called by the default ``before_train_epoch``, ``before_val_epoch`` and
        ``before_test_epoch`` with ``mode`` set to ``'train'``, ``'val'`` or
        ``'test'`` respectively.

        Args:
            runner (Runner): The runner of the training, validation or testing
                process.
            mode (str): Current mode of runner. Defaults to 'train'.
        """

    def _after_epoch(self, runner, mode: str = 'train') -> None:
        """All subclasses should override this method, if they need any
        operations after each epoch.

        Called by the default ``after_train_epoch``, ``after_val_epoch`` and
        ``after_test_epoch`` with ``mode`` set to ``'train'``, ``'val'`` or
        ``'test'`` respectively.

        Args:
            runner (Runner): The runner of the training, validation or testing
                process.
            mode (str): Current mode of runner. Defaults to 'train'.
        """

    def _before_iter(self,
                     runner,
                     batch_idx: int,
                     data_batch: DATA_BATCH = None,
                     mode: str = 'train') -> None:
        """All subclasses should override this method, if they need any
        operations before each iteration.

        Called by the default ``before_train_iter``, ``before_val_iter`` and
        ``before_test_iter`` with ``mode`` set to ``'train'``, ``'val'`` or
        ``'test'`` respectively.

        Args:
            runner (Runner): The runner of the training, validation or testing
                process.
            batch_idx (int): The index of the current batch in the loop.
            data_batch (dict or tuple or list, optional): Data from dataloader.
                Defaults to None.
            mode (str): Current mode of runner. Defaults to 'train'.
        """

    def _after_iter(self,
                    runner,
                    batch_idx: int,
                    data_batch: DATA_BATCH = None,
                    outputs: Optional[Union[Sequence, dict]] = None,
                    mode: str = 'train') -> None:
        """All subclasses should override this method, if they need any
        operations after each iteration.

        Called by the default ``after_train_iter``, ``after_val_iter`` and
        ``after_test_iter`` with ``mode`` set to ``'train'``, ``'val'`` or
        ``'test'`` respectively.

        Args:
            runner (Runner): The runner of the training, validation or testing
                process.
            batch_idx (int): The index of the current batch in the loop.
            data_batch (dict or tuple or list, optional): Data from dataloader.
                Defaults to None.
            outputs (dict or Sequence, optional): Outputs from model.
                Defaults to None.
            mode (str): Current mode of runner. Defaults to 'train'.
        """

    def every_n_epochs(self, runner, n: int, start: int = 0) -> bool:
        """Test whether the current epoch hits an interval of ``n`` epochs.

        The check is made on ``runner.epoch + 1 - start``: it must be
        non-negative and a multiple of ``n``.

        Args:
            runner (Runner): The runner of the training, validation or testing
                process.
            n (int): Interval in epochs. A non-positive ``n`` never triggers.
            start (int): Starting from `start` to check the logic for
                every n epochs. Defaults to 0.

        Returns:
            bool: Whether current epoch can be evenly divided by n.
        """
        dividend = runner.epoch + 1 - start
        return dividend % n == 0 if dividend >= 0 and n > 0 else False

    def every_n_inner_iters(self, batch_idx: int, n: int) -> bool:
        """Test whether current inner iteration can be evenly divided by n.

        The check is made on ``batch_idx + 1``, i.e. the 1-based position of
        the batch in the current loop.

        Args:
            batch_idx (int): Current batch index of the training, validation
                or testing loop.
            n (int): Interval in inner iterations. A non-positive ``n`` never
                triggers.

        Returns:
            bool: Whether current inner iteration can be evenly
            divided by n.
        """
        return (batch_idx + 1) % n == 0 if n > 0 else False

    def every_n_train_iters(self, runner, n: int, start: int = 0) -> bool:
        """Test whether current training iteration can be evenly divided by n.

        The check is made on ``runner.iter + 1 - start``: it must be
        non-negative and a multiple of ``n``.

        Args:
            runner (Runner): The runner of the training, validation or testing
                process.
            n (int): Interval in iterations. A non-positive ``n`` never
                triggers.
            start (int): Starting from `start` to check the logic for
                every n iterations. Defaults to 0.

        Returns:
            bool: Return True if the current iteration can be evenly divided
            by n, otherwise False.
        """
        dividend = runner.iter + 1 - start
        return dividend % n == 0 if dividend >= 0 and n > 0 else False

    def end_of_epoch(self, dataloader, batch_idx: int) -> bool:
        """Check whether the current iteration reaches the last iteration of
        the dataloader.

        Args:
            dataloader (Dataloader): The dataloader of the training,
                validation or testing process.
            batch_idx (int): The index of the current batch in the loop.

        Returns:
            bool: True when ``batch_idx + 1 == len(dataloader)``, i.e. the
            current batch is the last one of the epoch.
        """
        return batch_idx + 1 == len(dataloader)

    def is_last_train_epoch(self, runner) -> bool:
        """Test whether current epoch is the last train epoch.

        Args:
            runner (Runner): The runner of the training process.

        Returns:
            bool: True when ``runner.epoch + 1 == runner.max_epochs``.
        """
        return runner.epoch + 1 == runner.max_epochs

    def is_last_train_iter(self, runner) -> bool:
        """Test whether current iteration is the last train iteration.

        Args:
            runner (Runner): The runner of the training process.

        Returns:
            bool: True when ``runner.iter + 1 == runner.max_iters``.
        """
        return runner.iter + 1 == runner.max_iters

    def get_triggered_stages(self) -> list:
        """Get all triggered stages with method name of the hook.

        A stage is triggered when this hook's class overrides the stage method
        itself, or overrides the private helper that the stage delegates to
        (e.g. overriding ``_before_iter`` triggers ``before_train_iter``,
        ``before_val_iter`` and ``before_test_iter``).

        Returns:
            list: Names of the triggered stages, ordered as in
            :attr:`Hook.stages`, so the result is deterministic.
        """
        trigger_stages = {
            stage
            for stage in Hook.stages if is_method_overridden(stage, Hook, self)
        }

        # The default epoch/iter stage methods delegate to these private
        # helpers, so overriding a helper implicitly triggers the
        # train/val/test variants of the corresponding stage.
        method_stages_map = {
            '_before_epoch':
            ['before_train_epoch', 'before_val_epoch', 'before_test_epoch'],
            '_after_epoch':
            ['after_train_epoch', 'after_val_epoch', 'after_test_epoch'],
            '_before_iter':
            ['before_train_iter', 'before_val_iter', 'before_test_iter'],
            '_after_iter':
            ['after_train_iter', 'after_val_iter', 'after_test_iter'],
        }

        for method, map_stages in method_stages_map.items():
            if is_method_overridden(method, Hook, self):
                trigger_stages.update(map_stages)

        # Iterating a set of str follows hash order, which is randomized per
        # interpreter process; emit stages in their canonical order instead.
        # Every mapped stage above is a member of Hook.stages, so nothing is
        # dropped by this filter.
        return [stage for stage in Hook.stages if stage in trigger_stages]
| |
|