Skip to content

Commit

Permalink
Merge branch 'main' into eucap2025
Browse files Browse the repository at this point in the history
  • Loading branch information
jeertmans authored Oct 17, 2024
2 parents a21b29e + 3e44c6f commit a35cbd8
Show file tree
Hide file tree
Showing 9 changed files with 188 additions and 52 deletions.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ repos:
- id: ruff-format
types_or: [python, pyi, jupyter]
- repo: https://github.com/RobertCraigie/pyright-python
rev: v1.1.384
rev: v1.1.385
hooks:
- id: pyright
- repo: https://github.com/doublify/pre-commit-rust
Expand Down
70 changes: 50 additions & 20 deletions differt/src/differt/geometry/triangle_mesh.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
# ruff: noqa: ERA001

import sys
from typing import Any
from typing import Any, Literal, overload

import equinox as eqx
import jax
Expand Down Expand Up @@ -312,14 +312,42 @@ def set_face_colors(
is_leaf=lambda x: x is None,
)

@overload
@classmethod
def plane(
cls,
vertex_a: Float[Array, "3"],
vertex_b: Float[Array, "3"],
vertex_c: Float[Array, "3"],
*,
normal: Literal[None] = None,
side_length: Float[ArrayLike, " "] = 1.0,
rotate: Float[ArrayLike, " "] | None = None,
) -> Self: ...

@overload
@classmethod
def plane(
cls,
vertex_a: Float[Array, "3"],
vertex_b: Literal[None] = None,
vertex_c: Literal[None] = None,
*,
normal: Float[Array, "3"],
side_length: Float[ArrayLike, " "] = 1.0,
rotate: Float[ArrayLike, " "] | None = None,
) -> Self: ...

