diff --git a/python/ray/train/lightning/_lightning_utils.py b/python/ray/train/lightning/_lightning_utils.py
index 43b8e8cf51db..afb0ef57eb71 100644
--- a/python/ray/train/lightning/_lightning_utils.py
+++ b/python/ray/train/lightning/_lightning_utils.py
@@ -21,12 +21,21 @@
 LIGHTNING_REPORT_STAGE_KEY = "_report_on"
 
 
+def get_worker_root_device():
+    """Get the first torch device of the current worker if there are multiple."""
+    devices = ray.train.torch.get_device()
+    if isinstance(devices, list):
+        return devices[0]
+    else:
+        return devices
+
+
 class RayDDPStrategy(DDPStrategy):
     """Subclass of DDPStrategy to ensure compatibility with Ray orchestration."""
 
     @property
     def root_device(self) -> torch.device:
-        return ray.train.torch.get_device()
+        return get_worker_root_device()
 
     @property
     def distributed_sampler_kwargs(self) -> Dict[str, Any]:
diff --git a/python/ray/train/lightning/lightning_trainer.py b/python/ray/train/lightning/lightning_trainer.py
index 3c82420ba494..4e65e320ac81 100644
--- a/python/ray/train/lightning/lightning_trainer.py
+++ b/python/ray/train/lightning/lightning_trainer.py
@@ -1,5 +1,4 @@
 import os
-import ray
 from inspect import isclass
 from typing import Any, Dict, Optional, Type
 import pytorch_lightning as pl
@@ -26,6 +25,7 @@
     RayEnvironment,
     RayDataModule,
     RayModelCheckpoint,
+    get_worker_root_device,
 )
 
 
@@ -503,7 +503,7 @@ def _lightning_train_loop_per_worker(config):
 
     # Setup trainer's parallel devices
     if trainer_config.get("accelerator", None) == "gpu":
-        current_device = ray.train.torch.get_device()
+        current_device = get_worker_root_device()
         trainer_config["devices"] = [current_device.index]
 
     # Setup ray cluster environment info
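
For context, a minimal sketch (not part of the diff) of how the new `get_worker_root_device()` helper is intended to be used inside a Ray Train worker. The `configure_devices` function name and the `trainer_config` dict are hypothetical; the device-index logic mirrors the change in `lightning_trainer.py`, assuming `ray.train.torch.get_device()` can return a list when a worker is assigned multiple GPUs, as the helper's own `isinstance` check implies.

```python
# Illustrative sketch only, not part of the change itself.
import torch

from ray.train.lightning._lightning_utils import get_worker_root_device


def configure_devices(trainer_config: dict) -> None:
    """Hypothetical helper mirroring the per-worker device setup in the diff."""
    if trainer_config.get("accelerator", None) == "gpu":
        # get_worker_root_device() normalizes the (possibly list-valued) result
        # of ray.train.torch.get_device() to a single torch.device, whose index
        # is then passed to Lightning via `devices`.
        current_device: torch.device = get_worker_root_device()
        trainer_config["devices"] = [current_device.index]
```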