Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(lib): path segments that are too small are masked #132

Merged
merged 2 commits into from
Oct 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion differt/src/differt/geometry/paths.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,7 +251,7 @@ def group_by_objects(self) -> Int[Array, " *batch"]:

>>> from differt.geometry.paths import Paths
>>>
>>> key = jax.random.PRNGKey(1234)
>>> key = jax.random.key(1234)
>>> key_v, key_o = jax.random.split(key, 2)
>>> *batch, path_length = (2, 6, 3)
>>> vertices = jax.random.uniform(key_v, (*batch, path_length, 3))
Expand Down
2 changes: 1 addition & 1 deletion differt/src/differt/geometry/triangle_mesh.py
Original file line number Diff line number Diff line change
Expand Up @@ -409,7 +409,7 @@ def sample(
Args:
size: The size of the sample, i.e., the number of triangles.
replace: Whether to sample with or without replacement.
key: The :func:`jax.random.PRNGKey` to be used.
key: The :func:`jax.random.key` to be used.

Returns:
A new random mesh.
Expand Down
2 changes: 1 addition & 1 deletion differt/src/differt/rt/image_method.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ def image_of_vertices_with_respect_to_mirrors(
... image_of_vertices_with_respect_to_mirrors,
... )
>>>
>>> key = jax.random.PRNGKey(0)
>>> key = jax.random.key(0)
>>> (
... key0,
... key1,
Expand Down
2 changes: 1 addition & 1 deletion differt/src/differt/rt/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -367,7 +367,7 @@ def rays_intersect_any_triangle(

@jaxtyped(typechecker=typechecker)
def scan_fun(
intersect: Bool[Array, " *#batch"],
intersect: Bool[Array, " *batch"],
triangle_vertices: Float[Array, "*#batch 3 3"],
) -> tuple[Bool[Array, " *batch"], None]:
t, hit = rays_intersect_triangles(
Expand Down
39 changes: 30 additions & 9 deletions differt/src/differt/scene/triangle_scene.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,10 +51,13 @@ def _compute_paths(
path_candidates: Int[Array, "num_path_candidates order"],
*,
parallel: bool = False,
**kwargs: Any,
epsilon: Float[ArrayLike, " "] | None = None,
hit_tol: Float[ArrayLike, " "] | None = None,
min_len: Float[ArrayLike, " "] | None = None,
) -> Paths:
epsilon = kwargs.pop("epsilon", None)
hit_tol = kwargs.pop("hit_tol", None)
if min_len is None:
dtype = jnp.result_type(mesh.vertices, from_vertices, to_vertices)
min_len = 10 * jnp.finfo(dtype).eps

# 1 - Broadcast arrays

Expand Down Expand Up @@ -144,9 +147,15 @@ def fun(
hit_tol=hit_tol,
).any(axis=-1) # Reduce on 'order'

# TODO: we also need to somehow mask degenerate paths, e.g., when two reflections occur on an edge
# 3.4 - Identify path segments that are too small (e.g., double-reflection inside an edge)

mask = inside_triangles & valid_reflections & ~blocked
ray_lengths = jnp.sum(ray_directions * ray_directions, axis=-1) # Squared norm

too_small = (ray_lengths < min_len).any(
axis=-1
) # Any path segment being too smal

mask = inside_triangles & valid_reflections & ~blocked & ~too_small

return full_paths, mask

Expand Down Expand Up @@ -359,7 +368,9 @@ def compute_paths(
chunk_size: int | None = None,
path_candidates: Int[Array, "num_path_candidates order"] | None = None,
parallel: bool = False,
**kwargs: Any,
epsilon: Float[ArrayLike, " "] | None = None,
hit_tol: Float[ArrayLike, " "] | None = None,
min_len: Float[ArrayLike, " "] | None = None,
) -> Paths | SizedIterator[Paths]:
"""
Compute paths between all pairs of transmitters and receivers in the scene, that undergo a fixed number of interaction with objects.
Expand All @@ -383,8 +394,14 @@ def compute_paths(
parallel: If :data:`True`, ray tracing is performed in parallel across all available
devices. Either the number of transmitters or the number of receivers
**must** be a multiple of :func:`jax.device_count`, otherwise an error is raised.
kwargs: Keyword arguments passed to
epsilon: Tolelance for checking ray / objects intersection, see
:func:`rays_intersect_triangles<differt.rt.utils.rays_intersect_triangles>`.
hit_tol: Tolerance for checking blockage (i.e., obstruction), see
:func:`rays_intersect_any_triangle<differt.rt.utils.rays_intersect_any_triangle>`.
min_len: Minimal (squared) length that each path segment must have for a path to be valid.

If not specified, the default is ten times the epsilon value
of the currently used floating point dtype.

Returns:
The paths, as class wrapping path vertices, object indices, and a masked
Expand Down Expand Up @@ -426,7 +443,9 @@ def compute_paths(
to_vertices,
path_candidates,
parallel=parallel,
**kwargs,
epsilon=epsilon,
hit_tol=hit_tol,
min_len=min_len,
).reshape(*tx_batch, *rx_batch, path_candidates.shape[0])
for path_candidates in path_candidates_iter
)
Expand All @@ -445,7 +464,9 @@ def compute_paths(
to_vertices,
path_candidates,
parallel=parallel,
**kwargs,
epsilon=epsilon,
hit_tol=hit_tol,
min_len=min_len,
).reshape(*tx_batch, *rx_batch, path_candidates.shape[0])

def plot(
Expand Down
6 changes: 3 additions & 3 deletions differt/src/differt/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def sorted_array2(array: Shaped[Array, "m n"]) -> Shaped[Array, "m n"]:
... )
>>>
>>> arr = jnp.arange(10).reshape(5, 2)
>>> key = jax.random.PRNGKey(1234)
>>> key = jax.random.key(1234)
>>> (
... key1,
... key2,
Expand Down Expand Up @@ -175,7 +175,7 @@ def minimize(
>>>
>>> batch = (1, 2, 3)
>>> n = 10
>>> key = jax.random.PRNGKey(1234)
>>> key = jax.random.key(1234)
>>> offset = jax.random.uniform(key, (*batch, n))
>>>
>>> def f(x, offset, scale=2.0):
Expand Down Expand Up @@ -246,7 +246,7 @@ def sample_points_in_bounding_box(
bounding_box: The bounding box (min. and max. coordinates).
size: The sample size or :data:`None`. If :data:`None`,
the returned array is 1D. Otherwise, it is 2D.
key: The :func:`jax.random.PRNGKey` to be used.
key: The :func:`jax.random.key` to be used.

Returns:
An array of points randomly sampled.
Expand Down
2 changes: 1 addition & 1 deletion differt/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ def seed() -> int:

@pytest.fixture
def key(seed: int) -> PRNGKeyArray:
return jax.random.PRNGKey(seed)
return jax.random.key(seed)


@pytest.fixture
Expand Down
4 changes: 2 additions & 2 deletions differt/tests/scene/fixtures.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,5 +33,5 @@ def advanced_path_tracing_example_scene(
def simple_street_canyon_scene(sionna_folder: Path) -> TriangleScene:
file = get_sionna_scene("simple_street_canyon", folder=sionna_folder)
scene = TriangleScene.load_xml(file)
scene = eqx.tree_at(lambda s: s.transmitters, scene, jnp.array([-37.0, 14.0, 35.0]))
return eqx.tree_at(lambda s: s.receivers, scene, jnp.array([12.0, 0.0, 35.0]))
scene = eqx.tree_at(lambda s: s.transmitters, scene, jnp.array([-22.0, 0.0, 32.0]))
return eqx.tree_at(lambda s: s.receivers, scene, jnp.array([+22.0, 0.0, 32.0]))
78 changes: 74 additions & 4 deletions differt/tests/scene/test_triangle_scene.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import pytest
from jaxtyping import Array, Int, PRNGKeyArray

from differt.geometry.paths import Paths
from differt.geometry.utils import assemble_paths, normalize
from differt.scene.sionna import (
get_sionna_scene,
Expand Down Expand Up @@ -97,12 +98,81 @@ def test_compute_paths_on_advanced_path_tracing_example(
with jax.debug_nans(False): # noqa: FBT003
got = scene.compute_paths(order)

chex.assert_trees_all_close(got.masked_vertices, expected_path_vertices) # type: ignore[reportAttributeAccessIssue]
chex.assert_trees_all_equal(got.masked_objects, expected_objects) # type: ignore[reportAttributeAccessIssue]
assert isinstance(got, Paths) # Hint to Pyright

normals = jnp.take(scene.mesh.normals, got.masked_objects[..., 1:-1], axis=0) # type: ignore[reportAttributeAccessIssue]
chex.assert_trees_all_close(got.masked_vertices, expected_path_vertices)
chex.assert_trees_all_equal(got.masked_objects, expected_objects)

rays = jnp.diff(got.masked_vertices, axis=-2) # type: ignore[reportAttributeAccessIssue]
normals = jnp.take(scene.mesh.normals, got.masked_objects[..., 1:-1], axis=0)

rays = jnp.diff(got.masked_vertices, axis=-2)

rays = normalize(rays)[0]

indicents = rays[..., :-1, :]
reflecteds = rays[..., +1:, :]

dot_incidents = jnp.sum(-indicents * normals, axis=-1)
dot_reflecteds = jnp.sum(reflecteds * normals, axis=-1)

chex.assert_trees_all_close(dot_incidents, dot_reflecteds)

@pytest.mark.parametrize(
("order", "expected_path_vertices", "expected_objects"),
[
(0, jnp.empty((1, 0, 3)), jnp.array([[0, 0]])),
(
1,
jnp.array([
[[0.0, -8.613334655761719, 32.0]],
[[0.0, 9.571563720703125, 32.0]],
[[1.9073486328125e-06, 0.0, -0.030788421630859375]],
]),
jnp.array([[0, 18, 0], [0, 38, 0], [0, 72, 0]]),
),
(
2,
jnp.array([
[
[-11.579630851745605, -8.613335609436035, 32.0],
[10.420369148254395, 9.571564674377441, 32.0],
],
[
[-10.420370101928711, 9.571562767028809, 32.0],
[11.579629898071289, -8.613335609436035, 32.0],
],
]),
jnp.array([[0, 19, 39, 0], [0, 38, 18, 0]]),
),
],
)
def test_compute_paths_on_simple_street_canyon(
self,
order: int,
expected_path_vertices: Array,
expected_objects: Array,
simple_street_canyon_scene: TriangleScene,
) -> None:
scene = simple_street_canyon_scene
expected_path_vertices = assemble_paths(
scene.transmitters[None, :],
expected_path_vertices,
scene.receivers[None, :],
)

with jax.debug_nans(False): # noqa: FBT003
got = scene.compute_paths(order)

assert isinstance(got, Paths) # Hint to Pyright

chex.assert_trees_all_close(
got.masked_vertices, expected_path_vertices, atol=1e-5
)
chex.assert_trees_all_equal(got.masked_objects, expected_objects)

normals = jnp.take(scene.mesh.normals, got.masked_objects[..., 1:-1], axis=0)

rays = jnp.diff(got.masked_vertices, axis=-2)

rays = normalize(rays)[0]

Expand Down
2 changes: 1 addition & 1 deletion differt/tests/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def wrapper(fun: Callable[..., Any]) -> Callable[..., Any]:
@wraps(fun)
def _wrapper_(*args: Any, **kwargs: Any) -> Any:
bound_args = sig.bind(*args, **kwargs)
keys = jax.random.split(jax.random.PRNGKey(seed), len(arg_names))
keys = jax.random.split(jax.random.key(seed), len(arg_names))
for key, arg_name in zip(keys, arg_names, strict=False):
shape = bound_args.arguments[arg_name]
bound_args.arguments[arg_name] = sampler(key, shape)
Expand Down
4 changes: 2 additions & 2 deletions docs/source/notebooks/performance_tips.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@
},
"outputs": [],
"source": [
"key = jax.random.PRNGKey(1234)\n",
"key = jax.random.key(1234)\n",
"key1, key2, key3 = jax.random.split(key, 3)\n",
"\n",
"batch = (10, 100)\n",
Expand Down Expand Up @@ -250,7 +250,7 @@
"source": [
"from beartype import beartype as typechecker\n",
"\n",
"key = jax.random.PRNGKey(1234)\n",
"key = jax.random.key(1234)\n",
"key1, key2, key3 = jax.random.split(key, 3)\n",
"\n",
"batch = (100, 10, 2)\n",
Expand Down
2 changes: 1 addition & 1 deletion docs/source/notebooks/type_checking.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@
"metadata": {},
"outputs": [],
"source": [
"key = jax.random.PRNGKey(1234)\n",
"key = jax.random.key(1234)\n",
"\n",
"arr = jax.random.randint(key, (10, 4), 0, 2)\n",
"arr"
Expand Down
Loading