[Train] Support FSDP Strategy for LightningTrainer #34148

Merged
Changes from 11 commits
27 changes: 24 additions & 3 deletions python/ray/train/lightning/_lightning_utils.py
@@ -2,12 +2,18 @@
import shutil
import torch
import tempfile
import pytorch_lightning as pl

from packaging.version import Version
from typing import Any, Dict, Optional

import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.strategies import DDPStrategy
from pytorch_lightning.plugins.environments import LightningEnvironment
from pytorch_lightning.strategies import DDPStrategy

if Version(pl.__version__) >= Version("2.0.0"):
from pytorch_lightning.strategies import FSDPStrategy
else:
from pytorch_lightning.strategies import DDPFullyShardedStrategy as FSDPStrategy

import ray
from ray.air import session
@@ -36,6 +42,21 @@ def distributed_sampler_kwargs(self) -> Dict[str, Any]:
)


class RayFSDPStrategy(FSDPStrategy):
"""Subclass of FSDPStrategy to ensure compatibility with Ray orchestration."""

@property
def root_device(self) -> torch.device:
return ray.train.torch.get_device()

@property
def distributed_sampler_kwargs(self) -> Dict[str, Any]:
return dict(
num_replicas=self.world_size,
rank=self.global_rank,
)


class RayEnvironment(LightningEnvironment):
"""Setup Lightning DDP training environment for Ray cluster."""

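Not part of the diff — a rough sketch of how the Ray-aware strategy classes above could be handed to a plain Lightning Trainer on a Ray Train worker. The helper name build_worker_trainer is invented for illustration; LightningTrainer performs this wiring internally.

import pytorch_lightning as pl
from ray.train.lightning._lightning_utils import RayDDPStrategy, RayFSDPStrategy

def build_worker_trainer(strategy_name="ddp", strategy_kwargs=None, **trainer_kwargs):
    # RayFSDPStrategy resolves to FSDPStrategy on PyTorch Lightning >= 2.0 and to
    # DDPFullyShardedStrategy on older versions, per the import guard above.
    # Both Ray subclasses report Ray's assigned device and rank, so Lightning's
    # samplers and collectives line up with Ray Train's worker placement. This
    # only makes sense inside a Ray Train worker, where
    # ray.train.torch.get_device() knows which device the worker owns.
    strategy_cls = RayFSDPStrategy if strategy_name == "fsdp" else RayDDPStrategy
    return pl.Trainer(strategy=strategy_cls(**(strategy_kwargs or {})), **trainer_kwargs)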
34 changes: 26 additions & 8 deletions python/ray/train/lightning/lightning_trainer.py
@@ -23,6 +23,7 @@
from ray.util import PublicAPI
from ray.train.lightning._lightning_utils import (
RayDDPStrategy,
RayFSDPStrategy,
RayEnvironment,
RayDataModule,
RayModelCheckpoint,
@@ -85,7 +86,7 @@ def __init__(self) -> None:
self._module_init_config = {}
self._trainer_init_config = {}
self._trainer_fit_params = {}
self._ddp_strategy_config = {}
self._strategy_config = {}
self._model_checkpoint_config = {}

def module(
@@ -107,8 +108,9 @@ def trainer(self, **kwargs) -> "LightningConfigBuilder":
"""Set up the configurations of ``pytorch_lightning.Trainer``.

Note that you don't have to specify the `strategy` argument here since the
``LightningTrainer`` creates a DDPStrategy by default. You can set up
advanced configurations for DDPStrategy via the `.ddp_strategy()` method.
``LightningTrainer`` creates a PyTorch Lightning Strategy object with the
configurations specified in the `.strategy()` method. If no configuration
is specified, it creates a DDPStrategy by default.

Args:
kwargs: The initialization arguments for ``pytorch_lightning.Trainer``
@@ -142,14 +144,19 @@ def fit_params(self, **kwargs) -> "LightningConfigBuilder":
self._trainer_fit_params.update(**kwargs)
return self

def ddp_strategy(self, **kwargs) -> "LightningConfigBuilder":
def strategy(self, name: str = "ddp", **kwargs) -> "LightningConfigBuilder":
"""Set up the configurations of ``pytorch_lightning.strategies.DDPStrategy``.
woshiyyya marked this conversation as resolved.
Show resolved Hide resolved

Args:
name: The name of your distributed strategy. You can choose
from "ddp" and "fsdp". Default: "ddp".
kwargs: For valid arguments to pass, please refer to:
https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.strategies.DDPStrategy.html
and
https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.strategies.FSDPStrategy.html
"""
self._ddp_strategy_config.update(**kwargs)
self._strategy_config["_strategy_name"] = name
self._strategy_config.update(**kwargs)
return self

def checkpointing(self, **kwargs) -> "LightningConfigBuilder":
@@ -461,7 +468,8 @@ def _lightning_train_loop_per_worker(config):
trainer_fit_params = ptl_config["_trainer_fit_params"]
module_class = ptl_config["_module_class"]
module_init_config = ptl_config["_module_init_config"]
ddp_strategy_config = ptl_config["_ddp_strategy_config"]
strategy_config = ptl_config["_strategy_config"]
strategy_name = strategy_config.pop("_strategy_name", "ddp")
model_checkpoint_config = ptl_config["_model_checkpoint_config"]

# Prepare data
@@ -519,9 +527,19 @@ def _lightning_train_loop_per_worker(config)
logger.warning(
"`strategy` specified in `LightningConfig.trainer_init_config` "
"will be ignored. LightningTrainer will create a RayDDPStrategy "
"object based on `LightningConfig.ddp_strategy_config`."
"or RayFSDPStrategy object based on `LightningConfigBuilder.strategy()`."
)
trainer_config["strategy"] = RayDDPStrategy(**ddp_strategy_config)

assert strategy_name in [
"ddp",
"fsdp",
], (
"LightningTrainer currently supports 'ddp' and 'fsdp' strategies. "
"Please choose one of them."
)

if strategy_name == "ddp":
trainer_config["strategy"] = RayDDPStrategy(**strategy_config)
if strategy_name == "fsdp":
trainer_config["strategy"] = RayFSDPStrategy(**strategy_config)

# LightningTrainer always requires checkpointing
trainer_config["enable_checkpointing"] = True
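Not part of the diff — an end-to-end sketch of how the new strategy() builder method could be used, assuming the LightningConfigBuilder and LightningTrainer entry points shown in this PR. ToyModule and the toy dataloader are placeholders invented so the sketch is self-contained.

import torch
import pytorch_lightning as pl
from torch.utils.data import DataLoader, TensorDataset

from ray.air import ScalingConfig
from ray.train.lightning import LightningConfigBuilder, LightningTrainer

class ToyModule(pl.LightningModule):
    # Minimal LightningModule used only for illustration.
    def __init__(self, input_dim=32, output_dim=4):
        super().__init__()
        self.layer = torch.nn.Linear(input_dim, output_dim)

    def training_step(self, batch, batch_idx):
        x, y = batch
        return torch.nn.functional.mse_loss(self.layer(x), y)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)

train_loader = DataLoader(
    TensorDataset(torch.randn(64, 32), torch.randn(64, 4)), batch_size=8
)

lightning_config = (
    LightningConfigBuilder()
    .module(cls=ToyModule, input_dim=32, output_dim=4)
    .trainer(max_epochs=1, accelerator="gpu")
    .strategy("fsdp")  # or "ddp" (the default); extra kwargs are forwarded to the strategy
    .fit_params(train_dataloaders=train_loader)
    .build()
)

trainer = LightningTrainer(
    lightning_config=lightning_config,
    scaling_config=ScalingConfig(num_workers=2, use_gpu=True),
)
result = trainer.fit()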
16 changes: 13 additions & 3 deletions python/ray/train/tests/test_lightning_trainer.py
@@ -49,15 +49,19 @@ def __init__(self) -> None:
)
assert config["_module_init_config"]["input_dim"] == 10
assert config["_trainer_init_config"]["log_every_n_steps"] == 100
assert not config["_ddp_strategy_config"]
assert not config["_strategy_config"]
assert not config["_model_checkpoint_config"]


@pytest.mark.parametrize("strategy", ["ddp", "fsdp"])
@pytest.mark.parametrize("accelerator", ["cpu", "gpu"])
@pytest.mark.parametrize("datasource", ["dataloader", "datamodule"])
def test_trainer_with_native_dataloader(
ray_start_6_cpus_2_gpus, accelerator, datasource
ray_start_6_cpus_2_gpus, strategy, accelerator, datasource
):
if accelerator == "cpu" and strategy == "fsdp":
return

num_epochs = 4
batch_size = 8
num_workers = 2
@@ -67,6 +71,7 @@ def test_trainer_with_native_dataloader(
LightningConfigBuilder()
.module(LinearModule, input_dim=32, output_dim=4)
.trainer(max_epochs=num_epochs, accelerator=accelerator)
.strategy(strategy)
)

datamodule = DummyDataModule(batch_size, dataset_size)
@@ -97,8 +102,12 @@ def test_trainer_with_native_dataloader(
assert "val_loss" in results.metrics


@pytest.mark.parametrize("strategy", ["ddp", "fsdp"])
@pytest.mark.parametrize("accelerator", ["cpu", "gpu"])
def test_trainer_with_ray_data(ray_start_6_cpus_2_gpus, accelerator):
def test_trainer_with_ray_data(ray_start_6_cpus_2_gpus, strategy, accelerator):
if accelerator == "cpu" and strategy == "fsdp":
return

num_epochs = 4
batch_size = 8
num_workers = 2
@@ -112,6 +121,7 @@ def test_trainer_with_ray_data(ray_start_6_cpus_2_gpus, accelerator):
LightningConfigBuilder()
.module(cls=LinearModule, input_dim=32, output_dim=4)
.trainer(max_epochs=num_epochs, accelerator=accelerator)
.strategy(strategy)
.build()
)

1 change: 1 addition & 0 deletions python/requirements/ml/requirements_tune.txt
@@ -33,6 +33,7 @@ pytest-remotedata==0.3.2
lightning-bolts==0.4.0
protobuf==3.19.6
pytorch-lightning==1.6.5
fairscale==0.4.6
shortuuid==1.0.1
scikit-optimize==0.9.0
sigopt==7.5.0