diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 5749ad4..5f29a07 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -52,7 +52,7 @@ jobs: - name: Install Differt2D run: | - poetry install --with test + poetry install --with test,github-action - name: Install FilesFinder uses: taiki-e/install-action@v2 diff --git a/README.md b/README.md index 93a34de..a7bc5c3 100644 --- a/README.md +++ b/README.md @@ -89,8 +89,8 @@ List of available extras: + **`examples`**: install required dependencies to run all the scripts in the [examples](https://github.com/jeertmans/DiffeRT2d/tree/main/examples) - folder. -+ **`gui`**: W.I.P, do not use. + folder. ++ **`gui`**: W.I.P., do not use. diff --git a/differt2d/geometry.py b/differt2d/geometry.py index 878c77e..705c5fc 100644 --- a/differt2d/geometry.py +++ b/differt2d/geometry.py @@ -9,7 +9,6 @@ import jax import jax.numpy as jnp -from jax import jit from .abc import Interactable, Plottable from .logic import greater_equal, less_equal, logical_and, logical_or, true_value @@ -96,7 +95,7 @@ def segments_intersect( kwargs_no_function = kwargs.copy() kwargs_no_function.pop("function", None) - @jit + @jax.jit def test(num, den): den_is_zero = den == 0.0 den = jnp.where(den_is_zero, 1.0, den) @@ -188,7 +187,7 @@ class Ray(Plottable): """Ray points (origin, dest).""" points: Array # a b - @partial(jit, inline=True) + @partial(jax.jit, inline=True) def origin(self) -> Array: """ Returns the origin of this object. @@ -197,7 +196,7 @@ def origin(self) -> Array: """ return self.points[0] - @partial(jit, inline=True) + @partial(jax.jit, inline=True) def dest(self) -> Array: """ Returns the destination of this object. @@ -206,7 +205,7 @@ def dest(self) -> Array: """ return self.points[1] - @partial(jit, inline=True) + @partial(jax.jit, inline=True) def t(self) -> Array: """ Returns the direction vector of this object. @@ -279,7 +278,7 @@ class Wall(Ray, Interactable): plt.show() """ - @jit + @jax.jit def normal(self) -> Array: """ Returns the normal to the current wall, @@ -295,28 +294,28 @@ def normal(self) -> Array: return n @staticmethod - @partial(jit, inline=True) + @partial(jax.jit, inline=True) def parameters_count() -> int: return 1 - @jit + @jax.jit def parametric_to_cartesian(self, param_coords: Array) -> Array: return self.origin() + param_coords * self.t() - @jit + @jax.jit def cartesian_to_parametric(self, carte_coords: Array) -> Array: other = carte_coords - self.origin() squared_length = jnp.dot(self.t(), self.t()) squared_length = jnp.where(squared_length == 0.0, 1.0, squared_length) return jnp.dot(self.t(), other) / squared_length - @jit + @jax.jit def contains_parametric(self, param_coords: Array) -> Array: ge = greater_equal(param_coords, 0.0) le = less_equal(param_coords, 1.0) return logical_and(ge, le) - @jit + @jax.jit def intersects_cartesian( self, ray: Array, @@ -329,7 +328,7 @@ def intersects_cartesian( ray[1, :], ) - @jit + @jax.jit def evaluate_cartesian(self, ray_path: Array) -> Array: i = ray_path[1, :] - ray_path[0, :] # Incident r = ray_path[2, :] - ray_path[1, :] # Reflected @@ -339,7 +338,7 @@ def evaluate_cartesian(self, ray_path: Array) -> Array: e = r - (i - 2 * jnp.dot(i, n) * n) return jnp.dot(e, e) - @jit + @jax.jit def image_of(self, point: Array) -> Array: """ Returns the image of a point with respect to @@ -374,7 +373,7 @@ class RIS(Wall): The constant angle of reflection. """ - @jit + @jax.jit def evaluate_cartesian(self, ray_path: Array) -> Array: v2 = ray_path[2, :] - ray_path[1, :] n = self.normal() @@ -456,7 +455,7 @@ def from_tx_objects_rx( points = jnp.row_stack([tx.point, points, rx.point]) return cls(points=points) - @jit + @jax.jit def length(self) -> Array: """ Returns the length of this path. @@ -465,7 +464,7 @@ def length(self) -> Array: """ return path_length(self.points) - @jit + @jax.jit def on_objects(self, objects: List[Interactable]) -> Array: """ Returns whether the path correctly passes on the objects. @@ -485,7 +484,7 @@ def on_objects(self, objects: List[Interactable]) -> Array: return contains - @jit + @jax.jit def intersects_with_objects( self, objects: List[Interactable], path_candidate: List[int] ) -> Array: @@ -535,7 +534,7 @@ def parametric_to_cartesian_from_slice(obj, parametric_coords, start, size): return obj.parametric_to_cartesian(parametric_coords) -@partial(jit, static_argnames=("n",)) +@partial(jax.jit, static_argnames=("n",)) def parametric_to_cartesian(objects, parametric_coords, n, tx_coords, rx_coords): cartesian_coords = jnp.empty((n + 2, 2)) cartesian_coords = cartesian_coords.at[0].set(tx_coords) @@ -589,7 +588,7 @@ class ImagePath(Path): """ @classmethod - @partial(jit, static_argnames=["cls"]) + @partial(jax.jit, static_argnames=["cls"]) def from_tx_objects_rx( cls, tx: Point, @@ -634,7 +633,7 @@ def from_tx_objects_rx( walls = stack_leaves(objects) - @jit + @jax.jit def path_loss(cartesian_coords): _loss = 0.0 for i, obj in enumerate(objects): @@ -672,7 +671,7 @@ class FermatPath(Path): """ @classmethod - @partial(jit, static_argnames=("cls", "steps", "optimizer")) + @partial(jax.jit, static_argnames=("cls", "steps", "optimizer")) def from_tx_objects_rx( cls, tx: Point, @@ -722,7 +721,7 @@ def from_tx_objects_rx( n_unknowns = sum([obj.parameters_count() for obj in objects]) - @jit + @jax.jit def loss_fun(theta): cartesian_coords = parametric_to_cartesian( objects, theta, n, tx.point, rx.point @@ -730,7 +729,7 @@ def loss_fun(theta): return path_length(cartesian_coords) - @jit + @jax.jit def path_loss(cartesian_coords): _loss = 0.0 for i, obj in enumerate(objects): @@ -757,7 +756,7 @@ class MinPath(Path): """ @classmethod - @partial(jit, static_argnames=("cls", "steps", "optimizer")) + @partial(jax.jit, static_argnames=("cls", "steps", "optimizer")) def from_tx_objects_rx( cls, tx: Point, @@ -807,7 +806,7 @@ def from_tx_objects_rx( n_unknowns = sum(obj.parameters_count() for obj in objects) - @jit + @jax.jit def loss_fun(theta): cartesian_coords = parametric_to_cartesian( objects, theta, n, tx.point, rx.point