[Train] Support FSDP Strategy for LightningTrainer (ray-project#34148)
Signed-off-by: woshiyyya <xiaoyunxuan1998@gmail.com>
Signed-off-by: Jack He <jackhe2345@gmail.com>
woshiyyya authored and ProjectsByJackHe committed May 4, 2023
1 parent ab6dc60 commit 7de94f2
Showing 4 changed files with 71 additions and 16 deletions.
27 changes: 24 additions & 3 deletions python/ray/train/lightning/_lightning_utils.py
@@ -2,12 +2,18 @@
import shutil
import torch
import tempfile
import pytorch_lightning as pl

from packaging.version import Version
from typing import Any, Dict, Optional

import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.strategies import DDPStrategy
from pytorch_lightning.plugins.environments import LightningEnvironment
from pytorch_lightning.strategies import DDPStrategy

if Version(pl.__version__) >= Version("2.0.0"):
from pytorch_lightning.strategies import FSDPStrategy
else:
from pytorch_lightning.strategies import DDPFullyShardedStrategy as FSDPStrategy

import ray
from ray.air import session
@@ -45,6 +51,21 @@ def distributed_sampler_kwargs(self) -> Dict[str, Any]:
)


class RayFSDPStrategy(FSDPStrategy):
"""Subclass of FSDPStrategy to ensure compatibility with Ray orchestration."""

@property
def root_device(self) -> torch.device:
return get_worker_root_device()

@property
def distributed_sampler_kwargs(self) -> Dict[str, Any]:
return dict(
num_replicas=self.world_size,
rank=self.global_rank,
)


class RayEnvironment(LightningEnvironment):
"""Setup Lightning DDP training environment for Ray cluster."""

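Note the version gate above: on PyTorch Lightning 2.0+ the new wrapper extends the native FSDPStrategy, while on older releases it falls back to the fairscale-backed DDPFullyShardedStrategy (which is why fairscale is pinned in the requirements change further down). The two overridden properties are what tie Lightning's data sharding to Ray's worker layout: Lightning reads distributed_sampler_kwargs from the active strategy when it injects a DistributedSampler into the dataloaders, and root_device reports the accelerator Ray assigned to the worker. A small self-contained sketch (plain PyTorch, outside of Ray, with hypothetical rank/world-size values) of what those sampler kwargs do:

import torch
from torch.utils.data import DataLoader, DistributedSampler, TensorDataset

# Hypothetical values a two-worker Ray Train run would report for its second worker.
sampler_kwargs = {"num_replicas": 2, "rank": 1}

dataset = TensorDataset(torch.arange(8).unsqueeze(1))
sampler = DistributedSampler(dataset, shuffle=False, **sampler_kwargs)
loader = DataLoader(dataset, batch_size=2, sampler=sampler)

for (batch,) in loader:
    # Rank 1 only sees its shard of the data: samples 1, 3, 5, 7, in two batches.
    print(batch.flatten().tolist())
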
38 changes: 28 additions & 10 deletions python/ray/train/lightning/lightning_trainer.py
@@ -22,6 +22,7 @@
from ray.util import PublicAPI
from ray.train.lightning._lightning_utils import (
RayDDPStrategy,
RayFSDPStrategy,
RayEnvironment,
RayDataModule,
RayModelCheckpoint,
@@ -85,7 +86,7 @@ def __init__(self) -> None:
self._module_init_config = {}
self._trainer_init_config = {}
self._trainer_fit_params = {}
self._ddp_strategy_config = {}
self._strategy_config = {}
self._model_checkpoint_config = {}

def module(
@@ -107,8 +108,9 @@ def trainer(self, **kwargs) -> "LightningConfigBuilder":
"""Set up the configurations of ``pytorch_lightning.Trainer``.
Note that you don't have to specify the `strategy` argument here since the
``LightningTrainer`` creates a DDPStrategy by default. You can set up
advanced configurations for DDPStrategy via the `.ddp_strategy()` method.
``LightningTrainer`` creates a PyTorch Lightning Strategy object with the
configurations specified in the `.strategy()` method. If no configuration
is specified, it creates a DDPStrategy by default.
Args:
kwargs: The initialization arguments for ``pytorch_lightning.Trainer``
@@ -142,14 +144,25 @@ def fit_params(self, **kwargs) -> "LightningConfigBuilder":
self._trainer_fit_params.update(**kwargs)
return self

def ddp_strategy(self, **kwargs) -> "LightningConfigBuilder":
"""Set up the configurations of ``pytorch_lightning.strategies.DDPStrategy``.
def strategy(self, name: str = "ddp", **kwargs) -> "LightningConfigBuilder":
"""Set up the configurations of ``pytorch_lightning.strategies.Strategy``.
Args:
name: The name of your distributed strategy. You can choose
from "ddp" and "fsdp". Default: "ddp".
kwargs: For valid arguments to pass, please refer to:
https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.strategies.DDPStrategy.html
and
https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.strategies.FSDPStrategy.html
"""
self._ddp_strategy_config.update(**kwargs)
if name not in ["ddp", "fsdp"]:
raise ValueError(
"LightningTrainer currently supports 'ddp' and 'fsdp' strategy. "
"Please choose one of them."
)

self._strategy_config["_strategy_name"] = name
self._strategy_config.update(**kwargs)
return self

def checkpointing(self, **kwargs) -> "LightningConfigBuilder":
@@ -461,7 +474,8 @@ def _lightning_train_loop_per_worker(config):
trainer_fit_params = ptl_config["_trainer_fit_params"]
module_class = ptl_config["_module_class"]
module_init_config = ptl_config["_module_init_config"]
ddp_strategy_config = ptl_config["_ddp_strategy_config"]
strategy_config = ptl_config["_strategy_config"]
strategy_name = strategy_config.pop("_strategy_name", "ddp")
model_checkpoint_config = ptl_config["_model_checkpoint_config"]

# Prepare data
@@ -518,10 +532,14 @@ def _lightning_train_loop_per_worker(config):
if "strategy" in trainer_config:
logger.warning(
"`strategy` specified in `LightningConfig.trainer_init_config` "
"will be ignored. LightningTrainer will create a RayDDPStrategy "
"object based on `LightningConfig.ddp_strategy_config`."
"will be ignored. LightningTrainer will create a strategy based on "
"the settings passed into `LightningConfigBuilder.strategy()`."
)
trainer_config["strategy"] = RayDDPStrategy(**ddp_strategy_config)

if strategy_name == "ddp":
trainer_config["strategy"] = RayDDPStrategy(**strategy_config)
if strategy_name == "fsdp":
trainer_config["strategy"] = RayFSDPStrategy(**strategy_config)

# LightningTrainer always requires checkpointing
trainer_config["enable_checkpointing"] = True
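Taken together, the builder change and the dispatch above mean a user selects the distributed strategy by name instead of constructing a Lightning strategy object. Below is a rough usage sketch of the new .strategy() entry point, assuming the Ray 2.4-era LightningTrainer API; MyLightningModule is a placeholder for your own pl.LightningModule, and the argument values are illustrative only.

from ray.air.config import ScalingConfig
from ray.train.lightning import LightningConfigBuilder, LightningTrainer

lightning_config = (
    LightningConfigBuilder()
    .module(cls=MyLightningModule, input_dim=32, output_dim=4)  # placeholder module and its init kwargs
    .trainer(max_epochs=4, accelerator="gpu")                   # plain pytorch_lightning.Trainer kwargs
    .strategy(name="fsdp")                                      # "ddp" (default) or "fsdp"
    .build()
)

trainer = LightningTrainer(
    lightning_config=lightning_config,
    scaling_config=ScalingConfig(num_workers=2, use_gpu=True),
)
result = trainer.fit()

Any extra keyword arguments passed to .strategy() are forwarded to RayDDPStrategy or RayFSDPStrategy, matching the dispatch in _lightning_train_loop_per_worker above.
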
21 changes: 18 additions & 3 deletions python/ray/train/tests/test_lightning_trainer.py
@@ -40,6 +40,11 @@ def __init__(self) -> None:
):
LightningConfigBuilder().module(cls=LinearModule).trainer(10, 100)

with pytest.raises(
ValueError, match="LightningTrainer currently supports 'ddp' and 'fsdp'"
):
LightningConfigBuilder().strategy(name="dummy_strategy")

config = (
LightningConfigBuilder()
.module(cls=LinearModule, input_dim=10)
@@ -49,15 +54,19 @@ def __init__(self) -> None:
)
assert config["_module_init_config"]["input_dim"] == 10
assert config["_trainer_init_config"]["log_every_n_steps"] == 100
assert not config["_ddp_strategy_config"]
assert not config["_strategy_config"]
assert not config["_model_checkpoint_config"]


@pytest.mark.parametrize("strategy", ["ddp", "fsdp"])
@pytest.mark.parametrize("accelerator", ["cpu", "gpu"])
@pytest.mark.parametrize("datasource", ["dataloader", "datamodule"])
def test_trainer_with_native_dataloader(
ray_start_6_cpus_2_gpus, accelerator, datasource
ray_start_6_cpus_2_gpus, strategy, accelerator, datasource
):
if accelerator == "cpu" and strategy == "fsdp":
return

num_epochs = 4
batch_size = 8
num_workers = 2
@@ -67,6 +76,7 @@ def test_trainer_with_native_dataloader(
LightningConfigBuilder()
.module(LinearModule, input_dim=32, output_dim=4)
.trainer(max_epochs=num_epochs, accelerator=accelerator)
.strategy(strategy)
)

datamodule = DummyDataModule(batch_size, dataset_size)
@@ -97,8 +107,12 @@ def test_trainer_with_native_dataloader(
assert "val_loss" in results.metrics


@pytest.mark.parametrize("strategy", ["ddp", "fsdp"])
@pytest.mark.parametrize("accelerator", ["cpu", "gpu"])
def test_trainer_with_ray_data(ray_start_6_cpus_2_gpus, accelerator):
def test_trainer_with_ray_data(ray_start_6_cpus_2_gpus, strategy, accelerator):
if accelerator == "cpu" and strategy == "fsdp":
return

num_epochs = 4
batch_size = 8
num_workers = 2
@@ -112,6 +126,7 @@ def test_trainer_with_ray_data(ray_start_6_cpus_2_gpus, accelerator):
LightningConfigBuilder()
.module(cls=LinearModule, input_dim=32, output_dim=4)
.trainer(max_epochs=num_epochs, accelerator=accelerator)
.strategy(strategy)
.build()
)

1 change: 1 addition & 0 deletions python/requirements/ml/requirements_tune.txt
@@ -33,6 +33,7 @@ pytest-remotedata==0.3.2
lightning-bolts==0.4.0
protobuf==3.19.6
pytorch-lightning==1.6.5
fairscale==0.4.6
shortuuid==1.0.1
scikit-optimize==0.9.0
sigopt==7.5.0