Skip to content

Commit

Permalink
chore(lib): stacking pytrees to use scan
Browse files Browse the repository at this point in the history
  • Loading branch information
jeertmans committed Aug 4, 2023
1 parent e72e91a commit 7184ab2
Show file tree
Hide file tree
Showing 10 changed files with 255 additions and 98 deletions.
78 changes: 62 additions & 16 deletions differt2d/geometry.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from __future__ import annotations

from functools import partial
from typing import TYPE_CHECKING, Any, List, Optional, Tuple
from typing import TYPE_CHECKING, Any, List, Optional, Tuple, Union

import jax
import jax.numpy as jnp
Expand Down Expand Up @@ -551,6 +551,37 @@ def parametric_to_cartesian(objects, parametric_coords, n, tx_coords, rx_coords)
return cartesian_coords


Pytree = Union[list, tuple, dict]


def stack_leaves(pytrees: Pytree, axis: int = 0) -> Pytree:
"""
Stack the leaves of one or more Pytrees along a new axis.
Solution inspired from:
https://github.com/google/jax/discussions/16882#discussioncomment-6638501.
:param pytress: One or more Pytrees.
:param axis: Axis along which leaves are stacked.
:return: A new Pytree with leaves stacked along the new axis.
"""
return jax.tree_util.tree_map(lambda *xs: jnp.stack(xs, axis=axis), *pytrees)


def unstack_leaves(pytrees) -> List[Pytree]:
"""
Unstack the leaves of a Pytree.
Reciprocal of :func:`stack_leaves`.
:param pytrees: A Pytree.
:return: A list of Pytrees,
where each Pytree has the same structure as the input Pytree,
but each leaf contains only one part of the original leaf.
"""
leaves, treedef = jax.tree_util.tree_flatten(pytrees)
return [treedef.unflatten(leaf) for leaf in zip(*leaves)]


@dataclass
class ImagePath(Path):
"""
Expand Down Expand Up @@ -597,6 +628,12 @@ def from_tx_objects_rx(
"""
n = len(objects)

if n == 0:
points = jnp.row_stack([tx.point, rx.point])
return cls(points=points, loss=jnp.array(0.0))

walls = stack_leaves(objects)

@jit
def path_loss(cartesian_coords):
_loss = 0.0
Expand All @@ -605,24 +642,23 @@ def path_loss(cartesian_coords):

return _loss

image = tx.point
images = jnp.empty((n, 2))
def forward(image, wall):
image = wall.image_of(image)
return image, image

for i in range(n):
image = objects[i].image_of(image)
images = images.at[i, :].set(image)

points = jnp.empty_like(images)

point = rx.point
for i in reversed(range(n)):
obj = objects[i]
p = obj.origin()
n = obj.normal()
u = point - images[i, :]
def backward(point, x):
wall, image = x
p = wall.origin()
n = wall.normal()
u = point - image
v = p - point
point = point + jnp.dot(v, n) * u / jnp.dot(u, n)
points = points.at[i, :].set(point)
return point, point

_, images = jax.lax.scan(forward, init=tx.point, xs=walls)
_, points = jax.lax.scan(
backward, init=rx.point, xs=(walls, images), reverse=True
)

points = jnp.row_stack([tx.point, points, rx.point])

