From 3e44c6fbd9e6e217943e44a2f978b2949459e218 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=C3=A9rome=20Eertmans?= Date: Thu, 17 Oct 2024 17:33:09 +0200 Subject: [PATCH] feat(lib): extend support for parallelism (#148) * feat(lib): extend support for parallelism * chore(lib): add `@overload` to provide better type hint * chore(lib): add more overloads * fix(lib): oupsi * fix(tests): escape regex * fix(docs): remove 'Either' as it is no longer true --- .pre-commit-config.yaml | 2 +- differt/src/differt/geometry/triangle_mesh.py | 70 +++++++++++++------ differt/src/differt/geometry/utils.py | 28 ++++++++ differt/src/differt/scene/triangle_scene.py | 62 +++++++++++++--- differt/src/differt/utils.py | 20 +++++- differt/tests/benchmarks/test_rt.py | 4 +- differt/tests/geometry/test_triangle_mesh.py | 40 ++++++++--- differt/tests/scene/test_triangle_scene.py | 13 ++-- pyproject.toml | 1 + 9 files changed, 188 insertions(+), 52 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 538615b1..a444e560 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -24,7 +24,7 @@ repos: - id: ruff-format types_or: [python, pyi, jupyter] - repo: https://github.com/RobertCraigie/pyright-python - rev: v1.1.384 + rev: v1.1.385 hooks: - id: pyright - repo: https://github.com/doublify/pre-commit-rust diff --git a/differt/src/differt/geometry/triangle_mesh.py b/differt/src/differt/geometry/triangle_mesh.py index afcab8a0..6121e5e9 100644 --- a/differt/src/differt/geometry/triangle_mesh.py +++ b/differt/src/differt/geometry/triangle_mesh.py @@ -2,7 +2,7 @@ # ruff: noqa: ERA001 import sys -from typing import Any +from typing import Any, Literal, overload import equinox as eqx import jax @@ -312,14 +312,42 @@ def set_face_colors( is_leaf=lambda x: x is None, ) + @overload + @classmethod + def plane( + cls, + vertex_a: Float[Array, "3"], + vertex_b: Float[Array, "3"], + vertex_c: Float[Array, "3"], + *, + normal: Literal[None] = None, + side_length: Float[ArrayLike, " "] = 1.0, + rotate: Float[ArrayLike, " "] | None = None, + ) -> Self: ... + + @overload + @classmethod + def plane( + cls, + vertex_a: Float[Array, "3"], + vertex_b: Literal[None] = None, + vertex_c: Literal[None] = None, + *, + normal: Float[Array, "3"], + side_length: Float[ArrayLike, " "] = 1.0, + rotate: Float[ArrayLike, " "] | None = None, + ) -> Self: ... + @classmethod @jaxtyped( typechecker=None ) # typing.Self is (currently) not compatible with jaxtyping and beartype def plane( cls, - vertex: Float[Array, "3"], - *other_vertices: Float[Array, "3"], + vertex_a: Float[Array, "3"], + vertex_b: Float[Array, "3"] | None = None, + vertex_c: Float[Array, "3"] | None = None, + *, normal: Float[Array, "3"] | None = None, side_length: Float[ArrayLike, " "] = 1.0, rotate: Float[ArrayLike, " "] | None = None, @@ -328,10 +356,13 @@ def plane( Create an plane mesh, made of two triangles. Args: - vertex: The center of the plane. - other_vertices: Two other vertices that define the plane. + vertex_a: The center of the plane. + vertex_b: Any second vertex on the plane. - This or ``normal`` is required. + This and ``vertex_c``, or ``normal`` is required. + vertex_c: Any third vertex on the plane. + + This and ``vertex_b``, or ``normal`` is required. normal: The plane normal. Must be of unit length. @@ -344,23 +375,22 @@ def plane( A new plane mesh. Raises: - ValueError: If neither ``other_vertices`` nor ``normal`` has been provided, + ValueError: If neither ``vertex_b`` and ``vertex_c``, nor ``normal`` have been provided, or if both have been provided simultaneously. """ - if (other_vertices == ()) == (normal is None): - msg = "You must specify one of 'other_vertices' or 'normal', not both." + if (vertex_b is None) != (vertex_c is None): + msg = "You must specify either of both of 'vertex_b' and 'vertex_c', or none." + raise ValueError(msg) + + if (vertex_b is None) == (normal is None): + msg = "You must specify one of ('vertex_b', 'vertex_c') or 'normal', not both." raise ValueError(msg) - if other_vertices: - if len(other_vertices) != 2: # noqa: PLR2004 - msg = ( - "You must provide exactly 3 vertices to create a new plane, " - f"but you provided {len(other_vertices) + 1}." - ) - raise ValueError(msg) - u = other_vertices[0] - vertex - v = other_vertices[1] - vertex + + if vertex_b is not None: + u = vertex_b - vertex_a + v = vertex_c - vertex_a w = jnp.cross(u, v) - (normal, _) = normalize(w) + normal = normalize(w)[0] u, v = orthogonal_basis( normal, # type: ignore[reportArgumentType] @@ -375,7 +405,7 @@ def plane( rotation_matrix = rotation_matrix_along_axis(rotate, normal) vertices = (rotation_matrix @ vertices.T).T - vertices += vertex + vertices += vertex_a triangles = jnp.array([[0, 1, 2], [0, 2, 3]], dtype=int) return cls(vertices=vertices, triangles=triangles) diff --git a/differt/src/differt/geometry/utils.py b/differt/src/differt/geometry/utils.py index 8454254c..d0b9694d 100644 --- a/differt/src/differt/geometry/utils.py +++ b/differt/src/differt/geometry/utils.py @@ -1,5 +1,7 @@ """Utilities for working with 3D geometries.""" +from typing import Literal, overload + import equinox as eqx import jax import jax.numpy as jnp @@ -26,6 +28,32 @@ def pairwise_cross( return jnp.cross(u[:, None, :], v[None, :, :]) +@overload +def normalize( + vector: Float[Array, "*batch 3"], + keepdims: Literal[False] = False, +) -> tuple[Float[Array, "*batch 3"], Float[Array, " *batch"]]: ... + + +@overload +def normalize( + vector: Float[Array, "*batch 3"], + keepdims: Literal[True], +) -> tuple[Float[Array, "*batch 3"], Float[Array, " *batch 1"]]: ... + + +# Workaround currently needed, +# see: https://github.com/microsoft/pyright/issues/9149 +@overload +def normalize( + vector: Float[Array, "*batch 3"], + keepdims: bool, +) -> ( + tuple[Float[Array, "*batch 3"], Float[Array, " *batch"]] + | tuple[Float[Array, "*batch 3"], Float[Array, " *batch 1"]] +): ... + + @eqx.filter_jit @jaxtyped(typechecker=typechecker) def normalize( diff --git a/differt/src/differt/scene/triangle_scene.py b/differt/src/differt/scene/triangle_scene.py index 55f120ef..ffc1db12 100644 --- a/differt/src/differt/scene/triangle_scene.py +++ b/differt/src/differt/scene/triangle_scene.py @@ -1,10 +1,11 @@ """Scene made of triangles and utilities.""" # ruff: noqa: ERA001 +import math import sys import warnings from collections.abc import Mapping -from typing import Any +from typing import Any, Literal, overload import equinox as eqx import jax @@ -188,25 +189,25 @@ def fun( if parallel: num_devices = jax.device_count() - # TODO: allow also to have i,i mesh if product of both is a multiple of 'num_devices' - if from_vertices.shape[0] % num_devices == 0: - in_specs = (P("i", None), P(None, None)) - out_specs = (P("i", None, None, None, None), P("i", None, None)) - elif to_vertices.shape[0] % num_devices == 0: - in_specs = (P(None, None), P("i", None)) - out_specs = (P(None, "i", None, None, None), P(None, "i", None)) + if (from_vertices.shape[0] * to_vertices.shape[0]) % num_devices == 0: + tx_mesh = math.gcd(from_vertices.shape[0], num_devices) + rx_mesh = num_devices // tx_mesh + in_specs = (P("i", None), P("j", None)) + out_specs = (P("i", "j", None, None, None), P("i", "j", None)) else: msg = ( f"Found {num_devices} devices available, " "but could not find any input with a size that is a multiple of that value. " - "Please user a number of transmitter or receiver points that is a " + "Please user a number of transmitter and receiver points that is a " f"multiple of {num_devices}." ) raise ValueError(msg) fun = shard_map( # type: ignore[reportAssigmentType] fun, - Mesh(mesh_utils.create_device_mesh((num_devices,)), axis_names=("i",)), + Mesh( + mesh_utils.create_device_mesh((tx_mesh, rx_mesh)), axis_names=("i", "j") + ), in_specs=in_specs, out_specs=out_specs, ) @@ -401,6 +402,45 @@ def load_xml(cls, file: str) -> Self: core_scene = differt_core.scene.triangle_scene.TriangleScene.load_xml(file) return cls.from_core(core_scene) + @overload + def compute_paths( + self, + order: int | None = None, + *, + chunk_size: Literal[None] = None, + path_candidates: Int[Array, "num_path_candidates order"] | None = None, + parallel: bool = False, + epsilon: Float[ArrayLike, " "] | None = None, + hit_tol: Float[ArrayLike, " "] | None = None, + min_len: Float[ArrayLike, " "] | None = None, + ) -> Paths: ... + + @overload + def compute_paths( + self, + order: int | None = None, + *, + chunk_size: int, + path_candidates: Literal[None] = None, + parallel: bool = False, + epsilon: Float[ArrayLike, " "] | None = None, + hit_tol: Float[ArrayLike, " "] | None = None, + min_len: Float[ArrayLike, " "] | None = None, + ) -> SizedIterator[Paths]: ... + + @overload + def compute_paths( + self, + order: int | None = None, + *, + chunk_size: int, + path_candidates: Int[Array, "num_path_candidates order"], + parallel: bool = False, + epsilon: Float[ArrayLike, " "] | None = None, + hit_tol: Float[ArrayLike, " "] | None = None, + min_len: Float[ArrayLike, " "] | None = None, + ) -> Paths: ... + def compute_paths( self, order: int | None = None, @@ -435,7 +475,7 @@ def compute_paths( If ``self.mesh.assume_quads`` is :data:`True`, then path candidates are rounded down toward the nearest even value. parallel: If :data:`True`, ray tracing is performed in parallel across all available - devices. Either the number of transmitters or the number of receivers + devices. The number of transmitters times the number of receivers **must** be a multiple of :func:`jax.device_count`, otherwise an error is raised. epsilon: Tolelance for checking ray / objects intersection, see :func:`rays_intersect_triangles`. diff --git a/differt/src/differt/utils.py b/differt/src/differt/utils.py index 0a489fff..e0ab13b3 100644 --- a/differt/src/differt/utils.py +++ b/differt/src/differt/utils.py @@ -2,7 +2,7 @@ import sys from collections.abc import Callable, Iterable, Mapping -from typing import Any +from typing import Any, Literal, overload import chex import equinox as eqx @@ -231,6 +231,24 @@ def f( return x, losses[-1] +@overload +def sample_points_in_bounding_box( + bounding_box: Float[Array, "2 3"], + shape: Literal[None] = None, + *, + key: PRNGKeyArray, +) -> Float[Array, "3"]: ... + + +@overload +def sample_points_in_bounding_box( + bounding_box: Float[Array, "2 3"], + shape: tuple[int, ...], + *, + key: PRNGKeyArray, +) -> Float[Array, "3"]: ... + + @eqx.filter_jit @jaxtyped(typechecker=typechecker) def sample_points_in_bounding_box( diff --git a/differt/tests/benchmarks/test_rt.py b/differt/tests/benchmarks/test_rt.py index 4d934d3e..18b8ad82 100644 --- a/differt/tests/benchmarks/test_rt.py +++ b/differt/tests/benchmarks/test_rt.py @@ -86,7 +86,7 @@ def bench_fun() -> None: def bench_fun() -> None: scene.compute_paths( order, - chunk_size=chunk_size, - ).vertices.block_until_ready() # type: ignore[reportAttributeAccessIssue] + chunk_size=None, + ).vertices.block_until_ready() _ = benchmark(bench_fun) diff --git a/differt/tests/geometry/test_triangle_mesh.py b/differt/tests/geometry/test_triangle_mesh.py index 18f20714..3946a3da 100644 --- a/differt/tests/geometry/test_triangle_mesh.py +++ b/differt/tests/geometry/test_triangle_mesh.py @@ -1,4 +1,5 @@ import logging +import re from contextlib import AbstractContextManager from contextlib import nullcontext as does_not_raise @@ -168,25 +169,46 @@ def test_plane(self, key: PRNGKeyArray) -> None: chex.assert_trees_all_equal(got, expected) - vertices = jax.random.uniform(key, (3, 3)) - _ = TriangleMesh.plane(*vertices) + vertex_a, vertex_b, vertex_c = jax.random.uniform(key, (3, 3)).T + _ = TriangleMesh.plane(vertex_a, vertex_b, vertex_c) with pytest.raises( ValueError, - match="You must specify one of 'other_vertices' or 'normal', not both.", + match="You must specify either of both of 'vertex_b' and 'vertex_c', or none.", ): - _ = TriangleMesh.plane(*vertices, normal=normal) + _ = TriangleMesh.plane(vertex_a, vertex_b) # type: ignore[reportCallIssue] - vertices = jax.random.uniform(key, (4, 3)) + with pytest.raises( + ValueError, + match=re.escape( + "You must specify either of both of 'vertex_b' and 'vertex_c', or none." + ), + ): + _ = TriangleMesh.plane(vertex_a, vertex_c=vertex_c) # type: ignore[reportCallIssue] + + with pytest.raises( + ValueError, + match=re.escape( + "You must specify one of ('vertex_b', 'vertex_c') or 'normal', not both." + ), + ): + _ = TriangleMesh.plane(vertex_a, vertex_b, vertex_c, normal=normal) - with pytest.raises(ValueError, match="You must provide exactly 3 vertices"): - _ = TriangleMesh.plane(*vertices) + with pytest.raises( + ValueError, + match=re.escape( + "You must specify one of ('vertex_b', 'vertex_c') or 'normal', not both." + ), + ): + _ = TriangleMesh.plane(vertex_a, vertex_b, vertex_c, normal=normal) with pytest.raises( ValueError, - match="You must specify one of 'other_vertices' or 'normal', not both.", + match=re.escape( + "You must specify one of ('vertex_b', 'vertex_c') or 'normal', not both." + ), ): - _ = TriangleMesh.plane(center) + _ = TriangleMesh.plane(center) # type: ignore[reportCallIssue] def test_empty(self) -> None: assert TriangleMesh.empty().is_empty diff --git a/differt/tests/scene/test_triangle_scene.py b/differt/tests/scene/test_triangle_scene.py index f1013fbd..0415ce8c 100644 --- a/differt/tests/scene/test_triangle_scene.py +++ b/differt/tests/scene/test_triangle_scene.py @@ -10,7 +10,6 @@ import pytest from jaxtyping import Array, Int, PRNGKeyArray -from differt.geometry.paths import Paths from differt.geometry.utils import assemble_paths, normalize from differt.scene.sionna import ( get_sionna_scene, @@ -103,8 +102,6 @@ def test_compute_paths_on_advanced_path_tracing_example( with jax.debug_nans(False): # noqa: FBT003 got = scene.compute_paths(order) - assert isinstance(got, Paths) # Hint to Pyright - chex.assert_trees_all_close(got.masked_vertices, expected_path_vertices) chex.assert_trees_all_equal(got.masked_objects, expected_objects) @@ -173,8 +170,6 @@ def test_compute_paths_on_simple_street_canyon( with jax.debug_nans(False): # noqa: FBT003 got = scene.compute_paths(order) - assert isinstance(got, Paths) # Hint to Pyright - chex.assert_trees_all_close( got.masked_vertices, expected_path_vertices, atol=1e-5 ) @@ -277,7 +272,7 @@ def test_compute_paths_on_grid( num_path_candidates = scene.mesh.triangles.shape[0] chex.assert_shape( - paths.vertices, # type: ignore[reportAttributeAccessIssue] + paths.vertices, (n_tx, m_tx, n_rx, m_rx, num_path_candidates, 3, 3), ) @@ -289,6 +284,8 @@ def test_compute_paths_on_grid( (1, 1, 8, 8, does_not_raise()), (4, 2, 1, 1, does_not_raise()), (1, 1, 2, 4, does_not_raise()), + (1, 4, 2, 1, does_not_raise()), + (1, 2, 4, 1, does_not_raise()), ( 7, 1, @@ -299,7 +296,7 @@ def test_compute_paths_on_grid( ( 1, 2, - 4, + 3, 1, pytest.raises(ValueError, match="Found 8 devices available"), ), @@ -324,7 +321,7 @@ def test_compute_paths_parallel( num_path_candidates = scene.mesh.triangles.shape[0] chex.assert_shape( - paths.vertices, # type: ignore[reportAttributeAccessIssue] + paths.vertices, (n_tx, m_tx, n_rx, m_rx, num_path_candidates, 3, 3), ) diff --git a/pyproject.toml b/pyproject.toml index 7df381f2..a2e91493 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -173,6 +173,7 @@ extend-ignore = [ "ISC001", # single-line-implicit-string-concatenation, conflicts with formatter "ISC002", # multi-line-implicit-string-concatenation, conflicts with formatter "PD", # pandas-vet + "PLR0904", # too-many-public-methods, counts @overload... "PLR0913", # too-many-arguments "PLR0914", # too-many-local-variables "PLR6104", # non-augmented-assignment