Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Move checkpoint funtions from runner to a new sub-package #1495

Draft
wants to merge 8 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 34 additions & 0 deletions docs/en/api/checkpoint.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
.. role:: hidden
:class: hidden-section

mmengine.checkpoint
===================================

.. contents:: mmengine.checkpoint
:depth: 2
:local:
:backlinks: top

.. currentmodule:: mmengine.checkpoint

.. autosummary::
:toctree: generated
:nosignatures:
:template: classtemplate.rst

CheckpointLoader

.. autosummary::
:toctree: generated
:nosignatures:

load_checkpoint
save_checkpoint
load_state_dict
get_state_dict
weights_to_cpu
find_latest_checkpoint
get_deprecated_model_names
get_external_models
get_mmcls_models
get_torchvision_models
4 changes: 4 additions & 0 deletions docs/en/api/runner.rst
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,10 @@ Loop
Checkpoints
----------------

.. warn::

All functions and classes in this file have been moved to `mmengine.checkpoint`. Please import them from `mmengine.checkpoint`.

.. autosummary::
:toctree: generated
:nosignatures:
Expand Down
2 changes: 1 addition & 1 deletion docs/en/api/strategy.rst
Original file line number Diff line number Diff line change
Expand Up @@ -37,5 +37,5 @@ mmengine._strategy
:nosignatures:
:template: classtemplate.rst

CollosalAIModelWrapper
ColossalAIModelWrapper
ColossalAIOptimWrapper
1 change: 1 addition & 0 deletions docs/en/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,7 @@ You can switch between Chinese and English documents in the lower-left corner of
mmengine.dataset <api/dataset>
mmengine.infer <api/infer>
mmengine.device <api/device>
mmengine.checkpoint <api/checkpoint>
mmengine.hub <api/hub>
mmengine.logging <api/logging>
mmengine.visualization <api/visualization>
Expand Down
34 changes: 34 additions & 0 deletions docs/zh_cn/api/checkpoint.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
.. role:: hidden
:class: hidden-section

mmengine.checkpoint
===================================

.. contents:: mmengine.checkpoint
:depth: 2
:local:
:backlinks: top

.. currentmodule:: mmengine.checkpoint

.. autosummary::
:toctree: generated
:nosignatures:
:template: classtemplate.rst

CheckpointLoader

.. autosummary::
:toctree: generated
:nosignatures:

load_checkpoint
save_checkpoint
load_state_dict
get_state_dict
weights_to_cpu
find_latest_checkpoint
get_deprecated_model_names
get_external_models
get_mmcls_models
get_torchvision_models
4 changes: 4 additions & 0 deletions docs/zh_cn/api/runner.rst
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,10 @@ Loop
Checkpoints
----------------

.. warn::

所有的函数和类在这个文件中已经被移动到 `mmengine.checkpoint`。请从 `mmengine.checkpoint` 导入它们。

.. autosummary::
:toctree: generated
:nosignatures:
Expand Down
2 changes: 1 addition & 1 deletion docs/zh_cn/api/strategy.rst
Original file line number Diff line number Diff line change
Expand Up @@ -37,5 +37,5 @@ mmengine._strategy
:nosignatures:
:template: classtemplate.rst

CollosalAIModelWrapper
ColossalAIModelWrapper
ColossalAIOptimWrapper
1 change: 1 addition & 0 deletions docs/zh_cn/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,7 @@
mmengine.dataset <api/dataset>
mmengine.infer <api/infer>
mmengine.device <api/device>
mmengine.checkpoint <api/checkpoint>
mmengine.hub <api/hub>
mmengine.logging <api/logging>
mmengine.visualization <api/visualization>
Expand Down
3 changes: 1 addition & 2 deletions mmengine/_strategy/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from torch.optim import Optimizer

import mmengine
from mmengine.checkpoint.io import _load_checkpoint_to_model
from mmengine.config import Config, ConfigDict
from mmengine.dist import (broadcast, get_dist_info, infer_launcher,
is_distributed)
Expand Down Expand Up @@ -792,8 +793,6 @@ def load_model_state_dict(
revise_keys: list = [(r'^module.', '')],
) -> None:
"""Load model state from dict."""
from mmengine.runner.checkpoint import _load_checkpoint_to_model

