[Train] Add FullyShardedDataParallel support to TorchTrainer #28096
Conversation
Wow, thanks for the contribution @markrogersjr! Will take a closer look later!
ci/env/install-dependencies.sh (outdated)

```diff
@@ -412,7 +412,7 @@ install_dependencies() {
       1.5) TORCHVISION_VERSION=0.6.0;;
       *) TORCHVISION_VERSION=0.5.0;;
     esac
-    pip install --use-deprecated=legacy-resolver --upgrade torch=="${TORCH_VERSION-1.9.0}" torchvision=="${TORCHVISION_VERSION}"
+    pip install --use-deprecated=legacy-resolver --upgrade torch=="${TORCH_VERSION-1.11.0}" torchvision=="${TORCHVISION_VERSION}"
```
Hmm, is this change necessary? This block is only here for legacy reasons for the Ray Serve tests; for the Ray Train tests it doesn't go inside this conditional.
```diff
@@ -51,6 +52,8 @@ def prepare_model(
     move_to_device: bool = True,
     wrap_ddp: bool = True,
     ddp_kwargs: Optional[Dict[str, Any]] = None,
+    wrap_fsdp: bool = False,
+    fsdp_kwargs: Optional[Dict[str, Any]] = None,
```
Rather than a separate `wrap_fsdp` arg, can we change the API to have a single `parallel_strategy` arg that can be set to either `"ddp"`, `"fsdp"`, or `None`? This will allow for more extensibility if we add more distributed strategies later. Then we can also consolidate `ddp_kwargs` and `fsdp_kwargs`.
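The consolidated signature could look something like this (a hypothetical sketch of the suggested API; the names `parallel_strategy` and `parallel_strategy_kwargs`, and the validation logic, are assumptions, not the merged implementation):

```python
from typing import Any, Dict, Optional

# Allowed values for the consolidated strategy argument (sketch).
VALID_STRATEGIES = ("ddp", "fsdp", None)


def prepare_model(
    model: Any,
    move_to_device: bool = True,
    parallel_strategy: Optional[str] = "ddp",
    parallel_strategy_kwargs: Optional[Dict[str, Any]] = None,
) -> Any:
    """One strategy arg replaces wrap_ddp/wrap_fsdp, and one kwargs
    dict replaces ddp_kwargs/fsdp_kwargs."""
    if parallel_strategy not in VALID_STRATEGIES:
        raise ValueError(
            f"Invalid parallel_strategy {parallel_strategy!r}; "
            "expected 'ddp', 'fsdp', or None."
        )
    kwargs = parallel_strategy_kwargs or {}
    if parallel_strategy == "ddp":
        pass  # here: wrap with DistributedDataParallel(model, **kwargs)
    elif parallel_strategy == "fsdp":
        pass  # here: wrap with FullyShardedDataParallel(model, **kwargs)
    return model
```

A caller would then write `prepare_model(model, parallel_strategy="fsdp")` instead of `prepare_model(model, wrap_fsdp=True)`, and an unknown strategy fails fast with a `ValueError`.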
python/ray/train/tests/test_gpu.py (outdated)

```diff
@@ -7,6 +7,7 @@
 import torch
 import torchvision
 from torch.nn.parallel import DistributedDataParallel
+from torch.distributed.fsdp import FullyShardedDataParallel
```
This import will fail if the user's PyTorch version is less than 1.11, right?

To make sure that users can still use Ray Train with older versions of torch, can we do something like the following:

```python
from distutils.version import LooseVersion

import torch

if LooseVersion(torch.__version__) < LooseVersion("1.11.0"):
    FullyShardedDataParallel = None
else:
    from torch.distributed.fsdp import FullyShardedDataParallel
```

Then in `prepare_model`, if `FullyShardedDataParallel` is `None` but the user enables it, we can raise an error saying that the user needs to upgrade their torch version to enable FSDP.
python/requirements_ml_docker.txt (outdated)

```diff
@@ -5,7 +5,7 @@ tblib

 # If you make changes to the torch versions, please also make the corresponding changes to `requirements_dl.txt`!
 -f https://download.pytorch.org/whl/torch_stable.html
-torch==1.9.0+cu111
+torch==1.11.0+cu111
```
Unfortunately our dependencies for CI are a bit of a mess right now, so upgrading won't work out of the box.

Instead, can we do the following:

1. Split out the FSDP test into a separate file (`test_torch_fsdp.py`).
2. Add this file to `ray/train/tests/BUILD` with the following tags: `tags = ["team:ml", "exclusive", "gpu_only", "torch_1_11"]`.
3. Create a new test suite like this: https://github.com/ray-project/ray/blob/master/.buildkite/pipeline.gpu.large.yml#L1-L10, except with the following changes:

```yaml
- label: ":tv: :steam_locomotive: Train GPU tests (PyTorch 1.11)"
  conditions: ["RAY_CI_TRAIN_AFFECTED"]
  commands:
    - cleanup() { if [ "${BUILDKITE_PULL_REQUEST}" = "false" ]; then ./ci/build/upload_build_info.sh; fi }; trap cleanup EXIT
    - PYTHON=3.7 TRAIN_TESTING=1 ./ci/env/install-dependencies.sh
    # Because the Python version changed, we need to re-install Ray here
    - rm -rf ./python/ray/thirdparty_files; rm -rf ./python/ray/pickle5_files; ./ci/ci.sh build
    - pip install -U torch==1.11.0+cu113 --extra-index-url https://download.pytorch.org/whl/cu113
    - ./ci/env/env_info.sh
    - bazel test --config=ci $(./ci/run/bazel_export_options) --build_tests_only --test_tag_filters=gpu,gpu_only,torch_1_11,-ray_air python/ray/train/...
```

(Note: the `pip install` line drops the stray `-r` flag, which would otherwise be interpreted as a requirements-file argument.)
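For step 2, the corresponding `BUILD` entry might look like this (a Bazel/Starlark sketch; the `size`, `srcs`, and `deps` values are assumptions based on typical entries in that file, only the `tags` come from the comment above):

```python
# ray/train/tests/BUILD (sketch only)
py_test(
    name = "test_torch_fsdp",
    size = "medium",
    srcs = ["test_torch_fsdp.py"],
    tags = ["team:ml", "exclusive", "gpu_only", "torch_1_11"],
    deps = [":train_lib"],
)
```

The `torch_1_11` tag is what lets the new Buildkite suite in step 3 select exactly this test via `--test_tag_filters`.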
@amogkam Thanks for the feedback, I tried to incorporate your suggestions as closely as possible. Several of the tests are still not passing. I tried tweaking the PyTorch 1.11 test configuration to no avail. Any suggestions?
Thanks @markrogersjr! Left some comments; I think these should work.
@markrogersjr can you also merge in latest master?
@amogkam I managed to get most of the tests to pass; it looks like the rest could be failing due to flakiness. Please let me know how I can help from here!
Thanks @markrogersjr, this looks great to me! I just pushed some additional changes, primarily for backwards compatibility. But I will make sure to get this merged in!
Excellent, thanks @markrogersjr!
@amogkam thank you as well, nice work!
…ject#28096) As of version 1.11, PyTorch supports automatically sharding large models via `FullyShardedDataParallel`. This change is necessary to take advantage of this new feature.

Signed-off-by: Mark Rogers <m@inmimo.me>
Signed-off-by: Amog Kamsetty <amogkamsetty@yahoo.com>
Co-authored-by: Amog Kamsetty <amogkamsetty@yahoo.com>
Signed-off-by: ilee300a <ilee300@anyscale.com>
Why are these changes needed?
As of version 1.11, PyTorch supports automatically sharding large models via `FullyShardedDataParallel`. This change is necessary to take advantage of this new feature.

Related issue number

Checks

- I've signed off every commit (`git commit -s`) in this PR.
- I've run `scripts/format.sh` to lint the changes in this PR.