From 66bffd6daa14ca1dcec9ad1e2252ab92c0091448 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=C3=A9rome=20Eertmans?= Date: Thu, 17 Oct 2024 15:27:21 +0200 Subject: [PATCH] feat(lib): extend support for parallelism --- differt/src/differt/scene/triangle_scene.py | 21 +++++++++++---------- differt/tests/scene/test_triangle_scene.py | 4 +++- 2 files changed, 14 insertions(+), 11 deletions(-) diff --git a/differt/src/differt/scene/triangle_scene.py b/differt/src/differt/scene/triangle_scene.py index 55f120ef..36a4bc78 100644 --- a/differt/src/differt/scene/triangle_scene.py +++ b/differt/src/differt/scene/triangle_scene.py @@ -1,6 +1,7 @@ """Scene made of triangles and utilities.""" # ruff: noqa: ERA001 +import math import sys import warnings from collections.abc import Mapping @@ -188,25 +189,25 @@ def fun( if parallel: num_devices = jax.device_count() - # TODO: allow also to have i,i mesh if product of both is a multiple of 'num_devices' - if from_vertices.shape[0] % num_devices == 0: - in_specs = (P("i", None), P(None, None)) - out_specs = (P("i", None, None, None, None), P("i", None, None)) - elif to_vertices.shape[0] % num_devices == 0: - in_specs = (P(None, None), P("i", None)) - out_specs = (P(None, "i", None, None, None), P(None, "i", None)) + if (from_vertices.shape[0] * to_vertices.shape[0]) % num_devices == 0: + tx_mesh = math.gcd(from_vertices.shape[0], num_devices) + rx_mesh = num_devices // tx_mesh + in_specs = (P("i", None), P("j", None)) + out_specs = (P("i", "j", None, None, None), P("i", "j", None)) else: msg = ( f"Found {num_devices} devices available, " "but could not find any input with a size that is a multiple of that value. " - "Please user a number of transmitter or receiver points that is a " + "Please user a number of transmitter and receiver points that is a " f"multiple of {num_devices}." ) raise ValueError(msg) fun = shard_map( # type: ignore[reportAssigmentType] fun, - Mesh(mesh_utils.create_device_mesh((num_devices,)), axis_names=("i",)), + Mesh( + mesh_utils.create_device_mesh((tx_mesh, rx_mesh)), axis_names=("i", "j") + ), in_specs=in_specs, out_specs=out_specs, ) @@ -435,7 +436,7 @@ def compute_paths( If ``self.mesh.assume_quads`` is :data:`True`, then path candidates are rounded down toward the nearest even value. parallel: If :data:`True`, ray tracing is performed in parallel across all available - devices. Either the number of transmitters or the number of receivers + devices. Either the number of transmitters times the number of receivers **must** be a multiple of :func:`jax.device_count`, otherwise an error is raised. epsilon: Tolelance for checking ray / objects intersection, see :func:`rays_intersect_triangles`. diff --git a/differt/tests/scene/test_triangle_scene.py b/differt/tests/scene/test_triangle_scene.py index f1013fbd..9947a6cb 100644 --- a/differt/tests/scene/test_triangle_scene.py +++ b/differt/tests/scene/test_triangle_scene.py @@ -289,6 +289,8 @@ def test_compute_paths_on_grid( (1, 1, 8, 8, does_not_raise()), (4, 2, 1, 1, does_not_raise()), (1, 1, 2, 4, does_not_raise()), + (1, 4, 2, 1, does_not_raise()), + (1, 2, 4, 1, does_not_raise()), ( 7, 1, @@ -299,7 +301,7 @@ def test_compute_paths_on_grid( ( 1, 2, - 4, + 3, 1, pytest.raises(ValueError, match="Found 8 devices available"), ),