if is_model_wrapper(self.model):
model = self.model.module
else:
Expand Down
4 changes: 2 additions & 2 deletions mmengine/_strategy/colossalai.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,14 +26,14 @@
import mmengine
from mmengine import mkdir_or_exist
from mmengine._strategy import BaseStrategy
from mmengine.checkpoint.io import _load_checkpoint, save_checkpoint
from mmengine.device import get_device
from mmengine.dist import init_dist, is_main_process
from mmengine.fileio import join_path
from mmengine.model import BaseDataPreprocessor
from mmengine.optim import BaseOptimWrapper, OptimWrapper, _ParamScheduler
from mmengine.registry import STRATEGIES, Registry
from mmengine.registry.root import MODEL_WRAPPERS, OPTIM_WRAPPERS, OPTIMIZERS
from mmengine.runner.checkpoint import _load_checkpoint, save_checkpoint
from mmengine.utils import get_git_hash

# Component for colossalai `plugins` and `mixed_precisions`
Expand Down Expand Up @@ -191,7 +191,7 @@ def __getattr__(self, name):
class ColossalAIStrategy(BaseStrategy):
"""
Args:
config: (str or dict): The colossalai config file to setup distributed
config (str or dict): The colossalai config file to setup distributed
environment. See more details in the `colossalai config tutorial`_.
mixed_precision (str or MixedPrecision): The mixed precision to run the
training. Defaults to None. If the argument is a string, it can be
Expand Down
2 changes: 1 addition & 1 deletion mmengine/_strategy/deepspeed.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,11 @@
import torch.nn as nn

import mmengine
from mmengine.checkpoint import save_checkpoint, weights_to_cpu
from mmengine.dist import init_dist, is_main_process
from mmengine.optim import BaseOptimWrapper, _ParamScheduler
from mmengine.registry import (MODEL_WRAPPERS, OPTIM_WRAPPERS, OPTIMIZERS,
STRATEGIES)
from mmengine.runner.checkpoint import save_checkpoint, weights_to_cpu
from mmengine.utils import apply_to, digit_version, get_git_hash
from .base import BaseStrategy

Expand Down
3 changes: 1 addition & 2 deletions mmengine/_strategy/fsdp.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from torch.optim.lr_scheduler import LRScheduler

import mmengine
from mmengine.checkpoint import save_checkpoint
from mmengine.config import Config, ConfigDict
from mmengine.device import get_device
from mmengine.dist import get_rank, is_main_process
Expand Down Expand Up @@ -259,8 +260,6 @@ def save_checkpoint(self,
checkpoint before saving the checkpoint.
Defaults to None.
"""
from mmengine.runner.checkpoint import save_checkpoint

state_dict: dict = dict()
state_dict['state_dict'] = self.model_state_dict()

Expand Down
5 changes: 1 addition & 4 deletions mmengine/_strategy/single_device.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import torch.nn as nn

import mmengine
from mmengine.checkpoint.io import _load_checkpoint, save_checkpoint
from mmengine.device import get_device
from mmengine.model import revert_sync_batchnorm
from mmengine.optim import BaseOptimWrapper, _ParamScheduler
Expand Down Expand Up @@ -135,8 +136,6 @@ def load_checkpoint(
checkpoint after loading the checkpoint.
Defaults to None.
"""
from mmengine.runner.checkpoint import _load_checkpoint

self.logger.info(f'Load checkpoint from {filename}')

if map_location == 'default':
Expand Down Expand Up @@ -255,8 +254,6 @@ def save_checkpoint(
checkpoint before saving the checkpoint.
Defaults to None.
"""
from mmengine.runner.checkpoint import save_checkpoint

state_dict: dict = dict()
state_dict['state_dict'] = self.model_state_dict()

Expand Down
14 changes: 14 additions & 0 deletions mmengine/checkpoint/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .io import (get_state_dict, load_checkpoint, load_state_dict,
save_checkpoint, weights_to_cpu)
from .loader import CheckpointLoader
from .utils import (find_latest_checkpoint, get_deprecated_model_names,
get_external_models, get_mmcls_models,
get_torchvision_models)

__all__ = [
'CheckpointLoader', 'find_latest_checkpoint', 'get_deprecated_model_names',
'get_external_models', 'get_mmcls_models', 'get_state_dict',
'get_torchvision_models', 'load_checkpoint', 'load_state_dict',
'save_checkpoint', 'weights_to_cpu'
]
Loading
Loading