[Checkpointing - PR4] Refactored the CheckpointLoader into a load_checkpoint function #693

Merged: 40 commits into dev on Mar 11, 2022

Conversation

@ravi-mosaicml (Contributor) commented on Mar 8, 2022

Since checkpoint loading happens in `Trainer.__init__` (except for the restoration of the RNG state), there is no need for a checkpoint loader class.

  1. The `CheckpointLoader` class is replaced with a `load_checkpoint` function, and all of its private members are converted into private, module-level helper functions (see the sketch after this list).
  2. The file-downloading portions of the checkpoint loader are refactored into their own standalone utility, `composer.utils.file_retriever`, with its own test cases.
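
A minimal sketch of the refactor pattern, with illustrative names and signatures rather than Composer's exact API: the stateless class becomes a plain function, and its former private methods become private, module-level helpers.

```python
from typing import Any, Dict

import torch


def load_checkpoint(path: str, state: Any) -> None:
    """Load the checkpoint at ``path`` into ``state`` (illustrative signature)."""
    local_path = _retrieve_file(path)
    state_dict = _deserialize(local_path)
    state.load_state_dict(state_dict)


def _retrieve_file(path: str) -> str:
    # The real PR delegates remote downloads to composer.utils.file_retriever;
    # this sketch assumes ``path`` is already a local file.
    return path


def _deserialize(local_path: str) -> Dict[str, Any]:
    # Assumes the checkpoint is a torch-serialized dict.
    return torch.load(local_path, map_location="cpu")
```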

This PR is the first in a series for cleaning up the checkpoint API. One of the prerequisites is storing the seed on the state.
Here, only the rank zero seed is stored on the state, since only the rank zero state is persisted in a checkpoint. The trainer uses a distributed reduction to share the seed across ranks, so the same seed is restored when resuming from a checkpoint, even if a seed was not originally specified (see the sketch below).
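
A sketch of sharing rank zero's seed across ranks, written against `torch.distributed` directly; `sync_seed_from_rank_zero` is a hypothetical helper, and Composer wraps collectives in its own dist utilities.

```python
import torch
import torch.distributed as dist


def sync_seed_from_rank_zero(seed: int) -> int:
    """Return rank zero's seed on every rank (hypothetical helper).

    On a single process (no distributed run), the seed is returned unchanged.
    """
    if not (dist.is_available() and dist.is_initialized()):
        return seed
    # Broadcast rank zero's value so every rank restores the same seed.
    # (With the NCCL backend the tensor would need to live on the GPU.)
    seed_tensor = torch.tensor(seed, dtype=torch.int64)
    dist.broadcast(seed_tensor, src=0)
    return int(seed_tensor.item())
```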

This PR ignores the `seed` parameter passed into the trainer when resuming from a checkpoint. For the time being, if a new seed is desired, the `seed` attribute must be removed from the checkpoint state dict. #497 will introduce a cleaner API for this (edge) use case.
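
Until #497 lands, the workaround is to strip the seed from the checkpoint before resuming. A sketch, assuming the checkpoint is a torch-serialized dict; the nesting of the keys ("state" containing "seed") is an assumption for illustration.

```python
import torch


def strip_seed(checkpoint_path: str, output_path: str) -> None:
    """Drop the persisted seed so a new one can take effect on resume.

    The nesting ("state" -> "seed") is an assumed layout; inspect the
    checkpoint to find the actual structure.
    """
    ckpt = torch.load(checkpoint_path, map_location="cpu")
    ckpt.get("state", {}).pop("seed", None)
    torch.save(ckpt, output_path)
```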
1. RNG serialization and deserialization are moved from `composer.trainer._checkpoint` to `composer.utils.reproducibility`. This change is needed to refactor the checkpoint saver into a public module.
2. Helper methods are moved from `composer.trainer._deepspeed` to `composer.core.state` to determine whether the model is a DeepSpeed model.
3. A similar helper is added for `is_model_ddp` (see the sketch after this list).
4. State dict serialization and deserialization are refactored to support serialization of `@property`s. Leading underscores are no longer stored in the checkpoint, as they are a state implementation detail and not something that should be persisted through the checkpoint.
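
A sketch of two of the patterns above, simplified rather than the exact Composer implementations: an `is_model_ddp`-style check, and storing private attributes under their public names by stripping leading underscores (`serialize_fields` is a hypothetical helper).

```python
from typing import Any, Dict, Sequence

import torch
from torch.nn.parallel import DistributedDataParallel


def is_model_ddp(model: torch.nn.Module) -> bool:
    # True when the model has been wrapped for distributed data parallel.
    return isinstance(model, DistributedDataParallel)


def serialize_fields(obj: Any, fields: Sequence[str]) -> Dict[str, Any]:
    """Store each attribute under its public name (hypothetical helper).

    A private attribute such as ``_seed`` is saved as ``seed`` so the
    checkpoint does not leak implementation details; loading it back can
    go through a ``@property`` of the same public name.
    """
    return {name.lstrip("_"): getattr(obj, name) for name in fields}
```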
@ravi-mosaicml removed the request for review from jbloxham on March 9, 2022 22:55
@ajaysaini725 (Contributor) left a comment


LGTM

Base automatically changed from ravi/i414_p1.2 to dev March 11, 2022 02:49
@ravi-mosaicml merged commit 2a0b253 into dev on Mar 11, 2022
@ravi-mosaicml deleted the ravi/i414_p1.3 branch on March 11, 2022 04:16