Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore(lib): make plotting utilities accept JAX arrays #146

Merged
merged 3 commits into from
Oct 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 3 additions & 4 deletions differt/src/differt/geometry/paths.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,10 @@
import equinox as eqx
import jax
import jax.numpy as jnp
import numpy as np
from beartype import beartype as typechecker
from jaxtyping import Array, ArrayLike, Bool, Float, Int, Shaped, jaxtyped

from differt.plotting import draw_paths
from differt.plotting import PlotOutput, draw_paths

if sys.version_info >= (3, 11):
from typing import Self
Expand Down Expand Up @@ -311,7 +310,7 @@ def reduce(
"""
return jnp.sum(fun(self.vertices), where=self.mask)

def plot(self, **kwargs: Any) -> Any:
def plot(self, **kwargs: Any) -> PlotOutput:
"""
Plot the (masked) paths on a 3D scene.

Expand All @@ -322,4 +321,4 @@ def plot(self, **kwargs: Any) -> Any:
Returns:
The resulting plot output.
"""
return draw_paths(np.asarray(self.masked_vertices), **kwargs)
return draw_paths(self.masked_vertices, **kwargs)
11 changes: 5 additions & 6 deletions differt/src/differt/geometry/triangle_mesh.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,11 @@
import equinox as eqx
import jax
import jax.numpy as jnp
import numpy as np
from beartype import beartype as typechecker
from jaxtyping import Array, ArrayLike, Bool, Float, Int, PRNGKeyArray, jaxtyped

import differt_core.geometry.triangle_mesh
from differt.plotting import draw_mesh
from differt.plotting import PlotOutput, draw_mesh

from .utils import normalize, orthogonal_basis, rotation_matrix_along_axis

Expand Down Expand Up @@ -420,7 +419,7 @@ def load_ply(cls, file: str) -> Self:
core_mesh = differt_core.geometry.triangle_mesh.TriangleMesh.load_ply(file)
return cls.from_core(core_mesh)

def plot(self, **kwargs: Any) -> Any:
def plot(self, **kwargs: Any) -> PlotOutput:
"""
Plot this mesh on a 3D scene.

Expand All @@ -432,11 +431,11 @@ def plot(self, **kwargs: Any) -> Any:
The resulting plot output.
"""
if "face_colors" not in kwargs and self.face_colors is not None:
kwargs["face_colors"] = np.asarray(self.face_colors)
kwargs["face_colors"] = self.face_colors

return draw_mesh(
vertices=np.asarray(self.vertices),
triangles=np.asarray(self.triangles),
vertices=self.vertices,
triangles=self.triangles,
**kwargs,
)

Expand Down
2 changes: 1 addition & 1 deletion differt/src/differt/geometry/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -357,7 +357,7 @@ def fibonacci_lattice(
... )
>>> from differt.plotting import draw_markers
>>>
>>> xyz = np.asarray(fibonacci_lattice(100))
>>> xyz = fibonacci_lattice(100)
>>> fig = draw_markers(xyz, marker={"color": xyz[:, 0]}, backend="plotly")
>>> fig # doctest: +SKIP
"""
Expand Down
12 changes: 6 additions & 6 deletions differt/src/differt/plotting/__init__.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,10 @@
"""
Plotting utilities for DiffeRT objects.

.. warning::
.. tip::

Unlike in other modules, plotting utilities work
Unlike in other modules, plotting utilities also work
with NumPy arrays (:class:`np.ndarray<numpy.ndarray>`)
instead of JAX arrays. Therefore, it is important to first
convert any JAX array into its NumPy equivalent with
:func:`np.asarray<numpy.asarray>` before using it as an
argument to any of the functions defined here.


.. note::
Expand Down Expand Up @@ -88,13 +84,15 @@
"""

__all__ = (
"PlotOutput",
"dispatch",
"draw_contour",
"draw_image",
"draw_markers",
"draw_mesh",
"draw_paths",
"draw_rays",
"draw_surface",
"get_backend",
"process_matplotlib_kwargs",
"process_plotly_kwargs",
Expand All @@ -106,12 +104,14 @@
)

from ._core import (
PlotOutput,
draw_contour,
draw_image,
draw_markers,
draw_mesh,
draw_paths,
draw_rays,
draw_surface,
)
from ._utils import (
dispatch,
Expand Down
Loading