Skip to content

Commit

Permalink
feat(lib): extend support for parallelism
Browse files Browse the repository at this point in the history
  • Loading branch information
jeertmans committed Oct 17, 2024
1 parent 451ad49 commit 66bffd6
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 11 deletions.
21 changes: 11 additions & 10 deletions differt/src/differt/scene/triangle_scene.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Scene made of triangles and utilities."""
# ruff: noqa: ERA001

import math
import sys
import warnings
from collections.abc import Mapping
Expand Down Expand Up @@ -188,25 +189,25 @@ def fun(
if parallel:
num_devices = jax.device_count()

# TODO: allow also to have i,i mesh if product of both is a multiple of 'num_devices'
if from_vertices.shape[0] % num_devices == 0:
in_specs = (P("i", None), P(None, None))
out_specs = (P("i", None, None, None, None), P("i", None, None))
elif to_vertices.shape[0] % num_devices == 0:
in_specs = (P(None, None), P("i", None))
out_specs = (P(None, "i", None, None, None), P(None, "i", None))
if (from_vertices.shape[0] * to_vertices.shape[0]) % num_devices == 0:
tx_mesh = math.gcd(from_vertices.shape[0], num_devices)
rx_mesh = num_devices // tx_mesh
in_specs = (P("i", None), P("j", None))
out_specs = (P("i", "j", None, None, None), P("i", "j", None))
else:
msg = (
f"Found {num_devices} devices available, "
"but could not find any input with a size that is a multiple of that value. "
"Please user a number of transmitter or receiver points that is a "
"Please user a number of transmitter and receiver points that is a "
f"multiple of {num_devices}."
)
raise ValueError(msg)

fun = shard_map( # type: ignore[reportAssigmentType]
fun,
Mesh(mesh_utils.create_device_mesh((num_devices,)), axis_names=("i",)),
Mesh(
mesh_utils.create_device_mesh((tx_mesh, rx_mesh)), axis_names=("i", "j")
),
in_specs=in_specs,
out_specs=out_specs,
)
Expand Down Expand Up @@ -435,7 +436,7 @@ def compute_paths(
If ``self.mesh.assume_quads`` is :data:`True`, then path candidates are
rounded down toward the nearest even value.
parallel: If :data:`True`, ray tracing is performed in parallel across all available
devices. Either the number of transmitters or the number of receivers
devices. Either the number of transmitters times the number of receivers
**must** be a multiple of :func:`jax.device_count`, otherwise an error is raised.
epsilon: Tolelance for checking ray / objects intersection, see
:func:`rays_intersect_triangles<differt.rt.utils.rays_intersect_triangles>`.
Expand Down
4 changes: 3 additions & 1 deletion differt/tests/scene/test_triangle_scene.py
Original file line number Diff line number Diff line change
Expand Up @@ -289,6 +289,8 @@ def test_compute_paths_on_grid(
(1, 1, 8, 8, does_not_raise()),
(4, 2, 1, 1, does_not_raise()),
(1, 1, 2, 4, does_not_raise()),
(1, 4, 2, 1, does_not_raise()),
(1, 2, 4, 1, does_not_raise()),
(
7,
1,
Expand All @@ -299,7 +301,7 @@ def test_compute_paths_on_grid(
(
1,
2,
4,
3,
1,
pytest.raises(ValueError, match="Found 8 devices available"),
),
Expand Down

0 comments on commit 66bffd6

Please sign in to comment.