From c1f9cd85c6b5b9c8164c93d7d704d9df10350232 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=C3=A9rome=20Eertmans?= Date: Mon, 6 May 2024 18:18:21 +0200 Subject: [PATCH 01/16] feat(lib): add random sampling and ML features --- differt/src/differt/geometry/triangle_mesh.py | 31 ++++++++++++++++++- differt/src/differt/utils.py | 31 ++++++++++++++++++- differt/tests/conftest.py | 11 +++++++ differt/tests/test_utils.py | 28 ++++++++++++++++- 4 files changed, 98 insertions(+), 3 deletions(-) diff --git a/differt/src/differt/geometry/triangle_mesh.py b/differt/src/differt/geometry/triangle_mesh.py index da508838..95d5e1d6 100644 --- a/differt/src/differt/geometry/triangle_mesh.py +++ b/differt/src/differt/geometry/triangle_mesh.py @@ -4,10 +4,11 @@ from typing import Any import equinox as eqx +import jax import jax.numpy as jnp import numpy as np from beartype import beartype as typechecker -from jaxtyping import Array, Bool, Float, UInt, jaxtyped +from jaxtyping import Array, Bool, Float, Key, UInt, jaxtyped import differt_core.geometry.triangle_mesh @@ -134,6 +135,13 @@ def diffraction_edges(self) -> UInt[Array, "num_edges 3"]: """The diffraction edges.""" raise NotImplementedError + @cached_property + def bounding_box(self) -> Float[Array, "2 3"]: + """The bounding box (min. and max. coordinates).""" + return jnp.vstack( + (jnp.min(self.vertices, axis=0), jnp.max(self.vertices, axis=0)) + ) + @classmethod def empty(cls) -> "TriangleMesh": """ @@ -207,3 +215,24 @@ def plot(self, **kwargs: Any) -> Any: triangles=np.asarray(self.triangles), **kwargs, ) + + @eqx.filter_jit + def sample(self, size: int, replace: bool = False, *, key: Key) -> "TriangleMesh": + """ + Generate a new mesh by randomly sampling triangles from this geometry. + + Args: + size: The size of the sample, i.e., the number of triangles. + replace: Whether to sample with or without replacement. + key: The :class:`jax.random.PRNGKey` to be used. + + Return: + A new random mesh. + """ + triangles = self.triangles[ + jax.random.choice( + key, self.triangles.shape[0], shape=(size,), replace=replace + ), + :, + ] + return TriangleMesh(vertices=self.vertices, triangles=triangles) diff --git a/differt/src/differt/utils.py b/differt/src/differt/utils.py index 799b6ff9..39f012b7 100644 --- a/differt/src/differt/utils.py +++ b/differt/src/differt/utils.py @@ -6,11 +6,12 @@ from typing import Any, Callable, Optional, Union import chex +import equinox as eqx import jax import jax.numpy as jnp import optax from beartype import beartype as typechecker -from jaxtyping import Array, Num, Shaped, jaxtyped +from jaxtyping import Array, Float, Key, Num, Shaped, jaxtyped if sys.version_info >= (3, 11): from typing import TypeVarTuple, Unpack @@ -228,3 +229,31 @@ def f( (x, _), losses = jax.lax.scan(f, init=(x0, opt_state), xs=None, length=steps) return x, losses[-1] + + +@eqx.filter_jit +def sample_points_in_bounding_box( + bounding_box: Float[Array, "2 3"], size: Optional[int] = None, *, key: Key +) -> Float[Array, "?size 3"]: + """ + Sample point(s) in a 3D bounding box. + + Args: + bounding_box: The bounding box (min. and max. coordinates). + size: The sample size or :py:data:`None`. If :py:data:`None`, + the returned array is 1D. Otherwise, it is 2D. + key: The :class:`jax.random.PRNGKey` to be used. + + Return: + An array of points randomly sampled. + """ + amin = bounding_box[0, :] + amax = bounding_box[1, :] + scale = amax - amin + + if size is None: + r = jax.random.uniform(key, shape=(3,)) + return r * scale + amin + + r = jax.random.uniform(key, shape=(size, 3)) + return r * scale[None, :] + amin[None, :] diff --git a/differt/tests/conftest.py b/differt/tests/conftest.py index 75d2d8f8..96fbea27 100644 --- a/differt/tests/conftest.py +++ b/differt/tests/conftest.py @@ -1,10 +1,21 @@ from pathlib import Path +import jax import pytest from differt.scene.sionna import download_sionna_scenes +@pytest.fixture +def seed() -> int: + return 1234 + + +@pytest.fixture +def key(seed: int) -> jax.random.PRNGKey: + return jax.random.PRNGKey(seed) + + def pytest_sessionstart(session: pytest.Session) -> None: download_sionna_scenes() diff --git a/differt/tests/test_utils.py b/differt/tests/test_utils.py index b0c850fe..04610346 100644 --- a/differt/tests/test_utils.py +++ b/differt/tests/test_utils.py @@ -1,9 +1,10 @@ import chex +import jax import jax.numpy as jnp import pytest from jaxtyping import Array -from differt.utils import minimize, sorted_array2 +from differt.utils import minimize, sample_points_in_bounding_box, sorted_array2 @pytest.mark.parametrize( @@ -67,3 +68,28 @@ def fun(x: Array, a: Array, b: Array, c: Array) -> Array: chex.assert_trees_all_close(got_x, expected_x) chex.assert_trees_all_close(got_loss, c) + + +def test_sample_points_in_bounding_box(key: jax.random.PRNGKey) -> None: + def assert_in_bounds(a: Array, bounds: Array) -> None: + if a.ndim == 1: + a = jnp.reshape(a, (1, a.size)) + + assert bounds.shape[0] == 2 + assert a.shape[1] == bounds.shape[1] + + for i in range(a.shape[1]): + assert jnp.all(a[:, i] >= bounds[0, i]) + assert jnp.all(a[:, i] <= bounds[1, i]) + + bounding_box = jnp.array([[-1.0, -2.0, -3.0], [+4.0, +5.0, +6.0]]) + + got = sample_points_in_bounding_box(bounding_box, key=key) + + assert_in_bounds(got, bounding_box) + assert got.shape == (3,) + + got = sample_points_in_bounding_box(bounding_box, size=100, key=key) + + assert_in_bounds(got, bounding_box) + assert got.shape == (100, 3) From b9d95f799a1fafc53d9e8b2948d3402b51311aa6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=C3=A9rome=20Eertmans?= Date: Tue, 7 May 2024 10:30:04 +0200 Subject: [PATCH 02/16] feat(lib): add empty graph method and fix type hint --- differt-core/src/rt/graph.rs | 36 +++++++++++++++++++ differt/src/differt/geometry/triangle_mesh.py | 6 ++-- differt/tests/conftest.py | 3 +- differt/tests/test_utils.py | 5 ++- 4 files changed, 44 insertions(+), 6 deletions(-) diff --git a/differt-core/src/rt/graph.rs b/differt-core/src/rt/graph.rs index 527ae3d9..ac84ae4a 100644 --- a/differt-core/src/rt/graph.rs +++ b/differt-core/src/rt/graph.rs @@ -117,6 +117,9 @@ pub mod complete { /// A complete graph, i.e., /// a simple undirected graph in which every pair of /// distinc nodes is connected by a unique edge. + /// + /// Args: + /// num_nodes: The number of nodes. #[pyclass] #[derive(Clone, Debug)] pub struct CompleteGraph { @@ -542,6 +545,12 @@ pub mod directed { } impl DiGraph { + #[inline] + pub fn empty(num_nodes: usize) -> Self { + Self { + edges_list: vec![vec![]; num_nodes], + } + } #[inline] pub fn get_adjacent_nodes(&self, node: NodeId) -> &[NodeId] { self.edges_list[node].as_ref() @@ -569,6 +578,23 @@ pub mod directed { #[pymethods] impl DiGraph { + /// Create an edgeless directed graph with ``num_nodes`` nodes. + /// + /// This is equivalent to creating a directed graph from + /// an adjacency matrix will all entries equal to :py:data:`False`. + /// + /// Args: + /// graph (CompleteGraph): The number of nodes. + /// + /// Return: + /// DiGraph: A directed graph. + #[classmethod] + #[pyo3(name = "empty")] + #[pyo3(text_signature = "(cls, num_nodes)")] + fn py_empty(_: Bound<'_, PyType>, num_nodes: usize) -> Self { + Self::empty(num_nodes) + } + /// Create a directed graph from an adjacency matrix. /// /// Each row of the adjacency matrix ``M`` contains boolean @@ -1083,6 +1109,16 @@ mod tests { assert_eq!(got, expected); } + #[rstest] + #[case(9, 2)] + #[case(3, 3)] + fn test_empty_di_graph_returns_all_paths(#[case] num_nodes: usize, #[case] depth: usize) { + let mut graph = DiGraph::empty(num_nodes); + let (from, to) = graph.insert_from_and_to_nodes(false); + + assert_eq!(graph.all_paths(from, to, depth + 2, true).count(), 0); + } + #[rstest] #[case(9, 2)] #[case(3, 3)] diff --git a/differt/src/differt/geometry/triangle_mesh.py b/differt/src/differt/geometry/triangle_mesh.py index 95d5e1d6..56a2dd99 100644 --- a/differt/src/differt/geometry/triangle_mesh.py +++ b/differt/src/differt/geometry/triangle_mesh.py @@ -8,7 +8,7 @@ import jax.numpy as jnp import numpy as np from beartype import beartype as typechecker -from jaxtyping import Array, Bool, Float, Key, UInt, jaxtyped +from jaxtyping import Array, Bool, Float, PRNGKeyArray, UInt, jaxtyped import differt_core.geometry.triangle_mesh @@ -217,7 +217,9 @@ def plot(self, **kwargs: Any) -> Any: ) @eqx.filter_jit - def sample(self, size: int, replace: bool = False, *, key: Key) -> "TriangleMesh": + def sample( + self, size: int, replace: bool = False, *, key: PRNGKeyArray + ) -> "TriangleMesh": """ Generate a new mesh by randomly sampling triangles from this geometry. diff --git a/differt/tests/conftest.py b/differt/tests/conftest.py index 96fbea27..7e07c762 100644 --- a/differt/tests/conftest.py +++ b/differt/tests/conftest.py @@ -2,6 +2,7 @@ import jax import pytest +from jaxtyping import PRNGKeyArray from differt.scene.sionna import download_sionna_scenes @@ -12,7 +13,7 @@ def seed() -> int: @pytest.fixture -def key(seed: int) -> jax.random.PRNGKey: +def key(seed: int) -> PRNGKeyArray: return jax.random.PRNGKey(seed) diff --git a/differt/tests/test_utils.py b/differt/tests/test_utils.py index 04610346..204ab8c1 100644 --- a/differt/tests/test_utils.py +++ b/differt/tests/test_utils.py @@ -1,8 +1,7 @@ import chex -import jax import jax.numpy as jnp import pytest -from jaxtyping import Array +from jaxtyping import Array, PRNGKeyArray from differt.utils import minimize, sample_points_in_bounding_box, sorted_array2 @@ -70,7 +69,7 @@ def fun(x: Array, a: Array, b: Array, c: Array) -> Array: chex.assert_trees_all_close(got_loss, c) -def test_sample_points_in_bounding_box(key: jax.random.PRNGKey) -> None: +def test_sample_points_in_bounding_box(key: PRNGKeyArray) -> None: def assert_in_bounds(a: Array, bounds: Array) -> None: if a.ndim == 1: a = jnp.reshape(a, (1, a.size)) From 9af41a22e0e39990d170cac275554a01eba3fdb5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=C3=A9rome=20Eertmans?= Date: Tue, 7 May 2024 10:43:16 +0200 Subject: [PATCH 03/16] fix(docs): type hints --- differt-core/src/rt/graph.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/differt-core/src/rt/graph.rs b/differt-core/src/rt/graph.rs index ac84ae4a..09c39eed 100644 --- a/differt-core/src/rt/graph.rs +++ b/differt-core/src/rt/graph.rs @@ -119,7 +119,7 @@ pub mod complete { /// distinc nodes is connected by a unique edge. /// /// Args: - /// num_nodes: The number of nodes. + /// num_nodes (int): The number of nodes. #[pyclass] #[derive(Clone, Debug)] pub struct CompleteGraph { @@ -584,7 +584,7 @@ pub mod directed { /// an adjacency matrix will all entries equal to :py:data:`False`. /// /// Args: - /// graph (CompleteGraph): The number of nodes. + /// num\_nodes (int): The number of nodes. /// /// Return: /// DiGraph: A directed graph. From 46579d25ebd287b8c5534cc19618909c10d8dfb4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=C3=A9rome=20Eertmans?= Date: Tue, 7 May 2024 10:57:57 +0200 Subject: [PATCH 04/16] fix(docs): type hint --- differt/src/differt/geometry/triangle_mesh.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/differt/src/differt/geometry/triangle_mesh.py b/differt/src/differt/geometry/triangle_mesh.py index 56a2dd99..e7ef0262 100644 --- a/differt/src/differt/geometry/triangle_mesh.py +++ b/differt/src/differt/geometry/triangle_mesh.py @@ -226,7 +226,7 @@ def sample( Args: size: The size of the sample, i.e., the number of triangles. replace: Whether to sample with or without replacement. - key: The :class:`jax.random.PRNGKey` to be used. + key: The :func:`jax.random.PRNGKey` to be used. Return: A new random mesh. From 94ed298292c7bdbe555bf6cdc05bf513f06d6931 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=C3=A9rome=20Eertmans?= Date: Tue, 7 May 2024 11:51:26 +0200 Subject: [PATCH 05/16] refactor(lib): make plotting logic simpler --- differt/src/differt/plotting/__init__.py | 2 + differt/src/differt/plotting/_utils.py | 132 ++++++++++++----------- 2 files changed, 71 insertions(+), 63 deletions(-) diff --git a/differt/src/differt/plotting/__init__.py b/differt/src/differt/plotting/__init__.py index 9eed764f..709aba4e 100644 --- a/differt/src/differt/plotting/__init__.py +++ b/differt/src/differt/plotting/__init__.py @@ -88,6 +88,7 @@ """ __all__ = ( + "Dispatcher", "dispatch", "set_defaults", "use", @@ -104,6 +105,7 @@ from ._core import draw_image, draw_markers, draw_mesh, draw_paths from ._utils import ( + Dispatcher, dispatch, process_matplotlib_kwargs, process_plotly_kwargs, diff --git a/differt/src/differt/plotting/_utils.py b/differt/src/differt/plotting/_utils.py index 7df79ba7..8738a1d4 100644 --- a/differt/src/differt/plotting/_utils.py +++ b/differt/src/differt/plotting/_utils.py @@ -7,7 +7,7 @@ from contextlib import contextmanager from functools import wraps from threading import Lock -from typing import TYPE_CHECKING, Any, Callable, Generic, Protocol, TypeVar +from typing import TYPE_CHECKING, Any, Callable, Generic, TypeVar # Immutables @@ -144,19 +144,81 @@ def use(*args: Any, **kwargs: Any) -> Iterator[str]: DEFAULT_KWARGS = default_kwargs -class Dispatcher(Protocol, Generic[P, T]): # pragma: no cover +class Dispatcher(Generic[P, T]): + """A callable that automatically dispatches between different backends.""" + + def __init__(self) -> None: + self.__registry: dict[str, Callable[P, T]] = {} + def __call__( self, *args: P.args, **kwargs: P.kwargs, - ) -> T: ... + ) -> T: + """ + Call the appropriate backend implementation based on the default backend and the provided arguments. + + Args: + args: Positional arguments passed to the correct backend implementation. + kwargs: Keyword arguments passed to the correct backend implementation. + + Return: + The result of the call. + """ + # We cannot currently add keyword argument to the signature, + # at least Pyright will not allow that, + # see: https://github.com/microsoft/pyright/issues/5844. + # + # The motivation is detailed in P612: + # https://peps.python.org/pep-0612/#concatenating-keyword-parameters. + backend: str = kwargs.pop("backend", DEFAULT_BACKEND) # type: ignore + + try: + return self.__registry[backend](*args, **kwargs) + except KeyError: + raise NotImplementedError( + f"No backend implementation for '{backend}'" + ) from None def register( self, backend: str, - ) -> Callable[[Callable[P, T]], Callable[P, T]]: ... + ) -> Callable[[Callable[P, T]], Callable[P, T]]: + """ + Return a wrapper that will call the decorated function for the specified backend. - def dispatch(self, backend: str) -> Callable[P, T]: ... + Args: + backend: The name of backend for which the decorated + function will be called. + + Return: + A wrapper to be put before the backend-specific implementation. + """ + if backend not in SUPPORTED_BACKENDS: + raise ValueError( + f"Unsupported backend '{backend}', " + f"allowed values are: {', '.join(SUPPORTED_BACKENDS)}." + ) + + def wrapper(impl: Callable[P, T]) -> Callable[P, T]: + """Actually register the backend implementation.""" + + @wraps(impl) + def __wrapper__(*args: P.args, **kwargs: P.kwargs) -> T: # noqa: N807 + try: + return impl(*args, **kwargs) + except ImportError as e: + raise ImportError( + "An import error occurred when dispatching " + f"plot utility to backend '{backend}'. " + "Did you correctly install it?" + ) from e + + self.__registry[backend] = __wrapper__ + + return __wrapper__ + + return wrapper def dispatch(fun: Callable[P, T]) -> Dispatcher[P, T]: @@ -222,65 +284,9 @@ class instance. Traceback (most recent call last): ValueError: Unsupported backend 'numpy', allowed values are: ... """ - registry = {} - - def register( - backend: str, - ) -> Callable[[Callable[P, T]], Callable[P, T]]: - """Register a new implementation.""" - if backend not in SUPPORTED_BACKENDS: - raise ValueError( - f"Unsupported backend '{backend}', " - f"allowed values are: {', '.join(SUPPORTED_BACKENDS)}." - ) - - def wrapper(impl: Callable[P, T]) -> Callable[P, T]: - """Actually register the backend implementation.""" - - @wraps(impl) - def __wrapper__(*args: P.args, **kwargs: P.kwargs) -> T: # noqa: N807 - try: - return impl(*args, **kwargs) - except ImportError as e: - raise ImportError( - "An import error occurred when dispatching " - f"plot utility to backend '{backend}'. " - "Did you correctly install it?" - ) from e - - registry[backend] = __wrapper__ - - return __wrapper__ - - return wrapper - - def dispatch(backend: str) -> Callable[P, T]: - try: - return registry[backend] - except KeyError: - raise NotImplementedError( - f"No backend implementation for '{backend}'" - ) from None - - @wraps(fun) - def main_wrapper( - *args: P.args, - **kwargs: P.kwargs, - ) -> T: - # We cannot currently add keyword argument to the signature, - # at least Pyright will not allow that, - # see: https://github.com/microsoft/pyright/issues/5844. - # - # The motivation is detailed in P612: - # https://peps.python.org/pep-0612/#concatenating-keyword-parameters. - backend: str = kwargs.pop("backend", DEFAULT_BACKEND) # type: ignore - return dispatch(backend)(*args, **kwargs) - - main_wrapper.register = register # type: ignore[attr-defined] - main_wrapper.dispatch = dispatch # type: ignore[attr-defined] - main_wrapper.registry = registry # type: ignore[attr-defined] + dispatcher: Dispatcher[P, T] = Dispatcher() - return main_wrapper # type: ignore[return-value] + return wraps(fun)(dispatcher) # type: ignore def view_from_canvas(canvas: SceneCanvas) -> ViewBox: From b29cf27d1decd826e9f2b9f40b6a793b1a67c76f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=C3=A9rome=20Eertmans?= Date: Tue, 7 May 2024 16:59:30 +0200 Subject: [PATCH 06/16] chore(docs): cleanup --- differt/src/differt/plotting/__init__.py | 2 - differt/src/differt/plotting/_utils.py | 168 +++++++++++------------ differt/tests/plotting/test_utils.py | 4 +- 3 files changed, 82 insertions(+), 92 deletions(-) diff --git a/differt/src/differt/plotting/__init__.py b/differt/src/differt/plotting/__init__.py index 709aba4e..9eed764f 100644 --- a/differt/src/differt/plotting/__init__.py +++ b/differt/src/differt/plotting/__init__.py @@ -88,7 +88,6 @@ """ __all__ = ( - "Dispatcher", "dispatch", "set_defaults", "use", @@ -105,7 +104,6 @@ from ._core import draw_image, draw_markers, draw_mesh, draw_paths from ._utils import ( - Dispatcher, dispatch, process_matplotlib_kwargs, process_plotly_kwargs, diff --git a/differt/src/differt/plotting/_utils.py b/differt/src/differt/plotting/_utils.py index 8738a1d4..d7840130 100644 --- a/differt/src/differt/plotting/_utils.py +++ b/differt/src/differt/plotting/_utils.py @@ -5,7 +5,7 @@ import importlib from collections.abc import Iterator, MutableMapping from contextlib import contextmanager -from functools import wraps +from functools import update_wrapper, wraps from threading import Lock from typing import TYPE_CHECKING, Any, Callable, Generic, TypeVar @@ -104,16 +104,15 @@ def set_defaults(backend: str | None = None, **kwargs: Any) -> str: f"We currently support: {', '.join(SUPPORTED_BACKENDS)}." ) - with BACKEND_LOCK: - try: - importlib.import_module(f"{backend}") - DEFAULT_BACKEND = backend - DEFAULT_KWARGS = kwargs - return backend - except ImportError: - raise ImportError( - f"Could not load backend '{backend}', did you install it?" - ) from None + try: + importlib.import_module(f"{backend}") + DEFAULT_BACKEND = backend + DEFAULT_KWARGS = kwargs + return backend + except ImportError: + raise ImportError( + f"Could not load backend '{backend}', did you install it?" + ) from None @contextmanager @@ -137,18 +136,81 @@ def use(*args: Any, **kwargs: Any) -> Iterator[str]: default_backend = DEFAULT_BACKEND default_kwargs = DEFAULT_KWARGS - try: - yield set_defaults(*args, **kwargs) - finally: - DEFAULT_BACKEND = default_backend - DEFAULT_KWARGS = default_kwargs + with BACKEND_LOCK: + try: + yield set_defaults(*args, **kwargs) + finally: + DEFAULT_BACKEND = default_backend + DEFAULT_KWARGS = default_kwargs -class Dispatcher(Generic[P, T]): - """A callable that automatically dispatches between different backends.""" +class dispatch(Generic[P, T]): + """ + A class that transforms a function into a backend dispatcher for plot functions. + + Args: + fun: The callable that will register future dispatch + functions for each backend implementation. + + Notes: + Only the functions registered with :meth:`register` will be called. + The :data:`fun` argument wrapped inside :class:`dispatch` is + only used for documentation, but never called. + + Examples: + The following example shows how one can implement plotting + utilities on different backends for a given plot. - def __init__(self) -> None: + >>> import differt.plotting as dplt + >>> + >>> @dplt.dispatch + ... def plot_line(vertices, color): + ... pass + >>> + >>> @plot_line.register("matplotlib") + ... def _(vertices, color): + ... print("Using matplotlib backend") + >>> + >>> @plot_line.register("plotly") + ... def _(vertices, color): + ... print("Using plotly backend") + >>> + >>> plot_line( + ... _, + ... _, + ... backend="matplotlib", + ... ) + Using matplotlib backend + >>> + >>> plot_line( + ... _, + ... _, + ... backend="plotly", + ... ) + Using plotly backend + >>> + >>> plot_line( + ... _, + ... _, + ... backend="vispy", + ... ) # doctest: +IGNORE_EXCEPTION_DETAIL + Traceback (most recent call last): + NotImplementedError: No backend implementation for 'vispy' + >>> + >>> # The default backend is VisPy so unimplemented too. + >>> plot_line(_, _) # doctest: +IGNORE_EXCEPTION_DETAIL + Traceback (most recent call last): + NotImplementedError: No backend implementation for 'vispy' + >>> + >>> @plot_line.register("numpy") # doctest: +IGNORE_EXCEPTION_DETAIL + ... def _(vertices, color): + ... pass + Traceback (most recent call last): + ValueError: Unsupported backend 'numpy', allowed values are: ... + """ + def __init__(self, fun: Callable[P, T]) -> None: self.__registry: dict[str, Callable[P, T]] = {} + update_wrapper(self, fun) def __call__( self, @@ -221,74 +283,6 @@ def __wrapper__(*args: P.args, **kwargs: P.kwargs) -> T: # noqa: N807 return wrapper -def dispatch(fun: Callable[P, T]) -> Dispatcher[P, T]: - """ - Transform a function into a backend dispatcher for plot functions. - - Args: - fun: The callable that will register future dispatch - functions for each backend implementation. - - Return: - The same callable, wrapped in a :py:class:`Dispatcher` - class instance. - - Examples: - The following example shows how one can implement plotting - utilities on different backends for a given plot. - - >>> import differt.plotting as dplt - >>> - >>> @dplt.dispatch - ... def plot_line(vertices, color): - ... pass - >>> - >>> @plot_line.register("matplotlib") - ... def _(vertices, color): - ... print("Using matplotlib backend") - >>> - >>> @plot_line.register("plotly") - ... def _(vertices, color): - ... print("Using plotly backend") - >>> - >>> plot_line( - ... _, - ... _, - ... backend="matplotlib", - ... ) - Using matplotlib backend - >>> - >>> plot_line( - ... _, - ... _, - ... backend="plotly", - ... ) - Using plotly backend - >>> - >>> plot_line( - ... _, - ... _, - ... backend="vispy", - ... ) # doctest: +IGNORE_EXCEPTION_DETAIL - Traceback (most recent call last): - NotImplementedError: No backend implementation for 'vispy' - >>> - >>> # The default backend is VisPy so unimplemented too. - >>> plot_line(_, _) # doctest: +IGNORE_EXCEPTION_DETAIL - Traceback (most recent call last): - NotImplementedError: No backend implementation for 'vispy' - >>> - >>> @plot_line.register("numpy") # doctest: +IGNORE_EXCEPTION_DETAIL - ... def _(vertices, color): - ... pass - Traceback (most recent call last): - ValueError: Unsupported backend 'numpy', allowed values are: ... - """ - dispatcher: Dispatcher[P, T] = Dispatcher() - - return wraps(fun)(dispatcher) # type: ignore - - def view_from_canvas(canvas: SceneCanvas) -> ViewBox: """ Return the view from the specified canvas. diff --git a/differt/tests/plotting/test_utils.py b/differt/tests/plotting/test_utils.py index 2e57431c..e536a0fa 100644 --- a/differt/tests/plotting/test_utils.py +++ b/differt/tests/plotting/test_utils.py @@ -142,9 +142,7 @@ def est_missing_default_backend_module( ( "vispy", "matplotlib", - pytest.param( - "plotly", marks=pytest.mark.xfail(reason="Unknown, to be investigated...") - ), + "plotly" ), ) def test_missing_backend_module( From f000e2389e53aca3f41832b6d0536678b9e7560a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=C3=A9rome=20Eertmans?= Date: Wed, 8 May 2024 11:56:29 +0200 Subject: [PATCH 07/16] chore(docs): fix disappearing functions --- differt/src/differt/plotting/_utils.py | 98 ++++++++++++++++---------- differt/src/differt/utils.py | 5 +- differt/tests/plotting/test_utils.py | 6 +- docs/source/conf.py | 19 +---- 4 files changed, 66 insertions(+), 62 deletions(-) diff --git a/differt/src/differt/plotting/_utils.py b/differt/src/differt/plotting/_utils.py index d7840130..77319655 100644 --- a/differt/src/differt/plotting/_utils.py +++ b/differt/src/differt/plotting/_utils.py @@ -3,9 +3,10 @@ from __future__ import annotations import importlib +import types from collections.abc import Iterator, MutableMapping from contextlib import contextmanager -from functools import update_wrapper, wraps +from functools import wraps from threading import Lock from typing import TYPE_CHECKING, Any, Callable, Generic, TypeVar @@ -144,14 +145,31 @@ def use(*args: Any, **kwargs: Any) -> Iterator[str]: DEFAULT_KWARGS = default_kwargs -class dispatch(Generic[P, T]): +class _Dispatcher(Generic[P, T]): + registry: types.MappingProxyType[str, Callable[P, T]] + + def __call__( + self, + *args: P.args, + **kwargs: P.kwargs, + ) -> T: ... + def register( + self, + backend: str, + ) -> Callable[[Callable[P, T]], Callable[P, T]]: ... + + +def dispatch(fun: Callable[P, T]) -> _Dispatcher[P, T]: """ - A class that transforms a function into a backend dispatcher for plot functions. + Transform a function into a backend dispatcher for plot functions. Args: fun: The callable that will register future dispatch functions for each backend implementation. + Return: + A callable that can register backend implementations with ``.register``. + Notes: Only the functions registered with :meth:`register` will be called. The :data:`fun` argument wrapped inside :class:`dispatch` is @@ -208,42 +226,9 @@ class dispatch(Generic[P, T]): Traceback (most recent call last): ValueError: Unsupported backend 'numpy', allowed values are: ... """ - def __init__(self, fun: Callable[P, T]) -> None: - self.__registry: dict[str, Callable[P, T]] = {} - update_wrapper(self, fun) - - def __call__( - self, - *args: P.args, - **kwargs: P.kwargs, - ) -> T: - """ - Call the appropriate backend implementation based on the default backend and the provided arguments. - - Args: - args: Positional arguments passed to the correct backend implementation. - kwargs: Keyword arguments passed to the correct backend implementation. - - Return: - The result of the call. - """ - # We cannot currently add keyword argument to the signature, - # at least Pyright will not allow that, - # see: https://github.com/microsoft/pyright/issues/5844. - # - # The motivation is detailed in P612: - # https://peps.python.org/pep-0612/#concatenating-keyword-parameters. - backend: str = kwargs.pop("backend", DEFAULT_BACKEND) # type: ignore - - try: - return self.__registry[backend](*args, **kwargs) - except KeyError: - raise NotImplementedError( - f"No backend implementation for '{backend}'" - ) from None + registry: dict[str, Callable[P, T]] = {} def register( - self, backend: str, ) -> Callable[[Callable[P, T]], Callable[P, T]]: """ @@ -276,12 +261,49 @@ def __wrapper__(*args: P.args, **kwargs: P.kwargs) -> T: # noqa: N807 "Did you correctly install it?" ) from e - self.__registry[backend] = __wrapper__ + registry[backend] = __wrapper__ return __wrapper__ return wrapper + @wraps(fun) + def wrapper( + *args: P.args, + **kwargs: P.kwargs, + ) -> T: + """ + Call the appropriate backend implementation based on the default backend and the provided arguments. + + Args: + args: Positional arguments passed to the correct backend implementation. + kwargs: Keyword arguments passed to the correct backend implementation. + + Return: + The result of the call. + """ + # We cannot currently add keyword argument to the signature, + # at least Pyright will not allow that, + # see: https://github.com/microsoft/pyright/issues/5844. + # + # The motivation is detailed in P612: + # https://peps.python.org/pep-0612/#concatenating-keyword-parameters. + backend: str = kwargs.pop("backend", DEFAULT_BACKEND) # type: ignore + + try: + return registry[backend](*args, **kwargs) + except KeyError: + raise NotImplementedError( + f"No backend implementation for '{backend}'" + ) from None + + return wrapper + + wrapper.register = register # type: ignore + wrapper.registry = types.MappingProxyType(registry) # type: ignore + + return wrapper # type: ignore + def view_from_canvas(canvas: SceneCanvas) -> ViewBox: """ diff --git a/differt/src/differt/utils.py b/differt/src/differt/utils.py index 39f012b7..6bbe7eed 100644 --- a/differt/src/differt/utils.py +++ b/differt/src/differt/utils.py @@ -11,7 +11,7 @@ import jax.numpy as jnp import optax from beartype import beartype as typechecker -from jaxtyping import Array, Float, Key, Num, Shaped, jaxtyped +from jaxtyping import Array, Float, Num, PRNGKeyArray, Shaped, jaxtyped if sys.version_info >= (3, 11): from typing import TypeVarTuple, Unpack @@ -232,8 +232,9 @@ def f( @eqx.filter_jit +@jaxtyped(typechecker=typechecker) def sample_points_in_bounding_box( - bounding_box: Float[Array, "2 3"], size: Optional[int] = None, *, key: Key + bounding_box: Float[Array, "2 3"], size: Optional[int] = None, *, key: PRNGKeyArray ) -> Float[Array, "?size 3"]: """ Sample point(s) in a 3D bounding box. diff --git a/differt/tests/plotting/test_utils.py b/differt/tests/plotting/test_utils.py index e536a0fa..ef2e5ab7 100644 --- a/differt/tests/plotting/test_utils.py +++ b/differt/tests/plotting/test_utils.py @@ -139,11 +139,7 @@ def est_missing_default_backend_module( @pytest.mark.parametrize( "backend", - ( - "vispy", - "matplotlib", - "plotly" - ), + ("vispy", "matplotlib", "plotly"), ) def test_missing_backend_module( backend: str, missing_modules: MissingModulesContextGenerator diff --git a/docs/source/conf.py b/docs/source/conf.py index 718fe447..a6418378 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -107,8 +107,6 @@ ("*", "text/html", 0), ] -# TODO: fix JS warnings about html-manager (wrong version?) - # -- Bibtex bibtex_bibfiles = ["references.bib"] @@ -171,21 +169,8 @@ # Patches -# TODO: fix Plotly's Figure not linking to docs with intersphinx. - -""" -def fix_signature(app, what, name, obj, options, signature, return_annotation): - target = "~plotly.graph_objs._figure.Figure" - sub = ":py:class:`Figure`" - sub = "~plotly.graph_objects.Figure" - if return_annotation and target in return_annotation: - return_annotation = return_annotation.replace(target, sub) - return signature, return_annotation.replace(target, sub) - - -def setup(app): - app.connect("autodoc-process-signature", fix_signature, priority=999) -""" +# TODO: fix Plotly's Figure not linking to docs with intersphinx, +# reported here https://github.com/sphinx-doc/sphinx/issues/12360. def fix_sionna_folder(app, obj: Any, bound_method: bool) -> None: From 520f5f773fca4f57982457a3ee3eb7d6d199ba7c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=C3=A9rome=20Eertmans?= Date: Wed, 8 May 2024 12:04:12 +0200 Subject: [PATCH 08/16] fix(docs): typos --- differt/src/differt/plotting/_utils.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/differt/src/differt/plotting/_utils.py b/differt/src/differt/plotting/_utils.py index 77319655..8c8e1319 100644 --- a/differt/src/differt/plotting/_utils.py +++ b/differt/src/differt/plotting/_utils.py @@ -168,11 +168,11 @@ def dispatch(fun: Callable[P, T]) -> _Dispatcher[P, T]: functions for each backend implementation. Return: - A callable that can register backend implementations with ``.register``. + A callable that can register backend implementations with ``register``. Notes: - Only the functions registered with :meth:`register` will be called. - The :data:`fun` argument wrapped inside :class:`dispatch` is + Only the functions registered with ``register``` will be called. + The :data:`fun` argument wrapped inside :fun:`dispatch` is only used for documentation, but never called. Examples: From 72a04e91b456809a8e4a45171c3655c67dc7402b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=C3=A9rome=20Eertmans?= Date: Wed, 8 May 2024 12:08:48 +0200 Subject: [PATCH 09/16] chore(docs): nicer looking example --- differt/src/differt/plotting/_core.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/differt/src/differt/plotting/_core.py b/differt/src/differt/plotting/_core.py index 7970c4c7..c5247fd5 100644 --- a/differt/src/differt/plotting/_core.py +++ b/differt/src/differt/plotting/_core.py @@ -137,12 +137,12 @@ def draw_paths( >>> from differt.plotting import draw_paths >>> >>> def rotation(angle: float) -> np.ndarray: - ... c = np.cos(angle) - ... s = np.sin(angle) + ... co = np.cos(angle) + ... si = np.sin(angle) ... return np.array( ... [ - ... [+c, -s, 0.0], - ... [+s, +c, 0.0], + ... [+co, -si, 0.0], + ... [+si, +co, 0.0], ... [0.0, 0.0, 1.0], ... ] ... ) From 79105f7883c5061f848b825cca61bdee5c4d1dd4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=C3=A9rome=20Eertmans?= Date: Wed, 8 May 2024 12:20:53 +0200 Subject: [PATCH 10/16] chore(docs): add reuse example --- differt/src/differt/plotting/_utils.py | 21 ++++++++++++++++++++- 1 file changed, 20 insertions(+), 1 deletion(-) diff --git a/differt/src/differt/plotting/_utils.py b/differt/src/differt/plotting/_utils.py index 8c8e1319..85bfb59b 100644 --- a/differt/src/differt/plotting/_utils.py +++ b/differt/src/differt/plotting/_utils.py @@ -172,7 +172,7 @@ def dispatch(fun: Callable[P, T]) -> _Dispatcher[P, T]: Notes: Only the functions registered with ``register``` will be called. - The :data:`fun` argument wrapped inside :fun:`dispatch` is + The :data:`fun` argument wrapped inside :func:`dispatch` is only used for documentation, but never called. Examples: @@ -488,6 +488,25 @@ def reuse(**kwargs: Any) -> Iterator[SceneCanvas | MplFigure | Figure]: Return: The canvas or figure that is reused for this context. + + Examples: + The following example show how the same figure is reused + for multiple plots. + + .. plotly:: + + >>> from differt.plotting import draw_image, reuse + >>> + >>> x = np.linspace(-1.0, +1.0, 100) + >>> y = np.linspace(-4.0, +4.0, 200) + >>> X, Y = np.meshgrid(x, y) + >>> + >>> with reuse(backend="plotly") as fig: + ... for z0, w in enumerate(jnp.linspace(0, 10 * jnp.pi, 5)): + ... Z = np.cos(w * X) * np.sin(w * Y) + ... draw_image(Z, x=x, y=y, z0=z0) # TODO: fix colorbar + >>> + >>> fig # doctest: +SKIP """ global DEFAULT_KWARGS backend: str | None = kwargs.pop("backend", None) From 219d57fa6e238e84ec5658d1c0aff44494b9f820 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=C3=A9rome=20Eertmans?= Date: Wed, 8 May 2024 13:48:38 +0200 Subject: [PATCH 11/16] fix(lib): type hints --- differt/src/differt/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/differt/src/differt/utils.py b/differt/src/differt/utils.py index 6bbe7eed..a89d1e58 100644 --- a/differt/src/differt/utils.py +++ b/differt/src/differt/utils.py @@ -235,7 +235,7 @@ def f( @jaxtyped(typechecker=typechecker) def sample_points_in_bounding_box( bounding_box: Float[Array, "2 3"], size: Optional[int] = None, *, key: PRNGKeyArray -) -> Float[Array, "?size 3"]: +) -> Union[Float[Array, "size 3"], Float[Array, "3"]]: """ Sample point(s) in a 3D bounding box. From 5598df19037987814b1f78b75ea57b629bf01232 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=C3=A9rome=20Eertmans?= Date: Wed, 8 May 2024 14:29:06 +0200 Subject: [PATCH 12/16] fix(docs): examples and tests --- differt/src/differt/plotting/_utils.py | 75 ++++++++++++++++++++------ differt/tests/plotting/test_utils.py | 6 +++ 2 files changed, 64 insertions(+), 17 deletions(-) diff --git a/differt/src/differt/plotting/_utils.py b/differt/src/differt/plotting/_utils.py index 85bfb59b..686b12e0 100644 --- a/differt/src/differt/plotting/_utils.py +++ b/differt/src/differt/plotting/_utils.py @@ -65,35 +65,45 @@ def set_defaults(backend: str | None = None, **kwargs: Any) -> str: ImportError: If the backend is not installed. Examples: - The following example shows how to set the default plotting backend. + The following example shows how to set the default plotting backend + and other plotting defaults. >>> import differt.plotting as dplt >>> >>> @dplt.dispatch - ... def my_plot(): + ... def my_plot(*args, **kwargs): ... pass >>> >>> @my_plot.register("vispy") - ... def _(): - ... print("Using vispy backend") + ... def _(*args, **kwargs): + ... dplt.process_vispy_kwargs(kwargs) + ... print(f"Using vispy backend with {args = }, {kwargs = }") >>> >>> @my_plot.register("matplotlib") - ... def _(): - ... print("Using matplotlib backend") + ... def _(*args, **kwargs): + ... dplt.process_matplotlib_kwargs(kwargs) + ... print(f"Using matplotlib backend with {args = }, {kwargs = }") >>> >>> my_plot() # When not specified, use default backend - Using vispy backend - >>> - >>> my_plot(backend="matplotlib") # We can force the backend - Using matplotlib backend + Using vispy backend with args = (), kwargs = {} >>> - >>> dplt.set_defaults("matplotlib") # Or change the default backend... + >>> dplt.set_defaults("matplotlib") # We can change the default backend 'matplotlib' >>> my_plot() # So that now it defaults to 'matplotlib' - Using matplotlib backend + Using matplotlib backend with args = (), kwargs = {} >>> + >>> dplt.set_defaults( + ... "matplotlib", color="red" + ... ) # We can also specify additional defaults + 'matplotlib' + >>> my_plot() # So that now it defaults to 'matplotlib' and color='red' + Using matplotlib backend with args = (), kwargs = {'color': 'red'} >>> my_plot(backend="vispy") # Of course, the 'vispy' backend is still available - Using vispy backend + Using vispy backend with args = (), kwargs = {'color': 'red'} + >>> my_plot(backend="vispy", color="green") # And we can also override any default + Using vispy backend with args = (), kwargs = {'color': 'green'} + >>> dplt.set_defaults("vispy") # Reset all defaults + 'vispy' """ global DEFAULT_BACKEND, DEFAULT_KWARGS @@ -132,6 +142,40 @@ def use(*args: Any, **kwargs: Any) -> Iterator[str]: Return: The name of the default backend used in this context. + + Examples: + The following example shows how set plot defaults in a context. + + >>> import differt.plotting as dplt + >>> + >>> @dplt.dispatch + ... def my_plot(*args, **kwargs): + ... pass + >>> + >>> @my_plot.register("vispy") + ... def _(*args, **kwargs): + ... dplt.process_vispy_kwargs(kwargs) + ... print(f"Using vispy backend with {args = }, {kwargs = }") + >>> + >>> @my_plot.register("plotly") + ... def _(*args, **kwargs): + ... dplt.process_plotly_kwargs(kwargs) + ... print(f"Using plotly backend with {args = }, {kwargs = }") + >>> + >>> my_plot() # When not specified, use default backend + Using vispy backend with args = (), kwargs = {} + >>> with dplt.use(): # No parameters = reset defaults (except the default backend) + ... my_plot() + Using vispy backend with args = (), kwargs = {} + >>> with dplt.use("plotly"): # We can change the default backend + ... my_plot() # So that now it defaults to 'matplotlib' + Using plotly backend with args = (), kwargs = {} + >>> + >>> with dplt.use( + ... "plotly", color="black" + ... ): # We can also specify additional defaults + ... my_plot() + Using plotly backend with args = (), kwargs = {'color': 'black'} """ global DEFAULT_BACKEND, DEFAULT_KWARGS default_backend = DEFAULT_BACKEND @@ -481,8 +525,6 @@ def reuse(**kwargs: Any) -> Iterator[SceneCanvas | MplFigure | Figure]: """Create a context manager that will automatically reuse the current canvas / figure. Args: - args: Positional arguments passed to - :py:func:`set_defaults`. kwargs: Keywords arguments passed to :py:func:`set_defaults`. @@ -501,11 +543,10 @@ def reuse(**kwargs: Any) -> Iterator[SceneCanvas | MplFigure | Figure]: >>> y = np.linspace(-4.0, +4.0, 200) >>> X, Y = np.meshgrid(x, y) >>> - >>> with reuse(backend="plotly") as fig: + >>> with reuse(backend="plotly") as fig: # doctest: +SKIP ... for z0, w in enumerate(jnp.linspace(0, 10 * jnp.pi, 5)): ... Z = np.cos(w * X) * np.sin(w * Y) ... draw_image(Z, x=x, y=y, z0=z0) # TODO: fix colorbar - >>> >>> fig # doctest: +SKIP """ global DEFAULT_KWARGS diff --git a/differt/tests/plotting/test_utils.py b/differt/tests/plotting/test_utils.py index ef2e5ab7..073e52dd 100644 --- a/differt/tests/plotting/test_utils.py +++ b/differt/tests/plotting/test_utils.py @@ -2,6 +2,7 @@ import builtins import importlib +import platform import sys from contextlib import AbstractContextManager as ContextManager from contextlib import contextmanager @@ -24,6 +25,11 @@ view_from_canvas, ) +if not platform.system() == "Darwin" and platform.processor() == "arm": + pytest.skip( + "skipping tests on macOS (m1) runners at the moment...", allow_module_level=True + ) + _LOCK = Lock() From 28b5ee33f4910719b1b6ddb726fcadf037672ba0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=C3=A9rome=20Eertmans?= Date: Wed, 8 May 2024 15:24:33 +0200 Subject: [PATCH 13/16] chore(docs): renaming some axes --- differt/src/differt/plotting/_core.py | 24 ++++++++++++------------ docs/source/conf.py | 3 ++- 2 files changed, 14 insertions(+), 13 deletions(-) diff --git a/differt/src/differt/plotting/_core.py b/differt/src/differt/plotting/_core.py index c5247fd5..d87e0ec2 100644 --- a/differt/src/differt/plotting/_core.py +++ b/differt/src/differt/plotting/_core.py @@ -311,9 +311,9 @@ def _( @dispatch # type: ignore def draw_image( - data: Num[np.ndarray, "m n"] | Num[np.ndarray, "m n 3"] | Num[np.ndarray, "m n 4"], - x: Float[np.ndarray, " *m"] | None = None, - y: Float[np.ndarray, " *n"] | None = None, + data: Num[np.ndarray, "rows cols"] | Num[np.ndarray, "rows cols 3"] | Num[np.ndarray, "rows cols 4"], + x: Float[np.ndarray, " rows"] | None = None, + y: Float[np.ndarray, " cols"] | None = None, z0: float = 0.0, **kwargs: Any, ) -> Canvas | MplFigure | Figure: # type: ignore @@ -368,9 +368,9 @@ def draw_image( @draw_image.register("vispy") def _( - data: Num[np.ndarray, "m n"] | Num[np.ndarray, "m n 3"] | Num[np.ndarray, "m n 4"], - x: Float[np.ndarray, " ..."] | None = None, - y: Float[np.ndarray, " ..."] | None = None, + data: Num[np.ndarray, "rows cols"] | Num[np.ndarray, "rows cols 3"] | Num[np.ndarray, "rows cols 4"], + x: Float[np.ndarray, " rows"] | None = None, + y: Float[np.ndarray, " cols"] | None = None, z0: float = 0.0, **kwargs: Any, ) -> Canvas: @@ -416,9 +416,9 @@ def _( @draw_image.register("matplotlib") def _( - data: Num[np.ndarray, "m n"] | Num[np.ndarray, "m n 3"] | Num[np.ndarray, "m n 4"], - x: Float[np.ndarray, " ..."] | None = None, - y: Float[np.ndarray, " ..."] | None = None, + data: Num[np.ndarray, "rows cols"] | Num[np.ndarray, "rows cols 3"] | Num[np.ndarray, "rows cols 4"], + x: Float[np.ndarray, " rows"] | None = None, + y: Float[np.ndarray, " cols"] | None = None, z0: float = 0.0, **kwargs: Any, ) -> MplFigure: @@ -431,9 +431,9 @@ def _( @draw_image.register("plotly") def _( - data: Num[np.ndarray, "m n"] | Num[np.ndarray, "m n 3"] | Num[np.ndarray, "m n 4"], - x: Float[np.ndarray, " ..."] | None = None, - y: Float[np.ndarray, " ..."] | None = None, + data: Num[np.ndarray, "rows cols"] | Num[np.ndarray, "rows cols 3"] | Num[np.ndarray, "rows cols 4"], + x: Float[np.ndarray, " rows"] | None = None, + y: Float[np.ndarray, " cols"] | None = None, z0: float = 0.0, **kwargs: Any, ) -> Figure: diff --git a/docs/source/conf.py b/docs/source/conf.py index a6418378..50175610 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -10,6 +10,7 @@ import os from datetime import date from typing import Any +from sphinx.application import Sphinx from differt import __version__ @@ -173,7 +174,7 @@ # reported here https://github.com/sphinx-doc/sphinx/issues/12360. -def fix_sionna_folder(app, obj: Any, bound_method: bool) -> None: +def fix_sionna_folder(app: Sphinx, obj: Any, bound_method: bool) -> None: """ Rename the default folder to a more readeable name. """ From f70942d1a32318767ed60bdd48de15e51bf2adcc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=C3=A9rome=20Eertmans?= Date: Wed, 8 May 2024 17:24:55 +0200 Subject: [PATCH 14/16] fix(docs): type hints etc --- .pre-commit-config.yaml | 2 +- differt/src/differt/plotting/_core.py | 120 +++++++++++++++---------- differt/src/differt/plotting/_utils.py | 50 ++++++----- docs/source/conf.py | 3 +- pyproject.toml | 3 +- 5 files changed, 106 insertions(+), 72 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 7c07eda0..62566c44 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -24,7 +24,7 @@ repos: - id: ruff-format types_or: [python, pyi, jupyter] - repo: https://github.com/RobertCraigie/pyright-python - rev: v1.1.361 + rev: v1.1.362 hooks: - id: pyright - repo: https://github.com/doublify/pre-commit-rust diff --git a/differt/src/differt/plotting/_core.py b/differt/src/differt/plotting/_core.py index d87e0ec2..a97d597b 100644 --- a/differt/src/differt/plotting/_core.py +++ b/differt/src/differt/plotting/_core.py @@ -1,9 +1,7 @@ """Core plotting implementations.""" -from __future__ import annotations - from collections.abc import Mapping, Sequence -from typing import TYPE_CHECKING, Any +from typing import Any, Optional, Union import numpy as np from jaxtyping import Float, Num, UInt @@ -15,18 +13,34 @@ process_vispy_kwargs, ) -if TYPE_CHECKING: +# We cannot use from __future__ import annotations because +# otherwise array annotations do not render correctly. +# We cannot rely on TYPE_CHECKING-guarded annotation +# because Sphinx will fail to import this NumPy or Jax typing +# Hence, we prefer to silence pyright instead. + +try: from matplotlib.figure import Figure as MplFigure +except ImportError: + MplFigure = Any + +try: from plotly.graph_objects import Figure +except ImportError: + Figure = Any + +try: from vispy.scene.canvas import SceneCanvas as Canvas +except ImportError: + Canvas = Any -@dispatch # type: ignore +@dispatch def draw_mesh( vertices: Float[np.ndarray, "num_vertices 3"], triangles: UInt[np.ndarray, "num_triangles 3"], **kwargs: Any, -) -> Canvas | MplFigure | Figure: # type: ignore +) -> Union[Canvas, MplFigure, Figure]: # type: ignore[reportInvalidTypeForm] """ Plot a 3D mesh made of triangles. @@ -72,7 +86,7 @@ def _( vertices: Float[np.ndarray, "num_vertices 3"], triangles: UInt[np.ndarray, "num_triangles 3"], **kwargs: Any, -) -> Canvas: +) -> Canvas: # type: ignore[reportInvalidTypeForm] from vispy.scene.visuals import Mesh canvas, view = process_vispy_kwargs(kwargs) @@ -88,7 +102,7 @@ def _( vertices: Float[np.ndarray, "num_vertices 3"], triangles: UInt[np.ndarray, "num_triangles 3"], **kwargs: Any, -) -> MplFigure: +) -> MplFigure: # type: ignore[reportInvalidTypeForm] fig, ax = process_matplotlib_kwargs(kwargs) x, y, z = vertices.T @@ -102,7 +116,7 @@ def _( vertices: Float[np.ndarray, "num_vertices 3"], triangles: UInt[np.ndarray, "num_triangles 3"], **kwargs: Any, -) -> Figure: +) -> Figure: # type: ignore[reportInvalidTypeForm] fig = process_plotly_kwargs(kwargs) x, y, z = vertices.T @@ -111,10 +125,10 @@ def _( return fig.add_mesh3d(x=x, y=y, z=z, i=i, j=j, k=k, **kwargs) -@dispatch # type: ignore +@dispatch def draw_paths( - paths: Float[np.ndarray, "*batch path_length 3"], **kwargs: Any -) -> Canvas | MplFigure | Figure: # type: ignore + paths: Float[np.ndarray, r"\*batch path_length 3"], **kwargs: Any +) -> Union[Canvas, MplFigure, Figure]: # type: ignore[reportInvalidTypeForm] """ Plot a batch of paths of the same length. @@ -172,7 +186,7 @@ def draw_paths( @draw_paths.register("vispy") -def _(paths: Float[np.ndarray, "*batch path_length 3"], **kwargs: Any) -> Canvas: +def _(paths: Float[np.ndarray, "*batch path_length 3"], **kwargs: Any) -> Canvas: # type: ignore[reportInvalidTypeForm] from vispy.scene.visuals import LinePlot canvas, view = process_vispy_kwargs(kwargs) @@ -186,7 +200,7 @@ def _(paths: Float[np.ndarray, "*batch path_length 3"], **kwargs: Any) -> Canvas @draw_paths.register("matplotlib") -def _(paths: Float[np.ndarray, "*batch path_length 3"], **kwargs: Any) -> MplFigure: +def _(paths: Float[np.ndarray, "*batch path_length 3"], **kwargs: Any) -> MplFigure: # type: ignore[reportInvalidTypeForm] fig, ax = process_matplotlib_kwargs(kwargs) for i in np.ndindex(paths.shape[:-2]): @@ -196,7 +210,7 @@ def _(paths: Float[np.ndarray, "*batch path_length 3"], **kwargs: Any) -> MplFig @draw_paths.register("plotly") -def _(paths: Float[np.ndarray, "*batch path_length 3"], **kwargs: Any) -> Figure: +def _(paths: Float[np.ndarray, "*batch path_length 3"], **kwargs: Any) -> Figure: # type: ignore[reportInvalidTypeForm] fig = process_plotly_kwargs(kwargs) for i in np.ndindex(paths.shape[:-2]): @@ -206,13 +220,13 @@ def _(paths: Float[np.ndarray, "*batch path_length 3"], **kwargs: Any) -> Figure return fig -@dispatch # type: ignore +@dispatch def draw_markers( markers: Float[np.ndarray, "num_markers 3"], - labels: Sequence[str] | None = None, - text_kwargs: Mapping[str, Any] | None = None, + labels: Optional[Sequence[str]] = None, + text_kwargs: Optional[Mapping[str, Any]] = None, **kwargs: Any, -) -> Canvas | MplFigure | Figure: # type: ignore +) -> Union[Canvas, MplFigure, Figure]: # type: ignore[reportInvalidTypeForm] """ Plot markers and, optionally, their label. @@ -259,10 +273,10 @@ def draw_markers( @draw_markers.register("vispy") def _( markers: Float[np.ndarray, "num_markers 3"], - labels: Sequence[str] | None = None, - text_kwargs: Mapping[str, Any] | None = None, + labels: Optional[Sequence[str]] = None, + text_kwargs: Optional[Mapping[str, Any]] = None, **kwargs: Any, -) -> Canvas: +) -> Canvas: # type: ignore[reportInvalidTypeForm] from vispy.scene.visuals import Markers, Text canvas, view = process_vispy_kwargs(kwargs) @@ -280,20 +294,20 @@ def _( @draw_markers.register("matplotlib") def _( markers: Float[np.ndarray, "num_markers 3"], - labels: Sequence[str] | None = None, - text_kwargs: Mapping[str, Any] | None = None, + labels: Optional[Sequence[str]] = None, + text_kwargs: Optional[Mapping[str, Any]] = None, **kwargs: Any, -) -> MplFigure: +) -> MplFigure: # type: ignore[reportInvalidTypeForm] raise NotImplementedError # TODO @draw_markers.register("plotly") def _( markers: Float[np.ndarray, "num_markers 3"], - labels: Sequence[str] | None = None, - text_kwargs: Mapping[str, Any] | None = None, + labels: Optional[Sequence[str]] = None, + text_kwargs: Optional[Mapping[str, Any]] = None, **kwargs: Any, -) -> Figure: +) -> Figure: # type: ignore[reportInvalidTypeForm] fig = process_plotly_kwargs(kwargs) if labels: @@ -309,14 +323,18 @@ def _( ) -@dispatch # type: ignore +@dispatch def draw_image( - data: Num[np.ndarray, "rows cols"] | Num[np.ndarray, "rows cols 3"] | Num[np.ndarray, "rows cols 4"], - x: Float[np.ndarray, " rows"] | None = None, - y: Float[np.ndarray, " cols"] | None = None, + data: Union[ + Num[np.ndarray, "rows cols"], + Num[np.ndarray, "rows cols 3"], + Num[np.ndarray, "rows cols 4"], + ], + x: Optional[Float[np.ndarray, " rows"]] = None, + y: Optional[Float[np.ndarray, " cols"]] = None, z0: float = 0.0, **kwargs: Any, -) -> Canvas | MplFigure | Figure: # type: ignore +) -> Union[Canvas, MplFigure, Figure]: # type: ignore[reportInvalidTypeForm] """ Plot a 2D image on a 3D canvas, at using a fixed z-coordinate. @@ -368,12 +386,16 @@ def draw_image( @draw_image.register("vispy") def _( - data: Num[np.ndarray, "rows cols"] | Num[np.ndarray, "rows cols 3"] | Num[np.ndarray, "rows cols 4"], - x: Float[np.ndarray, " rows"] | None = None, - y: Float[np.ndarray, " cols"] | None = None, + data: Union[ + Num[np.ndarray, "rows cols"], + Num[np.ndarray, "rows cols 3"], + Num[np.ndarray, "rows cols 4"], + ], + x: Optional[Float[np.ndarray, " rows"]] = None, + y: Optional[Float[np.ndarray, " cols"]] = None, z0: float = 0.0, **kwargs: Any, -) -> Canvas: +) -> Canvas: # type: ignore[reportInvalidTypeForm] from vispy.scene.visuals import Image from vispy.visuals.transforms import STTransform @@ -416,12 +438,16 @@ def _( @draw_image.register("matplotlib") def _( - data: Num[np.ndarray, "rows cols"] | Num[np.ndarray, "rows cols 3"] | Num[np.ndarray, "rows cols 4"], - x: Float[np.ndarray, " rows"] | None = None, - y: Float[np.ndarray, " cols"] | None = None, + data: Union[ + Num[np.ndarray, "rows cols"], + Num[np.ndarray, "rows cols 3"], + Num[np.ndarray, "rows cols 4"], + ], + x: Optional[Float[np.ndarray, " rows"]] = None, + y: Optional[Float[np.ndarray, " cols"]] = None, z0: float = 0.0, **kwargs: Any, -) -> MplFigure: +) -> MplFigure: # type: ignore[reportInvalidTypeForm] fig, ax = process_matplotlib_kwargs(kwargs) ax.plot_surface(X=x, Y=y, Z=np.full_like(data, z0), color=data, **kwargs) @@ -431,12 +457,16 @@ def _( @draw_image.register("plotly") def _( - data: Num[np.ndarray, "rows cols"] | Num[np.ndarray, "rows cols 3"] | Num[np.ndarray, "rows cols 4"], - x: Float[np.ndarray, " rows"] | None = None, - y: Float[np.ndarray, " cols"] | None = None, + data: Union[ + Num[np.ndarray, "rows cols"], + Num[np.ndarray, "rows cols 3"], + Num[np.ndarray, "rows cols 4"], + ], + x: Optional[Float[np.ndarray, " rows"]] = None, + y: Optional[Float[np.ndarray, " cols"]] = None, z0: float = 0.0, **kwargs: Any, -) -> Figure: +) -> Figure: # type: ignore[reportInvalidTypeForm] fig = process_plotly_kwargs(kwargs) return fig.add_surface( diff --git a/differt/src/differt/plotting/_utils.py b/differt/src/differt/plotting/_utils.py index 686b12e0..22a03660 100644 --- a/differt/src/differt/plotting/_utils.py +++ b/differt/src/differt/plotting/_utils.py @@ -1,14 +1,13 @@ """Useful decorators for plotting.""" -from __future__ import annotations - import importlib +import sys import types from collections.abc import Iterator, MutableMapping from contextlib import contextmanager from functools import wraps from threading import Lock -from typing import TYPE_CHECKING, Any, Callable, Generic, TypeVar +from typing import TYPE_CHECKING, Any, Callable, Generic, Optional, TypeVar, Union # Immutables @@ -25,28 +24,31 @@ DEFAULT_KWARGS: MutableMapping[str, Any] = {} """The default keyword arguments.""" -if TYPE_CHECKING: - import sys +if sys.version_info >= (3, 10): + from typing import ParamSpec +else: + from typing_extensions import ParamSpec + +P = ParamSpec("P") + +if TYPE_CHECKING: from matplotlib.figure import Figure as MplFigure from mpl_toolkits.mplot3d import Axes3D from plotly.graph_objects import Figure - from vispy.scene.canvas import SceneCanvas + from vispy.scene.canvas import SceneCanvas as Canvas from vispy.scene.widgets.viewbox import ViewBox - - if sys.version_info >= (3, 10): - from typing import ParamSpec - else: - from typing_extensions import ParamSpec - - P = ParamSpec("P") - T = TypeVar("T", SceneCanvas, MplFigure, Figure) else: - P = TypeVar("P") - T = TypeVar("T") + MplFigure = Any + Axes3D = Any + Figure = Any + Canvas = Any + ViewBox = Any + +T = TypeVar("T", Canvas, MplFigure, Figure) -def set_defaults(backend: str | None = None, **kwargs: Any) -> str: +def set_defaults(backend: Optional[str] = None, **kwargs: Any) -> str: """ Set default keyword arguments for future plotting utilities. @@ -215,7 +217,7 @@ def dispatch(fun: Callable[P, T]) -> _Dispatcher[P, T]: A callable that can register backend implementations with ``register``. Notes: - Only the functions registered with ``register``` will be called. + Only the functions registered with ``register`` will be called. The :data:`fun` argument wrapped inside :func:`dispatch` is only used for documentation, but never called. @@ -349,7 +351,7 @@ def wrapper( return wrapper # type: ignore -def view_from_canvas(canvas: SceneCanvas) -> ViewBox: +def view_from_canvas(canvas: Canvas) -> ViewBox: """ Return the view from the specified canvas. @@ -387,7 +389,7 @@ def default_view() -> ViewBox: def process_vispy_kwargs( kwargs: MutableMapping[str, Any], -) -> tuple[SceneCanvas, ViewBox]: +) -> tuple[Canvas, ViewBox]: """ Process keyword arguments passed to some VisPy plotting utility. @@ -399,7 +401,7 @@ def process_vispy_kwargs( The keys specified below will be removed from the mapping. Keyword Args: - convas (:py:class:`SceneCanvas`): + canvas (:py:class:`SceneCanvas`): The canvas that draws contents of the scene. If not provided, will try to access canvas from ``view`` (if supplied). view (:py:class:`Viewbox`): @@ -476,7 +478,7 @@ def process_matplotlib_kwargs( or plt.figure() ) - def current_ax3d() -> Axes3D | None: + def current_ax3d() -> Optional[Axes3D]: if len(figure.axes) > 0: ax = figure.gca() if isinstance(ax, Axes3D): @@ -521,7 +523,7 @@ def process_plotly_kwargs( @contextmanager -def reuse(**kwargs: Any) -> Iterator[SceneCanvas | MplFigure | Figure]: +def reuse(**kwargs: Any) -> Iterator[Union[Canvas, MplFigure, Figure]]: """Create a context manager that will automatically reuse the current canvas / figure. Args: @@ -550,7 +552,7 @@ def reuse(**kwargs: Any) -> Iterator[SceneCanvas | MplFigure | Figure]: >>> fig # doctest: +SKIP """ global DEFAULT_KWARGS - backend: str | None = kwargs.pop("backend", None) + backend: Optional[str] = kwargs.pop("backend", None) with use(backend=backend) as b: try: diff --git a/docs/source/conf.py b/docs/source/conf.py index 50175610..df19d654 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -10,6 +10,7 @@ import os from datetime import date from typing import Any + from sphinx.application import Sphinx from differt import __version__ @@ -72,7 +73,7 @@ "../../differt-core/python/differt_core", ] apidoc_output_dirs = "reference" -apidoc_exclude_patterns = ["conftest.py", "*scene/scenes/*"] +apidoc_exclude_patterns = ["*conftest.py", "*scene/scenes/*"] apidoc_separate = True apidoc_no_toc = True apidoc_max_depth = 1 diff --git a/pyproject.toml b/pyproject.toml index 1a835d4e..9ef80bb6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -60,7 +60,8 @@ omit = ["**/*/conftest.py"] allow-direct-references = true [tool.pyright] -include = ["differt/src/differt", "tests"] +deprecateTypingAliases = true +include = ["differt/src/differt", "differt/tests"] venv = ".venv" venvPath = "." From 5c90adf06e34dcc7f4aa0fa497cfafa9b6f24390 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=C3=A9rome=20Eertmans?= Date: Fri, 10 May 2024 14:17:42 +0200 Subject: [PATCH 15/16] fix(docs): wip --- docs/source/conf.py | 2 +- src/differt_dev/sphinxext/apidoc.py | 17 +++++------------ 2 files changed, 6 insertions(+), 13 deletions(-) diff --git a/docs/source/conf.py b/docs/source/conf.py index df19d654..8678eb99 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -73,7 +73,7 @@ "../../differt-core/python/differt_core", ] apidoc_output_dirs = "reference" -apidoc_exclude_patterns = ["*conftest.py", "*scene/scenes/*"] +apidoc_exclude_patterns = ["conftest.py", "*/scene/scenes/*"] apidoc_separate = True apidoc_no_toc = True apidoc_max_depth = 1 diff --git a/src/differt_dev/sphinxext/apidoc.py b/src/differt_dev/sphinxext/apidoc.py index bc382d8c..7e6a27b9 100644 --- a/src/differt_dev/sphinxext/apidoc.py +++ b/src/differt_dev/sphinxext/apidoc.py @@ -71,18 +71,11 @@ def builder_inited(app: Sphinx) -> None: # noqa: C901 output_dir = path.join(app.srcdir, output_dir) - apidoc.main([*options, f"-o={output_dir}", module_dir, *exclude_patterns]) - apidoc.main( - [ - *options, - f"-o={output_dir}", - module_dir, - *[ - path.join(module_dir, exclude_pattern) - for exclude_pattern in exclude_patterns - ], - ] - ) + args = [*options, f"-o={output_dir}", module_dir, *[path.join(module_dir, exclude_pattern) for exclude_pattern in exclude_patterns]] + + print(f"{args = }") + + apidoc.main(args) def setup(app: Sphinx) -> None: From 112f41915ee985e1942ebf5993a35bfa3d7821e8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=C3=A9rome=20Eertmans?= Date: Mon, 13 May 2024 09:33:39 +0200 Subject: [PATCH 16/16] fix(docs): exclude patterns --- docs/source/conf.py | 2 +- src/differt_dev/sphinxext/apidoc.py | 12 +++++++++--- 2 files changed, 10 insertions(+), 4 deletions(-) diff --git a/docs/source/conf.py b/docs/source/conf.py index 8678eb99..ee06339f 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -73,7 +73,7 @@ "../../differt-core/python/differt_core", ] apidoc_output_dirs = "reference" -apidoc_exclude_patterns = ["conftest.py", "*/scene/scenes/*"] +apidoc_exclude_patterns = ["conftest.py", "scene/scenes/**"] apidoc_separate = True apidoc_no_toc = True apidoc_max_depth = 1 diff --git a/src/differt_dev/sphinxext/apidoc.py b/src/differt_dev/sphinxext/apidoc.py index 7e6a27b9..40f4f75a 100644 --- a/src/differt_dev/sphinxext/apidoc.py +++ b/src/differt_dev/sphinxext/apidoc.py @@ -71,9 +71,15 @@ def builder_inited(app: Sphinx) -> None: # noqa: C901 output_dir = path.join(app.srcdir, output_dir) - args = [*options, f"-o={output_dir}", module_dir, *[path.join(module_dir, exclude_pattern) for exclude_pattern in exclude_patterns]] - - print(f"{args = }") + args = [ + *options, + f"-o={output_dir}", + module_dir, + *[ + path.join(module_dir, exclude_pattern) + for exclude_pattern in exclude_patterns + ], + ] apidoc.main(args)