Skip to content

Commit

Permalink
fix(tests): pin beartype<0.20 and relax sample_points_in_...'s args
Browse files Browse the repository at this point in the history
  • Loading branch information
jeertmans committed Feb 27, 2025
1 parent 19ae8fa commit 4b4fdd8
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 44 deletions.
46 changes: 11 additions & 35 deletions differt/src/differt/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from collections.abc import Callable, Iterable, Mapping
from functools import partial
from typing import Any, Concatenate, ParamSpec, overload
from typing import Any, Concatenate, ParamSpec

import chex
import equinox as eqx
Expand Down Expand Up @@ -108,19 +108,20 @@ def sorted_array2(array: Shaped[ArrayLike, "m n"]) -> Shaped[Array, "m n"]:
if array.size == 0:
return array

return array[jnp.lexsort(array.T[::-1])] # type: ignore[reportArgumentType]
return array[jnp.lexsort(array.T[::-1])]


# Redefined here, because chex uses deprecated type hints
# TODO: fixme when google/chex#361 is resolved.
_OptState = chex.Array | Iterable["_OptState"] | Mapping[Any, "_OptState"]
ArrayTree = chex.Array | Iterable["ArrayTree"] | Mapping[Any, "ArrayTree"]
OptState = ArrayTree
# TODO: fixme when Python >= 3.11
_P = ParamSpec("_P")
P = ParamSpec("P")


@eqx.filter_jit
def minimize(
fun: Callable[Concatenate[Num[Array, " n"], _P], Num[Array, " "]],
fun: Callable[Concatenate[Num[Array, " n"], P], Num[Array, " "]],
x0: Num[ArrayLike, "*batch n"],
args: tuple[Any, ...] = (),
steps: int = 1000,
Expand Down Expand Up @@ -255,9 +256,9 @@ def minimize(
opt_state = optimizer.init(x0)

def f(
carry: tuple[Num[Array, "*batch n"], _OptState],
carry: tuple[Num[Array, "*batch n"], OptState],
_: None,
) -> tuple[tuple[Num[Array, "*batch n"], _OptState], Num[Array, " *batch"]]:
) -> tuple[tuple[Num[Array, "*batch n"], OptState], Num[Array, " *batch"]]:
x, opt_state = carry
loss, grads = f_and_df(x, *args)
updates, opt_state = optimizer.update(grads, opt_state)
Expand All @@ -270,41 +271,19 @@ def f(
return x, losses[-1]


@overload
def sample_points_in_bounding_box(
bounding_box: Float[ArrayLike, "2 3"],
shape: None = None,
*,
key: PRNGKeyArray,
) -> Float[Array, "3"]: ...


@overload
def sample_points_in_bounding_box(
bounding_box: Float[ArrayLike, "2 3"],
shape: tuple[int, ...],
*,
key: PRNGKeyArray,
) -> Float[Array, "3"]: ...


@partial(jax.jit, static_argnames=("shape",))
def sample_points_in_bounding_box(
bounding_box: Float[ArrayLike, "2 3"],
shape: tuple[int, ...] | None = None,
shape: tuple[int, ...] = (),
*,
key: PRNGKeyArray,
) -> (
Float[Array, "*shape 3"] | Float[Array, "3"]
): # TODO: use symbolic expression to link to 'shape' parameter
) -> Float[Array, "{*shape} 3"]:
"""
Sample point(s) in a 3D bounding box.
Args:
bounding_box: The bounding box (min. and max. coordinates).
shape: The sample shape or :data:`None`. If :data:`None`,
the returned array is 1D. Otherwise, the shape
of the returned array is ``(*shape, 3)``.
shape: The sample shape.
key: The :func:`jax.random.PRNGKey` to be used.
Returns:
Expand All @@ -315,9 +294,6 @@ def sample_points_in_bounding_box(
amax = bounding_box[1, :]
scale = amax - amin

if shape is None:
shape = ()

r = jax.random.uniform(key, shape=(*shape, 3))

return r * scale + amin
Expand Down
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ dev = [
]
docs = [
"differt[all]",
"beartype>=0.19.0",
"beartype>=0.19.0,<0.20",
"myst-nb>=0.17.2",
"pillow>=10.1.0",
"sphinx>=8.1.3",
Expand All @@ -35,7 +35,7 @@ prof = [
]
tests = [
"differt[all]",
"beartype>=0.19.0",
"beartype>=0.19.0,<0.20", # Beartype 0.20 fails to type-check OptState
"chex>=0.1.84",
"pytest>=7.4.3",
"pytest-benchmark>=4.0.0",
Expand Down
14 changes: 7 additions & 7 deletions uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit 4b4fdd8

Please sign in to comment.