Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

perf(lib): unroll image method and use_scan=True/False #119

Merged
merged 10 commits into from
Sep 28, 2024
2 changes: 2 additions & 0 deletions differt/src/differt/rt/image_method.py
Original file line number Diff line number Diff line change
Expand Up @@ -309,12 +309,14 @@ def backward(
forward,
init=from_vertices,
xs=(mirror_vertices, mirror_normals),
unroll=True,
)
_, paths = jax.lax.scan(
backward,
init=to_vertices,
xs=(mirror_vertices, mirror_normals, images),
reverse=True,
unroll=True,
)

return jnp.moveaxis(paths, 0, -2) # Put 'num_mirrors' axis at the end
Expand Down
127 changes: 82 additions & 45 deletions differt/src/differt/rt/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -303,6 +303,7 @@ def rays_intersect_any_triangle(
triangle_vertices: Float[Array, "*#batch num_triangles 3 3"],
*,
hit_tol: Float[ArrayLike, " "] | None = None,
use_scan: bool = True,
**kwargs: Any,
) -> Bool[Array, " *batch"]:
"""
Expand All @@ -328,6 +329,9 @@ def rays_intersect_any_triangle(

If not specified, the default is ten times the epsilon value
of the currently used floating point dtype.
use_scan: Whether to use :func:`jax.lax.scan` to potentially avoid
allocating multiple arrays of size ``#batch num_triangles 3 3``,
at the cost of a slower runtime.
kwargs: Keyword arguments passed to
:func:`rays_intersect_triangles`.

Expand All @@ -340,32 +344,44 @@ def rays_intersect_any_triangle(

hit_threshold = 1.0 - hit_tol

# Put 'num_triangles' axis as leading axis
triangle_vertices = jnp.moveaxis(triangle_vertices, -3, 0)
if use_scan:
# Put 'num_triangles' axis as leading axis
triangle_vertices = jnp.moveaxis(triangle_vertices, -3, 0)

batch = jnp.broadcast_shapes(
ray_origins.shape[:-1], ray_directions.shape[:-1], triangle_vertices.shape[1:-2]
)

@jaxtyped(typechecker=typechecker)
def scan_fun(
intersect: Bool[Array, " *#batch"],
triangle_vertices: Float[Array, "*#batch 3 3"],
) -> tuple[Bool[Array, " *batch"], None]:
t, hit = rays_intersect_triangles(
ray_origins,
ray_directions,
triangle_vertices,
**kwargs,
batch = jnp.broadcast_shapes(
ray_origins.shape[:-1],
ray_directions.shape[:-1],
triangle_vertices.shape[1:-2],
)
intersect = intersect | ((t < hit_threshold) & hit)
return intersect, None

return jax.lax.scan(
scan_fun,
init=jnp.zeros(batch, dtype=bool),
xs=triangle_vertices,
)[0]
@jaxtyped(typechecker=typechecker)
def scan_fun(
intersect: Bool[Array, " *#batch"],
triangle_vertices: Float[Array, "*#batch 3 3"],
) -> tuple[Bool[Array, " *batch"], None]:
t, hit = rays_intersect_triangles(
ray_origins,
ray_directions,
triangle_vertices,
**kwargs,
)
intersect = intersect | ((t < hit_threshold) & hit)
return intersect, None

return jax.lax.scan(
scan_fun,
init=jnp.zeros(batch, dtype=bool),
xs=triangle_vertices,
)[0]

t, hit = rays_intersect_triangles(
ray_origins[..., None, :],
ray_directions[..., None, :],
triangle_vertices,
**kwargs,
)

return ((t < hit_threshold) & hit).any(axis=-1)


@eqx.filter_jit
Expand All @@ -374,6 +390,8 @@ def triangles_visible_from_vertices(
vertices: Float[Array, "*#batch 3"],
triangle_vertices: Float[Array, "*#batch num_triangles 3 3"],
num_rays: int = int(1e6),
*,
use_scan: bool = True,
**kwargs: Any,
) -> Bool[Array, "*batch num_triangles"]:
"""
Expand All @@ -392,6 +410,9 @@ def triangles_visible_from_vertices(
num_rays: The number of rays to launch.

The larger, the more accurate.
use_scan: Whether to use :func:`jax.lax.scan` to potentially avoid
allocating multiple arrays of size ``#batch num_triangles num_rays 3 3``,
at the cost of a slower runtime.
kwargs: Keyword arguments passed to
:func:`rays_intersect_triangles`.

Expand Down Expand Up @@ -434,28 +455,44 @@ def triangles_visible_from_vertices(
# [num_rays 3]
ray_directions = fibonacci_lattice(num_rays)

batch = jnp.broadcast_shapes(ray_origins.shape[:-1], triangle_vertices.shape[:-3])

@jaxtyped(typechecker=typechecker)
def scan_fun(
visible: Bool[Array, "*batch num_triangles"],
ray_direction: Float[Array, "3"],
) -> tuple[Bool[Array, " *batch num_triangles"], None]:
t, hit = rays_intersect_triangles(
ray_origins[..., None, :],
ray_direction[..., None, :],
triangle_vertices,
**kwargs,
)
# A triangle is visible if it is the first triangle to be intersected by a ray.
visible = visible | (
t == jnp.min(t, axis=-1, keepdims=True, initial=jnp.inf, where=hit)
if use_scan:
batch = jnp.broadcast_shapes(
ray_origins.shape[:-1], triangle_vertices.shape[:-3]
)

return visible, None
@jaxtyped(typechecker=typechecker)
def scan_fun(
visible: Bool[Array, "*batch num_triangles"],
ray_direction: Float[Array, "3"],
) -> tuple[Bool[Array, " *batch num_triangles"], None]:
t, hit = rays_intersect_triangles(
ray_origins[..., None, :],
ray_direction[..., None, :],
triangle_vertices,
**kwargs,
)
# A triangle is visible if it is the first triangle to be intersected by a ray.
visible = visible | (
t == jnp.min(t, axis=-1, keepdims=True, initial=jnp.inf, where=hit)
)

return visible, None

return jax.lax.scan(
scan_fun,
init=jnp.zeros((*batch, triangle_vertices.shape[-3]), dtype=bool),
xs=ray_directions,
)[0]

# TODO: test swapping axes to see if it improves performances

t, hit = rays_intersect_triangles(
ray_origins[..., None, None, :],
ray_directions,
triangle_vertices[..., :, None, :, :],
**kwargs,
)

return jax.lax.scan(
scan_fun,
init=jnp.zeros((*batch, triangle_vertices.shape[-3]), dtype=bool),
xs=ray_directions,
)[0]
return (t == jnp.min(t, axis=-2, keepdims=True, initial=jnp.inf, where=hit)).any(
axis=-1
)
4 changes: 2 additions & 2 deletions differt/tests/benchmarks/conftest.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
__all__ = (
"basic_planar_mirrors_setup",
"large_random_planar_mirrors_setup",
"simple_street_canyon_scene",
"sionna_folder",
)


from ..rt.fixtures import basic_planar_mirrors_setup
from ..scene.fixtures import simple_street_canyon_scene, sionna_folder
from .fixtures import large_random_planar_mirrors_setup
22 changes: 22 additions & 0 deletions differt/tests/benchmarks/fixtures.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
import jax
import pytest

from ..rt.utils import PlanarMirrorsSetup

from jaxtyping import PRNGKeyArray


@pytest.fixture
def large_random_planar_mirrors_setup(key: PRNGKeyArray) -> PlanarMirrorsSetup:
num_mirrors = 3
num_path_candidates = 10_000

key_from, key_to, key_m_vertices, key_m_normals, key_paths = jax.random.split(key, 5)

from_vertices = jax.random.uniform(key_from, (3,))
to_vertices = jax.random.uniform(key_to, (3,))
mirror_vertices = jax.random.uniform(key_m_vertices, (num_path_candidates, num_mirrors, 3))
mirror_normals = jax.random.uniform(key_m_normals, (num_path_candidates, num_mirrors, 3))
paths = jax.random.uniform(key_paths, (num_path_candidates, num_mirrors, 3))

return PlanarMirrorsSetup(from_vertices, to_vertices, mirror_vertices, mirror_normals, paths)
47 changes: 20 additions & 27 deletions differt/tests/benchmarks/test_rt.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import jax
import pytest
from jaxtyping import PRNGKeyArray
from pytest_codspeed import BenchmarkFixture

from differt.rt.fermat import fermat_path_on_planar_mirrors
Expand All @@ -12,29 +11,13 @@

from ..rt.utils import PlanarMirrorsSetup

batches = pytest.mark.parametrize(
"batch",
[
(),
(10,),
(
10,
20,
30,
),
],
)


@pytest.mark.benchmark(group="image_method")
@batches
def test_image_method(
batch: tuple[int, ...],
basic_planar_mirrors_setup: PlanarMirrorsSetup,
key: PRNGKeyArray,
large_random_planar_mirrors_setup: PlanarMirrorsSetup,
benchmark: BenchmarkFixture,
) -> None:
setup = basic_planar_mirrors_setup.broadcast_to(*batch).add_noeffect_noise(key=key)
setup = large_random_planar_mirrors_setup
_ = benchmark(
lambda: image_method(
setup.from_vertices,
Expand All @@ -46,14 +29,11 @@ def test_image_method(


@pytest.mark.benchmark(group="fermat_method")
@batches
def test_fermat(
batch: tuple[int, ...],
basic_planar_mirrors_setup: PlanarMirrorsSetup,
key: PRNGKeyArray,
large_random_planar_mirrors_setup: PlanarMirrorsSetup,
benchmark: BenchmarkFixture,
) -> None:
setup = basic_planar_mirrors_setup.broadcast_to(*batch).add_noeffect_noise(key=key)
setup = large_random_planar_mirrors_setup
_ = benchmark(
lambda: fermat_path_on_planar_mirrors(
setup.from_vertices,
Expand All @@ -66,25 +46,32 @@ def test_fermat(

@pytest.mark.benchmark(group="triangles_visible_from_vertices")
@pytest.mark.parametrize("num_rays", [100, 1000, 10000])
@pytest.mark.parametrize("use_scan", [False, True])
def test_transmitter_visibility_in_simple_street_canyon_scene(
num_rays: int,
use_scan: bool,
simple_street_canyon_scene: TriangleScene,
benchmark: BenchmarkFixture,
) -> None:
scene = simple_street_canyon_scene
_ = benchmark(
lambda: triangles_visible_from_vertices(
scene.transmitters, scene.mesh.triangle_vertices, num_rays=num_rays
scene.transmitters,
scene.mesh.triangle_vertices,
num_rays=num_rays,
use_scan=use_scan,
).block_until_ready()
)


@pytest.mark.benchmark(group="compute_paths")
@pytest.mark.parametrize("order", [0, 1, 2])
@pytest.mark.parametrize("chunk_size", [None, 20_000])
@pytest.mark.parametrize("use_scan", [False, True])
def test_compute_paths_in_simple_street_canyon_scene(
order: int,
chunk_size: int | None,
use_scan: bool,
simple_street_canyon_scene: TriangleScene,
benchmark: BenchmarkFixture,
) -> None:
Expand All @@ -93,15 +80,21 @@ def test_compute_paths_in_simple_street_canyon_scene(

@jax.debug_nans(False) # noqa: FBT003
def bench_fun() -> None:
for path in scene.compute_paths(order, chunk_size=chunk_size):
for path in scene.compute_paths(
order,
chunk_size=chunk_size,
use_scan=use_scan,
):
path.vertices.block_until_ready()

else:

@jax.debug_nans(False) # noqa: FBT003
def bench_fun() -> None:
scene.compute_paths(
order, chunk_size=chunk_size
order,
chunk_size=chunk_size,
use_scan=use_scan,
).vertices.block_until_ready() # type: ignore[reportAttributeAccessIssue]

_ = benchmark(bench_fun)
6 changes: 6 additions & 0 deletions differt/tests/rt/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,13 +238,15 @@ def test_rays_intersect_triangles_random_inputs(
)
@pytest.mark.parametrize("epsilon", [None, 1e-6, 1e-2])
@pytest.mark.parametrize("hit_tol", [None, 0.0, 0.001, -0.5, 0.5])
@pytest.mark.parametrize("use_scan", [False, True])
@random_inputs("ray_origins", "ray_directions", "triangle_vertices")
def test_rays_intersect_any_triangle(
ray_origins: Array,
ray_directions: Array,
triangle_vertices: Array,
epsilon: float | None,
hit_tol: float | None,
use_scan: bool,
expectation: AbstractContextManager[Exception],
) -> None:
if hit_tol is None:
Expand All @@ -259,6 +261,7 @@ def test_rays_intersect_any_triangle(
triangle_vertices,
epsilon=epsilon,
hit_tol=hit_tol,
use_scan=use_scan,
)
expected_t, expected_hit = rays_intersect_triangles(
ray_origins[..., None, :],
Expand Down Expand Up @@ -293,17 +296,20 @@ def test_rays_intersect_any_triangle(
),
],
)
@pytest.mark.parametrize("use_scan", [False, True])
def test_triangles_visible_from_vertices(
vertex: Array,
expected_number: int,
num_rays: int,
use_scan: bool,
expectation: AbstractContextManager[Exception],
cube_vertices: Array,
) -> None:
visible_triangles = triangles_visible_from_vertices(
vertex,
cube_vertices,
num_rays=num_rays,
use_scan=use_scan,
)

with expectation:
Expand Down
10 changes: 5 additions & 5 deletions differt/tests/rt/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,11 @@

@jaxtyped(typechecker=typechecker)
class PlanarMirrorsSetup(eqx.Module):
from_vertices: Float[Array, "*batch 3"]
to_vertices: Float[Array, "*batch 3"]
mirror_vertices: Float[Array, "*batch num_mirrors 3"]
mirror_normals: Float[Array, "*batch num_mirrors 3"]
paths: Float[Array, "*batch num_mirrors 3"]
from_vertices: Float[Array, "*#batch 3"]
to_vertices: Float[Array, "*#batch 3"]
mirror_vertices: Float[Array, "*#batch num_mirrors 3"]
mirror_normals: Float[Array, "*#batch num_mirrors 3"]
paths: Float[Array, "*#batch num_mirrors 3"]

def broadcast_to(self, *batch: int) -> "PlanarMirrorsSetup":
num_mirrors = self.mirror_vertices.shape[-2]
Expand Down
Loading
Loading