Skip to content

Commit

Permalink
feat(lib): add basic vertex diffraction (#70)
Browse files Browse the repository at this point in the history
* feat(lib): add basic vertex diffraction

* fix(tests): correct

* chore(tests): improve coverage

* chore(tests): added missed line coverage

* fix(docs): examples

* fix(docs): typo

* fix(docs): add missing key argument

* chore(lib): changing default color

* chore(docs): re-order changelog

* fix(examples): set_offsets
  • Loading branch information
jeertmans authored Jun 27, 2024
1 parent ab6f9fb commit 59cdf8f
Show file tree
Hide file tree
Showing 11 changed files with 441 additions and 40 deletions.
15 changes: 15 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,21 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
(unreleased)=
## [Unreleased](https://github.com/jeertmans/DiffeRT2d/compare/v0.3.2...HEAD)

(unreleased-added)=
### Added

+ Added `Vertex` class for basic vertex diffraction.
[#70](https://github.com/jeertmans/DiffeRT2d/pull/70)
+ Added `get_vertices` method to `Wall`.
[#70](https://github.com/jeertmans/DiffeRT2d/pull/70)
+ Added `filter_objects` method to `Scene`.
[#70](https://github.com/jeertmans/DiffeRT2d/pull/70)
+ Added `filter_objects` parameters to `Scene.all_path_candidates`
and related methods.
[#70](https://github.com/jeertmans/DiffeRT2d/pull/70)
+ Added a lower-level, cached, variant of `all_path_candidates`.
[#70](https://github.com/jeertmans/DiffeRT2d/pull/70)

(unreleased-chore)=
### Chore

Expand Down
108 changes: 98 additions & 10 deletions differt2d/geometry.py
Original file line number Diff line number Diff line change
Expand Up @@ -307,12 +307,14 @@ def plot(

x, y = self.xy

artists: MutableSequence[Artist] = ax.plot(
[x],
[y],
*args,
**kwargs,
) # type: ignore[reportGeneralTypeIssues]
artists: list[Artist] = [
ax.scatter(
x,
y,
*args,
**kwargs,
)
]

if annotate:
xytext: tuple[float, float] = self.xy + jnp.asarray(
Expand All @@ -335,6 +337,84 @@ def bounding_box(self) -> Float[Array, "2 2"]: # type: ignore[reportIncompatibl
return jnp.vstack([self.xy, self.xy])


@jaxtyped(typechecker=typechecker)
class Vertex(Point, Object, eqx.Module):
"""
A vertex for corner diffraction.
.. plot::
:include-source: true
import matplotlib.pyplot as plt
import jax.numpy as jnp
from differt2d.geometry import Wall
ax = plt.gca()
wall = Wall(xys=jnp.array([[0., 0.], [1., 0.]]))
_ = wall.plot(ax)
for vertex in wall.get_vertices():
_ = vertex.plot(ax)
plt.show() # doctest: +SKIP
"""

@staticmethod
@partial(jax.jit, inline=True)
@jaxtyped(typechecker=typechecker)
def parameters_count() -> int: # type: ignore[reportIncompatibleMethodOverride] # noqa: D102
return 0

@jax.jit
@jaxtyped(typechecker=typechecker)
def parametric_to_cartesian( # type: ignore[reportIncompatibleMethodOverride] # noqa: D102
self,
param_coords: Float[Array, " {self.parameters_count()}"], # type: ignore[reportUndefinedVariable]
) -> Float[Array, "2"]:
return self.xy

@jax.jit
@jaxtyped(typechecker=typechecker)
def cartesian_to_parametric( # type: ignore[reportIncompatibleMethodOverride] # noqa: D102
self, carte_coords: Float[Array, "2"]
) -> Float[Array, " {self.parameters_count()}"]: # type: ignore[reportUndefinedVariable]
return jnp.empty_like(carte_coords, shape=0)

@partial(jax.jit, static_argnames=("approx", "function"))
@jaxtyped(typechecker=typechecker)
def contains_parametric( # type: ignore[reportIncompatibleMethodOverride] # noqa: D102
self,
param_coords: Float[Array, " {self.parameters_count()}"], # type: ignore[reportUndefinedVariable]
approx: Optional[bool] = None,
**kwargs: Any,
) -> Truthy:
return true_value(approx=approx)

@partial(jax.jit, static_argnames=("approx", "function"))
@jaxtyped(typechecker=typechecker)
def intersects_cartesian( # type: ignore[reportIncompatibleMethodOverride] # noqa: D102
self,
ray: Float[Array, "2 2"],
patch: ScalarFloat = DEFAULT_PATCH,
approx: Optional[bool] = None,
**kwargs: Any,
) -> Truthy:
return false_value(approx=approx)

@jax.jit
@jaxtyped(typechecker=typechecker)
def evaluate_cartesian(self, ray_path: Float[Array, "3 2"]) -> Float[Array, " "]: # type: ignore[reportIncompatibleMethodOverride] # noqa: D102
return jnp.array(0.0, dtype=ray_path.dtype)

@jaxtyped(typechecker=typechecker)
def plot( # noqa: D102
self, ax: Axes, *args: Any, **kwargs: Any
) -> MutableSequence[Artist]: # pragma: no cover
kwargs.setdefault("edgecolors", "black")
kwargs.setdefault("facecolors", (1.0, 1.0, 0.0, 0.5))
kwargs.setdefault("linestyle", "dashed")
return super().plot(ax, *args, **kwargs)


@jaxtyped(typechecker=typechecker)
class Ray(Plottable, eqx.Module):
"""
Expand Down Expand Up @@ -572,6 +652,16 @@ def image_of(self, point: Float[Array, "2"]) -> Float[Array, "2"]:
i = point - self.origin()
return point - 2.0 * jnp.dot(i, self.normal()) * self.normal()

@jax.jit
@jaxtyped(typechecker=typechecker)
def get_vertices(self) -> tuple[Vertex, Vertex]:
"""
Returns the two vertices of this wall.
:return: The two vertices.
"""
return Vertex(xy=self.xys[0, :]), Vertex(xy=self.xys[1, :])


@jaxtyped(typechecker=typechecker)
class RIS(Wall, eqx.Module):
Expand Down Expand Up @@ -740,7 +830,7 @@ def on_objects(
return contains

@partial(jax.jit, inline=True, static_argnames=("approx", "function"))
@jaxtyped(typechecker=None)
@jaxtyped(typechecker=typechecker)
def intersects_with_objects(
self,
objects: Sequence[Interactable],
Expand All @@ -765,9 +855,7 @@ def intersects_with_objects(
:return: Whether this path intersects any of the objects.
"""
interacting_object_indices = [-1] + [i for i in path_candidate] + [-1]
intersects = false_value(
approx=approx
) # TODO(fixme): why is this a NumPy array?
intersects = false_value(approx=approx)

for i in range(self.xys.shape[0] - 1):
ray_path = self.xys[i : i + 2, :]
Expand Down
6 changes: 3 additions & 3 deletions differt2d/logic.py
Original file line number Diff line number Diff line change
Expand Up @@ -597,10 +597,10 @@ def true_value(approx: Optional[bool] = None) -> Truthy:
"""
if approx is None:
approx = ENABLE_APPROX
return jnp.array(1.0) if approx else jnp.array(True)
return jnp.array(1.0) if approx else jnp.array(True, dtype=bool)


@partial(jax.jit, inline=True, static_argnames=("approx",))
@partial(jax.jit, inline=False, static_argnames=("approx",))
@jaxtyped(typechecker=typechecker)
def false_value(approx: Optional[bool] = None) -> Truthy:
"""
Expand All @@ -613,4 +613,4 @@ def false_value(approx: Optional[bool] = None) -> Truthy:
"""
if approx is None:
approx = ENABLE_APPROX
return jnp.array(0.0) if approx else jnp.array(False)
return jnp.array(0.0) if approx else jnp.array(False, dtype=bool)
Loading

0 comments on commit 59cdf8f

Please sign in to comment.