Skip to content

Commit

Permalink
chore(docs): small improvements
Browse files Browse the repository at this point in the history
  • Loading branch information
jeertmans committed Sep 10, 2024
1 parent c3e8674 commit 68cf19e
Show file tree
Hide file tree
Showing 5 changed files with 34 additions and 29 deletions.
6 changes: 3 additions & 3 deletions differt/src/differt/em/utils.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
"""Utilities for working with EM fields."""

from typing import Any, Union
from typing import Any

import jax
from beartype import beartype as typechecker
from jaxtyping import Array, Float, jaxtyped
from jaxtyping import Array, ArrayLike, Float, jaxtyped

from ..geometry.utils import path_lengths
from .constants import c
Expand All @@ -14,7 +14,7 @@
@jaxtyped(typechecker=typechecker)
def lengths_to_delays(
lengths: Float[Array, " *#batch"],
speed: Union[float, Float[Array, " *#batch"]] = c,
speed: Float[ArrayLike, " *#batch"] = c,
) -> Float[Array, " *#batch"]:
"""
Compute the delay, in seconds, corresponding to each length.
Expand Down
12 changes: 7 additions & 5 deletions differt/src/differt/geometry/triangle_mesh.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import jax.numpy as jnp
import numpy as np
from beartype import beartype as typechecker
from jaxtyping import Array, Bool, Float, Int, PRNGKeyArray, jaxtyped
from jaxtyping import Array, ArrayLike, Bool, Float, Int, PRNGKeyArray, jaxtyped

import differt_core.geometry.triangle_mesh

Expand All @@ -17,6 +17,7 @@
from .utils import normalize, orthogonal_basis, rotation_matrix_along_axis


@eqx.filter_jit
@jaxtyped(typechecker=typechecker)
def triangles_contain_vertices_assuming_inside_same_plane(
triangle_vertices: Float[Array, "*batch 3 3"], vertices: Float[Array, "*batch 3"]
Expand Down Expand Up @@ -75,11 +76,12 @@ def triangles_contain_vertices_assuming_inside_same_plane(
return all_pos | all_neg


@eqx.filter_jit
@jaxtyped(typechecker=typechecker)
def paths_intersect_triangles(
paths: Float[Array, "*batch path_length 3"],
triangle_vertices: Float[Array, "num_triangles 3 3"],
epsilon: float = 1e-6,
epsilon: Float[ArrayLike, " "] = 1e-6,
) -> Bool[Array, " *batch"]:
"""
Return whether each path intersect with any of the triangles.
Expand Down Expand Up @@ -161,8 +163,8 @@ def plane(
vertex: Float[Array, "3"],
*other_vertices: Float[Array, "3"],
normal: Optional[Float[Array, "3"]] = None,
side_length: float = 1.0,
rotate: Optional[float] = None,
side_length: Float[ArrayLike, " "] = 1.0,
rotate: Optional[Float[ArrayLike, " "]] = None,
) -> "TriangleMesh":
"""
Create an plane mesh, made of two triangles.
Expand Down Expand Up @@ -205,7 +207,7 @@ def plane(
vertices = s * jnp.array([u + v, v - u, -u - v, u - v])

if rotate:
rotation_matrix = rotation_matrix_along_axis(normal)
rotation_matrix = rotation_matrix_along_axis(rotate, normal)
vertices = (rotation_matrix @ vertices.T).T

vertices += vertex
Expand Down
33 changes: 18 additions & 15 deletions differt/src/differt/rt/image_method.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,11 +40,11 @@
>>> from differt.plotting import draw_markers, draw_paths, reuse
>>> from differt.rt.image_method import image_method
>>>
>>> from_vertices = jnp.array([[+2.0, -1.0, +0.0]])
>>> to_vertices = jnp.array([[+2.0, +4.0, +0.0]])
>>> from_vertex = jnp.array([+2.0, -1.0, +0.0])
>>> to_vertex = jnp.array([+2.0, +4.0, +0.0])
>>> mirror_vertices = jnp.array(
... [
... [2.8, 2.8, 0.0],
... [3.0, 3.0, 0.0],
... [4.0, 3.4, 0.0],
... ]
... )
Expand All @@ -56,30 +56,33 @@
... )
>>> mirror_normals, _ = normalize(mirror_normals)
>>> path = image_method(
... from_vertices,
... to_vertices,
... from_vertex,
... to_vertex,
... mirror_vertices,
... mirror_normals,
... )
>>> with reuse(backend="plotly") as fig:
>>> with reuse(backend="plotly") as fig: # doctest: +SKIP
... TriangleMesh.plane(
... mirror_vertices[0], normal=mirror_normals[0], rotate=jnp.pi / 4
... ).plot(color="red") # TODO: fix angle
... mirror_vertices[0], normal=mirror_normals[0], rotate=-0.954
... ).plot(color="red")
... TriangleMesh.plane(mirror_vertices[1], normal=mirror_normals[1]).plot(
... color="red"
... )
...
... full_path = jnp.concatenate(
... (
... jnp.expand_dims(from_vertices, -2),
... from_vertex[None, :],
... path,
... jnp.expand_dims(to_vertices, -2),
... to_vertex[None, :],
... ),
... axis=-2,
... axis=0,
... )
... draw_paths(full_path, marker={"color": "green"})
... markers = jnp.concatenate((from_vertices, to_vertices))
... draw_markers(markers, labels=["BS", "UE"])
... draw_paths(full_path, marker={"color": "green"}, name="Final path")
... markers = jnp.vstack((from_vertex, to_vertex))
... draw_markers(
... markers, labels=["BS", "UE"], marker={"color": "black"}, name="BS/UE"
... )
... fig.update_layout(scene_aspectmode="data")
>>> fig # doctest: +SKIP
"""

