From b153bf23072ce670f71f7865f09328d815381ff7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=C3=A9rome=20Eertmans?= Date: Mon, 19 Feb 2024 18:28:43 +0100 Subject: [PATCH] chore(lib): implementing Fermat path tracing --- docs/source/index.rst | 1 + docs/source/references.bib | 7 +++ python/differt/geometry/triangle_mesh.py | 8 +-- python/differt/geometry/utils.py | 4 +- python/differt/plotting/_core.py | 6 +- python/differt/plotting/_utils.py | 8 +-- python/differt/rt/fermat.py | 72 +++++++++++++++++++++ python/differt/rt/image_method.py | 39 +++++++----- python/differt/rt/utils.py | 8 +-- python/differt/utils.py | 79 +++++++++++++++++++++++- 10 files changed, 199 insertions(+), 33 deletions(-) create mode 100644 python/differt/rt/fermat.py diff --git a/docs/source/index.rst b/docs/source/index.rst index d6707072..68ca6100 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -8,6 +8,7 @@ DiffeRT: Differentiable Ray Tracing Toolbox for Radio Propagation notebooks/quickstart notebooks/advanced_path_tracing notebooks/plotting_backend + notebooks/diffraction notebooks/ray_tracing_at_city_scale .. toctree:: diff --git a/docs/source/references.bib b/docs/source/references.bib index 97eab436..7f5ab501 100644 --- a/docs/source/references.bib +++ b/docs/source/references.bib @@ -11,3 +11,10 @@ @inproceedings{mpt-eucap2023 pages = {1--5}, doi = {10.23919/EuCAP57121.2023.10132934} } +@misc{fermat-principle, + title = {Fermat's principle --- {Wikipedia}{,} The Free Encyclopedia}, + author = {{Wikipedia contributors}}, + year = 2024, + url = {https://en.wikipedia.org/w/index.php?title=Fermat%27s_principle&oldid=1206886888}, + note = {[Online; accessed 19-February-2024]} +} diff --git a/python/differt/geometry/triangle_mesh.py b/python/differt/geometry/triangle_mesh.py index 80cfb2ac..41946781 100644 --- a/python/differt/geometry/triangle_mesh.py +++ b/python/differt/geometry/triangle_mesh.py @@ -31,7 +31,7 @@ def triangles_contain_vertices_assuming_inside_same_plane( triangle_vertices: an array of triangle vertices. vertices: an array of vertices that will be checked. - Returns: + Return: A boolean array indicating whether vertices are in the corresponding triangles or not. """ # [*batch 3] @@ -89,7 +89,7 @@ def paths_intersect_triangles( a small portion of the path, to avoid indicating intersection when a path *bounces off* a triangle. - Returns: + Return: A boolean array indicating whether vertices are in the corresponding triangles or not. """ ray_origins = paths[..., :-1, :] @@ -146,7 +146,7 @@ def load_obj(cls, file: Path) -> "TriangleMesh": Args: file: The path to the Wavefront .obj file. - Returns: + Return: The corresponding mesh containing only triangles. """ mesh = _core.geometry.triangle_mesh.TriangleMesh.load_obj(str(file)) @@ -163,7 +163,7 @@ def plot(self, **kwargs: Any) -> Any: kwargs: Keyword arguments passed to :py:func:`draw_mesh`. - Returns: + Return: The resulting plot output. """ return draw_mesh( diff --git a/python/differt/geometry/utils.py b/python/differt/geometry/utils.py index 8c5b20b8..34b44ab1 100644 --- a/python/differt/geometry/utils.py +++ b/python/differt/geometry/utils.py @@ -15,7 +15,7 @@ def pairwise_cross( u: First array of vectors. v: Second array of vectors. - Returns: + Return: A 3D tensor with all cross products. """ return jnp.cross(u[:, None, :], v[None, :, :]) @@ -31,7 +31,7 @@ def normalize( Args: vector: An array of vectors. - Returns: + Return: The normalized vector and their length. :Examples: diff --git a/python/differt/plotting/_core.py b/python/differt/plotting/_core.py index 688ce36d..b1f05fcf 100644 --- a/python/differt/plotting/_core.py +++ b/python/differt/plotting/_core.py @@ -38,7 +38,7 @@ def draw_mesh( or :py:class:`Mesh3d`, depending on the backend. - Returns: + Return: The resulting plot output. """ @@ -102,7 +102,7 @@ def draw_paths( or :py:class:`Scatter3d`, depending on the backend. - Returns: + Return: The resulting plot output. """ @@ -165,7 +165,7 @@ def draw_markers( or :py:class:`Scatter3d`, depending on the backend. - Returns: + Return: The resulting plot output. Warning: diff --git a/python/differt/plotting/_utils.py b/python/differt/plotting/_utils.py index 41021a69..31d7b0e9 100644 --- a/python/differt/plotting/_utils.py +++ b/python/differt/plotting/_utils.py @@ -236,7 +236,7 @@ def view_from_canvas(canvas: SceneCanvas) -> ViewBox: Args: canvas: The canvas that draws the contents of the scene. - Returns: + Return: The view on which contents are displayed. """ from vispy.scene.widgets.viewbox import ViewBox @@ -287,7 +287,7 @@ def process_vispy_kwargs( must ensure that ``view in canvas.central_widget.children`` evaluates to :py:data:`True`. - Returns: + Return: The canvas and view used to display contents. """ from vispy import scene @@ -332,7 +332,7 @@ def process_matplotlib_kwargs( must ensure that ``ax in figure.axes`` evaluates to :py:data:`True`. - Returns: + Return: The figure and axes used to display contents. """ import matplotlib.pyplot as plt @@ -378,7 +378,7 @@ def process_plotly_kwargs( figure (:py:class:`Figure`): The figure that draws contents of the scene. - Returns: + Return: The figure used to display contents. """ import plotly.graph_objects as go diff --git a/python/differt/rt/fermat.py b/python/differt/rt/fermat.py new file mode 100644 index 00000000..8e168e17 --- /dev/null +++ b/python/differt/rt/fermat.py @@ -0,0 +1,72 @@ +""" +Path tracing utilities that utilize Fermat's principle. + +Fermat's principle states that the path taken by a ray between two +given points is the path that can be traveled in the least time +:cite:`fermat-principle`. In a homogeneous medium, +this means that the path of least time is also the path of last distance. + +As a result, this module offers minimization methods for finding ray paths. +""" +from jaxtyping import Array, Float, jaxtyped +from typeguard import typechecked as typechecker + + +@jaxtyped(typechecker=typechecker) +def fermat_path_on_planar_surfaces( + from_vertices: Float[Array, "*batch 3"], + to_vertices: Float[Array, "*batch 3"], + mirror_vertices: Float[Array, "num_mirrors *batch 3"], + mirror_normals: Float[Array, "num_mirrors *batch 3"], +) -> Float[Array, "num_mirrors *batch 3"]: + """ + Return the ray paths between pairs of vertices, that reflect on a given list of mirrors in between. + + Args: + from_vertices: An array of ``from`` vertices, i.e., vertices from which the + ray paths start. In a radio communications context, this is usually + an array of transmitters. + to_vertices: An array of ``to`` vertices, i.e., vertices to which the + ray paths end. In a radio communications context, this is usually + an array of receivers. + mirror_vertices: An array of mirror vertices. For each mirror, any + vertex on the infinite plane that describes the mirror is considered + to be a valid vertex. + mirror_normals: An array of mirror normals, where each normal has a unit + length and if perpendicular to the corresponding mirror. + + Return: + An array of ray paths obtained using Fermat's principle. + + .. note:: + + The paths do not contain the starting and ending vertices. + + You can easily create the complete ray paths using + :func:`jax.numpy.concatenate`: + + .. code-block:: python + + paths = fermat_path_on_planar_surfaces( + from_vertices, + to_vertices, + mirror_vertices, + mirror_normals, + ) + + full_paths = jnp.concatenate( + ( + from_vertices[ + None, + ..., + ], + paths, + to_vertices[ + None, + ..., + ], + ) + ) + """ + _mirror_directions = ... + raise NotImplementedError diff --git a/python/differt/rt/image_method.py b/python/differt/rt/image_method.py index 2fd606e3..16476d64 100644 --- a/python/differt/rt/image_method.py +++ b/python/differt/rt/image_method.py @@ -37,7 +37,7 @@ def image_of_vertices_with_respect_to_mirrors( mirror_normals: An array of mirror normals, where each normal has a unit length and if perpendicular to the corresponding mirror. - Returns: + Return: An array of image vertices. Examples: @@ -112,7 +112,7 @@ def intersection_of_line_segments_with_planes( plane_normals: an array of plane normals, where each normal has a unit length and if perpendicular to the corresponding plane. - Returns: + Return: An array of intersection vertices. """ u = segment_ends - segment_starts @@ -134,13 +134,20 @@ def image_method( Return the ray paths between pairs of vertices, that reflect on a given list of mirrors in between. Args: - from_vertices: *TODO*. - to_vertices: *TODO*. - mirror_vertices: *TODO*. - mirror_normals: *TODO*. + from_vertices: An array of ``from`` vertices, i.e., vertices from which the + ray paths start. In a radio communications context, this is usually + an array of transmitters. + to_vertices: An array of ``to`` vertices, i.e., vertices to which the + ray paths end. In a radio communications context, this is usually + an array of receivers. + mirror_vertices: An array of mirror vertices. For each mirror, any + vertex on the infinite plane that describes the mirror is considered + to be a valid vertex. + mirror_normals: An array of mirror normals, where each normal has a unit + length and if perpendicular to the corresponding mirror. - Returns: - *TODO*. + Return: + An array of ray paths obtained with the image method. .. note:: @@ -225,16 +232,20 @@ def consecutive_vertices_are_on_same_side_of_mirrors( Check if consecutive vertices, but skiping one every other vertex, are on the same side of a given mirror. The number of vertices ``num_vertices`` must be equal to ``num_mirrors + 2``. This check is needed after using :func:`image_method` because it can return - vertices that are behind a mirror, which causes the path to go trough this + vertices that are behind a mirror, which causes the path to go through this mirror, and is someone we want to avoid. Args: - vertices: *TODO*. - mirror_vertices: *TODO*. - mirror_normals: *TODO*. + vertices: An array of vertices, usually describing ray paths. + mirror_vertices: An array of mirror vertices. For each mirror, any + vertex on the infinite plane that describes the mirror is considered + to be a valid vertex. + mirror_normals: An array of mirror normals, where each normal has a unit + length and if perpendicular to the corresponding mirror. - Returns: - *TODO*. + Return: + A boolean array indicating whether pairs of consecutive vertices + are on the same side of the corresponding mirror. """ chex.assert_axis_dimension(vertices, 0, mirror_vertices.shape[0] + 2) diff --git a/python/differt/rt/utils.py b/python/differt/rt/utils.py index 4a907a65..133e1057 100644 --- a/python/differt/rt/utils.py +++ b/python/differt/rt/utils.py @@ -54,7 +54,7 @@ def generate_all_path_candidates( num_primitives: The (positive) number of primitives. order: The path order. An order less than one returns an empty array. - Returns: + Return: An unsigned array with primitive indices on each columns. Its number of columns is actually equal to ``num_primitives * ((num_primitives - 1) ** (order - 1))``. @@ -75,7 +75,7 @@ def generate_all_path_candidates_iter( num_primitives: The (positive) number of primitives. order: The path order. - Returns: + Return: An iterator of unsigned arrays with primitive indices. """ return map( @@ -96,7 +96,7 @@ def generate_all_path_candidates_chunks_iter( order: The path order. chunk_size: The size of each chunk. - Returns: + Return: An iterator of unsigned arrays with primitive indices. """ return map( @@ -133,7 +133,7 @@ def rays_intersect_triangles( triangle edges, a very common case if geometries are planes split into multiple triangles. - Returns: + Return: For each ray, return the scale factor of ``ray_directions`` for the vector to reach the corresponding triangle, and whether the intersection actually lies inside the triangle. diff --git a/python/differt/utils.py b/python/differt/utils.py index 84dac3d3..5d200636 100644 --- a/python/differt/utils.py +++ b/python/differt/utils.py @@ -1,6 +1,10 @@ """General purpose utilities.""" +from typing import Any, Callable + +import jax import jax.numpy as jnp -from jaxtyping import Array, Shaped, jaxtyped +import optax +from jaxtyping import Array, Num, Shaped, jaxtyped from typeguard import typechecked as typechecker @@ -12,7 +16,7 @@ def sorted_array2(array: Shaped[Array, "m n"]) -> Shaped[Array, "m n"]: Args: array: The input array. - Returns: + Return: A sorted copy of the input array. Examples: @@ -70,3 +74,74 @@ def sorted_array2(array: Shaped[Array, "m n"]) -> Shaped[Array, "m n"]: return array return array[jnp.lexsort(array.T[::-1])] # type: ignore + + +@jaxtyped(typechecker=typechecker) +def minimize( + fun: Callable[[Num[Array, "*batch n"]], Num[Array, " *batch"]], + x0: Num[Array, "*batch n"], + fun_args: tuple | None = None, + fun_kwargs: dict[str, Any] | None = None, + steps: int = 100, + optimizer: optax.GradientTransformation | None = None, +) -> tuple[Num[Array, "*batch n"], Num[Array, " *batch"]]: + """ + Minimize a scalar function of one or more variables. + + Args: + fun: The objective function to be minimized. + x0: The initial guess. + fun_args: Positional arguments to be passed to ``fun``. + fun_kwargs: Keyword arguments to be passed to ``fun``. + steps: The number of steps to perform. + optimizer: The optimizer to use. If not provided, + uses :func:`optax.adam` with a learning rate of ``0.1``. + + Return: + The solution array and the corresponding loss. + + Examples: + The following example shows how to minimize a basic function. + + >>> from differt.utils import minimize + >>> import chex + >>> + >>> def f(x, offset=1.0): + ... x = x - offset + ... return jnp.dot(x, x) + >>> + >>> x, y = minimize(f, jnp.zeros(10)) + >>> chex.assert_trees_all_close(x, jnp.ones(10), rtol=1e-2) + >>> chex.assert_trees_all_close(y, 0.0, atol=1e-4) + >>> + >>> # It is also possible to pass positional arguments + >>> x, y = minimize(f, jnp.zeros(10), fun_args=(2.0,)) + >>> chex.assert_trees_all_close(x, 2.0 * jnp.ones(10), rtol=1e-2) + >>> chex.assert_trees_all_close(y, 0.0, atol=1e-3) + >>> + >>> # Or even keyword arguments + >>> x, y = minimize(f, jnp.zeros(10), fun_kwargs=dict(offset=3.0)) + >>> chex.assert_trees_all_close(x, 3.0 * jnp.ones(10), rtol=1e-2) + >>> chex.assert_trees_all_close(y, 0.0, atol=1e-2) + """ + fun_args = fun_args if fun_args else tuple() + fun_kwargs = fun_kwargs if fun_kwargs else {} + optimizer = optimizer if optimizer else optax.adam(learning_rate=0.1) + + f_and_df = jax.value_and_grad(fun) + opt_state = optimizer.init(x0) + + # Cannot type check because jaxtyping fails with optax.OptState + # @jaxtyped(typechecker=typechecker) + def f( + carry: tuple[Num[Array, "*batch n"], optax.OptState], _: None + ) -> tuple[tuple[Num[Array, "*batch n"], optax.OptState], Num[Array, " *batch"]]: + x, opt_state = carry + loss, grads = f_and_df(x, *fun_args, **fun_kwargs) + updates, opt_state = optimizer.update(grads, opt_state) + x = x + updates + carry = (x, opt_state) + return carry, loss + + (x, _), losses = jax.lax.scan(f, init=(x0, opt_state), xs=None, length=steps) + return x, losses[-1]