diff --git a/differt/src/differt/geometry/paths.py b/differt/src/differt/geometry/paths.py index f12f5662..30114b28 100644 --- a/differt/src/differt/geometry/paths.py +++ b/differt/src/differt/geometry/paths.py @@ -251,7 +251,7 @@ def group_by_objects(self) -> Int[Array, " *batch"]: >>> from differt.geometry.paths import Paths >>> - >>> key = jax.random.PRNGKey(1234) + >>> key = jax.random.key(1234) >>> key_v, key_o = jax.random.split(key, 2) >>> *batch, path_length = (2, 6, 3) >>> vertices = jax.random.uniform(key_v, (*batch, path_length, 3)) diff --git a/differt/src/differt/geometry/triangle_mesh.py b/differt/src/differt/geometry/triangle_mesh.py index f3d536f4..49b35659 100644 --- a/differt/src/differt/geometry/triangle_mesh.py +++ b/differt/src/differt/geometry/triangle_mesh.py @@ -409,7 +409,7 @@ def sample( Args: size: The size of the sample, i.e., the number of triangles. replace: Whether to sample with or without replacement. - key: The :func:`jax.random.PRNGKey` to be used. + key: The :func:`jax.random.key` to be used. Returns: A new random mesh. diff --git a/differt/src/differt/rt/image_method.py b/differt/src/differt/rt/image_method.py index 44c0fcab..5134fb23 100644 --- a/differt/src/differt/rt/image_method.py +++ b/differt/src/differt/rt/image_method.py @@ -123,7 +123,7 @@ def image_of_vertices_with_respect_to_mirrors( ... image_of_vertices_with_respect_to_mirrors, ... ) >>> - >>> key = jax.random.PRNGKey(0) + >>> key = jax.random.key(0) >>> ( ... key0, ... key1, diff --git a/differt/src/differt/rt/utils.py b/differt/src/differt/rt/utils.py index 57fdd671..a7dafe88 100644 --- a/differt/src/differt/rt/utils.py +++ b/differt/src/differt/rt/utils.py @@ -367,7 +367,7 @@ def rays_intersect_any_triangle( @jaxtyped(typechecker=typechecker) def scan_fun( - intersect: Bool[Array, " *#batch"], + intersect: Bool[Array, " *batch"], triangle_vertices: Float[Array, "*#batch 3 3"], ) -> tuple[Bool[Array, " *batch"], None]: t, hit = rays_intersect_triangles( diff --git a/differt/src/differt/scene/triangle_scene.py b/differt/src/differt/scene/triangle_scene.py index 4d20a707..ae5bf7ec 100644 --- a/differt/src/differt/scene/triangle_scene.py +++ b/differt/src/differt/scene/triangle_scene.py @@ -51,10 +51,13 @@ def _compute_paths( path_candidates: Int[Array, "num_path_candidates order"], *, parallel: bool = False, - **kwargs: Any, + epsilon: Float[ArrayLike, " "] | None = None, + hit_tol: Float[ArrayLike, " "] | None = None, + min_len: Float[ArrayLike, " "] | None = None, ) -> Paths: - epsilon = kwargs.pop("epsilon", None) - hit_tol = kwargs.pop("hit_tol", None) + if min_len is None: + dtype = jnp.result_type(mesh.vertices, from_vertices, to_vertices) + min_len = 10 * jnp.finfo(dtype).eps # 1 - Broadcast arrays @@ -144,9 +147,15 @@ def fun( hit_tol=hit_tol, ).any(axis=-1) # Reduce on 'order' - # TODO: we also need to somehow mask degenerate paths, e.g., when two reflections occur on an edge + # 3.4 - Identify path segments that are too small (e.g., double-reflection inside an edge) - mask = inside_triangles & valid_reflections & ~blocked + ray_lengths = jnp.sum(ray_directions * ray_directions, axis=-1) # Squared norm + + too_small = (ray_lengths < min_len).any( + axis=-1 + ) # Any path segment being too smal + + mask = inside_triangles & valid_reflections & ~blocked & ~too_small return full_paths, mask @@ -359,7 +368,9 @@ def compute_paths( chunk_size: int | None = None, path_candidates: Int[Array, "num_path_candidates order"] | None = None, parallel: bool = False, - **kwargs: Any, + epsilon: Float[ArrayLike, " "] | None = None, + hit_tol: Float[ArrayLike, " "] | None = None, + min_len: Float[ArrayLike, " "] | None = None, ) -> Paths | SizedIterator[Paths]: """ Compute paths between all pairs of transmitters and receivers in the scene, that undergo a fixed number of interaction with objects. @@ -383,8 +394,14 @@ def compute_paths( parallel: If :data:`True`, ray tracing is performed in parallel across all available devices. Either the number of transmitters or the number of receivers **must** be a multiple of :func:`jax.device_count`, otherwise an error is raised. - kwargs: Keyword arguments passed to + epsilon: Tolelance for checking ray / objects intersection, see + :func:`rays_intersect_triangles`. + hit_tol: Tolerance for checking blockage (i.e., obstruction), see :func:`rays_intersect_any_triangle`. + min_len: Minimal (squared) length that each path segment must have for a path to be valid. + + If not specified, the default is ten times the epsilon value + of the currently used floating point dtype. Returns: The paths, as class wrapping path vertices, object indices, and a masked @@ -426,7 +443,9 @@ def compute_paths( to_vertices, path_candidates, parallel=parallel, - **kwargs, + epsilon=epsilon, + hit_tol=hit_tol, + min_len=min_len, ).reshape(*tx_batch, *rx_batch, path_candidates.shape[0]) for path_candidates in path_candidates_iter ) @@ -445,7 +464,9 @@ def compute_paths( to_vertices, path_candidates, parallel=parallel, - **kwargs, + epsilon=epsilon, + hit_tol=hit_tol, + min_len=min_len, ).reshape(*tx_batch, *rx_batch, path_candidates.shape[0]) def plot( diff --git a/differt/src/differt/utils.py b/differt/src/differt/utils.py index afeff339..8595c3fb 100644 --- a/differt/src/differt/utils.py +++ b/differt/src/differt/utils.py @@ -43,7 +43,7 @@ def sorted_array2(array: Shaped[Array, "m n"]) -> Shaped[Array, "m n"]: ... ) >>> >>> arr = jnp.arange(10).reshape(5, 2) - >>> key = jax.random.PRNGKey(1234) + >>> key = jax.random.key(1234) >>> ( ... key1, ... key2, @@ -175,7 +175,7 @@ def minimize( >>> >>> batch = (1, 2, 3) >>> n = 10 - >>> key = jax.random.PRNGKey(1234) + >>> key = jax.random.key(1234) >>> offset = jax.random.uniform(key, (*batch, n)) >>> >>> def f(x, offset, scale=2.0): @@ -246,7 +246,7 @@ def sample_points_in_bounding_box( bounding_box: The bounding box (min. and max. coordinates). size: The sample size or :data:`None`. If :data:`None`, the returned array is 1D. Otherwise, it is 2D. - key: The :func:`jax.random.PRNGKey` to be used. + key: The :func:`jax.random.key` to be used. Returns: An array of points randomly sampled. diff --git a/differt/tests/conftest.py b/differt/tests/conftest.py index 1b62cb23..dc663424 100644 --- a/differt/tests/conftest.py +++ b/differt/tests/conftest.py @@ -22,7 +22,7 @@ def seed() -> int: @pytest.fixture def key(seed: int) -> PRNGKeyArray: - return jax.random.PRNGKey(seed) + return jax.random.key(seed) @pytest.fixture diff --git a/differt/tests/scene/fixtures.py b/differt/tests/scene/fixtures.py index 236c1011..fc0f8d4e 100644 --- a/differt/tests/scene/fixtures.py +++ b/differt/tests/scene/fixtures.py @@ -33,5 +33,5 @@ def advanced_path_tracing_example_scene( def simple_street_canyon_scene(sionna_folder: Path) -> TriangleScene: file = get_sionna_scene("simple_street_canyon", folder=sionna_folder) scene = TriangleScene.load_xml(file) - scene = eqx.tree_at(lambda s: s.transmitters, scene, jnp.array([-37.0, 14.0, 35.0])) - return eqx.tree_at(lambda s: s.receivers, scene, jnp.array([12.0, 0.0, 35.0])) + scene = eqx.tree_at(lambda s: s.transmitters, scene, jnp.array([-22.0, 0.0, 32.0])) + return eqx.tree_at(lambda s: s.receivers, scene, jnp.array([+22.0, 0.0, 32.0])) diff --git a/differt/tests/scene/test_triangle_scene.py b/differt/tests/scene/test_triangle_scene.py index c8ccded9..19028057 100644 --- a/differt/tests/scene/test_triangle_scene.py +++ b/differt/tests/scene/test_triangle_scene.py @@ -10,6 +10,7 @@ import pytest from jaxtyping import Array, Int, PRNGKeyArray +from differt.geometry.paths import Paths from differt.geometry.utils import assemble_paths, normalize from differt.scene.sionna import ( get_sionna_scene, @@ -97,12 +98,81 @@ def test_compute_paths_on_advanced_path_tracing_example( with jax.debug_nans(False): # noqa: FBT003 got = scene.compute_paths(order) - chex.assert_trees_all_close(got.masked_vertices, expected_path_vertices) # type: ignore[reportAttributeAccessIssue] - chex.assert_trees_all_equal(got.masked_objects, expected_objects) # type: ignore[reportAttributeAccessIssue] + assert isinstance(got, Paths) # Hint to Pyright - normals = jnp.take(scene.mesh.normals, got.masked_objects[..., 1:-1], axis=0) # type: ignore[reportAttributeAccessIssue] + chex.assert_trees_all_close(got.masked_vertices, expected_path_vertices) + chex.assert_trees_all_equal(got.masked_objects, expected_objects) - rays = jnp.diff(got.masked_vertices, axis=-2) # type: ignore[reportAttributeAccessIssue] + normals = jnp.take(scene.mesh.normals, got.masked_objects[..., 1:-1], axis=0) + + rays = jnp.diff(got.masked_vertices, axis=-2) + + rays = normalize(rays)[0] + + indicents = rays[..., :-1, :] + reflecteds = rays[..., +1:, :] + + dot_incidents = jnp.sum(-indicents * normals, axis=-1) + dot_reflecteds = jnp.sum(reflecteds * normals, axis=-1) + + chex.assert_trees_all_close(dot_incidents, dot_reflecteds) + + @pytest.mark.parametrize( + ("order", "expected_path_vertices", "expected_objects"), + [ + (0, jnp.empty((1, 0, 3)), jnp.array([[0, 0]])), + ( + 1, + jnp.array([ + [[0.0, -8.613334655761719, 32.0]], + [[0.0, 9.571563720703125, 32.0]], + [[1.9073486328125e-06, 0.0, -0.030788421630859375]], + ]), + jnp.array([[0, 18, 0], [0, 38, 0], [0, 72, 0]]), + ), + ( + 2, + jnp.array([ + [ + [-11.579630851745605, -8.613335609436035, 32.0], + [10.420369148254395, 9.571564674377441, 32.0], + ], + [ + [-10.420370101928711, 9.571562767028809, 32.0], + [11.579629898071289, -8.613335609436035, 32.0], + ], + ]), + jnp.array([[0, 19, 39, 0], [0, 38, 18, 0]]), + ), + ], + ) + def test_compute_paths_on_simple_street_canyon( + self, + order: int, + expected_path_vertices: Array, + expected_objects: Array, + simple_street_canyon_scene: TriangleScene, + ) -> None: + scene = simple_street_canyon_scene + expected_path_vertices = assemble_paths( + scene.transmitters[None, :], + expected_path_vertices, + scene.receivers[None, :], + ) + + with jax.debug_nans(False): # noqa: FBT003 + got = scene.compute_paths(order) + + assert isinstance(got, Paths) # Hint to Pyright + + chex.assert_trees_all_close( + got.masked_vertices, expected_path_vertices, atol=1e-5 + ) + chex.assert_trees_all_equal(got.masked_objects, expected_objects) + + normals = jnp.take(scene.mesh.normals, got.masked_objects[..., 1:-1], axis=0) + + rays = jnp.diff(got.masked_vertices, axis=-2) rays = normalize(rays)[0] diff --git a/differt/tests/utils.py b/differt/tests/utils.py index 836391f0..da593f2b 100644 --- a/differt/tests/utils.py +++ b/differt/tests/utils.py @@ -29,7 +29,7 @@ def wrapper(fun: Callable[..., Any]) -> Callable[..., Any]: @wraps(fun) def _wrapper_(*args: Any, **kwargs: Any) -> Any: bound_args = sig.bind(*args, **kwargs) - keys = jax.random.split(jax.random.PRNGKey(seed), len(arg_names)) + keys = jax.random.split(jax.random.key(seed), len(arg_names)) for key, arg_name in zip(keys, arg_names, strict=False): shape = bound_args.arguments[arg_name] bound_args.arguments[arg_name] = sampler(key, shape) diff --git a/docs/source/notebooks/performance_tips.ipynb b/docs/source/notebooks/performance_tips.ipynb index ecc716d7..ee734bd4 100644 --- a/docs/source/notebooks/performance_tips.ipynb +++ b/docs/source/notebooks/performance_tips.ipynb @@ -109,7 +109,7 @@ }, "outputs": [], "source": [ - "key = jax.random.PRNGKey(1234)\n", + "key = jax.random.key(1234)\n", "key1, key2, key3 = jax.random.split(key, 3)\n", "\n", "batch = (10, 100)\n", @@ -250,7 +250,7 @@ "source": [ "from beartype import beartype as typechecker\n", "\n", - "key = jax.random.PRNGKey(1234)\n", + "key = jax.random.key(1234)\n", "key1, key2, key3 = jax.random.split(key, 3)\n", "\n", "batch = (100, 10, 2)\n", diff --git a/docs/source/notebooks/type_checking.ipynb b/docs/source/notebooks/type_checking.ipynb index 5d230705..1fc04dfc 100644 --- a/docs/source/notebooks/type_checking.ipynb +++ b/docs/source/notebooks/type_checking.ipynb @@ -77,7 +77,7 @@ "metadata": {}, "outputs": [], "source": [ - "key = jax.random.PRNGKey(1234)\n", + "key = jax.random.key(1234)\n", "\n", "arr = jax.random.randint(key, (10, 4), 0, 2)\n", "arr"