From 4eefd5d0f7cee9160c54da6235b17d4fe0c7ce69 Mon Sep 17 00:00:00 2001 From: LarryLian <554538252@qq.com> Date: Mon, 8 May 2023 14:37:49 +0800 Subject: [PATCH] [CORE] Add bundles_to_node_id info in placement_group_table Signed-off-by: LarryLian <554538252@qq.com> --- python/ray/_private/state.py | 4 ++++ python/ray/tests/test_placement_group_2.py | 18 +++++++++++++++++- 2 files changed, 21 insertions(+), 1 deletion(-) diff --git a/python/ray/_private/state.py b/python/ray/_private/state.py index c2d1fc594f17..45aa57d8b601 100644 --- a/python/ray/_private/state.py +++ b/python/ray/_private/state.py @@ -329,6 +329,10 @@ def get_strategy(strategy): bundle.bundle_id.bundle_index: MessageToDict(bundle)["unitResources"] for bundle in placement_group_info.bundles }, + "bundles_to_node_id": { + bundle.bundle_id.bundle_index: binary_to_hex(bundle.node_id) + for bundle in placement_group_info.bundles + }, "strategy": get_strategy(placement_group_info.strategy), "state": get_state(placement_group_info.state), "stats": { diff --git a/python/ray/tests/test_placement_group_2.py b/python/ray/tests/test_placement_group_2.py index 4356090bc0c4..c9b367e429f9 100644 --- a/python/ray/tests/test_placement_group_2.py +++ b/python/ray/tests/test_placement_group_2.py @@ -89,6 +89,9 @@ def test_pending_placement_group_wait(ray_start_cluster, connect_to_client): assert len(ready) == 0 table = ray.util.placement_group_table(placement_group) assert table["state"] == "PENDING" + for i in range(3): + assert len(table["bundles_to_node_id"][i]) == 0 + with pytest.raises(ray.exceptions.GetTimeoutError): ray.get(placement_group.ready(), timeout=0.1) @@ -115,11 +118,24 @@ def test_placement_group_wait(ray_start_cluster, connect_to_client): assert len(ready) == 1 table = ray.util.placement_group_table(placement_group) assert table["state"] == "CREATED" - pg = ray.get(placement_group.ready()) assert pg.bundle_specs == placement_group.bundle_specs assert pg.id.binary() == placement_group.id.binary() + @ray.remote + def get_node_id(): + return ray.get_runtime_context().get_node_id() + + for i in range(2): + scheduling_strategy = PlacementGroupSchedulingStrategy( + placement_group=placement_group, + placement_group_bundle_index=i, + ) + node_id = ray.get( + get_node_id.options(scheduling_strategy=scheduling_strategy).remote() + ) + assert node_id == table["bundles_to_node_id"][i] + @pytest.mark.parametrize("connect_to_client", [False, True]) def test_schedule_placement_group_when_node_add(ray_start_cluster, connect_to_client):