From 140bc97a1ac3262b0a05902bebd9aa08f1be880c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=C3=A9rome=20Eertmans?= Date: Mon, 14 Oct 2024 15:17:42 +0200 Subject: [PATCH 1/3] chore(lib): make plotting utilities accept JAX arrays --- differt/src/differt/geometry/paths.py | 7 +- differt/src/differt/geometry/triangle_mesh.py | 11 +- differt/src/differt/geometry/utils.py | 2 +- differt/src/differt/plotting/__init__.py | 12 +- differt/src/differt/plotting/_core.py | 273 ++++++++++++------ differt/src/differt/scene/triangle_scene.py | 11 +- differt/tests/plotting/test_core.py | 36 ++- .../notebooks/advanced_path_tracing.ipynb | 5 +- docs/source/numpy_vs_jax.md | 4 +- 9 files changed, 246 insertions(+), 115 deletions(-) diff --git a/differt/src/differt/geometry/paths.py b/differt/src/differt/geometry/paths.py index b3c4d90d..b914cd7e 100644 --- a/differt/src/differt/geometry/paths.py +++ b/differt/src/differt/geometry/paths.py @@ -7,11 +7,10 @@ import equinox as eqx import jax import jax.numpy as jnp -import numpy as np from beartype import beartype as typechecker from jaxtyping import Array, ArrayLike, Bool, Float, Int, Shaped, jaxtyped -from differt.plotting import draw_paths +from differt.plotting import PlotOutput, draw_paths if sys.version_info >= (3, 11): from typing import Self @@ -311,7 +310,7 @@ def reduce( """ return jnp.sum(fun(self.vertices), where=self.mask) - def plot(self, **kwargs: Any) -> Any: + def plot(self, **kwargs: Any) -> PlotOutput: """ Plot the (masked) paths on a 3D scene. @@ -322,4 +321,4 @@ def plot(self, **kwargs: Any) -> Any: Returns: The resulting plot output. """ - return draw_paths(np.asarray(self.masked_vertices), **kwargs) + return draw_paths(self.masked_vertices, **kwargs) diff --git a/differt/src/differt/geometry/triangle_mesh.py b/differt/src/differt/geometry/triangle_mesh.py index 61e35cf3..afcab8a0 100644 --- a/differt/src/differt/geometry/triangle_mesh.py +++ b/differt/src/differt/geometry/triangle_mesh.py @@ -7,12 +7,11 @@ import equinox as eqx import jax import jax.numpy as jnp -import numpy as np from beartype import beartype as typechecker from jaxtyping import Array, ArrayLike, Bool, Float, Int, PRNGKeyArray, jaxtyped import differt_core.geometry.triangle_mesh -from differt.plotting import draw_mesh +from differt.plotting import PlotOutput, draw_mesh from .utils import normalize, orthogonal_basis, rotation_matrix_along_axis @@ -420,7 +419,7 @@ def load_ply(cls, file: str) -> Self: core_mesh = differt_core.geometry.triangle_mesh.TriangleMesh.load_ply(file) return cls.from_core(core_mesh) - def plot(self, **kwargs: Any) -> Any: + def plot(self, **kwargs: Any) -> PlotOutput: """ Plot this mesh on a 3D scene. @@ -432,11 +431,11 @@ def plot(self, **kwargs: Any) -> Any: The resulting plot output. """ if "face_colors" not in kwargs and self.face_colors is not None: - kwargs["face_colors"] = np.asarray(self.face_colors) + kwargs["face_colors"] = self.face_colors return draw_mesh( - vertices=np.asarray(self.vertices), - triangles=np.asarray(self.triangles), + vertices=self.vertices, + triangles=self.triangles, **kwargs, ) diff --git a/differt/src/differt/geometry/utils.py b/differt/src/differt/geometry/utils.py index 4827efb7..8454254c 100644 --- a/differt/src/differt/geometry/utils.py +++ b/differt/src/differt/geometry/utils.py @@ -357,7 +357,7 @@ def fibonacci_lattice( ... ) >>> from differt.plotting import draw_markers >>> - >>> xyz = np.asarray(fibonacci_lattice(100)) + >>> xyz = fibonacci_lattice(100) >>> fig = draw_markers(xyz, marker={"color": xyz[:, 0]}, backend="plotly") >>> fig # doctest: +SKIP """ diff --git a/differt/src/differt/plotting/__init__.py b/differt/src/differt/plotting/__init__.py index 63f961c4..28da7e6f 100644 --- a/differt/src/differt/plotting/__init__.py +++ b/differt/src/differt/plotting/__init__.py @@ -1,14 +1,10 @@ """ Plotting utilities for DiffeRT objects. -.. warning:: +.. tip:: - Unlike in other modules, plotting utilities work + Unlike in other modules, plotting utilities also work with NumPy arrays (:class:`np.ndarray`) - instead of JAX arrays. Therefore, it is important to first - convert any JAX array into its NumPy equivalent with - :func:`np.asarray` before using it as an - argument to any of the functions defined here. .. note:: @@ -88,6 +84,7 @@ """ __all__ = ( + "PlotOutput", "dispatch", "draw_contour", "draw_image", @@ -95,6 +92,7 @@ "draw_mesh", "draw_paths", "draw_rays", + "draw_surface", "get_backend", "process_matplotlib_kwargs", "process_plotly_kwargs", @@ -106,12 +104,14 @@ ) from ._core import ( + PlotOutput, draw_contour, draw_image, draw_markers, draw_mesh, draw_paths, draw_rays, + draw_surface, ) from ._utils import ( dispatch, diff --git a/differt/src/differt/plotting/_core.py b/differt/src/differt/plotting/_core.py index f3cdbbea..4242a090 100644 --- a/differt/src/differt/plotting/_core.py +++ b/differt/src/differt/plotting/_core.py @@ -5,7 +5,7 @@ from typing import Any import numpy as np -from jaxtyping import Float, Int, Num +from jaxtyping import ArrayLike, Int, Real from ._utils import ( dispatch, @@ -35,11 +35,14 @@ except ImportError: Canvas = Any +PlotOutput = Canvas | MplFigure | Figure +"""The output of any plotting function.""" + @dispatch def draw_mesh( - vertices: Float[np.ndarray, "num_vertices 3"], - triangles: Int[np.ndarray, "num_triangles 3"], + vertices: Real[ArrayLike, "num_vertices 3"], + triangles: Int[ArrayLike, "num_triangles 3"], **kwargs: Any, ) -> Canvas | MplFigure | Figure: # type: ignore[reportInvalidTypeForm] """ @@ -96,14 +99,16 @@ def draw_mesh( @draw_mesh.register("vispy") def _( - vertices: Float[np.ndarray, "num_vertices 3"], - triangles: Int[np.ndarray, "num_triangles 3"], + vertices: Real[ArrayLike, "num_vertices 3"], + triangles: Int[ArrayLike, "num_triangles 3"], **kwargs: Any, ) -> Canvas: # type: ignore[reportInvalidTypeForm] from vispy.scene.visuals import Mesh # noqa: PLC0415 canvas, view = process_vispy_kwargs(kwargs) + vertices = np.asarray(vertices) + triangles = np.asarray(triangles) view.add(Mesh(vertices=vertices, faces=triangles, shading="flat", **kwargs)) view.camera.set_range() @@ -112,15 +117,16 @@ def _( @draw_mesh.register("matplotlib") def _( - vertices: Float[np.ndarray, "num_vertices 3"], - triangles: Int[np.ndarray, "num_triangles 3"], + vertices: Real[ArrayLike, "num_vertices 3"], + triangles: Int[ArrayLike, "num_triangles 3"], **kwargs: Any, ) -> MplFigure: # type: ignore[reportInvalidTypeForm] fig, ax = process_matplotlib_kwargs(kwargs) kwargs.pop("face_colors", None) - x, y, z = vertices.T + x, y, z = np.asarray(vertices).T + triangles = np.asarray(triangles) ax.plot_trisurf(x, y, z, triangles=triangles, **kwargs) return fig @@ -128,8 +134,8 @@ def _( @draw_mesh.register("plotly") def _( - vertices: Float[np.ndarray, "num_vertices 3"], - triangles: Int[np.ndarray, "num_triangles 3"], + vertices: Real[ArrayLike, "num_vertices 3"], + triangles: Int[ArrayLike, "num_triangles 3"], **kwargs: Any, ) -> Figure: # type: ignore[reportInvalidTypeForm] fig = process_plotly_kwargs(kwargs) @@ -139,15 +145,15 @@ def _( ) is not None and "facecolor" not in kwargs: kwargs["facecolor"] = face_colors - x, y, z = vertices.T - i, j, k = triangles.T + x, y, z = np.asarray(vertices).T + i, j, k = np.asarray(triangles).T return fig.add_mesh3d(x=x, y=y, z=z, i=i, j=j, k=k, **kwargs) @dispatch def draw_paths( - paths: Float[np.ndarray, "batch path_length 3"], + paths: Real[ArrayLike, "*batch path_length 3"], **kwargs: Any, ) -> Canvas | MplFigure | Figure: # type: ignore[reportInvalidTypeForm] """ @@ -204,7 +210,7 @@ def draw_paths( @draw_paths.register("vispy") def _( - paths: Float[np.ndarray, "*batch path_length 3"], + paths: Real[ArrayLike, "*batch path_length 3"], **kwargs: Any, ) -> Canvas: # type: ignore[reportInvalidTypeForm] from vispy.scene.visuals import LinePlot # noqa: PLC0415 @@ -213,6 +219,7 @@ def _( kwargs.setdefault("width", 3.0) kwargs.setdefault("marker_size", 0.0) + paths = np.asarray(paths) for path in paths.reshape(-1, *paths.shape[-2:]): x, y, z = path.T @@ -225,11 +232,13 @@ def _( @draw_paths.register("matplotlib") def _( - paths: Float[np.ndarray, "*batch path_length 3"], + paths: Real[ArrayLike, "*batch path_length 3"], **kwargs: Any, ) -> MplFigure: # type: ignore[reportInvalidTypeForm] fig, ax = process_matplotlib_kwargs(kwargs) + paths = np.asarray(paths) + for path in paths.reshape(-1, *paths.shape[-2:]): ax.plot(*path.T, **kwargs) @@ -238,11 +247,13 @@ def _( @draw_paths.register("plotly") def _( - paths: Float[np.ndarray, "*batch path_length 3"], + paths: Real[ArrayLike, "*batch path_length 3"], **kwargs: Any, ) -> Figure: # type: ignore[reportInvalidTypeForm] fig = process_plotly_kwargs(kwargs) + paths = np.asarray(paths) + for path in paths.reshape(-1, *paths.shape[-2:]): x, y, z = path.T fig = fig.add_scatter3d(x=x, y=y, z=z, **kwargs) @@ -252,8 +263,8 @@ def _( @dispatch def draw_rays( - ray_origins: Float[np.ndarray, "*batch 3"], - ray_directions: Float[np.ndarray, "*batch 3"], + ray_origins: Real[ArrayLike, "*batch 3"], + ray_directions: Real[ArrayLike, "*batch 3"], **kwargs: Any, ) -> Canvas | MplFigure | Figure: # type: ignore[reportInvalidTypeForm] """ @@ -280,11 +291,9 @@ def draw_rays( >>> from differt.geometry.utils import fibonacci_lattice >>> from differt.plotting import draw_rays >>> - >>> ray_origins = np.zeros(3) - >>> ray_directions = np.asarray( - ... fibonacci_lattice(50) - ... ) # From JAX to NumPy array - >>> ray_origins, ray_directions = np.broadcast_arrays( + >>> ray_origins = jnp.zeros(3) + >>> ray_directions = fibonacci_lattice(50) + >>> ray_origins, ray_directions = jnp.broadcast_arrays( ... ray_origins, ray_directions ... ) >>> fig = draw_rays( @@ -298,10 +307,12 @@ def draw_rays( @draw_rays.register("vispy") def _( - ray_origins: Float[np.ndarray, "*batch 3"], - ray_directions: Float[np.ndarray, "*batch 3"], + ray_origins: Real[ArrayLike, "*batch 3"], + ray_directions: Real[ArrayLike, "*batch 3"], **kwargs: Any, ) -> Canvas: # type: ignore[reportInvalidTypeForm] + ray_origins = np.asarray(ray_origins) + ray_directions = np.asarray(ray_directions) ray_ends = ray_origins + ray_directions paths = np.concatenate((ray_origins[..., None, :], ray_ends[..., None, :]), axis=-2) @@ -310,14 +321,14 @@ def _( @draw_rays.register("matplotlib") def _( - ray_origins: Float[np.ndarray, "*batch 3"], - ray_directions: Float[np.ndarray, "*batch 3"], + ray_origins: Real[ArrayLike, "*batch 3"], + ray_directions: Real[ArrayLike, "*batch 3"], **kwargs: Any, ) -> MplFigure: # type: ignore[reportInvalidTypeForm] fig, ax = process_matplotlib_kwargs(kwargs) - ray_origins = ray_origins.reshape(-1, 3) - ray_directions = ray_directions.reshape(-1, 3) + ray_origins = np.asarray(ray_origins).reshape(-1, 3) + ray_directions = np.asarray(ray_directions).reshape(-1, 3) ax.quiver(*ray_origins.T, *ray_directions.T, **kwargs) @@ -326,10 +337,12 @@ def _( @draw_rays.register("plotly") def _( - ray_origins: Float[np.ndarray, "*batch 3"], - ray_directions: Float[np.ndarray, "*batch 3"], + ray_origins: Real[ArrayLike, "*batch 3"], + ray_directions: Real[ArrayLike, "*batch 3"], **kwargs: Any, ) -> Figure: # type: ignore[reportInvalidTypeForm] + ray_origins = np.asarray(ray_origins) + ray_directions = np.asarray(ray_directions) ray_ends = ray_origins + ray_directions paths = np.concatenate((ray_origins[..., None, :], ray_ends[..., None, :]), axis=-2) @@ -338,7 +351,7 @@ def _( @dispatch def draw_markers( - markers: Float[np.ndarray, "num_markers 3"], + markers: Real[ArrayLike, "*batch 3"], labels: Sequence[str] | None = None, text_kwargs: Mapping[str, Any] | None = None, **kwargs: Any, @@ -387,7 +400,7 @@ def draw_markers( @draw_markers.register("vispy") def _( - markers: Float[np.ndarray, "num_markers 3"], + markers: Real[ArrayLike, "*batch 3"], labels: Sequence[str] | None = None, text_kwargs: Mapping[str, Any] | None = None, **kwargs: Any, @@ -398,6 +411,7 @@ def _( kwargs.setdefault("size", 1) kwargs.setdefault("edge_width_rel", 0.05) kwargs.setdefault("scaling", "scene") + markers = np.asarray(markers).reshape(-1, 3) view.add(Markers(pos=markers, **kwargs)) if labels: @@ -411,7 +425,7 @@ def _( @draw_markers.register("matplotlib") def _( - markers: Float[np.ndarray, "num_markers 3"], + markers: Real[ArrayLike, "*batch 3"], labels: Sequence[str] | None = None, text_kwargs: Mapping[str, Any] | None = None, **kwargs: Any, @@ -423,7 +437,7 @@ def _( warnings.warn(msg, UserWarning, stacklevel=2) del labels, text_kwargs - xs, ys, zs = markers.T + xs, ys, zs = np.asarray(markers).reshape(-1, 3).T ax.scatter(xs, ys, zs=zs, **kwargs) @@ -432,7 +446,7 @@ def _( @draw_markers.register("plotly") def _( - markers: Float[np.ndarray, "num_markers 3"], + markers: Real[ArrayLike, "*batch 3"], labels: Sequence[str] | None = None, text_kwargs: Mapping[str, Any] | None = None, # noqa: ARG001 **kwargs: Any, @@ -444,7 +458,7 @@ def _( else: kwargs = {"mode": "markers", **kwargs} - x, y, z = markers.T + x, y, z = np.asarray(markers).reshape(-1, 3).T return fig.add_scatter3d( x=x, y=y, @@ -456,11 +470,11 @@ def _( @dispatch def draw_image( - data: Num[np.ndarray, "rows cols"] - | Num[np.ndarray, "rows cols 3"] - | Num[np.ndarray, "rows cols 4"], - x: Float[np.ndarray, " cols"] | None = None, - y: Float[np.ndarray, " rows"] | None = None, + data: Real[ArrayLike, "rows cols"] + | Real[ArrayLike, "rows cols 3"] + | Real[ArrayLike, "rows cols 4"], + x: Real[ArrayLike, " cols"] | Real[ArrayLike, "rows cols 3"] | None = None, + y: Real[ArrayLike, " rows"] | Real[ArrayLike, "rows cols 3"] | None = None, z0: float = 0.0, **kwargs: Any, ) -> Canvas | MplFigure | Figure: # type: ignore[reportInvalidTypeForm] @@ -515,11 +529,11 @@ def draw_image( @draw_image.register("vispy") def _( - data: Num[np.ndarray, "rows cols"] - | Num[np.ndarray, "rows cols 3"] - | Num[np.ndarray, "rows cols 4"], - x: Float[np.ndarray, " cols"] | None = None, - y: Float[np.ndarray, " rows"] | None = None, + data: Real[ArrayLike, "rows cols"] + | Real[ArrayLike, "rows cols 3"] + | Real[ArrayLike, "rows cols 4"], + x: Real[ArrayLike, " cols"] | Real[ArrayLike, "rows cols 3"] | None = None, + y: Real[ArrayLike, " rows"] | Real[ArrayLike, "rows cols 3"] | None = None, z0: float = 0.0, **kwargs: Any, ) -> Canvas: # type: ignore[reportInvalidTypeForm] @@ -528,6 +542,7 @@ def _( canvas, view = process_vispy_kwargs(kwargs) + data = np.asarray(data) image = Image(data, **kwargs) m, n = data.shape[:2] @@ -562,23 +577,22 @@ def _( @draw_image.register("matplotlib") def _( - data: Num[np.ndarray, "rows cols"] - | Num[np.ndarray, "rows cols 3"] - | Num[np.ndarray, "rows cols 4"], - x: Float[np.ndarray, " cols"] | None = None, - y: Float[np.ndarray, " rows"] | None = None, + data: Real[ArrayLike, "rows cols"] + | Real[ArrayLike, "rows cols 3"] + | Real[ArrayLike, "rows cols 4"], + x: Real[ArrayLike, " cols"] | Real[ArrayLike, "rows cols 3"] | None = None, + y: Real[ArrayLike, " rows"] | Real[ArrayLike, "rows cols 3"] | None = None, z0: float = 0.0, **kwargs: Any, ) -> MplFigure: # type: ignore[reportInvalidTypeForm] fig, ax = process_matplotlib_kwargs(kwargs) + data = np.asarray(data) m, n = data.shape[:2] - if x is None: - x = np.arange(n) + x = np.arange(n) if x is None else np.asarray(x) - if y is None: - y = np.arange(m) + y = np.arange(m) if y is None else np.asarray(y) ax.contourf(x, y, data, offset=z0, **kwargs) @@ -587,16 +601,20 @@ def _( @draw_image.register("plotly") def _( - data: Num[np.ndarray, "rows cols"] - | Num[np.ndarray, "rows cols 3"] - | Num[np.ndarray, "rows cols 4"], - x: Float[np.ndarray, " cols"] | None = None, - y: Float[np.ndarray, " rows"] | None = None, + data: Real[ArrayLike, "rows cols"] + | Real[ArrayLike, "rows cols 3"] + | Real[ArrayLike, "rows cols 4"], + x: Real[ArrayLike, " cols"] | Real[ArrayLike, "rows cols 3"] | None = None, + y: Real[ArrayLike, " rows"] | Real[ArrayLike, "rows cols 3"] | None = None, z0: float = 0.0, **kwargs: Any, ) -> Figure: # type: ignore[reportInvalidTypeForm] fig = process_plotly_kwargs(kwargs) + data = np.asarray(data) + x = None if x is None else np.asarray(x) + y = None if y is None else np.asarray(y) + return fig.add_surface( x=x, y=y, @@ -608,11 +626,11 @@ def _( @dispatch def draw_contour( # noqa: PLR0917 - data: Num[np.ndarray, "rows cols"], - x: Float[np.ndarray, " cols"] | None = None, - y: Float[np.ndarray, " rows"] | None = None, + data: Real[ArrayLike, "rows cols"], + x: Real[ArrayLike, " cols"] | Real[ArrayLike, "rows cols 3"] | None = None, + y: Real[ArrayLike, " rows"] | Real[ArrayLike, "rows cols 3"] | None = None, z0: float = 0.0, - levels: int | Float[np.ndarray, " num_levels"] | None = None, + levels: int | Real[ArrayLike, " num_levels"] | None = None, fill: bool = False, **kwargs: Any, ) -> Canvas | MplFigure | Figure: # type: ignore[reportInvalidTypeForm] @@ -622,10 +640,10 @@ def draw_contour( # noqa: PLR0917 Args: data: The values over which the contour is drawn. x: The x-coordinates corresponding to first dimension - of the image. Those coordinates will be used to scale and translate + of the contour. Those coordinates will be used to scale and translate the contour. y: The y-coordinates corresponding to second dimension - of the image. Those coordinates will be used to scale and translate + of the contour. Those coordinates will be used to scale and translate the contour. z0: The z-coordinate at which the contour is placed. levels: The levels at which the contour is drawn. @@ -670,17 +688,19 @@ def draw_contour( # noqa: PLR0917 @draw_contour.register("vispy") def _( # noqa: PLR0917 - data: Num[np.ndarray, "rows cols"], - x: Float[np.ndarray, " cols"] | None = None, - y: Float[np.ndarray, " rows"] | None = None, + data: Real[ArrayLike, "rows cols"], + x: Real[ArrayLike, " cols"] | Real[ArrayLike, "rows cols"] | None = None, + y: Real[ArrayLike, " rows"] | Real[ArrayLike, "rows cols"] | None = None, z0: float = 0.0, - levels: int | Float[np.ndarray, " num_levels"] | None = None, + levels: int | Real[ArrayLike, " num_levels"] | None = None, fill: bool = False, **kwargs: Any, ) -> Canvas: # type: ignore[reportInvalidTypeForm] from vispy.scene.visuals import Isocurve # noqa: PLC0415 from vispy.visuals.transforms import STTransform # noqa: PLC0415 + data = np.asarray(data) + if isinstance(levels, int): msg = ( f"VisPy does not support using {type(levels)} as parameters for `levels`. " @@ -688,6 +708,8 @@ def _( # noqa: PLR0917 ) warnings.warn(msg, UserWarning, stacklevel=2) levels = np.linspace(data.min(), data.max(), levels + 1) + else: + levels = np.asarray(levels) if fill: msg = "VisPy does not support filling contour, this option is ignored." @@ -731,23 +753,24 @@ def _( # noqa: PLR0917 @draw_contour.register("matplotlib") def _( # noqa: PLR0917 - data: Num[np.ndarray, "rows cols"], - x: Float[np.ndarray, " cols"] | None = None, - y: Float[np.ndarray, " rows"] | None = None, + data: Real[ArrayLike, "rows cols"], + x: Real[ArrayLike, " cols"] | Real[ArrayLike, "rows cols"] | None = None, + y: Real[ArrayLike, " rows"] | Real[ArrayLike, "rows cols"] | None = None, z0: float = 0.0, - levels: int | Float[np.ndarray, " num_levels"] | None = None, + levels: int | Real[ArrayLike, " num_levels"] | None = None, fill: bool = False, **kwargs: Any, ) -> MplFigure: # type: ignore[reportInvalidTypeForm] fig, ax = process_matplotlib_kwargs(kwargs) + data = np.asarray(data) m, n = data.shape[:2] - if x is None: - x = np.arange(n) + x = np.arange(n) if x is None else np.asarray(x) + + y = np.arange(m) if y is None else np.asarray(y) - if y is None: - y = np.arange(m) + levels = levels if isinstance(levels, int) else np.asarray(levels) if fill: ax.contourf(x, y, data, offset=z0, levels=levels, **kwargs) @@ -759,11 +782,11 @@ def _( # noqa: PLR0917 @draw_contour.register("plotly") def _( # noqa: PLR0917 - data: Num[np.ndarray, "rows cols"], - x: Float[np.ndarray, " cols"] | None = None, - y: Float[np.ndarray, " rows"] | None = None, + data: Real[ArrayLike, "rows cols"], + x: Real[ArrayLike, " cols"] | Real[ArrayLike, "rows cols"] | None = None, + y: Real[ArrayLike, " rows"] | Real[ArrayLike, "rows cols"] | None = None, z0: float = 0.0, - levels: int | Float[np.ndarray, " num_levels"] | None = None, + levels: int | Real[ArrayLike, " num_levels"] | None = None, fill: bool = False, **kwargs: Any, ) -> Figure: # type: ignore[reportInvalidTypeForm] @@ -777,7 +800,8 @@ def _( # noqa: PLR0917 if isinstance(levels, int): kwargs.setdefault("autocontour", True) kwargs.setdefault("ncontours", levels) - elif isinstance(levels, np.ndarray): + elif isinstance(levels, ArrayLike): + levels = np.asarray(levels) msg = ( "Plotly does not support arbitrary level values, but only linearly spaced levels. " f"A range of values from {levels.min() = } to {levels.max() = } with step " @@ -789,9 +813,90 @@ def _( # noqa: PLR0917 contours["end"] = levels.max() contours["size"] = (levels.max() - levels.min()) / max(1, levels.size - 1) + data = np.asarray(data) + x = None if x is None else np.asarray(x) + y = None if y is None else np.asarray(y) + return fig.add_contour( x=x, y=y, z=data, **kwargs, ) + + +@dispatch +def draw_surface( + x: Real[ArrayLike, " cols"] | Real[ArrayLike, "rows cols"] | None = None, + y: Real[ArrayLike, " rows"] | Real[ArrayLike, "rows cols"] | None = None, + *, + z: Real[ArrayLike, "rows cols"], + colors: Real[ArrayLike, "rows cols"] | Real[ArrayLike, "rows cols 3"] | None = None, + **kwargs: Any, +) -> Canvas | MplFigure | Figure: # type: ignore[reportInvalidTypeForm] + """ + Plot a 3D surface. + + Args: + x: The x-coordinates corresponding to first dimension + of the surface. + y: The y-coordinates corresponding to second dimension + of the surface. + z: The z-coordinates corresponding to third dimension + of the surface. + colors: The color of values to use. + + In the Plotly backend, the default is to use the values in ``z``. + kwargs: Keyword arguments passed to + :class:`Isocurve`, + :meth:`contour`, + or :class:`Surface`, depending on the + backend. + + Returns: + The resulting plot output. + + Examples: + The following example shows how plot a 3-D surface, + without and with custom coloring. + + .. plotly:: + :fig-vars: fig1, fig2 + + >>> from differt.plotting import draw_surface + >>> + >>> u = np.linspace(0, 2 * np.pi, 100) + >>> v = np.linspace(0, np.pi, 100) + >>> x = np.outer(np.cos(u), np.sin(v)) + >>> y = np.outer(np.sin(u), np.sin(v)) + >>> z = np.outer(np.cos(u), np.cos(v)) + >>> fig1 = draw_surface(x, y, z=z, backend="plotly") + >>> fig1 # doctest: +SKIP + >>> + >>> fig2 = draw_surface( + ... x, y, z=z, colors=x * x + y * y + z * z, backend="plotly" + ... ) + >>> fig2 # doctest: +SKIP + + """ + + +@draw_surface.register("plotly") +def _( + x: Real[ArrayLike, " cols"] | Real[ArrayLike, "rows cols"] | None = None, + y: Real[ArrayLike, " rows"] | Real[ArrayLike, "rows cols"] | None = None, + *, + z: Real[ArrayLike, "rows cols"], + colors: Real[ArrayLike, "rows cols"] | Real[ArrayLike, "rows cols 3"] | None = None, + **kwargs: Any, +) -> Figure: # type: ignore[reportInvalidTypeForm] + fig = process_plotly_kwargs(kwargs) + + x = None if x is None else np.asarray(x) + y = None if y is None else np.asarray(y) + z = np.asarray(z) + colors = None if colors is None else np.asarray(colors) + + fig.add_surface(x=x, y=y, z=z, surfacecolor=colors, **kwargs) + + return fig diff --git a/differt/src/differt/scene/triangle_scene.py b/differt/src/differt/scene/triangle_scene.py index cac9b84b..769c80f9 100644 --- a/differt/src/differt/scene/triangle_scene.py +++ b/differt/src/differt/scene/triangle_scene.py @@ -9,7 +9,6 @@ import equinox as eqx import jax import jax.numpy as jnp -import numpy as np from beartype import beartype as typechecker from jax.experimental import mesh_utils from jax.experimental.shard_map import shard_map @@ -23,7 +22,7 @@ TriangleMesh, ) from differt.geometry.utils import assemble_paths -from differt.plotting import draw_markers, reuse +from differt.plotting import PlotOutput, draw_markers, reuse from differt.rt.image_method import ( consecutive_vertices_are_on_same_side_of_mirrors, image_method, @@ -526,7 +525,7 @@ def plot( rx_kwargs: Mapping[str, Any] | None = None, mesh_kwargs: Mapping[str, Any] | None = None, **kwargs: Any, - ) -> Any: # TODO: change output type + ) -> PlotOutput: """ Plot this scene on a 3D scene. @@ -549,12 +548,10 @@ def plot( with reuse(**kwargs) as result: if self.transmitters.size > 0: - draw_markers( - np.asarray(self.transmitters).reshape((-1, 3)), **tx_kwargs - ) + draw_markers(self.transmitters.reshape((-1, 3)), **tx_kwargs) if self.receivers.size > 0: - draw_markers(np.asarray(self.receivers).reshape((-1, 3)), **rx_kwargs) + draw_markers(self.receivers.reshape((-1, 3)), **rx_kwargs) self.mesh.plot(**mesh_kwargs) diff --git a/differt/tests/plotting/test_core.py b/differt/tests/plotting/test_core.py index 5084b100..94274ad0 100644 --- a/differt/tests/plotting/test_core.py +++ b/differt/tests/plotting/test_core.py @@ -10,6 +10,7 @@ draw_mesh, draw_paths, draw_rays, + draw_surface, use, ) @@ -21,8 +22,8 @@ def test_draw_mesh( backend: str, ) -> None: - vertices = np.array([[0, 0, 0], [1, 0, 0], [1, 1, 0], [0, 1, 0]], dtype=float) - triangles = np.array([[0, 1, 2], [0, 2, 3]], dtype=int) + vertices = np.array([[0, 0, 0], [1, 0, 0], [1, 1, 0], [0, 1, 0]]) + triangles = np.array([[0, 1, 2], [0, 2, 3]]) with use(backend): _ = draw_mesh(vertices, triangles) @@ -135,3 +136,34 @@ def test_draw_contour( levels=levels, fill=fill, ) + + +@pytest.mark.parametrize( + "backend", + ["vispy", "matplotlib", "plotly"], +) +@pytest.mark.parametrize( + "pass_xy", + [True, False], +) +@pytest.mark.parametrize( + "pass_colors", + [True, False], +) +def test_draw_surface( + backend: str, + pass_xy: bool, + pass_colors: bool, +) -> None: + x = np.linspace(0, 1, 10) + y = np.linspace(0, 1, 20) + X, Y = np.meshgrid(x, y) # noqa: N806 + Z = X * Y + + with use(backend): + _ = draw_surface( + x=x if pass_xy else None, + y=y if pass_xy else None, + z=z, + colors=X * X + Y * Y + Z * Z if pass_colors else None, + ) diff --git a/docs/source/notebooks/advanced_path_tracing.ipynb b/docs/source/notebooks/advanced_path_tracing.ipynb index 3bbe50ea..3cbcfbae 100644 --- a/docs/source/notebooks/advanced_path_tracing.ipynb +++ b/docs/source/notebooks/advanced_path_tracing.ipynb @@ -183,9 +183,8 @@ " 23, # Green\n", "] # Ideally, you will never hard-code the primitive indices yourself\n", "\n", - "# differt.plotting (dplt) works with NumPy arrays, not JAX arrays\n", - "vertices = np.asarray(mesh.vertices)\n", - "triangles = np.asarray(mesh.triangles[select, :])\n", + "vertices = mesh.vertices\n", + "triangles = mesh.triangles[select, :]\n", "\n", "dplt.draw_mesh(vertices, triangles[:2, :], figure=fig, color=\"red\")\n", "dplt.draw_mesh(vertices, triangles[2:, :], figure=fig, color=\"green\")" diff --git a/docs/source/numpy_vs_jax.md b/docs/source/numpy_vs_jax.md index 75ed611d..9635cb8e 100644 --- a/docs/source/numpy_vs_jax.md +++ b/docs/source/numpy_vs_jax.md @@ -34,13 +34,13 @@ We can identify two specific cases: 1. For plotting, we rely on third-party libraries that may not support JAX arrays, e.g., Vispy. As a result, - {mod}`differt.plotting` only works with NumPy arrays. + {mod}`differt.plotting` automatically convert NumPy arrays. 2. In the Rust code, there is no way of directly creating JAX arrays, but well for NumPy. Therefore, directly calling the functions declared with Rust code will return NumPy arrays. Similarly, NumPy arrays use the following type annotations: -`Dtype[ndarray, 'Shape']`. +`Dtype[np.ndarray, 'Shape']`. ## From JAX to NumPy and vice-versa From 6d033f1973fc81ff40e8688bb83f784287607a0a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=C3=A9rome=20Eertmans?= Date: Mon, 14 Oct 2024 15:21:21 +0200 Subject: [PATCH 2/3] chore(docs): fixes --- differt/src/differt/scene/triangle_scene.py | 4 ++-- differt/tests/plotting/test_core.py | 6 +++--- docs/source/conf.py | 1 + 3 files changed, 6 insertions(+), 5 deletions(-) diff --git a/differt/src/differt/scene/triangle_scene.py b/differt/src/differt/scene/triangle_scene.py index 769c80f9..55f120ef 100644 --- a/differt/src/differt/scene/triangle_scene.py +++ b/differt/src/differt/scene/triangle_scene.py @@ -548,10 +548,10 @@ def plot( with reuse(**kwargs) as result: if self.transmitters.size > 0: - draw_markers(self.transmitters.reshape((-1, 3)), **tx_kwargs) + draw_markers(self.transmitters, **tx_kwargs) if self.receivers.size > 0: - draw_markers(self.receivers.reshape((-1, 3)), **rx_kwargs) + draw_markers(self.receivers, **rx_kwargs) self.mesh.plot(**mesh_kwargs) diff --git a/differt/tests/plotting/test_core.py b/differt/tests/plotting/test_core.py index 94274ad0..37d6410f 100644 --- a/differt/tests/plotting/test_core.py +++ b/differt/tests/plotting/test_core.py @@ -162,8 +162,8 @@ def test_draw_surface( with use(backend): _ = draw_surface( - x=x if pass_xy else None, - y=y if pass_xy else None, - z=z, + x=X if pass_xy else None, + y=Y if pass_xy else None, + z=Z, colors=X * X + Y * Y + Z * Z if pass_colors else None, ) diff --git a/docs/source/conf.py b/docs/source/conf.py index 1e195228..033ea81c 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -105,6 +105,7 @@ # -- Sphinx autodoc typehints settings always_document_param_types = False +always_use_bars_union = True autodoc_member_order = "bysource" # We force class variables to appear first # -- MyST-nb settings From b9daffbcf585d1befd5bc4f2d74ee9c23658fb3f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=C3=A9rome=20Eertmans?= Date: Mon, 14 Oct 2024 16:37:12 +0200 Subject: [PATCH 3/3] chore(tests): fix tests --- differt/src/differt/plotting/_core.py | 82 ++++++++++++++++++++++++++- differt/tests/plotting/test_core.py | 9 ++- 2 files changed, 86 insertions(+), 5 deletions(-) diff --git a/differt/src/differt/plotting/_core.py b/differt/src/differt/plotting/_core.py index 4242a090..48769525 100644 --- a/differt/src/differt/plotting/_core.py +++ b/differt/src/differt/plotting/_core.py @@ -770,7 +770,8 @@ def _( # noqa: PLR0917 y = np.arange(m) if y is None else np.asarray(y) - levels = levels if isinstance(levels, int) else np.asarray(levels) + if not isinstance(levels, int) and isinstance(levels, ArrayLike): + levels = np.asarray(levels) if fill: ax.contourf(x, y, data, offset=z0, levels=levels, **kwargs) @@ -856,6 +857,10 @@ def draw_surface( Returns: The resulting plot output. + Warning: + Matplotlib requires ``colors`` to be RGB or RGBA values. + VisPy currently does not support colors. + Examples: The following example shows how plot a 3-D surface, without and with custom coloring. @@ -881,6 +886,75 @@ def draw_surface( """ +@draw_surface.register("vispy") +def _( + x: Real[ArrayLike, " cols"] | Real[ArrayLike, "rows cols"] | None = None, + y: Real[ArrayLike, " rows"] | Real[ArrayLike, "rows cols"] | None = None, + *, + z: Real[ArrayLike, "rows cols"], + colors: Real[ArrayLike, "rows cols"] | Real[ArrayLike, "rows cols 3"] | None = None, + **kwargs: Any, +) -> Canvas: # type: ignore[reportInvalidTypeForm] + from vispy.scene.visuals import SurfacePlot # noqa: PLC0415 + + canvas, view = process_vispy_kwargs(kwargs) + + x = None if x is None else np.asarray(x) + y = None if y is None else np.asarray(y) + z = np.asarray(z) + + if colors is not None: + msg = "VisPy does not currently support coloring like we would like." + warnings.warn(msg, UserWarning, stacklevel=2) + colors = None + + view.add(SurfacePlot(x=x, y=y, z=z, color=colors, **kwargs)) + + return canvas + + +@draw_surface.register("matplotlib") +def _( + x: Real[ArrayLike, " cols"] | Real[ArrayLike, "rows cols"] | None = None, + y: Real[ArrayLike, " rows"] | Real[ArrayLike, "rows cols"] | None = None, + *, + z: Real[ArrayLike, "rows cols"], + colors: Real[ArrayLike, "rows cols"] + | Real[ArrayLike, "rows cols 3"] + | Real[ArrayLike, "rows cols 4"] + | None = None, + **kwargs: Any, +) -> MplFigure: # type: ignore[reportInvalidTypeForm] + fig, ax = process_matplotlib_kwargs(kwargs) + + z = np.asarray(z) + + x = np.arange(z.shape[1]) if x is None else np.asarray(x) + + if x.ndim == 1: + x = np.broadcast_to(x[None, :], z.shape) + + y = np.arange(z.shape[0]) if y is None else np.asarray(y) + + if y.ndim == 1: + y = np.broadcast_to(y[:, None], z.shape) + + if colors is not None and "facecolors" not in kwargs: + colors = np.asarray(colors) + if colors.ndim != 3: # noqa: PLR2004 + msg = "Matplotlib requires 'colors' to be RGB or RGBA values." + warnings.warn(msg, UserWarning, stacklevel=2) + c_min = colors.min() + c_max = colors.max() + colors = np.broadcast_to(colors[..., None], (*colors.shape, 3)) + colors = (colors - c_min) / (c_max - c_min) + kwargs["facecolors"] = colors + + ax.plot_surface(x, y, z, **kwargs) + + return fig + + @draw_surface.register("plotly") def _( x: Real[ArrayLike, " cols"] | Real[ArrayLike, "rows cols"] | None = None, @@ -895,8 +969,10 @@ def _( x = None if x is None else np.asarray(x) y = None if y is None else np.asarray(y) z = np.asarray(z) - colors = None if colors is None else np.asarray(colors) - fig.add_surface(x=x, y=y, z=z, surfacecolor=colors, **kwargs) + if colors is not None and "surfacecolor" not in kwargs: + kwargs["surfacecolor"] = np.asarray(colors) + + fig.add_surface(x=x, y=y, z=z, **kwargs) return fig diff --git a/differt/tests/plotting/test_core.py b/differt/tests/plotting/test_core.py index 37d6410f..abaafa6e 100644 --- a/differt/tests/plotting/test_core.py +++ b/differt/tests/plotting/test_core.py @@ -158,9 +158,14 @@ def test_draw_surface( x = np.linspace(0, 1, 10) y = np.linspace(0, 1, 20) X, Y = np.meshgrid(x, y) # noqa: N806 - Z = X * Y + Z = X * Y # noqa: N806 - with use(backend): + if backend in {"vispy", "matplotlib"} and pass_colors: + expectation = pytest.warns(UserWarning) + else: + expectation = does_not_raise() + + with use(backend), expectation: _ = draw_surface( x=X if pass_xy else None, y=Y if pass_xy else None,