From e905855569ce6c48fabef4bd4585541a25c84567 Mon Sep 17 00:00:00 2001
From: woshiyyya
Date: Fri, 14 Apr 2023 15:07:09 -0700
Subject: [PATCH 1/3] init change

Signed-off-by: woshiyyya
---
 python/ray/train/lightning/_lightning_utils.py  | 11 ++++++++++-
 python/ray/train/lightning/lightning_trainer.py |  3 ++-
 2 files changed, 12 insertions(+), 2 deletions(-)

diff --git a/python/ray/train/lightning/_lightning_utils.py b/python/ray/train/lightning/_lightning_utils.py
index 43b8e8cf51db2..8f02a97dcb813 100644
--- a/python/ray/train/lightning/_lightning_utils.py
+++ b/python/ray/train/lightning/_lightning_utils.py
@@ -21,12 +21,21 @@
 LIGHTNING_REPORT_STAGE_KEY = "_report_on"
 
 
+def _get_worker_root_device():
+    """Get a single torch device for the current worker."""
+    devices = ray.train.torch.get_device()
+    if isinstance(devices, list):
+        return devices[0]
+    else:
+        return devices
+
+
 class RayDDPStrategy(DDPStrategy):
     """Subclass of DDPStrategy to ensure compatibility with Ray orchestration."""
 
     @property
     def root_device(self) -> torch.device:
-        return ray.train.torch.get_device()
+        return _get_worker_root_device()
 
     @property
     def distributed_sampler_kwargs(self) -> Dict[str, Any]:
diff --git a/python/ray/train/lightning/lightning_trainer.py b/python/ray/train/lightning/lightning_trainer.py
index 3c82420ba4940..723307ee0332f 100644
--- a/python/ray/train/lightning/lightning_trainer.py
+++ b/python/ray/train/lightning/lightning_trainer.py
@@ -26,6 +26,7 @@
     RayEnvironment,
     RayDataModule,
     RayModelCheckpoint,
+    _get_worker_root_device
 )
 
 
@@ -503,7 +504,7 @@ def _lightning_train_loop_per_worker(config):
 
     # Setup trainer's parallel devices
     if trainer_config.get("accelerator", None) == "gpu":
-        current_device = ray.train.torch.get_device()
+        current_device = _get_worker_root_device()
         trainer_config["devices"] = [current_device.index]
 
     # Setup ray cluster environment info

From 891c89e23693b7ac36377dd30d6290fe2e0a54ab Mon Sep 17 00:00:00 2001
From: woshiyyya
Date: Fri, 14 Apr 2023 15:07:38 -0700
Subject: [PATCH 2/3] wip

Signed-off-by: woshiyyya
---
 python/ray/train/lightning/lightning_trainer.py | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/python/ray/train/lightning/lightning_trainer.py b/python/ray/train/lightning/lightning_trainer.py
index 723307ee0332f..88d5833804955 100644
--- a/python/ray/train/lightning/lightning_trainer.py
+++ b/python/ray/train/lightning/lightning_trainer.py
@@ -1,5 +1,4 @@
 import os
-import ray
 from inspect import isclass
 from typing import Any, Dict, Optional, Type
 import pytorch_lightning as pl
@@ -26,7 +25,7 @@
     RayEnvironment,
     RayDataModule,
     RayModelCheckpoint,
-    _get_worker_root_device
+    _get_worker_root_device,
 )
 
 

From 732288e4a013de26bd48739b4d81803cdabc8ceb Mon Sep 17 00:00:00 2001
From: woshiyyya
Date: Fri, 14 Apr 2023 15:54:36 -0700
Subject: [PATCH 3/3] address comments

Signed-off-by: woshiyyya
---
 python/ray/train/lightning/_lightning_utils.py  | 6 +++---
 python/ray/train/lightning/lightning_trainer.py | 4 ++--
 2 files changed, 5 insertions(+), 5 deletions(-)

diff --git a/python/ray/train/lightning/_lightning_utils.py b/python/ray/train/lightning/_lightning_utils.py
index 8f02a97dcb813..afb0ef57eb71f 100644
--- a/python/ray/train/lightning/_lightning_utils.py
+++ b/python/ray/train/lightning/_lightning_utils.py
@@ -21,8 +21,8 @@
 LIGHTNING_REPORT_STAGE_KEY = "_report_on"
 
 
-def _get_worker_root_device():
-    """Get a single torch device for the current worker."""
+def get_worker_root_device():
+    """Get the first torch device of the current worker if there are multiple."""
     devices = ray.train.torch.get_device()
     if isinstance(devices, list):
         return devices[0]
@@ -35,7 +35,7 @@ class RayDDPStrategy(DDPStrategy):
 
     @property
     def root_device(self) -> torch.device:
-        return _get_worker_root_device()
+        return get_worker_root_device()
 
     @property
     def distributed_sampler_kwargs(self) -> Dict[str, Any]:
diff --git a/python/ray/train/lightning/lightning_trainer.py b/python/ray/train/lightning/lightning_trainer.py
index 88d5833804955..4e65e320ac81f 100644
--- a/python/ray/train/lightning/lightning_trainer.py
+++ b/python/ray/train/lightning/lightning_trainer.py
@@ -25,7 +25,7 @@
     RayEnvironment,
     RayDataModule,
     RayModelCheckpoint,
-    _get_worker_root_device,
+    get_worker_root_device,
 )
 
 
@@ -503,7 +503,7 @@ def _lightning_train_loop_per_worker(config):
 
     # Setup trainer's parallel devices
     if trainer_config.get("accelerator", None) == "gpu":
-        current_device = _get_worker_root_device()
+        current_device = get_worker_root_device()
         trainer_config["devices"] = [current_device.index]
 
     # Setup ray cluster environment info