@classmethod
@jaxtyped(
typechecker=None
) # typing.Self is (currently) not compatible with jaxtyping and beartype
def plane(
cls,
vertex: Float[Array, "3"],
*other_vertices: Float[Array, "3"],
vertex_a: Float[Array, "3"],
vertex_b: Float[Array, "3"] | None = None,
vertex_c: Float[Array, "3"] | None = None,
*,
normal: Float[Array, "3"] | None = None,
side_length: Float[ArrayLike, " "] = 1.0,
rotate: Float[ArrayLike, " "] | None = None,
Expand All @@ -328,10 +356,13 @@ def plane(
Create an plane mesh, made of two triangles.
Args:
vertex: The center of the plane.
other_vertices: Two other vertices that define the plane.
vertex_a: The center of the plane.
vertex_b: Any second vertex on the plane.
This or ``normal`` is required.
This and ``vertex_c``, or ``normal`` is required.
vertex_c: Any third vertex on the plane.
This and ``vertex_b``, or ``normal`` is required.
normal: The plane normal.
Must be of unit length.
Expand All @@ -344,23 +375,22 @@ def plane(
A new plane mesh.
Raises:
ValueError: If neither ``other_vertices`` nor ``normal`` has been provided,
ValueError: If neither ``vertex_b`` and ``vertex_c``, nor ``normal`` have been provided,
or if both have been provided simultaneously.
"""
if (other_vertices == ()) == (normal is None):
msg = "You must specify one of 'other_vertices' or 'normal', not both."
if (vertex_b is None) != (vertex_c is None):
msg = "You must specify either of both of 'vertex_b' and 'vertex_c', or none."
raise ValueError(msg)

if (vertex_b is None) == (normal is None):
msg = "You must specify one of ('vertex_b', 'vertex_c') or 'normal', not both."
raise ValueError(msg)
if other_vertices:
if len(other_vertices) != 2: # noqa: PLR2004
msg = (
"You must provide exactly 3 vertices to create a new plane, "
f"but you provided {len(other_vertices) + 1}."
)
raise ValueError(msg)
u = other_vertices[0] - vertex
v = other_vertices[1] - vertex

if vertex_b is not None:
u = vertex_b - vertex_a
v = vertex_c - vertex_a
w = jnp.cross(u, v)
(normal, _) = normalize(w)
normal = normalize(w)[0]

u, v = orthogonal_basis(
normal, # type: ignore[reportArgumentType]
Expand All @@ -375,7 +405,7 @@ def plane(
rotation_matrix = rotation_matrix_along_axis(rotate, normal)
vertices = (rotation_matrix @ vertices.T).T

vertices += vertex
vertices += vertex_a

triangles = jnp.array([[0, 1, 2], [0, 2, 3]], dtype=int)
return cls(vertices=vertices, triangles=triangles)
Expand Down
28 changes: 28 additions & 0 deletions differt/src/differt/geometry/utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
"""Utilities for working with 3D geometries."""

from typing import Literal, overload

import equinox as eqx
import jax
import jax.numpy as jnp
Expand All @@ -26,6 +28,32 @@ def pairwise_cross(
return jnp.cross(u[:, None, :], v[None, :, :])


@overload
def normalize(
vector: Float[Array, "*batch 3"],
keepdims: Literal[False] = False,
) -> tuple[Float[Array, "*batch 3"], Float[Array, " *batch"]]: ...


@overload
def normalize(
vector: Float[Array, "*batch 3"],
keepdims: Literal[True],
) -> tuple[Float[Array, "*batch 3"], Float[Array, " *batch 1"]]: ...


# Workaround currently needed,
# see: https://github.com/microsoft/pyright/issues/9149
@overload
def normalize(
vector: Float[Array, "*batch 3"],
keepdims: bool,
) -> (
tuple[Float[Array, "*batch 3"], Float[Array, " *batch"]]
| tuple[Float[Array, "*batch 3"], Float[Array, " *batch 1"]]
): ...


@eqx.filter_jit
@jaxtyped(typechecker=typechecker)
def normalize(
Expand Down
62 changes: 51 additions & 11 deletions differt/src/differt/scene/triangle_scene.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
"""Scene made of triangles and utilities."""
# ruff: noqa: ERA001

import math
import sys
import warnings
from collections.abc import Mapping
from typing import Any
from typing import Any, Literal, overload

import equinox as eqx
import jax
Expand Down Expand Up @@ -188,25 +189,25 @@ def fun(
if parallel:
num_devices = jax.device_count()

# TODO: allow also to have i,i mesh if product of both is a multiple of 'num_devices'
if from_vertices.shape[0] % num_devices == 0:
in_specs = (P("i", None), P(None, None))
out_specs = (P("i", None, None, None, None), P("i", None, None))
elif to_vertices.shape[0] % num_devices == 0:
in_specs = (P(None, None), P("i", None))
out_specs = (P(None, "i", None, None, None), P(None, "i", None))
if (from_vertices.shape[0] * to_vertices.shape[0]) % num_devices == 0:
tx_mesh = math.gcd(from_vertices.shape[0], num_devices)
rx_mesh = num_devices // tx_mesh
in_specs = (P("i", None), P("j", None))
out_specs = (P("i", "j", None, None, None), P("i", "j", None))
else:
msg = (
f"Found {num_devices} devices available, "
"but could not find any input with a size that is a multiple of that value. "
"Please user a number of transmitter or receiver points that is a "
"Please user a number of transmitter and receiver points that is a "
f"multiple of {num_devices}."
)
raise ValueError(msg)

fun = shard_map( # type: ignore[reportAssigmentType]
fun,
Mesh(mesh_utils.create_device_mesh((num_devices,)), axis_names=("i",)),
Mesh(
mesh_utils.create_device_mesh((tx_mesh, rx_mesh)), axis_names=("i", "j")
),
in_specs=in_specs,
out_specs=out_specs,
)
Expand Down Expand Up @@ -401,6 +402,45 @@ def load_xml(cls, file: str) -> Self:
core_scene = differt_core.scene.triangle_scene.TriangleScene.load_xml(file)
return cls.from_core(core_scene)

@overload
def compute_paths(
self,
order: int | None = None,
*,
chunk_size: Literal[None] = None,
path_candidates: Int[Array, "num_path_candidates order"] | None = None,
parallel: bool = False,
epsilon: Float[ArrayLike, " "] | None = None,
hit_tol: Float[ArrayLike, " "] | None = None,
min_len: Float[ArrayLike, " "] | None = None,
) -> Paths: ...

@overload
def compute_paths(
self,
order: int | None = None,
*,
chunk_size: int,
path_candidates: Literal[None] = None,
parallel: bool = False,
epsilon: Float[ArrayLike, " "] | None = None,
hit_tol: Float[ArrayLike, " "] | None = None,
min_len: Float[ArrayLike, " "] | None = None,
) -> SizedIterator[Paths]: ...

@overload
def compute_paths(
self,
order: int | None = None,
*,
chunk_size: int,
path_candidates: Int[Array, "num_path_candidates order"],
parallel: bool = False,
epsilon: Float[ArrayLike, " "] | None = None,
hit_tol: Float[ArrayLike, " "] | None = None,
min_len: Float[ArrayLike, " "] | None = None,
) -> Paths: ...

def compute_paths(
self,
order: int | None = None,
Expand Down Expand Up @@ -435,7 +475,7 @@ def compute_paths(
If ``self.mesh.assume_quads`` is :data:`True`, then path candidates are
rounded down toward the nearest even value.
parallel: If :data:`True`, ray tracing is performed in parallel across all available
devices. Either the number of transmitters or the number of receivers
devices. The number of transmitters times the number of receivers
**must** be a multiple of :func:`jax.device_count`, otherwise an error is raised.
epsilon: Tolelance for checking ray / objects intersection, see
:func:`rays_intersect_triangles<differt.rt.utils.rays_intersect_triangles>`.
Expand Down
20 changes: 19 additions & 1 deletion differt/src/differt/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import sys
from collections.abc import Callable, Iterable, Mapping
from typing import Any
from typing import Any, Literal, overload

import chex
import equinox as eqx
Expand Down Expand Up @@ -231,6 +231,24 @@ def f(
return x, losses[-1]


@overload
def sample_points_in_bounding_box(
bounding_box: Float[Array, "2 3"],
shape: Literal[None] = None,
*,
key: PRNGKeyArray,
) -> Float[Array, "3"]: ...


@overload
def sample_points_in_bounding_box(
bounding_box: Float[Array, "2 3"],
shape: tuple[int, ...],
*,
key: PRNGKeyArray,
) -> Float[Array, "3"]: ...


@eqx.filter_jit
@jaxtyped(typechecker=typechecker)
def sample_points_in_bounding_box(
Expand Down
4 changes: 2 additions & 2 deletions differt/tests/benchmarks/test_rt.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ def bench_fun() -> None:
def bench_fun() -> None:
scene.compute_paths(
order,
chunk_size=chunk_size,
).vertices.block_until_ready() # type: ignore[reportAttributeAccessIssue]
chunk_size=None,
).vertices.block_until_ready()

_ = benchmark(bench_fun)
40 changes: 31 additions & 9 deletions differt/tests/geometry/test_triangle_mesh.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import logging
import re
from contextlib import AbstractContextManager
from contextlib import nullcontext as does_not_raise

Expand Down Expand Up @@ -168,25 +169,46 @@ def test_plane(self, key: PRNGKeyArray) -> None:

chex.assert_trees_all_equal(got, expected)

vertices = jax.random.uniform(key, (3, 3))
_ = TriangleMesh.plane(*vertices)
vertex_a, vertex_b, vertex_c = jax.random.uniform(key, (3, 3)).T
_ = TriangleMesh.plane(vertex_a, vertex_b, vertex_c)

with pytest.raises(
ValueError,
match="You must specify one of 'other_vertices' or 'normal', not both.",
match="You must specify either of both of 'vertex_b' and 'vertex_c', or none.",
):
_ = TriangleMesh.plane(*vertices, normal=normal)
_ = TriangleMesh.plane(vertex_a, vertex_b) # type: ignore[reportCallIssue]

vertices = jax.random.uniform(key, (4, 3))
with pytest.raises(
ValueError,
match=re.escape(
"You must specify either of both of 'vertex_b' and 'vertex_c', or none."
),
):
_ = TriangleMesh.plane(vertex_a, vertex_c=vertex_c) # type: ignore[reportCallIssue]

with pytest.raises(
ValueError,
match=re.escape(
"You must specify one of ('vertex_b', 'vertex_c') or 'normal', not both."
),
):
_ = TriangleMesh.plane(vertex_a, vertex_b, vertex_c, normal=normal)

with pytest.raises(ValueError, match="You must provide exactly 3 vertices"):
_ = TriangleMesh.plane(*vertices)
with pytest.raises(
ValueError,
match=re.escape(
"You must specify one of ('vertex_b', 'vertex_c') or 'normal', not both."
),
):
_ = TriangleMesh.plane(vertex_a, vertex_b, vertex_c, normal=normal)

with pytest.raises(
ValueError,
match="You must specify one of 'other_vertices' or 'normal', not both.",
match=re.escape(
"You must specify one of ('vertex_b', 'vertex_c') or 'normal', not both."
),
):
_ = TriangleMesh.plane(center)
_ = TriangleMesh.plane(center) # type: ignore[reportCallIssue]

def test_empty(self) -> None:
assert TriangleMesh.empty().is_empty
Expand Down
Loading

0 comments on commit a35cbd8

Please sign in to comment.