diff --git a/differt-core/src/rt/graph.rs b/differt-core/src/rt/graph.rs index f29d1781..5494c0f8 100644 --- a/differt-core/src/rt/graph.rs +++ b/differt-core/src/rt/graph.rs @@ -117,6 +117,9 @@ pub mod complete { /// A complete graph, i.e., /// a simple undirected graph in which every pair of /// distinc nodes is connected by a unique edge. + /// + /// Args: + /// num_nodes (int): The number of nodes. #[pyclass] #[derive(Clone, Debug)] pub struct CompleteGraph { @@ -547,6 +550,12 @@ pub mod directed { } impl DiGraph { + #[inline] + pub fn empty(num_nodes: usize) -> Self { + Self { + edges_list: vec![vec![]; num_nodes], + } + } #[inline] pub fn get_adjacent_nodes(&self, node: NodeId) -> &[NodeId] { self.edges_list[node].as_ref() @@ -574,6 +583,23 @@ pub mod directed { #[pymethods] impl DiGraph { + /// Create an edgeless directed graph with ``num_nodes`` nodes. + /// + /// This is equivalent to creating a directed graph from + /// an adjacency matrix will all entries equal to :py:data:`False`. + /// + /// Args: + /// num\_nodes (int): The number of nodes. + /// + /// Return: + /// DiGraph: A directed graph. + #[classmethod] + #[pyo3(name = "empty")] + #[pyo3(text_signature = "(cls, num_nodes)")] + fn py_empty(_: Bound<'_, PyType>, num_nodes: usize) -> Self { + Self::empty(num_nodes) + } + /// Create a directed graph from an adjacency matrix. /// /// Each row of the adjacency matrix ``M`` contains boolean @@ -1111,6 +1137,16 @@ mod tests { assert_eq!(got, expected); } + #[rstest] + #[case(9, 2)] + #[case(3, 3)] + fn test_empty_di_graph_returns_all_paths(#[case] num_nodes: usize, #[case] depth: usize) { + let mut graph = DiGraph::empty(num_nodes); + let (from, to) = graph.insert_from_and_to_nodes(false); + + assert_eq!(graph.all_paths(from, to, depth + 2, true).count(), 0); + } + #[rstest] #[case(9, 2)] #[case(3, 3)] diff --git a/differt/src/differt/geometry/triangle_mesh.py b/differt/src/differt/geometry/triangle_mesh.py index da508838..e7ef0262 100644 --- a/differt/src/differt/geometry/triangle_mesh.py +++ b/differt/src/differt/geometry/triangle_mesh.py @@ -4,10 +4,11 @@ from typing import Any import equinox as eqx +import jax import jax.numpy as jnp import numpy as np from beartype import beartype as typechecker -from jaxtyping import Array, Bool, Float, UInt, jaxtyped +from jaxtyping import Array, Bool, Float, PRNGKeyArray, UInt, jaxtyped import differt_core.geometry.triangle_mesh @@ -134,6 +135,13 @@ def diffraction_edges(self) -> UInt[Array, "num_edges 3"]: """The diffraction edges.""" raise NotImplementedError + @cached_property + def bounding_box(self) -> Float[Array, "2 3"]: + """The bounding box (min. and max. coordinates).""" + return jnp.vstack( + (jnp.min(self.vertices, axis=0), jnp.max(self.vertices, axis=0)) + ) + @classmethod def empty(cls) -> "TriangleMesh": """ @@ -207,3 +215,26 @@ def plot(self, **kwargs: Any) -> Any: triangles=np.asarray(self.triangles), **kwargs, ) + + @eqx.filter_jit + def sample( + self, size: int, replace: bool = False, *, key: PRNGKeyArray + ) -> "TriangleMesh": + """ + Generate a new mesh by randomly sampling triangles from this geometry. + + Args: + size: The size of the sample, i.e., the number of triangles. + replace: Whether to sample with or without replacement. + key: The :func:`jax.random.PRNGKey` to be used. + + Return: + A new random mesh. + """ + triangles = self.triangles[ + jax.random.choice( + key, self.triangles.shape[0], shape=(size,), replace=replace + ), + :, + ] + return TriangleMesh(vertices=self.vertices, triangles=triangles) diff --git a/differt/src/differt/plotting/_core.py b/differt/src/differt/plotting/_core.py index 7970c4c7..a97d597b 100644 --- a/differt/src/differt/plotting/_core.py +++ b/differt/src/differt/plotting/_core.py @@ -1,9 +1,7 @@ """Core plotting implementations.""" -from __future__ import annotations - from collections.abc import Mapping, Sequence -from typing import TYPE_CHECKING, Any +from typing import Any, Optional, Union import numpy as np from jaxtyping import Float, Num, UInt @@ -15,18 +13,34 @@ process_vispy_kwargs, ) -if TYPE_CHECKING: +# We cannot use from __future__ import annotations because +# otherwise array annotations do not render correctly. +# We cannot rely on TYPE_CHECKING-guarded annotation +# because Sphinx will fail to import this NumPy or Jax typing +# Hence, we prefer to silence pyright instead. + +try: from matplotlib.figure import Figure as MplFigure +except ImportError: + MplFigure = Any + +try: from plotly.graph_objects import Figure +except ImportError: + Figure = Any + +try: from vispy.scene.canvas import SceneCanvas as Canvas +except ImportError: + Canvas = Any -@dispatch # type: ignore +@dispatch def draw_mesh( vertices: Float[np.ndarray, "num_vertices 3"], triangles: UInt[np.ndarray, "num_triangles 3"], **kwargs: Any, -) -> Canvas | MplFigure | Figure: # type: ignore +) -> Union[Canvas, MplFigure, Figure]: # type: ignore[reportInvalidTypeForm] """ Plot a 3D mesh made of triangles. @@ -72,7 +86,7 @@ def _( vertices: Float[np.ndarray, "num_vertices 3"], triangles: UInt[np.ndarray, "num_triangles 3"], **kwargs: Any, -) -> Canvas: +) -> Canvas: # type: ignore[reportInvalidTypeForm] from vispy.scene.visuals import Mesh canvas, view = process_vispy_kwargs(kwargs) @@ -88,7 +102,7 @@ def _( vertices: Float[np.ndarray, "num_vertices 3"], triangles: UInt[np.ndarray, "num_triangles 3"], **kwargs: Any, -) -> MplFigure: +) -> MplFigure: # type: ignore[reportInvalidTypeForm] fig, ax = process_matplotlib_kwargs(kwargs) x, y, z = vertices.T @@ -102,7 +116,7 @@ def _( vertices: Float[np.ndarray, "num_vertices 3"], triangles: UInt[np.ndarray, "num_triangles 3"], **kwargs: Any, -) -> Figure: +) -> Figure: # type: ignore[reportInvalidTypeForm] fig = process_plotly_kwargs(kwargs) x, y, z = vertices.T @@ -111,10 +125,10 @@ def _( return fig.add_mesh3d(x=x, y=y, z=z, i=i, j=j, k=k, **kwargs) -@dispatch # type: ignore +@dispatch def draw_paths( - paths: Float[np.ndarray, "*batch path_length 3"], **kwargs: Any -) -> Canvas | MplFigure | Figure: # type: ignore + paths: Float[np.ndarray, r"\*batch path_length 3"], **kwargs: Any +) -> Union[Canvas, MplFigure, Figure]: # type: ignore[reportInvalidTypeForm] """ Plot a batch of paths of the same length. @@ -137,12 +151,12 @@ def draw_paths( >>> from differt.plotting import draw_paths >>> >>> def rotation(angle: float) -> np.ndarray: - ... c = np.cos(angle) - ... s = np.sin(angle) + ... co = np.cos(angle) + ... si = np.sin(angle) ... return np.array( ... [ - ... [+c, -s, 0.0], - ... [+s, +c, 0.0], + ... [+co, -si, 0.0], + ... [+si, +co, 0.0], ... [0.0, 0.0, 1.0], ... ] ... ) @@ -172,7 +186,7 @@ def draw_paths( @draw_paths.register("vispy") -def _(paths: Float[np.ndarray, "*batch path_length 3"], **kwargs: Any) -> Canvas: +def _(paths: Float[np.ndarray, "*batch path_length 3"], **kwargs: Any) -> Canvas: # type: ignore[reportInvalidTypeForm] from vispy.scene.visuals import LinePlot canvas, view = process_vispy_kwargs(kwargs) @@ -186,7 +200,7 @@ def _(paths: Float[np.ndarray, "*batch path_length 3"], **kwargs: Any) -> Canvas @draw_paths.register("matplotlib") -def _(paths: Float[np.ndarray, "*batch path_length 3"], **kwargs: Any) -> MplFigure: +def _(paths: Float[np.ndarray, "*batch path_length 3"], **kwargs: Any) -> MplFigure: # type: ignore[reportInvalidTypeForm] fig, ax = process_matplotlib_kwargs(kwargs) for i in np.ndindex(paths.shape[:-2]): @@ -196,7 +210,7 @@ def _(paths: Float[np.ndarray, "*batch path_length 3"], **kwargs: Any) -> MplFig @draw_paths.register("plotly") -def _(paths: Float[np.ndarray, "*batch path_length 3"], **kwargs: Any) -> Figure: +def _(paths: Float[np.ndarray, "*batch path_length 3"], **kwargs: Any) -> Figure: # type: ignore[reportInvalidTypeForm] fig = process_plotly_kwargs(kwargs) for i in np.ndindex(paths.shape[:-2]): @@ -206,13 +220,13 @@ def _(paths: Float[np.ndarray, "*batch path_length 3"], **kwargs: Any) -> Figure return fig -@dispatch # type: ignore +@dispatch def draw_markers( markers: Float[np.ndarray, "num_markers 3"], - labels: Sequence[str] | None = None, - text_kwargs: Mapping[str, Any] | None = None, + labels: Optional[Sequence[str]] = None, + text_kwargs: Optional[Mapping[str, Any]] = None, **kwargs: Any, -) -> Canvas | MplFigure | Figure: # type: ignore +) -> Union[Canvas, MplFigure, Figure]: # type: ignore[reportInvalidTypeForm] """ Plot markers and, optionally, their label. @@ -259,10 +273,10 @@ def draw_markers( @draw_markers.register("vispy") def _( markers: Float[np.ndarray, "num_markers 3"], - labels: Sequence[str] | None = None, - text_kwargs: Mapping[str, Any] | None = None, + labels: Optional[Sequence[str]] = None, + text_kwargs: Optional[Mapping[str, Any]] = None, **kwargs: Any, -) -> Canvas: +) -> Canvas: # type: ignore[reportInvalidTypeForm] from vispy.scene.visuals import Markers, Text canvas, view = process_vispy_kwargs(kwargs) @@ -280,20 +294,20 @@ def _( @draw_markers.register("matplotlib") def _( markers: Float[np.ndarray, "num_markers 3"], - labels: Sequence[str] | None = None, - text_kwargs: Mapping[str, Any] | None = None, + labels: Optional[Sequence[str]] = None, + text_kwargs: Optional[Mapping[str, Any]] = None, **kwargs: Any, -) -> MplFigure: +) -> MplFigure: # type: ignore[reportInvalidTypeForm] raise NotImplementedError # TODO @draw_markers.register("plotly") def _( markers: Float[np.ndarray, "num_markers 3"], - labels: Sequence[str] | None = None, - text_kwargs: Mapping[str, Any] | None = None, + labels: Optional[Sequence[str]] = None, + text_kwargs: Optional[Mapping[str, Any]] = None, **kwargs: Any, -) -> Figure: +) -> Figure: # type: ignore[reportInvalidTypeForm] fig = process_plotly_kwargs(kwargs) if labels: @@ -309,14 +323,18 @@ def _( ) -@dispatch # type: ignore +@dispatch def draw_image( - data: Num[np.ndarray, "m n"] | Num[np.ndarray, "m n 3"] | Num[np.ndarray, "m n 4"], - x: Float[np.ndarray, " *m"] | None = None, - y: Float[np.ndarray, " *n"] | None = None, + data: Union[ + Num[np.ndarray, "rows cols"], + Num[np.ndarray, "rows cols 3"], + Num[np.ndarray, "rows cols 4"], + ], + x: Optional[Float[np.ndarray, " rows"]] = None, + y: Optional[Float[np.ndarray, " cols"]] = None, z0: float = 0.0, **kwargs: Any, -) -> Canvas | MplFigure | Figure: # type: ignore +) -> Union[Canvas, MplFigure, Figure]: # type: ignore[reportInvalidTypeForm] """ Plot a 2D image on a 3D canvas, at using a fixed z-coordinate. @@ -368,12 +386,16 @@ def draw_image( @draw_image.register("vispy") def _( - data: Num[np.ndarray, "m n"] | Num[np.ndarray, "m n 3"] | Num[np.ndarray, "m n 4"], - x: Float[np.ndarray, " ..."] | None = None, - y: Float[np.ndarray, " ..."] | None = None, + data: Union[ + Num[np.ndarray, "rows cols"], + Num[np.ndarray, "rows cols 3"], + Num[np.ndarray, "rows cols 4"], + ], + x: Optional[Float[np.ndarray, " rows"]] = None, + y: Optional[Float[np.ndarray, " cols"]] = None, z0: float = 0.0, **kwargs: Any, -) -> Canvas: +) -> Canvas: # type: ignore[reportInvalidTypeForm] from vispy.scene.visuals import Image from vispy.visuals.transforms import STTransform @@ -416,12 +438,16 @@ def _( @draw_image.register("matplotlib") def _( - data: Num[np.ndarray, "m n"] | Num[np.ndarray, "m n 3"] | Num[np.ndarray, "m n 4"], - x: Float[np.ndarray, " ..."] | None = None, - y: Float[np.ndarray, " ..."] | None = None, + data: Union[ + Num[np.ndarray, "rows cols"], + Num[np.ndarray, "rows cols 3"], + Num[np.ndarray, "rows cols 4"], + ], + x: Optional[Float[np.ndarray, " rows"]] = None, + y: Optional[Float[np.ndarray, " cols"]] = None, z0: float = 0.0, **kwargs: Any, -) -> MplFigure: +) -> MplFigure: # type: ignore[reportInvalidTypeForm] fig, ax = process_matplotlib_kwargs(kwargs) ax.plot_surface(X=x, Y=y, Z=np.full_like(data, z0), color=data, **kwargs) @@ -431,12 +457,16 @@ def _( @draw_image.register("plotly") def _( - data: Num[np.ndarray, "m n"] | Num[np.ndarray, "m n 3"] | Num[np.ndarray, "m n 4"], - x: Float[np.ndarray, " ..."] | None = None, - y: Float[np.ndarray, " ..."] | None = None, + data: Union[ + Num[np.ndarray, "rows cols"], + Num[np.ndarray, "rows cols 3"], + Num[np.ndarray, "rows cols 4"], + ], + x: Optional[Float[np.ndarray, " rows"]] = None, + y: Optional[Float[np.ndarray, " cols"]] = None, z0: float = 0.0, **kwargs: Any, -) -> Figure: +) -> Figure: # type: ignore[reportInvalidTypeForm] fig = process_plotly_kwargs(kwargs) return fig.add_surface( diff --git a/differt/src/differt/plotting/_utils.py b/differt/src/differt/plotting/_utils.py index 7df79ba7..22a03660 100644 --- a/differt/src/differt/plotting/_utils.py +++ b/differt/src/differt/plotting/_utils.py @@ -1,13 +1,13 @@ """Useful decorators for plotting.""" -from __future__ import annotations - import importlib +import sys +import types from collections.abc import Iterator, MutableMapping from contextlib import contextmanager from functools import wraps from threading import Lock -from typing import TYPE_CHECKING, Any, Callable, Generic, Protocol, TypeVar +from typing import TYPE_CHECKING, Any, Callable, Generic, Optional, TypeVar, Union # Immutables @@ -24,28 +24,31 @@ DEFAULT_KWARGS: MutableMapping[str, Any] = {} """The default keyword arguments.""" -if TYPE_CHECKING: - import sys +if sys.version_info >= (3, 10): + from typing import ParamSpec +else: + from typing_extensions import ParamSpec + +P = ParamSpec("P") + +if TYPE_CHECKING: from matplotlib.figure import Figure as MplFigure from mpl_toolkits.mplot3d import Axes3D from plotly.graph_objects import Figure - from vispy.scene.canvas import SceneCanvas + from vispy.scene.canvas import SceneCanvas as Canvas from vispy.scene.widgets.viewbox import ViewBox - - if sys.version_info >= (3, 10): - from typing import ParamSpec - else: - from typing_extensions import ParamSpec - - P = ParamSpec("P") - T = TypeVar("T", SceneCanvas, MplFigure, Figure) else: - P = TypeVar("P") - T = TypeVar("T") + MplFigure = Any + Axes3D = Any + Figure = Any + Canvas = Any + ViewBox = Any +T = TypeVar("T", Canvas, MplFigure, Figure) -def set_defaults(backend: str | None = None, **kwargs: Any) -> str: + +def set_defaults(backend: Optional[str] = None, **kwargs: Any) -> str: """ Set default keyword arguments for future plotting utilities. @@ -64,35 +67,45 @@ def set_defaults(backend: str | None = None, **kwargs: Any) -> str: ImportError: If the backend is not installed. Examples: - The following example shows how to set the default plotting backend. + The following example shows how to set the default plotting backend + and other plotting defaults. >>> import differt.plotting as dplt >>> >>> @dplt.dispatch - ... def my_plot(): + ... def my_plot(*args, **kwargs): ... pass >>> >>> @my_plot.register("vispy") - ... def _(): - ... print("Using vispy backend") + ... def _(*args, **kwargs): + ... dplt.process_vispy_kwargs(kwargs) + ... print(f"Using vispy backend with {args = }, {kwargs = }") >>> >>> @my_plot.register("matplotlib") - ... def _(): - ... print("Using matplotlib backend") + ... def _(*args, **kwargs): + ... dplt.process_matplotlib_kwargs(kwargs) + ... print(f"Using matplotlib backend with {args = }, {kwargs = }") >>> >>> my_plot() # When not specified, use default backend - Using vispy backend - >>> - >>> my_plot(backend="matplotlib") # We can force the backend - Using matplotlib backend + Using vispy backend with args = (), kwargs = {} >>> - >>> dplt.set_defaults("matplotlib") # Or change the default backend... + >>> dplt.set_defaults("matplotlib") # We can change the default backend 'matplotlib' >>> my_plot() # So that now it defaults to 'matplotlib' - Using matplotlib backend + Using matplotlib backend with args = (), kwargs = {} >>> + >>> dplt.set_defaults( + ... "matplotlib", color="red" + ... ) # We can also specify additional defaults + 'matplotlib' + >>> my_plot() # So that now it defaults to 'matplotlib' and color='red' + Using matplotlib backend with args = (), kwargs = {'color': 'red'} >>> my_plot(backend="vispy") # Of course, the 'vispy' backend is still available - Using vispy backend + Using vispy backend with args = (), kwargs = {'color': 'red'} + >>> my_plot(backend="vispy", color="green") # And we can also override any default + Using vispy backend with args = (), kwargs = {'color': 'green'} + >>> dplt.set_defaults("vispy") # Reset all defaults + 'vispy' """ global DEFAULT_BACKEND, DEFAULT_KWARGS @@ -104,16 +117,15 @@ def set_defaults(backend: str | None = None, **kwargs: Any) -> str: f"We currently support: {', '.join(SUPPORTED_BACKENDS)}." ) - with BACKEND_LOCK: - try: - importlib.import_module(f"{backend}") - DEFAULT_BACKEND = backend - DEFAULT_KWARGS = kwargs - return backend - except ImportError: - raise ImportError( - f"Could not load backend '{backend}', did you install it?" - ) from None + try: + importlib.import_module(f"{backend}") + DEFAULT_BACKEND = backend + DEFAULT_KWARGS = kwargs + return backend + except ImportError: + raise ImportError( + f"Could not load backend '{backend}', did you install it?" + ) from None @contextmanager @@ -132,34 +144,68 @@ def use(*args: Any, **kwargs: Any) -> Iterator[str]: Return: The name of the default backend used in this context. + + Examples: + The following example shows how set plot defaults in a context. + + >>> import differt.plotting as dplt + >>> + >>> @dplt.dispatch + ... def my_plot(*args, **kwargs): + ... pass + >>> + >>> @my_plot.register("vispy") + ... def _(*args, **kwargs): + ... dplt.process_vispy_kwargs(kwargs) + ... print(f"Using vispy backend with {args = }, {kwargs = }") + >>> + >>> @my_plot.register("plotly") + ... def _(*args, **kwargs): + ... dplt.process_plotly_kwargs(kwargs) + ... print(f"Using plotly backend with {args = }, {kwargs = }") + >>> + >>> my_plot() # When not specified, use default backend + Using vispy backend with args = (), kwargs = {} + >>> with dplt.use(): # No parameters = reset defaults (except the default backend) + ... my_plot() + Using vispy backend with args = (), kwargs = {} + >>> with dplt.use("plotly"): # We can change the default backend + ... my_plot() # So that now it defaults to 'matplotlib' + Using plotly backend with args = (), kwargs = {} + >>> + >>> with dplt.use( + ... "plotly", color="black" + ... ): # We can also specify additional defaults + ... my_plot() + Using plotly backend with args = (), kwargs = {'color': 'black'} """ global DEFAULT_BACKEND, DEFAULT_KWARGS default_backend = DEFAULT_BACKEND default_kwargs = DEFAULT_KWARGS - try: - yield set_defaults(*args, **kwargs) - finally: - DEFAULT_BACKEND = default_backend - DEFAULT_KWARGS = default_kwargs + with BACKEND_LOCK: + try: + yield set_defaults(*args, **kwargs) + finally: + DEFAULT_BACKEND = default_backend + DEFAULT_KWARGS = default_kwargs -class Dispatcher(Protocol, Generic[P, T]): # pragma: no cover +class _Dispatcher(Generic[P, T]): + registry: types.MappingProxyType[str, Callable[P, T]] + def __call__( self, *args: P.args, **kwargs: P.kwargs, ) -> T: ... - def register( self, backend: str, ) -> Callable[[Callable[P, T]], Callable[P, T]]: ... - def dispatch(self, backend: str) -> Callable[P, T]: ... - -def dispatch(fun: Callable[P, T]) -> Dispatcher[P, T]: +def dispatch(fun: Callable[P, T]) -> _Dispatcher[P, T]: """ Transform a function into a backend dispatcher for plot functions. @@ -168,8 +214,12 @@ def dispatch(fun: Callable[P, T]) -> Dispatcher[P, T]: functions for each backend implementation. Return: - The same callable, wrapped in a :py:class:`Dispatcher` - class instance. + A callable that can register backend implementations with ``register``. + + Notes: + Only the functions registered with ``register`` will be called. + The :data:`fun` argument wrapped inside :func:`dispatch` is + only used for documentation, but never called. Examples: The following example shows how one can implement plotting @@ -222,12 +272,21 @@ class instance. Traceback (most recent call last): ValueError: Unsupported backend 'numpy', allowed values are: ... """ - registry = {} + registry: dict[str, Callable[P, T]] = {} def register( backend: str, ) -> Callable[[Callable[P, T]], Callable[P, T]]: - """Register a new implementation.""" + """ + Return a wrapper that will call the decorated function for the specified backend. + + Args: + backend: The name of backend for which the decorated + function will be called. + + Return: + A wrapper to be put before the backend-specific implementation. + """ if backend not in SUPPORTED_BACKENDS: raise ValueError( f"Unsupported backend '{backend}', " @@ -254,19 +313,21 @@ def __wrapper__(*args: P.args, **kwargs: P.kwargs) -> T: # noqa: N807 return wrapper - def dispatch(backend: str) -> Callable[P, T]: - try: - return registry[backend] - except KeyError: - raise NotImplementedError( - f"No backend implementation for '{backend}'" - ) from None - @wraps(fun) - def main_wrapper( + def wrapper( *args: P.args, **kwargs: P.kwargs, ) -> T: + """ + Call the appropriate backend implementation based on the default backend and the provided arguments. + + Args: + args: Positional arguments passed to the correct backend implementation. + kwargs: Keyword arguments passed to the correct backend implementation. + + Return: + The result of the call. + """ # We cannot currently add keyword argument to the signature, # at least Pyright will not allow that, # see: https://github.com/microsoft/pyright/issues/5844. @@ -274,16 +335,23 @@ def main_wrapper( # The motivation is detailed in P612: # https://peps.python.org/pep-0612/#concatenating-keyword-parameters. backend: str = kwargs.pop("backend", DEFAULT_BACKEND) # type: ignore - return dispatch(backend)(*args, **kwargs) - main_wrapper.register = register # type: ignore[attr-defined] - main_wrapper.dispatch = dispatch # type: ignore[attr-defined] - main_wrapper.registry = registry # type: ignore[attr-defined] + try: + return registry[backend](*args, **kwargs) + except KeyError: + raise NotImplementedError( + f"No backend implementation for '{backend}'" + ) from None - return main_wrapper # type: ignore[return-value] + return wrapper + + wrapper.register = register # type: ignore + wrapper.registry = types.MappingProxyType(registry) # type: ignore + return wrapper # type: ignore -def view_from_canvas(canvas: SceneCanvas) -> ViewBox: + +def view_from_canvas(canvas: Canvas) -> ViewBox: """ Return the view from the specified canvas. @@ -321,7 +389,7 @@ def default_view() -> ViewBox: def process_vispy_kwargs( kwargs: MutableMapping[str, Any], -) -> tuple[SceneCanvas, ViewBox]: +) -> tuple[Canvas, ViewBox]: """ Process keyword arguments passed to some VisPy plotting utility. @@ -333,7 +401,7 @@ def process_vispy_kwargs( The keys specified below will be removed from the mapping. Keyword Args: - convas (:py:class:`SceneCanvas`): + canvas (:py:class:`SceneCanvas`): The canvas that draws contents of the scene. If not provided, will try to access canvas from ``view`` (if supplied). view (:py:class:`Viewbox`): @@ -410,7 +478,7 @@ def process_matplotlib_kwargs( or plt.figure() ) - def current_ax3d() -> Axes3D | None: + def current_ax3d() -> Optional[Axes3D]: if len(figure.axes) > 0: ax = figure.gca() if isinstance(ax, Axes3D): @@ -455,20 +523,36 @@ def process_plotly_kwargs( @contextmanager -def reuse(**kwargs: Any) -> Iterator[SceneCanvas | MplFigure | Figure]: +def reuse(**kwargs: Any) -> Iterator[Union[Canvas, MplFigure, Figure]]: """Create a context manager that will automatically reuse the current canvas / figure. Args: - args: Positional arguments passed to - :py:func:`set_defaults`. kwargs: Keywords arguments passed to :py:func:`set_defaults`. Return: The canvas or figure that is reused for this context. + + Examples: + The following example show how the same figure is reused + for multiple plots. + + .. plotly:: + + >>> from differt.plotting import draw_image, reuse + >>> + >>> x = np.linspace(-1.0, +1.0, 100) + >>> y = np.linspace(-4.0, +4.0, 200) + >>> X, Y = np.meshgrid(x, y) + >>> + >>> with reuse(backend="plotly") as fig: # doctest: +SKIP + ... for z0, w in enumerate(jnp.linspace(0, 10 * jnp.pi, 5)): + ... Z = np.cos(w * X) * np.sin(w * Y) + ... draw_image(Z, x=x, y=y, z0=z0) # TODO: fix colorbar + >>> fig # doctest: +SKIP """ global DEFAULT_KWARGS - backend: str | None = kwargs.pop("backend", None) + backend: Optional[str] = kwargs.pop("backend", None) with use(backend=backend) as b: try: diff --git a/differt/src/differt/utils.py b/differt/src/differt/utils.py index 10b6321f..7a669550 100644 --- a/differt/src/differt/utils.py +++ b/differt/src/differt/utils.py @@ -5,12 +5,13 @@ from functools import partial from typing import Any, Callable, Optional, Union -import chex # TODO: fixme, chex is not a dependency +import chex +import equinox as eqx import jax import jax.numpy as jnp import optax from beartype import beartype as typechecker -from jaxtyping import Array, Num, Shaped, jaxtyped +from jaxtyping import Array, Float, Num, PRNGKeyArray, Shaped, jaxtyped if sys.version_info >= (3, 11): from typing import TypeVarTuple, Unpack @@ -229,3 +230,32 @@ def f( (x, _), losses = jax.lax.scan(f, init=(x0, opt_state), xs=None, length=steps) return x, losses[-1] + + +@eqx.filter_jit +@jaxtyped(typechecker=typechecker) +def sample_points_in_bounding_box( + bounding_box: Float[Array, "2 3"], size: Optional[int] = None, *, key: PRNGKeyArray +) -> Union[Float[Array, "size 3"], Float[Array, "3"]]: + """ + Sample point(s) in a 3D bounding box. + + Args: + bounding_box: The bounding box (min. and max. coordinates). + size: The sample size or :py:data:`None`. If :py:data:`None`, + the returned array is 1D. Otherwise, it is 2D. + key: The :class:`jax.random.PRNGKey` to be used. + + Return: + An array of points randomly sampled. + """ + amin = bounding_box[0, :] + amax = bounding_box[1, :] + scale = amax - amin + + if size is None: + r = jax.random.uniform(key, shape=(3,)) + return r * scale + amin + + r = jax.random.uniform(key, shape=(size, 3)) + return r * scale[None, :] + amin[None, :] diff --git a/differt/tests/conftest.py b/differt/tests/conftest.py index 75d2d8f8..7e07c762 100644 --- a/differt/tests/conftest.py +++ b/differt/tests/conftest.py @@ -1,10 +1,22 @@ from pathlib import Path +import jax import pytest +from jaxtyping import PRNGKeyArray from differt.scene.sionna import download_sionna_scenes +@pytest.fixture +def seed() -> int: + return 1234 + + +@pytest.fixture +def key(seed: int) -> PRNGKeyArray: + return jax.random.PRNGKey(seed) + + def pytest_sessionstart(session: pytest.Session) -> None: download_sionna_scenes() diff --git a/differt/tests/plotting/test_utils.py b/differt/tests/plotting/test_utils.py index 2e57431c..073e52dd 100644 --- a/differt/tests/plotting/test_utils.py +++ b/differt/tests/plotting/test_utils.py @@ -2,6 +2,7 @@ import builtins import importlib +import platform import sys from contextlib import AbstractContextManager as ContextManager from contextlib import contextmanager @@ -24,6 +25,11 @@ view_from_canvas, ) +if not platform.system() == "Darwin" and platform.processor() == "arm": + pytest.skip( + "skipping tests on macOS (m1) runners at the moment...", allow_module_level=True + ) + _LOCK = Lock() @@ -139,13 +145,7 @@ def est_missing_default_backend_module( @pytest.mark.parametrize( "backend", - ( - "vispy", - "matplotlib", - pytest.param( - "plotly", marks=pytest.mark.xfail(reason="Unknown, to be investigated...") - ), - ), + ("vispy", "matplotlib", "plotly"), ) def test_missing_backend_module( backend: str, missing_modules: MissingModulesContextGenerator diff --git a/differt/tests/test_utils.py b/differt/tests/test_utils.py index b0c850fe..204ab8c1 100644 --- a/differt/tests/test_utils.py +++ b/differt/tests/test_utils.py @@ -1,9 +1,9 @@ import chex import jax.numpy as jnp import pytest -from jaxtyping import Array +from jaxtyping import Array, PRNGKeyArray -from differt.utils import minimize, sorted_array2 +from differt.utils import minimize, sample_points_in_bounding_box, sorted_array2 @pytest.mark.parametrize( @@ -67,3 +67,28 @@ def fun(x: Array, a: Array, b: Array, c: Array) -> Array: chex.assert_trees_all_close(got_x, expected_x) chex.assert_trees_all_close(got_loss, c) + + +def test_sample_points_in_bounding_box(key: PRNGKeyArray) -> None: + def assert_in_bounds(a: Array, bounds: Array) -> None: + if a.ndim == 1: + a = jnp.reshape(a, (1, a.size)) + + assert bounds.shape[0] == 2 + assert a.shape[1] == bounds.shape[1] + + for i in range(a.shape[1]): + assert jnp.all(a[:, i] >= bounds[0, i]) + assert jnp.all(a[:, i] <= bounds[1, i]) + + bounding_box = jnp.array([[-1.0, -2.0, -3.0], [+4.0, +5.0, +6.0]]) + + got = sample_points_in_bounding_box(bounding_box, key=key) + + assert_in_bounds(got, bounding_box) + assert got.shape == (3,) + + got = sample_points_in_bounding_box(bounding_box, size=100, key=key) + + assert_in_bounds(got, bounding_box) + assert got.shape == (100, 3) diff --git a/docs/source/conf.py b/docs/source/conf.py index 718fe447..ee06339f 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -11,6 +11,8 @@ from datetime import date from typing import Any +from sphinx.application import Sphinx + from differt import __version__ project = "DiffeRT" @@ -71,7 +73,7 @@ "../../differt-core/python/differt_core", ] apidoc_output_dirs = "reference" -apidoc_exclude_patterns = ["conftest.py", "*scene/scenes/*"] +apidoc_exclude_patterns = ["conftest.py", "scene/scenes/**"] apidoc_separate = True apidoc_no_toc = True apidoc_max_depth = 1 @@ -107,8 +109,6 @@ ("*", "text/html", 0), ] -# TODO: fix JS warnings about html-manager (wrong version?) - # -- Bibtex bibtex_bibfiles = ["references.bib"] @@ -171,24 +171,11 @@ # Patches -# TODO: fix Plotly's Figure not linking to docs with intersphinx. - -""" -def fix_signature(app, what, name, obj, options, signature, return_annotation): - target = "~plotly.graph_objs._figure.Figure" - sub = ":py:class:`Figure`" - sub = "~plotly.graph_objects.Figure" - if return_annotation and target in return_annotation: - return_annotation = return_annotation.replace(target, sub) - return signature, return_annotation.replace(target, sub) - - -def setup(app): - app.connect("autodoc-process-signature", fix_signature, priority=999) -""" +# TODO: fix Plotly's Figure not linking to docs with intersphinx, +# reported here https://github.com/sphinx-doc/sphinx/issues/12360. -def fix_sionna_folder(app, obj: Any, bound_method: bool) -> None: +def fix_sionna_folder(app: Sphinx, obj: Any, bound_method: bool) -> None: """ Rename the default folder to a more readeable name. """ diff --git a/pyproject.toml b/pyproject.toml index b40a13c8..9031ad59 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -60,7 +60,8 @@ omit = ["**/*/conftest.py"] allow-direct-references = true [tool.pyright] -include = ["differt/src/differt", "tests"] +deprecateTypingAliases = true +include = ["differt/src/differt", "differt/tests"] venv = ".venv" venvPath = "." diff --git a/src/differt_dev/sphinxext/apidoc.py b/src/differt_dev/sphinxext/apidoc.py index bc382d8c..40f4f75a 100644 --- a/src/differt_dev/sphinxext/apidoc.py +++ b/src/differt_dev/sphinxext/apidoc.py @@ -71,18 +71,17 @@ def builder_inited(app: Sphinx) -> None: # noqa: C901 output_dir = path.join(app.srcdir, output_dir) - apidoc.main([*options, f"-o={output_dir}", module_dir, *exclude_patterns]) - apidoc.main( - [ - *options, - f"-o={output_dir}", - module_dir, - *[ - path.join(module_dir, exclude_pattern) - for exclude_pattern in exclude_patterns - ], - ] - ) + args = [ + *options, + f"-o={output_dir}", + module_dir, + *[ + path.join(module_dir, exclude_pattern) + for exclude_pattern in exclude_patterns + ], + ] + + apidoc.main(args) def setup(app: Sphinx) -> None: