[AIR] Add node rank and local world size info to session #29919

Merged
merged 21 commits on Nov 3, 2022
60 changes: 60 additions & 0 deletions python/ray/air/session.py
@@ -275,6 +275,66 @@ def train_loop_per_worker():
return session.local_rank


@_warn_session_misuse(default_value=0)
def get_local_world_size() -> int:
"""Get the local rank of this worker (rank of the worker on its node).

Example:
>>> import ray
>>> from ray.air import session
>>> from ray.air.config import ScalingConfig
>>> from ray.train.torch import TorchTrainer
>>>
>>> def train_loop_per_worker():
... return session.get_local_world_size()
>>>
>>> train_dataset = ray.data.from_items(
... [{"x": x, "y": x + 1} for x in range(32)])
>>> trainer = TorchTrainer(train_loop_per_worker,
... scaling_config=ScalingConfig(num_workers=1),
... datasets={"train": train_dataset})
>>> trainer.fit() # doctest: +SKIP
"""
session = _get_session()
if not isinstance(session, _TrainSessionImpl):
raise RuntimeError(
"`get_local_world_size` can only be called for TrainSession! "
"Make sure you only use that in `train_loop_per_worker` function"
"that is passed into `DataParallelTrainer`."
)
return session.local_world_size


@_warn_session_misuse(default_value=0)
def get_node_rank() -> int:
"""Get the local rank of this worker (rank of the worker on its node).

Example:
>>> import ray
>>> from ray.air import session
>>> from ray.air.config import ScalingConfig
>>> from ray.train.torch import TorchTrainer
>>>
>>> def train_loop_per_worker():
... return session.get_node_rank()
>>>
>>> train_dataset = ray.data.from_items(
... [{"x": x, "y": x + 1} for x in range(32)])
>>> trainer = TorchTrainer(train_loop_per_worker,
... scaling_config=ScalingConfig(num_workers=1),
... datasets={"train": train_dataset})
>>> trainer.fit() # doctest: +SKIP
"""
session = _get_session()
if not isinstance(session, _TrainSessionImpl):
raise RuntimeError(
"`get_node_rank` can only be called for TrainSession! "
"Make sure you only use that in `train_loop_per_worker` function"
"that is passed into `DataParallelTrainer`."
)
return session.node_rank


@_warn_session_misuse()
def get_dataset_shard(
dataset_name: Optional[str] = None,
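A minimal usage sketch (not part of the diff) combining the new accessors with the existing ones inside a training loop; the per-node batch-size arithmetic is purely illustrative and assumes the same trainer setup as the doctests above.

from ray.air import session

def train_loop_per_worker():
    # Global identity of this worker across the whole cluster.
    world_rank = session.get_world_rank()
    world_size = session.get_world_size()
    # Added in this PR: identity of this worker within its node.
    local_rank = session.get_local_rank()
    local_world_size = session.get_local_world_size()
    node_rank = session.get_node_rank()
    # Illustrative use: split a hypothetical per-node budget across the node's workers.
    per_node_batch_size = 128
    per_worker_batch_size = per_node_batch_size // local_world_size
    print(
        f"worker {world_rank}/{world_size}: local rank {local_rank} of "
        f"{local_world_size} on node {node_rank}, batch size {per_worker_batch_size}"
    )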
2 changes: 1 addition & 1 deletion python/ray/train/BUILD
@@ -229,7 +229,7 @@ py_test(
size = "large",
srcs = ["tests/test_backend.py"],
tags = ["team:ml", "exclusive"],
deps = [":train_lib"]
deps = [":train_lib", ":conftest"]
)

py_test(
65 changes: 57 additions & 8 deletions python/ray/train/_internal/backend_executor.py
@@ -241,8 +241,12 @@ def set_gpu_ids():
)
ray.get(futures)

def _create_local_rank_map(self) -> Dict:
"""Create mapping from worker world_rank to local_rank.
def _create_rank_world_size_mappings(self) -> List[Dict]:
"""Create rank and world size mappings for workers.
There are three maps returned:
- local_rank_map, which maps from worker world_rank to local_rank.
- local_world_size_map, which maps from world_rank to local_world_size.
- node_rank_map, which maps from world_rank to node_rank.

Example:
Worker 0: 0.0.0.0
@@ -254,23 +258,58 @@ def _create_local_rank_map(self) -> Dict:
Workers 0, 1, 3 are on 0.0.0.0.
Workers 2, 4 are on 0.0.0.1.

Expected Output:
Expected local_rank_map:
{
0 -> 0,
1 -> 1,
2 -> 0,
3 -> 2,
4 -> 1
}

Expected local_world_size_map:
{
0 -> 3,
1 -> 3,
2 -> 2,
3 -> 3,
4 -> 2
}

Expected node_rank_map:
{
0 -> 0,
1 -> 0,
2 -> 1,
3 -> 0,
4 -> 1
}

"""
rank_mapping = {}
ip_dict = defaultdict(int)
local_rank_map = {} # map from world rank to local rank
local_world_size_map = {} # map from world rank to local world size
node_rank_map = {} # map from world rank to node rank
node_ips = {} # map from node ip to node index
node_cnt = 0 # count the number of nodes

ip_dict = defaultdict(int) # map from node ip to the number of workers on it.
for world_rank in range(len(self.worker_group)):
worker = self.worker_group.workers[world_rank]
node_ip = worker.metadata.node_ip
rank_mapping[world_rank] = ip_dict[node_ip]
local_rank_map[world_rank] = ip_dict[node_ip]
ip_dict[node_ip] += 1
return rank_mapping

if node_ip not in node_ips:
node_ips[node_ip] = node_cnt
node_cnt += 1
node_rank_map[world_rank] = node_ips[node_ip]

for world_rank in range(len(self.worker_group)):
worker = self.worker_group.workers[world_rank]
node_ip = worker.metadata.node_ip
local_world_size_map[world_rank] = ip_dict[node_ip]

return local_rank_map, local_world_size_map, node_rank_map

def start_training(
self,
@@ -301,6 +340,8 @@ def initialize_session(
train_func,
world_rank,
local_rank,
node_rank,
local_world_size,
world_size,
trial_info,
checkpoint,
@@ -312,6 +353,8 @@
training_func=train_func,
world_rank=world_rank,
local_rank=local_rank,
node_rank=node_rank,
local_world_size=local_world_size,
world_size=world_size,
trial_info=trial_info,
dataset_shard=dataset_shard,
@@ -331,7 +374,11 @@
actors = [worker.actor for worker in self.worker_group.workers]
self.dataset_shards = dataset_spec.get_dataset_shards(actors)

local_rank_map = self._create_local_rank_map()
(
local_rank_map,
local_world_size_map,
node_rank_map,
) = self._create_rank_world_size_mappings()

futures = []
for index in range(len(self.worker_group)):
@@ -341,6 +388,8 @@
initialize_session,
world_rank=index,
local_rank=local_rank_map[index],
node_rank=node_rank_map[index],
local_world_size=local_world_size_map[index],
world_size=len(self.worker_group),
trial_info=self._trial_info,
train_func=train_func,
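The mapping logic in _create_rank_world_size_mappings can be sanity-checked outside of Ray. A self-contained sketch of the same two-pass algorithm, assuming workers are identified only by their node IPs (the helper name build_rank_maps and the sample IPs are illustrative, not part of this PR):

from collections import defaultdict
from typing import Dict, List, Tuple

def build_rank_maps(worker_ips: List[str]) -> Tuple[Dict, Dict, Dict]:
    local_rank_map = {}        # world rank -> local rank
    local_world_size_map = {}  # world rank -> number of workers on the same node
    node_rank_map = {}         # world rank -> node rank
    workers_per_ip = defaultdict(int)
    node_index = {}            # node ip -> node rank, in order of first appearance

    # First pass: assign local ranks and node ranks while counting workers per node.
    for world_rank, ip in enumerate(worker_ips):
        local_rank_map[world_rank] = workers_per_ip[ip]
        workers_per_ip[ip] += 1
        if ip not in node_index:
            node_index[ip] = len(node_index)
        node_rank_map[world_rank] = node_index[ip]

    # Second pass: the local world size is only known once every worker is counted.
    for world_rank, ip in enumerate(worker_ips):
        local_world_size_map[world_rank] = workers_per_ip[ip]

    return local_rank_map, local_world_size_map, node_rank_map

# Matches the docstring example: workers 0, 1, 3 on 0.0.0.0 and workers 2, 4 on 0.0.0.1.
ips = ["0.0.0.0", "0.0.0.0", "0.0.0.1", "0.0.0.0", "0.0.0.1"]
local_rank, local_world_size, node_rank = build_rank_maps(ips)
assert local_rank == {0: 0, 1: 1, 2: 0, 3: 2, 4: 1}
assert local_world_size == {0: 3, 1: 3, 2: 2, 3: 3, 4: 2}
assert node_rank == {0: 0, 1: 0, 2: 1, 3: 0, 4: 1}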
4 changes: 4 additions & 0 deletions python/ray/train/_internal/session.py
@@ -65,6 +65,8 @@ def __init__(
training_func: Callable,
world_rank: int,
local_rank: int,
node_rank: int,
local_world_size: int,
world_size: int,
# TODO(xwjiang): Legacy Ray Train trainer clean up!
trial_info: Optional[TrialInfo] = None,
@@ -80,6 +82,8 @@

self.world_rank = world_rank
self.local_rank = local_rank
self.node_rank = node_rank
self.local_world_size = local_world_size
self.world_size = world_size
self.trial_info = trial_info
# TODO(xwjiang): Legacy Ray Train trainer clean up!
8 changes: 8 additions & 0 deletions python/ray/train/session.py
@@ -61,6 +61,14 @@ def world_rank(self) -> int:
def local_rank(self) -> int:
return self._session.local_rank

@property
def local_world_size(self) -> int:
return self._session.local_world_size

@property
def node_rank(self) -> int:
return self._session.node_rank

def get_dataset_shard(
self,
dataset_name: Optional[str] = None,
50 changes: 50 additions & 0 deletions python/ray/train/tests/conftest.py
@@ -60,3 +60,53 @@ def ray_2_node_2_gpu():

ray.shutdown()
cluster.shutdown()


@pytest.fixture
def ray_start_2_cpus():
address_info = ray.init(num_cpus=2)
yield address_info
# The code after the yield will run as teardown code.
ray.shutdown()


@pytest.fixture
def ray_4_node_4_cpu():
cluster = Cluster()
for _ in range(4):
cluster.add_node(num_cpus=4)

ray.init(address=cluster.address)

yield

ray.shutdown()
cluster.shutdown()


@pytest.fixture
def ray_2_node_4_gpu():
cluster = Cluster()
for _ in range(2):
cluster.add_node(num_cpus=2, num_gpus=4)

ray.init(address=cluster.address)

yield

ray.shutdown()
cluster.shutdown()


@pytest.fixture
def ray_2_node_2_cpu():
cluster = Cluster()
for _ in range(2):
cluster.add_node(num_cpus=2)

ray.init(address=cluster.address)

yield

ray.shutdown()
cluster.shutdown()
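A sketch of the kind of multi-node test these fixtures enable, assuming the ray_2_node_2_cpu fixture above and the new session accessors; the test name, worker count, and assertions are illustrative rather than taken from this PR, and only the metrics reported by the rank-0 worker are assumed to surface in result.metrics:

from ray.air import session
from ray.air.config import ScalingConfig
from ray.train.torch import TorchTrainer


def test_node_rank_and_local_world_size(ray_2_node_2_cpu):
    def train_loop_per_worker():
        session.report(
            {
                "node_rank": session.get_node_rank(),
                "local_world_size": session.get_local_world_size(),
            }
        )

    trainer = TorchTrainer(
        train_loop_per_worker,
        scaling_config=ScalingConfig(num_workers=2),
    )
    result = trainer.fit()
    # World rank 0's node is always assigned node rank 0 by the mapping above.
    assert result.metrics["node_rank"] == 0
    # With 2 workers on a 2-node cluster, a node hosts either 1 or 2 of them.
    assert result.metrics["local_world_size"] in (1, 2)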