Expand Down Expand Up @@ -679,6 +715,11 @@ def from_tx_objects_rx(
plt.show()
"""
n = len(objects)

if n == 0:
points = jnp.row_stack([tx.point, rx.point])
return cls(points=points, loss=jnp.array(0.0))

n_unknowns = sum([obj.parameters_count() for obj in objects])

@jit
Expand Down Expand Up @@ -759,6 +800,11 @@ def from_tx_objects_rx(
plt.show()
"""
n = len(objects)

if n == 0:
points = jnp.row_stack([tx.point, rx.point])
return cls(points=points, loss=jnp.array(0.0))

n_unknowns = sum(obj.parameters_count() for obj in objects)

@jit
Expand Down
24 changes: 11 additions & 13 deletions differt2d/scene.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,13 +40,6 @@ def power(path, path_candidate, objects):
return 1 / (1.0 + l2)


@partial(jax.jit, inline=True)
def los_exists(path, path_candidate, objects):
l1 = path.length()
l2 = l1 * l1
return 1 / (1.0 + l2)


# @partial(jax.jit, static_argnames=("objects", "function"))
def accumulate_at_location(
tx: Point, objects, rx: Point, path_candidates, function
Expand Down Expand Up @@ -601,19 +594,24 @@ def all_paths(

return paths

def accumulate_over_paths(
self, function=power, tol: float = 1e-4, **kwargs: Any
) -> Array:
def accumulate_over_paths(self, function=power, **kwargs: Any) -> Array:
"""
Accumulates some function evaluated for each path in the scene.
:param function: The function to accumulate.
"""
path_candidates = self.all_path_candidates(**kwargs)

return accumulate_at_location(
self.tx, self.objects, self.rx, path_candidates, function
)

def accumulate_on_grid(
self, X, Y, function=power, tol: float = 1e-4, **kwargs
self, X, Y, function=power, min_order: int = 0, max_order: int = 1, **kwargs
) -> Array:
path_candidates = self.all_path_candidates(**kwargs)
path_candidates = self.all_path_candidates(
min_order=min_order, max_order=max_order
)

grid = jnp.dstack((X, Y))

Expand All @@ -622,4 +620,4 @@ def accumulate_on_grid(
in_axes=(None, None, 0, None, None),
)

return vacc(self.tx, self.objects, grid, path_candidates, function)
return vacc(self.tx, self.objects, grid, path_candidates, function, **kwargs)
3 changes: 2 additions & 1 deletion examples/basic_scene.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,11 @@ def main():

X, Y = scene.grid(n=150)

Z = scene.accumulate_on_grid(X, Y)
Z = scene.accumulate_on_grid(X, Y, max_order=1)

plt.pcolormesh(X, Y, Z)

plt.savefig("power.png", transparent=True)
plt.show()


Expand Down
6 changes: 2 additions & 4 deletions examples/basic_scene_los.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,13 +39,11 @@ def line_of_sight(

def main():
ax = plt.gca()
ax.set_facecolor((1.0, 1.0, 1.0, 0.0))
scene = Scene.basic_scene()
scene.plot(ax)

for path in scene.all_paths():
path.plot(ax)

X, Y = scene.grid(n=100)
X, Y = scene.grid(n=300)

grid = jnp.dstack((X, Y))

Expand Down
25 changes: 25 additions & 0 deletions examples/basic_scene_power_received.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
import matplotlib.pyplot as plt
import typer

from differt2d.scene import Scene


def main(min_order: int = 0, max_order: int = 1, resolution: int = 150):
ax = plt.gca()
scene = Scene.basic_scene()
scene.plot(ax)

for path in scene.all_paths():
path.plot(ax)

X, Y = scene.grid(n=resolution)

Z = scene.accumulate_on_grid(X, Y, min_order=min_order, max_order=max_order)

plt.pcolormesh(X, Y, Z)

plt.show()


if __name__ == "__main__":
typer.run(main)
25 changes: 25 additions & 0 deletions examples/square_scene_power_received.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
import matplotlib.pyplot as plt
import typer

from differt2d.scene import Scene


def main(resolution: int = 150):
ax = plt.gca()
scene = Scene.square_scene()
scene.plot(ax)

for path in scene.all_paths():
path.plot(ax)

X, Y = scene.grid(n=resolution)

Z = scene.accumulate_on_grid(X, Y, min_order=1, max_order=1)

plt.pcolormesh(X, Y, Z)

plt.show()


if __name__ == "__main__":
typer.run(main)
Loading

0 comments on commit 7184ab2

Please sign in to comment.