[Train] LightningTrainer always resumes from the latest AIR checkpoint during restoration. (ray-project#35617)

Signed-off-by: woshiyyya <xiaoyunxuan1998@gmail.com>
Signed-off-by: e428265 <arvind.chandramouli@lmco.com>
woshiyyya authored and arvind-chandra committed Aug 31, 2023
1 parent 2545d05 commit ebcf5b2
Showing 2 changed files with 33 additions and 9 deletions.
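
For context, here is a minimal sketch of the restore flow this commit affects, assuming the Ray AIR `LightningTrainer.restore()` / `can_restore()` APIs of this release line; the experiment path below is hypothetical:

# Sketch only: resuming an interrupted LightningTrainer run. With this change,
# restoration always resumes from the latest AIR checkpoint, and any `ckpt_path`
# passed via fit_params is ignored (with a log message).
from ray.train.lightning import LightningTrainer

experiment_path = "/tmp/ray_results/air_trainer_restore_test"  # hypothetical path

if LightningTrainer.can_restore(experiment_path):
    trainer = LightningTrainer.restore(experiment_path)
    result = trainer.fit()  # training continues from the latest AIR checkpoint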
8 changes: 6 additions & 2 deletions python/ray/train/lightning/lightning_trainer.py
@@ -539,9 +539,13 @@ def _lightning_train_loop_per_worker(config):
 
     trainer = pl.Trainer(**trainer_config)
 
-    # Restore from a previously failed run
     checkpoint = session.get_checkpoint()
-    if checkpoint and "ckpt_path" not in trainer_fit_params:
+    if checkpoint:
+        checkpoint_log_message = "Resuming training from an AIR checkpoint."
+        if "ckpt_path" in trainer_fit_params:
+            checkpoint_log_message += " `ckpt_path` will be ignored."
+        logger.info(checkpoint_log_message)
+
         with checkpoint.as_directory() as ckpt_dir:
             trainer_fit_params["ckpt_path"] = f"{ckpt_dir}/{MODEL_KEY}"
     trainer.fit(lightning_module, **trainer_fit_params)
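
Read in isolation, the precedence rule added above amounts to the following standalone sketch (a hypothetical helper, not the Ray implementation; the "model" key merely stands in for MODEL_KEY):

# Illustration: an AIR checkpoint from a restored run always wins over a
# user-supplied `ckpt_path`.
from typing import Optional


def resolve_ckpt_path(
    air_ckpt_dir: Optional[str],
    user_ckpt_path: Optional[str],
    model_key: str = "model",
) -> Optional[str]:
    """Return the checkpoint path Lightning should resume from."""
    if air_ckpt_dir is not None:
        if user_ckpt_path is not None:
            print("Resuming from an AIR checkpoint; `ckpt_path` will be ignored.")
        return f"{air_ckpt_dir}/{model_key}"
    return user_ckpt_path


assert resolve_ckpt_path("/ckpts/000004", "last.ckpt") == "/ckpts/000004/model"
assert resolve_ckpt_path(None, "last.ckpt") == "last.ckpt"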
34 changes: 27 additions & 7 deletions python/ray/train/tests/test_lightning_trainer_restore.py
@@ -1,5 +1,7 @@
 import numpy as np
 import pytest
+import pytorch_lightning as pl
+from pytorch_lightning.callbacks import ModelCheckpoint
 
 import ray
 from ray.air import RunConfig, CheckpointConfig
@@ -108,32 +110,50 @@ def test_native_trainer_restore(ray_start_4_cpus_2_gpus):
     assert results.checkpoint
 
 
-def test_air_trainer_restore(ray_start_6_cpus, tmpdir):
+@pytest.mark.parametrize("resume_from_ckpt_path", [True, False])
+def test_air_trainer_restore(ray_start_6_cpus, tmpdir, resume_from_ckpt_path):
     """Test restore for LightningTrainer from a failed/interrupted trail."""
     exp_name = "air_trainer_restore_test"
 
     datamodule = DummyDataModule(8, 256)
     train_loader = datamodule.train_dataloader()
     val_loader = datamodule.val_dataloader()
 
+    max_epochs = 5
+    init_epoch = 1 if resume_from_ckpt_path else 0
+    error_epoch = 2
+
+    # init_epoch -> [error_epoch] -> max_epoch
+    training_iterations = max_epochs - init_epoch
+    iterations_since_restore = max_epochs - init_epoch - error_epoch
+
     lightning_config = (
         LightningConfigBuilder()
         .module(LinearModule, input_dim=32, output_dim=4)
-        .trainer(max_epochs=5, accelerator="cpu")
+        .trainer(max_epochs=max_epochs, accelerator="cpu")
         .fit_params(train_dataloaders=train_loader, val_dataloaders=val_loader)
-        .build()
     )
 
+    if resume_from_ckpt_path:
+        ckpt_dir = f"{tmpdir}/ckpts"
+        callback = ModelCheckpoint(dirpath=ckpt_dir, save_last=True)
+        pl_trainer = pl.Trainer(
+            max_epochs=init_epoch, accelerator="cpu", callbacks=[callback]
+        )
+        pl_model = LinearModule(input_dim=32, output_dim=4)
+        pl_trainer.fit(pl_model, train_dataloaders=train_loader)
+        lightning_config.fit_params(ckpt_path=f"{ckpt_dir}/last.ckpt")
+
     scaling_config = ray.air.ScalingConfig(num_workers=2, use_gpu=False)
 
     trainer = LightningTrainer(
-        lightning_config=lightning_config,
+        lightning_config=lightning_config.build(),
         scaling_config=scaling_config,
         run_config=RunConfig(
             local_dir=str(tmpdir),
             name=exp_name,
             checkpoint_config=CheckpointConfig(num_to_keep=1),
-            callbacks=[FailureInjectionCallback(num_iters=2)],
+            callbacks=[FailureInjectionCallback(num_iters=error_epoch)],
         ),
     )

@@ -144,8 +164,8 @@ def test_air_trainer_restore(ray_start_6_cpus, tmpdir):
     result = trainer.fit()
 
     assert not result.error
-    assert result.metrics["training_iteration"] == 5
-    assert result.metrics["iterations_since_restore"] == 3
+    assert result.metrics["training_iteration"] == training_iterations
+    assert result.metrics["iterations_since_restore"] == iterations_since_restore
     assert tmpdir / exp_name in result.log_dir.parents
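
For reference, plugging the test constants into the new expressions gives the expected metric values for both parametrizations (a quick sanity check, not part of the diff):

# Expected metrics per parametrization, using the constants defined in the test.
max_epochs, error_epoch = 5, 2

for resume_from_ckpt_path in (True, False):
    init_epoch = 1 if resume_from_ckpt_path else 0
    training_iterations = max_epochs - init_epoch                      # 4 if resuming, else 5
    iterations_since_restore = max_epochs - init_epoch - error_epoch   # 2 if resuming, else 3
    print(resume_from_ckpt_path, training_iterations, iterations_since_restore)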


