[Train] Add FullyShardedDataParallel support to TorchTrainer #28096

Merged 12 commits from the fsdp branch into ray-project:master on Sep 7, 2022

Conversation

markrogersjr (Contributor) commented Aug 25, 2022

Why are these changes needed?

As of version 1.11, PyTorch supports automatically sharding large models via FullyShardedDataParallel (FSDP). This change adds FSDP support to Ray Train's TorchTrainer so users can take advantage of this feature.
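
For illustration, a minimal sketch of how FSDP could be enabled from a training function once this lands, assuming the parallel_strategy argument discussed in the review below; the merged argument names may differ, and it requires torch>=1.11 and GPU workers.

import torch
import torch.nn as nn

from ray.air.config import ScalingConfig
from ray.train.torch import TorchTrainer, get_device, prepare_model


def train_loop_per_worker():
    device = get_device()
    # Shard the model's parameters across workers (FSDP) instead of
    # replicating them on every worker (DDP).
    model = prepare_model(nn.Linear(32, 4), parallel_strategy="fsdp")
    # With FSDP, build the optimizer from the wrapped model's parameters.
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
    for _ in range(10):
        batch = torch.randn(8, 32, device=device)
        loss = model(batch).sum()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()


trainer = TorchTrainer(
    train_loop_per_worker,
    scaling_config=ScalingConfig(num_workers=2, use_gpu=True),
)
trainer.fit()

TorchTrainer initializes the torch.distributed process group on each worker, so the sketch does not set one up explicitly.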

Related issue number

Checks

  • I've signed off every commit (by using the -s flag, i.e., git commit -s) in this PR.
  • I've run scripts/format.sh to lint the changes in this PR.
  • I've included any doc changes needed for https://docs.ray.io/en/master/.
  • I've made sure the tests are passing. Note that there might be a few flaky tests; see the recent failures at https://flakey-tests.ray.io/
  • Testing Strategy
    • Unit tests
    • Release tests
    • This PR is not tested :(

amogkam (Contributor) commented Aug 25, 2022

Wow thanks for the contribution @markrogersjr! Will take a closer look later!

@@ -412,7 +412,7 @@ install_dependencies() {
 1.5) TORCHVISION_VERSION=0.6.0;;
 *) TORCHVISION_VERSION=0.5.0;;
 esac
-pip install --use-deprecated=legacy-resolver --upgrade torch=="${TORCH_VERSION-1.9.0}" torchvision=="${TORCHVISION_VERSION}"
+pip install --use-deprecated=legacy-resolver --upgrade torch=="${TORCH_VERSION-1.11.0}" torchvision=="${TORCHVISION_VERSION}"
Review comment (Contributor):

Hmm, is this change necessary? This block is only here for legacy reasons for the Ray Serve tests; the Ray Train tests don't go through this conditional.

@@ -51,6 +52,8 @@ def prepare_model(
     move_to_device: bool = True,
     wrap_ddp: bool = True,
     ddp_kwargs: Optional[Dict[str, Any]] = None,
+    wrap_fsdp: bool = False,
+    fsdp_kwargs: Optional[Dict[str, Any]] = None,
Review comment (Contributor):

Rather than a separate wrap_fsdp arg, can we change the API to have a single parallel_strategy arg that can be set to "ddp", "fsdp", or None? This will allow for more extensibility if we add more distributed strategies later.

Then we can also consolidate ddp_kwargs and fsdp_kwargs.
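
For concreteness, a minimal sketch of what that consolidated API could look like; the argument names (parallel_strategy, parallel_strategy_kwargs) are illustrative rather than the final merged signature, and existing arguments such as move_to_device are omitted for brevity.

from typing import Any, Dict, Optional

import torch
from torch.nn.parallel import DistributedDataParallel


def prepare_model(
    model: torch.nn.Module,
    parallel_strategy: Optional[str] = "ddp",
    parallel_strategy_kwargs: Optional[Dict[str, Any]] = None,
) -> torch.nn.Module:
    # Dispatch on a single strategy string instead of separate wrap_ddp /
    # wrap_fsdp flags. Assumes torch.distributed is already initialized.
    kwargs = parallel_strategy_kwargs or {}
    if parallel_strategy is None:
        return model
    if parallel_strategy == "ddp":
        return DistributedDataParallel(model, **kwargs)
    if parallel_strategy == "fsdp":
        # Imported lazily so older torch versions without FSDP can still import this module.
        from torch.distributed.fsdp import FullyShardedDataParallel
        return FullyShardedDataParallel(model, **kwargs)
    raise ValueError(f"Unsupported parallel_strategy: {parallel_strategy!r}")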

@@ -7,6 +7,7 @@
 import torch
 import torchvision
 from torch.nn.parallel import DistributedDataParallel
+from torch.distributed.fsdp import FullyShardedDataParallel
Review comment (Contributor):

This import will fail if the user's PyTorch version is less than 1.11, right?

To make sure that users can still use Ray Train with older versions of torch, can we do something like the following:

import torch
from distutils.version import LooseVersion

# FSDP is only available in torch 1.11+, so fall back to None on older versions.
if LooseVersion(torch.__version__) < LooseVersion("1.11.0"):
    FullyShardedDataParallel = None
else:
    from torch.distributed.fsdp import FullyShardedDataParallel

Then in prepare_model, if FullyShardedDataParallel is None but the user enables it, we can raise an error saying that the user needs to upgrade their torch version to enable FSDP.
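
A sketch of that guard, for illustration (the helper name and error wording are hypothetical, not the merged code):

from typing import Optional


def _raise_if_fsdp_unavailable(parallel_strategy: Optional[str]) -> None:
    # The conditional import above binds FullyShardedDataParallel to None on
    # torch < 1.11, so requesting FSDP there fails early with a clear message.
    if parallel_strategy == "fsdp" and FullyShardedDataParallel is None:
        raise ImportError(
            "FullyShardedDataParallel requires torch>=1.11.0. "
            "Please upgrade PyTorch to use FSDP with Ray Train."
        )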

@@ -5,7 +5,7 @@ tblib

 # If you make changes to the torch versions, please also make the corresponding changes to `requirements_dl.txt`!
 -f https://download.pytorch.org/whl/torch_stable.html
-torch==1.9.0+cu111
+torch==1.11.0+cu111
Review comment (Contributor):

Unfortunately, our CI dependencies are a bit of a mess right now, so upgrading them across the board won't work out of the box.

Instead, can we do the following:

- label: ":tv: :steam_locomotive: Train GPU tests (PyTorch 1.11) "
  conditions: ["RAY_CI_TRAIN_AFFECTED"]
  commands:
    - cleanup() { if [ "${BUILDKITE_PULL_REQUEST}" = "false" ]; then ./ci/build/upload_build_info.sh; fi }; trap cleanup EXIT
    - PYTHON=3.7 TRAIN_TESTING=1 ./ci/env/install-dependencies.sh
    # Because Python version changed, we need to re-install Ray here
    - rm -rf ./python/ray/thirdparty_files; rm -rf ./python/ray/pickle5_files; ./ci/ci.sh build
    - pip install -U torch==1.11.0+cu113 --extra-index-url https://download.pytorch.org/whl/cu113
    - ./ci/env/env_info.sh
    - bazel test --config=ci $(./ci/run/bazel_export_options) --build_tests_only --test_tag_filters=gpu,gpu_only,torch_1_11,-ray_air python/ray/train/...

markrogersjr force-pushed the fsdp branch 8 times, most recently from 6568bd5 to 95c8e4b, on August 31, 2022 19:40
markrogersjr requested a review from a team as a code owner on August 31, 2022 19:44
markrogersjr force-pushed the fsdp branch 4 times, most recently from 8b9a85a to a7b7301, on September 1, 2022 05:53
markrogersjr (Contributor, Author) commented

@amogkam Thanks for the feedback; I tried to incorporate your suggestions as closely as possible. Several of the tests are still not passing, and tweaking the PyTorch 1.11 test configuration hasn't helped. Any suggestions?

amogkam (Contributor) left a review comment

Thanks @markrogersjr! Left some comments; I think these should work.

Review threads on .buildkite/pipeline.gpu.large.yml (resolved)
amogkam (Contributor) commented Sep 2, 2022

@markrogersjr can you also merge in latest master?

Signed-off-by: Mark Rogers <m@inmimo.me>
…gs in prepare_model

Signed-off-by: Mark Rogers <m@inmimo.me>
Signed-off-by: Mark Rogers <m@inmimo.me>
Signed-off-by: Mark Rogers <m@inmimo.me>
Signed-off-by: Mark Rogers <m@inmimo.me>
Signed-off-by: Mark Rogers <m@inmimo.me>
Signed-off-by: Mark Rogers <m@inmimo.me>
markrogersjr force-pushed the fsdp branch 3 times, most recently from 7984709 to d1b3ce6, on September 2, 2022 05:59
Signed-off-by: Mark Rogers <m@inmimo.me>
markrogersjr (Contributor, Author) commented

@amogkam I managed to get most tests to pass; it looks like the rest could be failing due to flakiness. Please let me know how I can help from here!

amogkam changed the title from "add fsdp support for torch trainer" to "[Train] Add FullyShardedDataParallel support to TorchTrainer" on Sep 2, 2022
Signed-off-by: Amog Kamsetty <amogkamsetty@yahoo.com>
amogkam (Contributor) commented Sep 2, 2022

Thanks @markrogersjr, this looks great to me! I just pushed some additional changes, primarily for backwards compatibility. But I will make sure to get this merged in!

Signed-off-by: Amog Kamsetty <amogkamsetty@yahoo.com>
Signed-off-by: Amog Kamsetty <amogkamsetty@yahoo.com>
Signed-off-by: Amog Kamsetty <amogkamsetty@yahoo.com>
amogkam (Contributor) left a review comment

Excellent, thanks @markrogersjr!

markrogersjr (Contributor, Author) commented

@amogkam thank you as well, nice work!

amogkam merged commit be92ab6 into ray-project:master on Sep 7, 2022
markrogersjr deleted the fsdp branch on September 7, 2022 14:43
ilee300a pushed a commit to ilee300a/ray that referenced this pull request Sep 12, 2022
…ject#28096)

As of version 1.11, PyTorch supports automatically sharding large models via FullyShardedDataParallel. This change is necessary to take advantage of this new feature.

Signed-off-by: Mark Rogers <m@inmimo.me>
Signed-off-by: Amog Kamsetty <amogkamsetty@yahoo.com>
Co-authored-by: Amog Kamsetty <amogkamsetty@yahoo.com>
Signed-off-by: ilee300a <ilee300@anyscale.com>