Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Train] Disable gathering the full state dict in RayFSDPStrategy for lightning>2.1 #44569

Merged
merged 7 commits into from
Apr 15, 2024

Conversation

woshiyyya
Copy link
Member

@woshiyyya woshiyyya commented Apr 8, 2024

Why are these changes needed?

lighting 2.0.x does not natively support FSDP state_dict_type. Therefore, we added default state dict gathering logic (#34967) to enable FSDP checkpointing. After 2.1, Lightning inherently supports FSDP state_dict_type, so we no longer need this patch logic.

This PR restricts the patch's applicability to Lightning versions 2.0 through 2.1, enabling users to leverage Lightning's native FSDP integration in versions beyond 2.1.

Related issue number

Closes #44501

Checks

  • I've signed off every commit(by using the -s flag, i.e., git commit -s) in this PR.
  • I've run scripts/format.sh to lint the changes in this PR.
  • I've included any doc changes needed for https://docs.ray.io/en/master/.
    • I've added any new APIs to the API Reference. For example, if I added a
      method in Tune, I've added it in doc/source/tune/api/ under the
      corresponding .rst file.
  • I've made sure the tests are passing. Note that there might be a few flaky tests, see the recent failures at https://flakey-tests.ray.io/
  • Testing Strategy
    • Unit tests
    • Release tests
    • This PR is not tested :(

Signed-off-by: woshiyyya <xiaoyunxuan1998@gmail.com>
Signed-off-by: woshiyyya <xiaoyunxuan1998@gmail.com>
Comment on lines 110 to 112
# Lightning < 2.1 lacks FSDP state_dict_type support.
# (PR: https://github.com/Lightning-AI/pytorch-lightning/pull/17623).
# We need this patch logic to enable FSDP checkpointing between 2.0 and 2.1.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we just force users to upgrade to versions outside of this range? It's a bit confusing for the behavior to be hardcoded to full state dict ckpt based on the library version.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Without this fix, the state dict of checkpoint in lightning 2.0.x will be empty.

After offline discussion, we will not raise an error since the hardcoded gathering logic doesn't contradict with the lightning behavior. Instead, we add a notice in RayFSDPStrategy docstring to recommend users upgrade to beyond 2.1 if they want to use FSDP.

python/ray/train/lightning/_lightning_utils.py Outdated Show resolved Hide resolved
Signed-off-by: woshiyyya <xiaoyunxuan1998@gmail.com>
Copy link
Contributor

@justinvyu justinvyu left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks! some small nits

python/ray/train/lightning/_lightning_utils.py Outdated Show resolved Hide resolved
python/ray/train/lightning/_lightning_utils.py Outdated Show resolved Hide resolved
woshiyyya and others added 4 commits April 8, 2024 16:39
Co-authored-by: Justin Yu <justinvyu@anyscale.com>
Signed-off-by: Yunxuan Xiao <xiaoyunxuan1998@gmail.com>
Signed-off-by: woshiyyya <xiaoyunxuan1998@gmail.com>
@justinvyu justinvyu merged commit 0731833 into ray-project:master Apr 15, 2024
5 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

[train] RayFSDPStrategy should allow for sharded checkpointing
2 participants