Skip to content

Commit

Permalink
chore(lib): make plotting utilities accept JAX arrays (#146)
Browse files Browse the repository at this point in the history
* chore(lib): make plotting utilities accept JAX arrays

* chore(docs): fixes

* chore(tests): fix tests
  • Loading branch information
jeertmans authored Oct 14, 2024
1 parent ef9013e commit 451ad49
Show file tree
Hide file tree
Showing 10 changed files with 328 additions and 115 deletions.
7 changes: 3 additions & 4 deletions differt/src/differt/geometry/paths.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,10 @@
import equinox as eqx
import jax
import jax.numpy as jnp
import numpy as np
from beartype import beartype as typechecker
from jaxtyping import Array, ArrayLike, Bool, Float, Int, Shaped, jaxtyped

from differt.plotting import draw_paths
from differt.plotting import PlotOutput, draw_paths

if sys.version_info >= (3, 11):
from typing import Self
Expand Down Expand Up @@ -311,7 +310,7 @@ def reduce(
"""
return jnp.sum(fun(self.vertices), where=self.mask)

def plot(self, **kwargs: Any) -> Any:
def plot(self, **kwargs: Any) -> PlotOutput:
"""
Plot the (masked) paths on a 3D scene.
Expand All @@ -322,4 +321,4 @@ def plot(self, **kwargs: Any) -> Any:
Returns:
The resulting plot output.
"""
return draw_paths(np.asarray(self.masked_vertices), **kwargs)
return draw_paths(self.masked_vertices, **kwargs)
11 changes: 5 additions & 6 deletions differt/src/differt/geometry/triangle_mesh.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,11 @@
import equinox as eqx
import jax
import jax.numpy as jnp
import numpy as np
from beartype import beartype as typechecker
from jaxtyping import Array, ArrayLike, Bool, Float, Int, PRNGKeyArray, jaxtyped

import differt_core.geometry.triangle_mesh
from differt.plotting import draw_mesh
from differt.plotting import PlotOutput, draw_mesh

from .utils import normalize, orthogonal_basis, rotation_matrix_along_axis

Expand Down Expand Up @@ -420,7 +419,7 @@ def load_ply(cls, file: str) -> Self:
core_mesh = differt_core.geometry.triangle_mesh.TriangleMesh.load_ply(file)
return cls.from_core(core_mesh)

def plot(self, **kwargs: Any) -> Any:
def plot(self, **kwargs: Any) -> PlotOutput:
"""
Plot this mesh on a 3D scene.
Expand All @@ -432,11 +431,11 @@ def plot(self, **kwargs: Any) -> Any:
The resulting plot output.
"""
if "face_colors" not in kwargs and self.face_colors is not None:
kwargs["face_colors"] = np.asarray(self.face_colors)
kwargs["face_colors"] = self.face_colors

return draw_mesh(
vertices=np.asarray(self.vertices),
triangles=np.asarray(self.triangles),
vertices=self.vertices,
triangles=self.triangles,
**kwargs,
)

Expand Down
2 changes: 1 addition & 1 deletion differt/src/differt/geometry/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -357,7 +357,7 @@ def fibonacci_lattice(
... )
>>> from differt.plotting import draw_markers
>>>
>>> xyz = np.asarray(fibonacci_lattice(100))
>>> xyz = fibonacci_lattice(100)
>>> fig = draw_markers(xyz, marker={"color": xyz[:, 0]}, backend="plotly")
>>> fig # doctest: +SKIP
"""
Expand Down
12 changes: 6 additions & 6 deletions differt/src/differt/plotting/__init__.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,10 @@
"""
Plotting utilities for DiffeRT objects.
.. warning::
.. tip::
Unlike in other modules, plotting utilities work
Unlike in other modules, plotting utilities also work
with NumPy arrays (:class:`np.ndarray<numpy.ndarray>`)
instead of JAX arrays. Therefore, it is important to first
convert any JAX array into its NumPy equivalent with
:func:`np.asarray<numpy.asarray>` before using it as an
argument to any of the functions defined here.
.. note::
Expand Down Expand Up @@ -88,13 +84,15 @@
"""

__all__ = (
"PlotOutput",
"dispatch",
"draw_contour",
"draw_image",
"draw_markers",
"draw_mesh",
"draw_paths",
"draw_rays",
"draw_surface",
"get_backend",
"process_matplotlib_kwargs",
"process_plotly_kwargs",
Expand All @@ -106,12 +104,14 @@
)

from ._core import (
PlotOutput,
draw_contour,
draw_image,
draw_markers,
draw_mesh,
draw_paths,
draw_rays,
draw_surface,
)
from ._utils import (
dispatch,
Expand Down
Loading

0 comments on commit 451ad49

Please sign in to comment.