| |
| import os.path as osp |
| import warnings |
|
|
| from annotator.uniformer.mmcv.fileio import FileClient |
| from ..dist_utils import allreduce_params, master_only |
| from .hook import HOOKS, Hook |
|
|
|
|
@HOOKS.register_module()
class CheckpointHook(Hook):
    """Save checkpoints periodically.

    Args:
        interval (int): The saving period. If ``by_epoch=True``, interval
            indicates epochs, otherwise it indicates iterations.
            Default: -1, which means "never".
        by_epoch (bool): Saving checkpoints by epoch or by iteration.
            Default: True.
        save_optimizer (bool): Whether to save optimizer state_dict in the
            checkpoint. It is usually used for resuming experiments.
            Default: True.
        out_dir (str, optional): The root directory to save checkpoints. If not
            specified, ``runner.work_dir`` will be used by default. If
            specified, the ``out_dir`` will be the concatenation of ``out_dir``
            and the last level directory of ``runner.work_dir``.
            `Changed in version 1.3.16.`
        max_keep_ckpts (int, optional): The maximum checkpoints to keep.
            In some cases we want only the latest few checkpoints and would
            like to delete old ones to save the disk space.
            Default: -1, which means unlimited.
        save_last (bool, optional): Whether to force the last checkpoint to be
            saved regardless of interval. Default: True.
        sync_buffer (bool, optional): Whether to synchronize buffers across
            different GPUs. Default: False.
        file_client_args (dict, optional): Arguments to instantiate a
            FileClient. See :class:`mmcv.fileio.FileClient` for details.
            Default: None.
            `New in version 1.3.16.`
        **kwargs: Extra keyword arguments forwarded to
            ``runner.save_checkpoint``. This hook also reads
            ``filename_tmpl`` and ``create_symlink`` from them.

    .. warning::
        Before v1.3.16, the ``out_dir`` argument indicates the path where the
        checkpoint is stored. However, since v1.3.16, ``out_dir`` indicates the
        root directory and the final path to save checkpoint is the
        concatenation of ``out_dir`` and the last level directory of
        ``runner.work_dir``. Suppose the value of ``out_dir`` is "/path/of/A"
        and the value of ``runner.work_dir`` is "/path/of/B", then the final
        path will be "/path/of/A/B".
    """

    def __init__(self,
                 interval=-1,
                 by_epoch=True,
                 save_optimizer=True,
                 out_dir=None,
                 max_keep_ckpts=-1,
                 save_last=True,
                 sync_buffer=False,
                 file_client_args=None,
                 **kwargs):
        self.interval = interval
        self.by_epoch = by_epoch
        self.save_optimizer = save_optimizer
        self.out_dir = out_dir
        self.max_keep_ckpts = max_keep_ckpts
        self.save_last = save_last
        # Passed through as ``**self.args`` to ``runner.save_checkpoint``;
        # ``before_run`` may also set or override ``create_symlink`` here.
        self.args = kwargs
        self.sync_buffer = sync_buffer
        self.file_client_args = file_client_args

    def before_run(self, runner):
        """Resolve the output directory and file backend before training.

        When ``out_dir`` is falsy, ``runner.work_dir`` is used. Otherwise the
        last path component of ``runner.work_dir`` is appended to
        ``out_dir``. The ``create_symlink`` option is forced off when the
        inferred file backend does not allow symbolic links.
        """
        if not self.out_dir:
            self.out_dir = runner.work_dir

        self.file_client = FileClient.infer_client(self.file_client_args,
                                                   self.out_dir)

        # Since v1.3.16 a user-provided ``out_dir`` is a root directory:
        # checkpoints go to ``out_dir/<basename of runner.work_dir>``.
        if self.out_dir != runner.work_dir:
            basename = osp.basename(runner.work_dir.rstrip(osp.sep))
            self.out_dir = self.file_client.join_path(self.out_dir, basename)

        runner.logger.info((f'Checkpoints will be saved to {self.out_dir} by '
                            f'{self.file_client.name}.'))

        # Only honour a user request for symlinks when the backend supports
        # them. Without an explicit request, follow the backend's capability.
        if 'create_symlink' in self.args:
            if (self.args['create_symlink']
                    and not self.file_client.allow_symlink):
                self.args['create_symlink'] = False
                # The trailing space after "changed" is required: the two
                # literals are concatenated implicitly.
                warnings.warn(
                    ('create_symlink is set as True by the user but is changed '
                     'to be False because creating symbolic link is not '
                     f'allowed in {self.file_client.name}'))
        else:
            self.args['create_symlink'] = self.file_client.allow_symlink

    def after_train_epoch(self, runner):
        """Save a checkpoint at the end of an epoch when one is due.

        Does nothing unless ``by_epoch`` is True. A checkpoint is written when
        ``every_n_epochs`` reports the interval has been reached. With
        ``save_last``, one is also written after the last epoch.
        """
        if not self.by_epoch:
            return

        if self.every_n_epochs(runner, self.interval) or (
                self.save_last and self.is_last_epoch(runner)):
            runner.logger.info(
                f'Saving checkpoint at {runner.epoch + 1} epochs')
            if self.sync_buffer:
                allreduce_params(runner.model.buffers())
            self._save_checkpoint(runner)

    @master_only
    def _save_checkpoint(self, runner):
        """Save the current checkpoint and delete unwanted checkpoints.

        After saving, the new checkpoint's path is recorded in
        ``runner.meta['hook_msgs']['last_ckpt']`` when ``runner.meta`` is not
        None. If ``max_keep_ckpts > 0``, older checkpoints are removed.
        """
        runner.save_checkpoint(
            self.out_dir, save_optimizer=self.save_optimizer, **self.args)

        # Both the default filename template and the 1-based counter depend
        # on whether checkpoints are indexed by epoch or by iteration.
        if self.by_epoch:
            default_tmpl = 'epoch_{}.pth'
            current_ckpt = runner.epoch + 1
        else:
            default_tmpl = 'iter_{}.pth'
            current_ckpt = runner.iter + 1
        filename_tmpl = self.args.get('filename_tmpl', default_tmpl)

        if runner.meta is not None:
            cur_ckpt_filename = filename_tmpl.format(current_ckpt)
            runner.meta.setdefault('hook_msgs', dict())
            runner.meta['hook_msgs']['last_ckpt'] = self.file_client.join_path(
                self.out_dir, cur_ckpt_filename)

        if self.max_keep_ckpts > 0:
            # Walk backwards in steps of ``interval``, starting from the
            # newest checkpoint that falls outside the keep window. Stop at
            # the first missing file, presumably because earlier ones were
            # removed by previous calls. Checkpoints off this grid, such as
            # an extra one written because of ``save_last``, are not removed.
            redundant_ckpts = range(
                current_ckpt - self.max_keep_ckpts * self.interval, 0,
                -self.interval)
            for _step in redundant_ckpts:
                ckpt_path = self.file_client.join_path(
                    self.out_dir, filename_tmpl.format(_step))
                if self.file_client.isfile(ckpt_path):
                    self.file_client.remove(ckpt_path)
                else:
                    break

    def after_train_iter(self, runner):
        """Save a checkpoint after an iteration when one is due.

        Does nothing when ``by_epoch`` is True. A checkpoint is written when
        ``every_n_iters`` reports the interval has been reached. With
        ``save_last``, one is also written after the last iteration.
        """
        if self.by_epoch:
            return

        if self.every_n_iters(runner, self.interval) or (
                self.save_last and self.is_last_iter(runner)):
            runner.logger.info(
                f'Saving checkpoint at {runner.iter + 1} iterations')
            if self.sync_buffer:
                allreduce_params(runner.model.buffers())
            self._save_checkpoint(runner)
|
|