[Train] LightningTrainer enable checkpoint full dict with FSDP strategy (#34967)

Signed-off-by: woshiyyya <xiaoyunxuan1998@gmail.com>
woshiyyya authored May 5, 2023
1 parent d6a9ef6 commit 3be4491
Showing 2 changed files with 81 additions and 8 deletions.
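
What this change does, in short: when training with the FSDP strategy, each worker only holds a shard of the model, so this commit overrides the strategy's lightning_module_state_dict() to gather the full, unsharded state dict onto rank 0 (offloaded to CPU) before the checkpoint is written. The sketch below shows how the feature is exercised from the user side; it is modeled directly on the test added in this commit, so LinearModule and DummyDataModule are test utilities from ray.train.tests.lightning_test_utils and the 2-GPU scaling config is only an example — substitute your own LightningModule, datamodule, and resources.

import ray
from ray.train.lightning import LightningConfigBuilder, LightningTrainer
# Test-only helpers used here as stand-ins for a real model / datamodule.
from ray.train.tests.lightning_test_utils import LinearModule, DummyDataModule

datamodule = DummyDataModule(batch_size=8, dataset_size=256)

lightning_config = (
    LightningConfigBuilder()
    .module(LinearModule, input_dim=32, output_dim=4, strategy="fsdp")
    .trainer(max_epochs=1, accelerator="gpu")
    .strategy("fsdp")  # train with FSDP; checkpoints now contain the full state dict
    .checkpointing(save_last=True)
    .fit_params(datamodule=datamodule)
    .build()
)

trainer = LightningTrainer(
    lightning_config=lightning_config,
    scaling_config=ray.air.ScalingConfig(num_workers=2, use_gpu=True),
)
results = trainer.fit()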
45 changes: 38 additions & 7 deletions python/ray/train/lightning/_lightning_utils.py
@@ -1,26 +1,38 @@
import ray
from ray.air import session
from ray.air.constants import MODEL_KEY
from ray.data.datastream import DataIterator
from ray.train.lightning.lightning_checkpoint import LightningCheckpoint

import logging
import shutil
import torch
import tempfile
from packaging.version import Version
from typing import Any, Dict, Optional
from torch.utils.data import IterableDataset, DataLoader

import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.plugins.environments import LightningEnvironment
from pytorch_lightning.strategies import DDPStrategy

if Version(pl.__version__) >= Version("2.0.0"):
_LIGHTNING_GREATER_EQUAL_2_0 = Version(pl.__version__) >= Version("2.0.0")
_TORCH_GREATER_EQUAL_1_12 = Version(torch.__version__) >= Version("1.12.0")
_TORCH_FSDP_AVAILABLE = _TORCH_GREATER_EQUAL_1_12 and torch.distributed.is_available()

if _LIGHTNING_GREATER_EQUAL_2_0:
    from pytorch_lightning.strategies import FSDPStrategy
else:
    from pytorch_lightning.strategies import DDPFullyShardedStrategy as FSDPStrategy

import ray
from ray.air import session
from ray.air.constants import MODEL_KEY
from ray.train.lightning.lightning_checkpoint import LightningCheckpoint
from torch.utils.data import IterableDataset, DataLoader
from ray.data.datastream import DataIterator
if _TORCH_FSDP_AVAILABLE:
    from torch.distributed.fsdp import (
        FullStateDictConfig,
        FullyShardedDataParallel,
        StateDictType,
    )


logger = logging.getLogger(__name__)

@@ -65,6 +77,25 @@ def distributed_sampler_kwargs(self) -> Dict[str, Any]:
            rank=self.global_rank,
        )

    def lightning_module_state_dict(self) -> Dict[str, Any]:
        """Gathers the full state dict to rank 0 on CPU."""
        assert self.model is not None, "Failed to get the state dict for a None model!"

        if _LIGHTNING_GREATER_EQUAL_2_0 and _TORCH_FSDP_AVAILABLE:
            with FullyShardedDataParallel.state_dict_type(
                module=self.model,
                state_dict_type=StateDictType.FULL_STATE_DICT,
                state_dict_config=FullStateDictConfig(
                    offload_to_cpu=True, rank0_only=True
                ),
            ):
                state_dict = self.model.state_dict()
                prefix_len = len("_forward_module.")
                return {k[prefix_len:]: v for k, v in state_dict.items()}
        else:
            # Otherwise Lightning uses FairScale FSDP; no need to unshard by ourselves.
            return super().lightning_module_state_dict()


class RayEnvironment(LightningEnvironment):
    """Setup Lightning DDP training environment for Ray cluster."""
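
A note on the lightning_module_state_dict override above: Lightning's FSDP strategy wraps the LightningModule in a forward-module wrapper, so the raw FSDP state dict keys come back as `_forward_module.<param name>`; stripping that prefix restores the plain module's key names, and FullStateDictConfig(offload_to_cpu=True, rank0_only=True) makes the gather materialize full tensors only on rank 0, on CPU. The same gathering pattern on a bare PyTorch FSDP module (no Lightning wrapper, so no prefix to strip) looks roughly like the sketch below — the function name is illustrative, and it assumes torch.distributed is already initialized, which Ray Train handles for the strategy code in this file.

import torch
from torch.distributed.fsdp import (
    FullStateDictConfig,
    FullyShardedDataParallel as FSDP,
    StateDictType,
)

def full_state_dict_on_rank0(fsdp_module: FSDP) -> dict:
    """Gather the complete (unsharded) state dict of an FSDP-wrapped module.

    With rank0_only=True, full tensors are materialized only on rank 0 and
    offloaded to CPU; other ranks receive an empty dict.
    """
    config = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
    with FSDP.state_dict_type(fsdp_module, StateDictType.FULL_STATE_DICT, config):
        return fsdp_module.state_dict()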
44 changes: 43 additions & 1 deletion python/ray/train/tests/test_lightning_checkpoint.py
@@ -4,9 +4,15 @@
import torch.nn as nn
import tempfile

from ray.train.lightning import LightningCheckpoint
import ray
from ray.air.constants import MODEL_KEY
from torch.utils.data import DataLoader
from ray.train.tests.lightning_test_utils import LinearModule, DummyDataModule
from ray.train.lightning import (
    LightningCheckpoint,
    LightningConfigBuilder,
    LightningTrainer,
)


class Net(pl.LightningModule):
@@ -100,6 +106,42 @@ def test_from_directory():
    assert torch.equal(output, checkpoint_output)


def test_fsdp_checkpoint():
    num_epochs = 1
    batch_size = 8
    input_dim = 32
    output_dim = 4
    dataset_size = 256

    datamodule = DummyDataModule(batch_size, dataset_size)

    config_builder = (
        LightningConfigBuilder()
        .module(
            LinearModule, input_dim=input_dim, output_dim=output_dim, strategy="fsdp"
        )
        .trainer(max_epochs=num_epochs, accelerator="gpu")
        .strategy("fsdp")
        .checkpointing(save_last=True)
        .fit_params(datamodule=datamodule)
    )

    scaling_config = ray.air.ScalingConfig(num_workers=2, use_gpu=True)

    trainer = LightningTrainer(
        lightning_config=config_builder.build(), scaling_config=scaling_config
    )

    results = trainer.fit()

    with results.checkpoint.as_directory() as checkpoint_dir:
        checkpoint = torch.load(f"{checkpoint_dir}/{MODEL_KEY}")
        model = LinearModule(input_dim=input_dim, output_dim=output_dim)

        for key in model.state_dict().keys():
            assert key in checkpoint["state_dict"]


if __name__ == "__main__":
    import sys

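
Because the saved file is an ordinary Lightning checkpoint whose "state_dict" holds unprefixed, unsharded weights, it can be restored into a plain (non-FSDP) module. The new test only asserts that every key of the unwrapped module is present; the sketch below goes one step further and loads the weights, assuming `results` is the object returned by trainer.fit() above and that LinearModule is constructed with the same dimensions used during training.

import torch
from ray.air.constants import MODEL_KEY
from ray.train.tests.lightning_test_utils import LinearModule

with results.checkpoint.as_directory() as checkpoint_dir:
    ckpt = torch.load(f"{checkpoint_dir}/{MODEL_KEY}", map_location="cpu")
    model = LinearModule(input_dim=32, output_dim=4)
    # Keys match the unwrapped module, so a strict load should succeed.
    model.load_state_dict(ckpt["state_dict"])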
