|
|
|
import os.path as osp
|
|
import warnings
|
|
|
|
from annotator.uniformer.mmcv.fileio import FileClient
|
|
from ..dist_utils import allreduce_params, master_only
|
|
from .hook import HOOKS, Hook
|
|
|
|
|
|
@HOOKS.register_module()
class CheckpointHook(Hook):
    """Save checkpoints periodically.

    Args:
        interval (int): The saving period. If ``by_epoch=True``, interval
            indicates epochs, otherwise it indicates iterations.
            Default: -1, which means "never".
        by_epoch (bool): Saving checkpoints by epoch or by iteration.
            Default: True.
        save_optimizer (bool): Whether to save optimizer state_dict in the
            checkpoint. It is usually used for resuming experiments.
            Default: True.
        out_dir (str, optional): The root directory to save checkpoints. If not
            specified, ``runner.work_dir`` will be used by default. If
            specified, the ``out_dir`` will be the concatenation of ``out_dir``
            and the last level directory of ``runner.work_dir``.
            `Changed in version 1.3.16.`
        max_keep_ckpts (int, optional): The maximum number of checkpoints to
            keep. In some cases we want only the latest few checkpoints and
            would like to delete old ones to save disk space. Old checkpoints
            are only pruned when ``interval > 0``.
            Default: -1, which means unlimited.
        save_last (bool, optional): Whether to force the last checkpoint to be
            saved regardless of interval. Default: True.
        sync_buffer (bool, optional): Whether to synchronize buffers in
            different gpus. Default: False.
        file_client_args (dict, optional): Arguments to instantiate a
            FileClient. See :class:`mmcv.fileio.FileClient` for details.
            Default: None.
            `New in version 1.3.16.`
        **kwargs: Extra keyword arguments stored in ``self.args`` and
            forwarded to ``runner.save_checkpoint``. ``filename_tmpl`` and
            ``create_symlink`` are also read by this hook.

    .. warning::
        Before v1.3.16, the ``out_dir`` argument indicates the path where the
        checkpoint is stored. However, since v1.3.16, ``out_dir`` indicates the
        root directory and the final path to save checkpoint is the
        concatenation of ``out_dir`` and the last level directory of
        ``runner.work_dir``. Suppose the value of ``out_dir`` is "/path/of/A"
        and the value of ``runner.work_dir`` is "/path/of/B", then the final
        path will be "/path/of/A/B".
    """

    def __init__(self,
                 interval=-1,
                 by_epoch=True,
                 save_optimizer=True,
                 out_dir=None,
                 max_keep_ckpts=-1,
                 save_last=True,
                 sync_buffer=False,
                 file_client_args=None,
                 **kwargs):
        self.interval = interval
        self.by_epoch = by_epoch
        self.save_optimizer = save_optimizer
        self.out_dir = out_dir
        self.max_keep_ckpts = max_keep_ckpts
        self.save_last = save_last
        # Forwarded verbatim to ``runner.save_checkpoint``.
        self.args = kwargs
        self.sync_buffer = sync_buffer
        self.file_client_args = file_client_args

    def before_run(self, runner):
        """Resolve the output directory and the file backend before training.

        Falls back to ``runner.work_dir`` when no ``out_dir`` was given.
        Otherwise, appends the last path component of ``runner.work_dir`` to
        ``out_dir``. Also reconciles the ``create_symlink`` option with what
        the inferred file backend allows.
        """
        if not self.out_dir:
            self.out_dir = runner.work_dir

        self.file_client = FileClient.infer_client(self.file_client_args,
                                                   self.out_dir)

        # When out_dir is user-specified, checkpoints go to
        # ``out_dir/<last component of work_dir>`` rather than out_dir itself.
        if self.out_dir != runner.work_dir:
            basename = osp.basename(runner.work_dir.rstrip(osp.sep))
            self.out_dir = self.file_client.join_path(self.out_dir, basename)

        runner.logger.info((f'Checkpoints will be saved to {self.out_dir} by '
                            f'{self.file_client.name}.'))

        # A backend that cannot create symlinks overrides an explicit
        # ``create_symlink=True`` (with a warning). An unset option defaults
        # to whatever the backend allows.
        if 'create_symlink' in self.args:
            if self.args[
                    'create_symlink'] and not self.file_client.allow_symlink:
                self.args['create_symlink'] = False
                warnings.warn(
                    ('create_symlink is set as True by the user but is '
                     'changed to be False because creating symbolic link is '
                     f'not allowed in {self.file_client.name}'))
        else:
            self.args['create_symlink'] = self.file_client.allow_symlink

    def after_train_epoch(self, runner):
        """Save a checkpoint after an epoch when running in epoch mode."""
        if not self.by_epoch:
            return

        # Save every ``interval`` epochs, and also at the final epoch when
        # ``save_last`` is enabled.
        if self.every_n_epochs(
                runner, self.interval) or (self.save_last
                                           and self.is_last_epoch(runner)):
            runner.logger.info(
                f'Saving checkpoint at {runner.epoch + 1} epochs')
            if self.sync_buffer:
                allreduce_params(runner.model.buffers())
            self._save_checkpoint(runner)

    @master_only
    def _save_checkpoint(self, runner):
        """Save the current checkpoint and delete unwanted checkpoints.

        When ``runner.meta`` is not None, the path of the checkpoint just
        written is recorded in ``runner.meta['hook_msgs']['last_ckpt']``.
        """
        runner.save_checkpoint(
            self.out_dir, save_optimizer=self.save_optimizer, **self.args)

        # Filename template and 1-based step number of the checkpoint just
        # saved. The same values are used for ``last_ckpt`` and for pruning.
        if self.by_epoch:
            default_tmpl = 'epoch_{}.pth'
            cur_step = runner.epoch + 1
        else:
            default_tmpl = 'iter_{}.pth'
            cur_step = runner.iter + 1
        filename_tmpl = self.args.get('filename_tmpl', default_tmpl)

        if runner.meta is not None:
            runner.meta.setdefault('hook_msgs', dict())
            runner.meta['hook_msgs']['last_ckpt'] = self.file_client.join_path(
                self.out_dir, filename_tmpl.format(cur_step))

        # Pruning walks back in steps of ``interval``. A non-positive interval
        # has no periodic checkpoints to prune, and ``interval == 0`` would
        # make ``range`` raise ValueError (step must not be zero).
        if self.max_keep_ckpts > 0 and self.interval > 0:
            self._remove_old_ckpts(filename_tmpl, cur_step)

    def _remove_old_ckpts(self, filename_tmpl, cur_step):
        """Delete periodic checkpoints older than the newest ``max_keep_ckpts``.

        Starts ``max_keep_ckpts * interval`` steps before ``cur_step`` and
        walks backwards in steps of ``interval``. Stops at the first
        checkpoint path that is not an existing file.
        """
        stale_steps = range(cur_step - self.max_keep_ckpts * self.interval, 0,
                            -self.interval)
        for step in stale_steps:
            ckpt_path = self.file_client.join_path(self.out_dir,
                                                   filename_tmpl.format(step))
            if not self.file_client.isfile(ckpt_path):
                # Presumably anything older was already pruned by an earlier
                # call, so stop scanning here.
                break
            self.file_client.remove(ckpt_path)

    def after_train_iter(self, runner):
        """Save a checkpoint after an iteration when running in iter mode."""
        if self.by_epoch:
            return

        # Save every ``interval`` iterations, and also at the final iteration
        # when ``save_last`` is enabled.
        if self.every_n_iters(
                runner, self.interval) or (self.save_last
                                           and self.is_last_iter(runner)):
            runner.logger.info(
                f'Saving checkpoint at {runner.iter + 1} iterations')
            if self.sync_buffer:
                allreduce_params(runner.model.buffers())
            self._save_checkpoint(runner)
|
|
|