Skip to content

Commit

Permalink
Merge pull request #19 from jeertmans/jit
Browse files Browse the repository at this point in the history
chore(lib): use jax.jit instead of jit
  • Loading branch information
jeertmans authored Aug 10, 2023
2 parents 5e06129 + 81cea1a commit 636e1d9
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 28 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ jobs:

- name: Install Differt2D
run: |
poetry install --with test
poetry install --with test,github-action
- name: Install FilesFinder
uses: taiki-e/install-action@v2
Expand Down
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -89,8 +89,8 @@ List of available extras:

+ **`examples`**: install required dependencies to run all the scripts in the
[examples](https://github.com/jeertmans/DiffeRT2d/tree/main/examples)
folder.
+ **`gui`**: W.I.P, do not use.
folder.
+ **`gui`**: W.I.P., do not use.


<!-- end install -->
Expand Down
49 changes: 24 additions & 25 deletions differt2d/geometry.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@

import jax
import jax.numpy as jnp
from jax import jit

from .abc import Interactable, Plottable
from .logic import greater_equal, less_equal, logical_and, logical_or, true_value
Expand Down Expand Up @@ -96,7 +95,7 @@ def segments_intersect(
kwargs_no_function = kwargs.copy()
kwargs_no_function.pop("function", None)

@jit
@jax.jit
def test(num, den):
den_is_zero = den == 0.0
den = jnp.where(den_is_zero, 1.0, den)
Expand Down Expand Up @@ -188,7 +187,7 @@ class Ray(Plottable):
"""Ray points (origin, dest)."""
points: Array # a b

@partial(jit, inline=True)
@partial(jax.jit, inline=True)
def origin(self) -> Array:
"""
Returns the origin of this object.
Expand All @@ -197,7 +196,7 @@ def origin(self) -> Array:
"""
return self.points[0]

@partial(jit, inline=True)
@partial(jax.jit, inline=True)
def dest(self) -> Array:
"""
Returns the destination of this object.
Expand All @@ -206,7 +205,7 @@ def dest(self) -> Array:
"""
return self.points[1]

@partial(jit, inline=True)
@partial(jax.jit, inline=True)
def t(self) -> Array:
"""
Returns the direction vector of this object.
Expand Down Expand Up @@ -279,7 +278,7 @@ class Wall(Ray, Interactable):
plt.show()
"""

@jit
@jax.jit
def normal(self) -> Array:
"""
Returns the normal to the current wall,
Expand All @@ -295,28 +294,28 @@ def normal(self) -> Array:
return n

@staticmethod
@partial(jit, inline=True)
@partial(jax.jit, inline=True)
def parameters_count() -> int:
return 1

@jit
@jax.jit
def parametric_to_cartesian(self, param_coords: Array) -> Array:
return self.origin() + param_coords * self.t()

@jit
@jax.jit
def cartesian_to_parametric(self, carte_coords: Array) -> Array:
other = carte_coords - self.origin()
squared_length = jnp.dot(self.t(), self.t())
squared_length = jnp.where(squared_length == 0.0, 1.0, squared_length)
return jnp.dot(self.t(), other) / squared_length

@jit
@jax.jit
def contains_parametric(self, param_coords: Array) -> Array:
ge = greater_equal(param_coords, 0.0)
le = less_equal(param_coords, 1.0)
return logical_and(ge, le)

@jit
@jax.jit
def intersects_cartesian(
self,
ray: Array,
Expand All @@ -329,7 +328,7 @@ def intersects_cartesian(
ray[1, :],
)

@jit
@jax.jit
def evaluate_cartesian(self, ray_path: Array) -> Array:
i = ray_path[1, :] - ray_path[0, :] # Incident
r = ray_path[2, :] - ray_path[1, :] # Reflected
Expand All @@ -339,7 +338,7 @@ def evaluate_cartesian(self, ray_path: Array) -> Array:
e = r - (i - 2 * jnp.dot(i, n) * n)
return jnp.dot(e, e)

@jit
@jax.jit
def image_of(self, point: Array) -> Array:
"""
Returns the image of a point with respect to
Expand Down Expand Up @@ -374,7 +373,7 @@ class RIS(Wall):
The constant angle of reflection.
"""

@jit
@jax.jit
def evaluate_cartesian(self, ray_path: Array) -> Array:
v2 = ray_path[2, :] - ray_path[1, :]
n = self.normal()
Expand Down Expand Up @@ -456,7 +455,7 @@ def from_tx_objects_rx(
points = jnp.row_stack([tx.point, points, rx.point])
return cls(points=points)

@jit
@jax.jit
def length(self) -> Array:
"""
Returns the length of this path.
Expand All @@ -465,7 +464,7 @@ def length(self) -> Array:
"""
return path_length(self.points)

@jit
@jax.jit
def on_objects(self, objects: List[Interactable]) -> Array:
"""
Returns whether the path correctly passes on the objects.
Expand All @@ -485,7 +484,7 @@ def on_objects(self, objects: List[Interactable]) -> Array:

return contains

@jit
@jax.jit
def intersects_with_objects(
self, objects: List[Interactable], path_candidate: List[int]
) -> Array:
Expand Down Expand Up @@ -535,7 +534,7 @@ def parametric_to_cartesian_from_slice(obj, parametric_coords, start, size):
return obj.parametric_to_cartesian(parametric_coords)


@partial(jit, static_argnames=("n",))
@partial(jax.jit, static_argnames=("n",))
def parametric_to_cartesian(objects, parametric_coords, n, tx_coords, rx_coords):
cartesian_coords = jnp.empty((n + 2, 2))
cartesian_coords = cartesian_coords.at[0].set(tx_coords)
Expand Down Expand Up @@ -589,7 +588,7 @@ class ImagePath(Path):
"""

@classmethod
@partial(jit, static_argnames=["cls"])
@partial(jax.jit, static_argnames=["cls"])
def from_tx_objects_rx(
cls,
tx: Point,
Expand Down Expand Up @@ -634,7 +633,7 @@ def from_tx_objects_rx(

walls = stack_leaves(objects)

@jit
@jax.jit
def path_loss(cartesian_coords):
_loss = 0.0
for i, obj in enumerate(objects):
Expand Down Expand Up @@ -672,7 +671,7 @@ class FermatPath(Path):
"""

@classmethod
@partial(jit, static_argnames=("cls", "steps", "optimizer"))
@partial(jax.jit, static_argnames=("cls", "steps", "optimizer"))
def from_tx_objects_rx(
cls,
tx: Point,
Expand Down Expand Up @@ -722,15 +721,15 @@ def from_tx_objects_rx(

n_unknowns = sum([obj.parameters_count() for obj in objects])

@jit
@jax.jit
def loss_fun(theta):
cartesian_coords = parametric_to_cartesian(
objects, theta, n, tx.point, rx.point
)

return path_length(cartesian_coords)

@jit
@jax.jit
def path_loss(cartesian_coords):
_loss = 0.0
for i, obj in enumerate(objects):
Expand All @@ -757,7 +756,7 @@ class MinPath(Path):
"""

@classmethod
@partial(jit, static_argnames=("cls", "steps", "optimizer"))
@partial(jax.jit, static_argnames=("cls", "steps", "optimizer"))
def from_tx_objects_rx(
cls,
tx: Point,
Expand Down Expand Up @@ -807,7 +806,7 @@ def from_tx_objects_rx(

n_unknowns = sum(obj.parameters_count() for obj in objects)

@jit
@jax.jit
def loss_fun(theta):
cartesian_coords = parametric_to_cartesian(
objects, theta, n, tx.point, rx.point
Expand Down

0 comments on commit 636e1d9

Please sign in to comment.