Skip to content

Commit

Permalink
feat(lib): fix empty scene path tracing and add path_candidates args (
Browse files Browse the repository at this point in the history
  • Loading branch information
jeertmans authored Oct 7, 2024
1 parent 141bdec commit 35b69e5
Show file tree
Hide file tree
Showing 10 changed files with 155 additions and 26 deletions.
9 changes: 6 additions & 3 deletions differt/src/differt/geometry/triangle_mesh.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,9 @@ def triangle_vertices(self) -> Float[Array, "{self.num_triangles} 3 3"]:
TODO: improve description.
"""
if self.triangles.size == 0:
return jnp.empty_like(self.vertices, shape=(0, 3, 3))

return jnp.take(self.vertices, self.triangles, axis=0)

@classmethod
Expand Down Expand Up @@ -294,11 +297,11 @@ def plane(
A new plane mesh.
Raises:
ValueError: If one of two ``other_vertices`` or ``normal``
were not provided.
ValueError: If neither ``other_vertices`` nor ``normal`` has been provided,
or if both have been provided simultaneously.
"""
if (other_vertices == ()) == (normal is None):
msg = "You must specify one of `other_vertices` or `normal`, not both."
msg = "You must specify one of 'other_vertices' or 'normal', not both."
raise ValueError(msg)
if other_vertices:
if len(other_vertices) != 2: # noqa: PLR2004
Expand Down
8 changes: 6 additions & 2 deletions differt/src/differt/plotting/_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,8 +281,12 @@ def draw_rays(
>>> from differt.plotting import draw_rays
>>>
>>> ray_origins = np.zeros(3)
>>> ray_directions = np.asarray(fibonacci_lattice(50)) # From JAX to NumPy array
>>> ray_origins, ray_directions = np.broadcast_arrays(ray_origins, ray_directions)
>>> ray_directions = np.asarray(
... fibonacci_lattice(50)
... ) # From JAX to NumPy array
>>> ray_origins, ray_directions = np.broadcast_arrays(
... ray_origins, ray_directions
... )
>>> fig = draw_rays(
... ray_origins,
... ray_directions,
Expand Down
12 changes: 9 additions & 3 deletions differt/src/differt/plotting/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,9 +153,13 @@ def set_defaults(backend: str | None = None, **kwargs: Any) -> BackendName:
'matplotlib'
>>> my_plot() # So that now it defaults to 'matplotlib' and color='red'
Using matplotlib backend with args = (), kwargs = {'color': 'red'}
>>> my_plot(backend="vispy") # Of course, the 'vispy' backend is still available
>>> my_plot(
... backend="vispy"
... ) # Of course, the 'vispy' backend is still available
Using vispy backend with args = (), kwargs = {'color': 'red'}
>>> my_plot(backend="vispy", color="green") # And we can also override any default
>>> my_plot(
... backend="vispy", color="green"
... ) # And we can also override any default
Using vispy backend with args = (), kwargs = {'color': 'green'}
>>> dplt.set_defaults("vispy") # Reset all defaults
'vispy'
Expand Down Expand Up @@ -204,7 +208,9 @@ def use(backend: str | None = None, **kwargs: Any) -> Iterator[BackendName]:
>>>
>>> my_plot() # When not specified, use default backend
Using vispy backend with args = (), kwargs = {}
>>> with dplt.use(): # No parameters = reset defaults (except the default backend)
>>> with (
... dplt.use()
... ): # No parameters = reset defaults (except the default backend)
... my_plot()
Using vispy backend with args = (), kwargs = {}
>>> with dplt.use("plotly"): # We can change the default backend
Expand Down
5 changes: 4 additions & 1 deletion differt/src/differt/rt/image_method.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,10 @@
... draw_paths(full_path, marker={"color": "green"}, name="Final path")
... markers = jnp.vstack((from_vertex, to_vertex))
... draw_markers(
... markers, labels=["BS", "UE"], marker={"color": "black"}, name="BS/UE"
... markers,
... labels=["BS", "UE"],
... marker={"color": "black"},
... name="BS/UE",
... )
... fig.update_layout(scene_aspectmode="data")
>>> fig # doctest: +SKIP
Expand Down
18 changes: 14 additions & 4 deletions differt/src/differt/rt/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,13 +230,18 @@ def rays_intersect_triangles(
>>> from differt.rt.utils import (
... rays_intersect_triangles,
... )
>>> from differt.scene.sionna import get_sionna_scene, download_sionna_scenes
>>> from differt.scene.sionna import (
... get_sionna_scene,
... download_sionna_scenes,
... )
>>> from differt.scene.triangle_scene import TriangleScene
>>>
>>> download_sionna_scenes()
>>> file = get_sionna_scene("simple_street_canyon")
>>> scene = TriangleScene.load_xml(file)
>>> scene = eqx.tree_at(lambda s: s.transmitters, scene, jnp.array([-33, 0, 32.0]))
>>> scene = eqx.tree_at(
... lambda s: s.transmitters, scene, jnp.array([-33, 0, 32.0])
... )
>>> ray_origins, ray_directions = jnp.broadcast_arrays(
... scene.transmitters, fibonacci_lattice(25)
... )
Expand Down Expand Up @@ -421,13 +426,18 @@ def triangles_visible_from_vertices(
>>> from differt.rt.utils import (
... triangles_visible_from_vertices,
... )
>>> from differt.scene.sionna import get_sionna_scene, download_sionna_scenes
>>> from differt.scene.sionna import (
... get_sionna_scene,
... download_sionna_scenes,
... )
>>> from differt.scene.triangle_scene import TriangleScene
>>>
>>> download_sionna_scenes()
>>> file = get_sionna_scene("simple_street_canyon")
>>> scene = TriangleScene.load_xml(file)
>>> scene = eqx.tree_at(lambda s: s.transmitters, scene, jnp.array([-33, 0, 32.0]))
>>> scene = eqx.tree_at(
... lambda s: s.transmitters, scene, jnp.array([-33, 0, 32.0])
... )
>>> visible_triangles = triangles_visible_from_vertices(
... scene.transmitters,
... scene.mesh.triangle_vertices,
Expand Down
49 changes: 43 additions & 6 deletions differt/src/differt/scene/triangle_scene.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
# ruff: noqa: ERA001

import sys
import warnings
from collections.abc import Mapping
from typing import Any

Expand Down Expand Up @@ -57,13 +58,17 @@ def _compute_paths(

# 1 - Broadcast arrays

num_path_candidates = path_candidates.shape[0]
num_path_candidates, order = path_candidates.shape

# [num_path_candidates order 3]
triangles = jnp.take(mesh.triangles, path_candidates, axis=0)
triangles = jnp.take(mesh.triangles, path_candidates, axis=0).reshape(
num_path_candidates, order, 3
) # reshape required if mesh is empty

# [num_path_candidates order 3 3]
triangle_vertices = jnp.take(mesh.vertices, triangles, axis=0)
triangle_vertices = jnp.take(mesh.vertices, triangles, axis=0).reshape(
num_path_candidates, order, 3, 3
) # reshape required if mesh is empty

# [num_path_candidates order 3]
mirror_vertices = triangle_vertices[
Expand Down Expand Up @@ -349,19 +354,32 @@ def load_xml(cls, file: str) -> Self:

def compute_paths(
self,
order: int,
order: int | None,
*,
chunk_size: int | None = None,
path_candidates: Int[Array, "num_path_candidates order"] | None = None,
parallel: bool = False,
**kwargs: Any,
) -> Paths | SizedIterator[Paths]:
"""
Compute paths between all pairs of transmitters and receivers in the scene, that undergo a fixed number of interaction with objects.
Note:
Currently, only :abbr:`LOS (line of sight)` and fixed ``order`` reflection paths are computed,
using the :func:`image_method<differt.rt.image_method.image_method>`. More types of interactions
and path tracing methods will be added in the future, so stay tuned!
Args:
order: The number of interaction, i.e., the number of bounces.
This or ``path_candidates`` must be specified.
chunk_size: If specified, it will iterate through chunks of path
candidates, and yield the result as an iterator over paths chunks.
Unused if ``path_candidates`` is provided.
path_candidates: An option array of path candidates, see :ref:`path_candidates`.
This is helpful to only generate paths on a subset of the scene.
parallel: If :data:`True`, ray tracing is performed in parallel across all available
devices. Either the number of transmitters or the number of receivers
**must** be a multiple of :func:`jax.device_count`, otherwise an error is raised.
Expand All @@ -371,7 +389,19 @@ def compute_paths(
Returns:
The paths, as class wrapping path vertices, object indices, and a masked
identify valid paths.
Raises:
ValueError: If neither ``order`` nor ``path_candidates`` has been provided,
or if both have been provided simultaneously.
"""
if (order is None) == (path_candidates is None):
msg = "You must specify one of 'order' or `path_candidates`, not both."
raise ValueError(msg)
if (chunk_size is not None) and (path_candidates is not None):
msg = "Argument 'chunk_size' is ignored when 'path_candidates' is provided."
warnings.warn(msg, UserWarning, stacklevel=2)
chunk_size = None

# 0 - Constants arrays of chunks
num_triangles = self.mesh.triangles.shape[0]
tx_batch = self.transmitters.shape[:-1]
Expand All @@ -384,7 +414,9 @@ def compute_paths(

if chunk_size:
path_candidates_iter = generate_all_path_candidates_chunks_iter(
num_triangles, order, chunk_size=chunk_size
num_triangles,
order, # type: ignore[reportArgumentType]
chunk_size=chunk_size,
)
size = path_candidates_iter.__len__
it = (
Expand All @@ -401,7 +433,12 @@ def compute_paths(

return SizedIterator(it, size=size)

path_candidates = generate_all_path_candidates(num_triangles, order)
if path_candidates is None:
path_candidates = generate_all_path_candidates(
num_triangles,
order, # type: ignore[reportArgumentType]
)

return _compute_paths(
self.mesh,
from_vertices,
Expand Down
4 changes: 3 additions & 1 deletion differt/src/differt/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,9 @@ def minimize(
>>> # You can also change the optimizer and the number of steps
>>> import optax
>>> optimizer = optax.noisy_sgd(learning_rate=0.003)
>>> x, y = minimize(f, jnp.zeros(5), args=(4.0,), steps=10000, optimizer=optimizer)
>>> x, y = minimize(
... f, jnp.zeros(5), args=(4.0,), steps=10000, optimizer=optimizer
... )
>>> chex.assert_trees_all_close(x, 4.0 * jnp.ones(5), rtol=1e-2)
>>> chex.assert_trees_all_close(y, 0.0, atol=1e-3)
Expand Down
4 changes: 2 additions & 2 deletions differt/tests/geometry/test_triangle_mesh.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ def test_plane(self, key: PRNGKeyArray) -> None:

with pytest.raises(
ValueError,
match="You must specify one of `other_vertices` or `normal`, not both.",
match="You must specify one of 'other_vertices' or 'normal', not both.",
):
_ = TriangleMesh.plane(*vertices, normal=normal)

Expand All @@ -143,7 +143,7 @@ def test_plane(self, key: PRNGKeyArray) -> None:

with pytest.raises(
ValueError,
match="You must specify one of `other_vertices` or `normal`, not both.",
match="You must specify one of 'other_vertices' or 'normal', not both.",
):
_ = TriangleMesh.plane(center)

Expand Down
71 changes: 67 additions & 4 deletions differt/tests/scene/test_triangle_scene.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from collections.abc import Iterator
from contextlib import AbstractContextManager
from contextlib import nullcontext as does_not_raise
from pathlib import Path
Expand All @@ -7,7 +8,7 @@
import jax
import jax.numpy as jnp
import pytest
from jaxtyping import Array
from jaxtyping import Array, Int, PRNGKeyArray

from differt.geometry.utils import assemble_paths, normalize
from differt.scene.sionna import (
Expand All @@ -17,6 +18,10 @@
from differt.scene.triangle_scene import TriangleScene
from differt_core.scene.sionna import SionnaScene

skip_if_not_8_devices = pytest.mark.skipif(
jax.device_count() != 8, reason="This test assumes there are exactly 8 devices."
)


class TestTriangleScene:
def test_load_xml(self, sionna_folder: Path) -> None:
Expand Down Expand Up @@ -109,6 +114,66 @@ def test_compute_paths_on_advanced_path_tracing_example(

chex.assert_trees_all_close(dot_incidents, dot_reflecteds)

@pytest.mark.parametrize(
("order", "chunk_size", "path_candidates", "expectation"),
[
(0, None, None, does_not_raise()),
(0, 1000, None, does_not_raise()),
(None, None, jnp.empty((1, 0), dtype=jnp.int32), does_not_raise()),
(
0,
None,
jnp.empty((1, 0), dtype=jnp.int32),
pytest.raises(ValueError, match="You must specify one of"),
),
(
None,
1000,
jnp.empty((1, 0), dtype=jnp.int32),
pytest.warns(UserWarning, match="Argument 'chunk_size' is ignored"),
),
],
)
@pytest.mark.parametrize(
"parallel", [False, pytest.param(True, marks=skip_if_not_8_devices)]
)
def test_compute_paths_on_empty_scene(
self,
order: int | None,
chunk_size: int | None,
path_candidates: Int[Array, "num_path_candidates order"] | None,
expectation: AbstractContextManager[Exception],
parallel: bool,
key: PRNGKeyArray,
) -> None:
key_tx, key_rx = jax.random.split(key, 2)

if parallel:
transmitters = jax.random.uniform(key_tx, (8, 3))
else:
transmitters = jax.random.uniform(key_tx, (1, 3))

receivers = jax.random.uniform(key_rx, (1, 3))

scene = TriangleScene(transmitters=transmitters, receivers=receivers)
expected_path_vertices = assemble_paths(
transmitters[:, None, None, None, :],
receivers[None, :, None, None, :],
)

with expectation:
with jax.debug_nans(False): # noqa: FBT003
got = scene.compute_paths(
order=order,
chunk_size=chunk_size,
path_candidates=path_candidates,
parallel=parallel,
)

paths = next(got) if isinstance(got, Iterator) else got

chex.assert_trees_all_close(paths.vertices, expected_path_vertices)

@pytest.mark.parametrize(("m_tx", "n_tx"), [(5, None), (3, 4)])
@pytest.mark.parametrize(("m_rx", "n_rx"), [(2, None), (1, 6)])
def test_compute_paths_on_grid(
Expand Down Expand Up @@ -136,9 +201,7 @@ def test_compute_paths_on_grid(
(n_tx, m_tx, n_rx, m_rx, num_path_candidates, 3, 3),
)

@pytest.mark.skipif(
jax.device_count() != 8, reason="This test assumes there are exactly 8 devices."
)
@skip_if_not_8_devices
@pytest.mark.parametrize(
("m_tx", "n_tx", "m_rx", "n_rx", "expectation"),
[
Expand Down
1 change: 1 addition & 0 deletions docs/source/notebooks/path_candidates.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
"id": "ce377447-7214-4981-b828-15d66d123c98",
"metadata": {},
"source": [
"(path_candidates)=\n",
"# Generating path candidates\n",
"\n",
"When performing deterministic, or exact, Ray Tracing, we aim at exactly finding all possibles paths\n",
Expand Down

0 comments on commit 35b69e5

Please sign in to comment.