Expand Down Expand Up @@ -302,7 +305,7 @@ def consecutive_vertices_are_on_same_side_of_mirrors(
This check is needed after using :func:`image_method` because it can return
vertices that are behind a mirror, which causes the path to go through this
mirror, and is someone we want to avoid.
mirror, and is something we want to avoid.
Args:
vertices: An array of vertices, usually describing ray paths.
Expand Down
8 changes: 4 additions & 4 deletions differt/src/differt/rt/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
import jax
import jax.numpy as jnp
from beartype import beartype as typechecker
from jaxtyping import Array, Bool, Float, Int, jaxtyped
from jaxtyping import Array, ArrayLike, Bool, Float, Int, jaxtyped

from differt_core.rt.graph import CompleteGraph

Expand Down Expand Up @@ -167,7 +167,7 @@ def rays_intersect_triangles(
ray_origins: Float[Array, "*batch 3"],
ray_directions: Float[Array, "*batch 3"],
triangle_vertices: Float[Array, "*batch 3 3"],
epsilon: Union[float, Float[Array, " "]] = 1e-6,
epsilon: Float[ArrayLike, " "] = 1e-6,
) -> tuple[Float[Array, " *batch"], Bool[Array, " *batch"]]:
"""
Return whether rays intersect corresponding triangles using the Möller-Trumbore algorithm.
Expand Down Expand Up @@ -229,8 +229,8 @@ def rays_intersect_any_triangle(
ray_origins: Float[Array, "*batch 3"],
ray_directions: Float[Array, "*batch 3"],
triangle_vertices: Float[Array, "num_triangles 3 3"],
epsilon: Union[float, Float[Array, " "]] = 1e-6,
hit_threshold: Union[float, Float[Array, " "]] = 0.999,
epsilon: Float[ArrayLike, " "] = 1e-6,
hit_threshold: Float[ArrayLike, " "] = 0.999,
) -> Bool[Array, " *batch"]:
"""
Return whether rays intersect any of the triangles using the Möller-Trumbore algorithm.
Expand Down
4 changes: 2 additions & 2 deletions differt/tests/plotting/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@
def test_draw_mesh(
backend: str,
) -> None:
vertices = np.array([[0, 0, 0], [1, 0, 0], [1, 1, 0], [0, 1, 0]], dtype=np.float32)
triangles = np.array([[0, 1, 2], [0, 2, 3]], dtype=np.int32)
vertices = np.array([[0, 0, 0], [1, 0, 0], [1, 1, 0], [0, 1, 0]], dtype=float)
triangles = np.array([[0, 1, 2], [0, 2, 3]], dtype=int)
with use(backend):
_ = draw_mesh(vertices, triangles)

Expand Down

0 comments on commit 68cf19e

Please sign in to comment.