Skip to content

Commit

Permalink
[Feature] Resume from the latest checkpoint automatically. (#245)
Browse files Browse the repository at this point in the history
* [Feature] Resume from the latest checkpoint automatically.

* fix windows path problem

* fix lint

* add code reference
  • Loading branch information
fangyixiao18 authored Mar 25, 2022
1 parent 7860475 commit 3dce8db
Show file tree
Hide file tree
Showing 5 changed files with 96 additions and 4 deletions.
9 changes: 8 additions & 1 deletion mmselfsup/apis/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@
from mmselfsup.core import (DistOptimizerHook, GradAccumFp16OptimizerHook,
build_optimizer)
from mmselfsup.datasets import build_dataloader, build_dataset
from mmselfsup.utils import get_root_logger, multi_gpu_test, single_gpu_test
from mmselfsup.utils import (find_latest_checkpoint, get_root_logger,
multi_gpu_test, single_gpu_test)


def init_random_seed(seed=None, device='cuda'):
Expand Down Expand Up @@ -192,6 +193,12 @@ def train_model(model,
eval_hook(val_dataloader, test_fn=eval_fn, **eval_cfg),
priority='LOW')

resume_from = None
if cfg.resume_from is None and cfg.get('auto_resume'):
resume_from = find_latest_checkpoint(cfg.work_dir)
if resume_from is not None:
cfg.resume_from = resume_from

if cfg.resume_from:
runner.resume(cfg.resume_from)
elif cfg.load_from:
Expand Down
5 changes: 3 additions & 2 deletions mmselfsup/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,14 @@
from .extractor import Extractor
from .gather import concat_all_gather, gather_tensors, gather_tensors_batch
from .logger import get_root_logger
from .misc import find_latest_checkpoint
from .setup_env import setup_multi_processes
from .test_helper import multi_gpu_test, single_gpu_test

__all__ = [
'AliasMethod', 'batch_shuffle_ddp', 'batch_unshuffle_ddp',
'dist_forward_collect', 'nondist_forward_collect', 'collect_env',
'distributed_sinkhorn', 'Extractor', 'concat_all_gather', 'gather_tensors',
'gather_tensors_batch', 'get_root_logger', 'multi_gpu_test',
'single_gpu_test', 'setup_multi_processes'
'gather_tensors_batch', 'get_root_logger', 'find_latest_checkpoint',
'multi_gpu_test', 'single_gpu_test', 'setup_multi_processes'
]
38 changes: 38 additions & 0 deletions mmselfsup/utils/misc.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,8 @@
# Copyright (c) OpenMMLab. All rights reserved.
import glob
import os.path as osp
import warnings

import mmcv
import numpy as np

Expand All @@ -15,3 +19,37 @@ def tensor2imgs(tensor, mean=(0, 0, 0), std=(1, 1, 1), to_rgb=True):
img, mean, std, to_bgr=to_rgb).astype(np.uint8)
imgs.append(np.ascontiguousarray(img))
return imgs


def find_latest_checkpoint(path, suffix='pth'):
"""Find the latest checkpoint from the working directory.
Args:
path(str): The path to find checkpoints.
suffix(str): File extension.
Defaults to pth.
Returns:
latest_path(str | None): File path of the latest checkpoint.
References:
.. [1] https://github.com/microsoft/SoftTeacher
/blob/main/ssod/utils/patch.py
.. [2] https://github.com/open-mmlab/mmdetection
/blob/master/mmdet/utils/misc.py#L7
"""
if not osp.exists(path):
warnings.warn('The path of checkpoints does not exist.')
return None
if osp.exists(osp.join(path, f'latest.{suffix}')):
return osp.join(path, f'latest.{suffix}')

checkpoints = glob.glob(osp.join(path, f'*.{suffix}'))
if len(checkpoints) == 0:
warnings.warn('There are no checkpoints in the path.')
return None
latest = -1
latest_path = None
for checkpoint in checkpoints:
count = int(osp.basename(checkpoint).split('_')[-1].split('.')[0])
if count > latest:
latest = count
latest_path = checkpoint
return latest_path
43 changes: 42 additions & 1 deletion tests/test_utils/test_misc.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp
import tempfile

import pytest
import torch

from mmselfsup.utils.misc import tensor2imgs
from mmselfsup.utils.misc import find_latest_checkpoint, tensor2imgs


def test_tensor2imgs():
Expand All @@ -12,3 +15,41 @@ def test_tensor2imgs():
fake_imgs = tensor2imgs(fake_tensor)
assert len(fake_imgs) == 3
assert fake_imgs[0].shape == (16, 16, 3)


def test_find_latest_checkpoint():
with tempfile.TemporaryDirectory() as tmpdir:
path = tmpdir
latest = find_latest_checkpoint(path)
# There are no checkpoints in the path.
assert latest is None

path = osp.join(tmpdir, 'none')
latest = find_latest_checkpoint(path)
# The path does not exist.
assert latest is None

with tempfile.TemporaryDirectory() as tmpdir:
with open(osp.join(tmpdir, 'latest.pth'), 'w') as f:
f.write('latest')
path = tmpdir
latest = find_latest_checkpoint(path)
assert latest == osp.join(tmpdir, 'latest.pth')

with tempfile.TemporaryDirectory() as tmpdir:
with open(osp.join(tmpdir, 'iter_4000.pth'), 'w') as f:
f.write('iter_4000')
with open(osp.join(tmpdir, 'iter_8000.pth'), 'w') as f:
f.write('iter_8000')
path = tmpdir
latest = find_latest_checkpoint(path)
assert latest == osp.join(tmpdir, 'iter_8000.pth')

with tempfile.TemporaryDirectory() as tmpdir:
with open(osp.join(tmpdir, 'epoch_1.pth'), 'w') as f:
f.write('epoch_1')
with open(osp.join(tmpdir, 'epoch_2.pth'), 'w') as f:
f.write('epoch_2')
path = tmpdir
latest = find_latest_checkpoint(path)
assert latest == osp.join(tmpdir, 'epoch_2.pth')
5 changes: 5 additions & 0 deletions tools/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,10 @@ def parse_args():
parser.add_argument('--work_dir', help='the dir to save logs and models')
parser.add_argument(
'--resume_from', help='the checkpoint file to resume from')
parser.add_argument(
'--auto-resume',
action='store_true',
help='resume from the latest checkpoint automatically')
group_gpus = parser.add_mutually_exclusive_group()
group_gpus.add_argument(
'--gpus',
Expand Down Expand Up @@ -100,6 +104,7 @@ def main():
osp.splitext(osp.basename(args.config))[0])
if args.resume_from is not None:
cfg.resume_from = args.resume_from
cfg.auto_resume = args.auto_resume
if args.gpus is not None:
cfg.gpu_ids = range(1)
warnings.warn('`--gpus` is deprecated because we only support '
Expand Down

0 comments on commit 3dce8db

Please sign in to comment.