Skip to content

Commit

Permalink
fix(tests): issue with beartype>=0.20 and relax `sample_points_in_...…
Browse files Browse the repository at this point in the history
…`'s args (#225)

* fix(tests): pin beartype<0.20 and relax `sample_points_in_...`'s args

* fixes

* fix: type hint
  • Loading branch information
jeertmans authored Feb 27, 2025
1 parent 19ae8fa commit d44c67e
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 42 deletions.
47 changes: 12 additions & 35 deletions differt/src/differt/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from collections.abc import Callable, Iterable, Mapping
from functools import partial
from typing import Any, Concatenate, ParamSpec, overload
from typing import Any, Concatenate, ParamSpec

import chex
import equinox as eqx
Expand Down Expand Up @@ -108,19 +108,21 @@ def sorted_array2(array: Shaped[ArrayLike, "m n"]) -> Shaped[Array, "m n"]:
if array.size == 0:
return array

return array[jnp.lexsort(array.T[::-1])] # type: ignore[reportArgumentType]
return array[jnp.lexsort(array.T[::-1])]


# Redefined here, because chex uses deprecated type hints
# TODO: fixme when google/chex#361 is resolved.
_OptState = chex.Array | Iterable["_OptState"] | Mapping[Any, "_OptState"]
# TODO: changme because beartype>=0.20 complains that it cannot import ArrayTree,
# and I don't see how to fix it.
OptState = Any
# TODO: fixme when Python >= 3.11
_P = ParamSpec("_P")
P = ParamSpec("P")


@eqx.filter_jit
def minimize(
fun: Callable[Concatenate[Num[Array, " n"], _P], Num[Array, " "]],
fun: Callable[Concatenate[Num[Array, " n"], P], Num[Array, " "]],
x0: Num[ArrayLike, "*batch n"],
args: tuple[Any, ...] = (),
steps: int = 1000,
Expand Down Expand Up @@ -255,9 +257,9 @@ def minimize(
opt_state = optimizer.init(x0)

def f(
carry: tuple[Num[Array, "*batch n"], _OptState],
carry: tuple[Num[Array, "*batch n"], OptState],
_: None,
) -> tuple[tuple[Num[Array, "*batch n"], _OptState], Num[Array, " *batch"]]:
) -> tuple[tuple[Num[Array, "*batch n"], OptState], Num[Array, " *batch"]]:
x, opt_state = carry
loss, grads = f_and_df(x, *args)
updates, opt_state = optimizer.update(grads, opt_state)
Expand All @@ -270,41 +272,19 @@ def f(
return x, losses[-1]


@overload
def sample_points_in_bounding_box(
bounding_box: Float[ArrayLike, "2 3"],
shape: None = None,
*,
key: PRNGKeyArray,
) -> Float[Array, "3"]: ...


@overload
def sample_points_in_bounding_box(
bounding_box: Float[ArrayLike, "2 3"],
shape: tuple[int, ...],
*,
key: PRNGKeyArray,
) -> Float[Array, "3"]: ...


@partial(jax.jit, static_argnames=("shape",))
def sample_points_in_bounding_box(
bounding_box: Float[ArrayLike, "2 3"],
shape: tuple[int, ...] | None = None,
shape: tuple[int, ...] = (),
*,
key: PRNGKeyArray,
) -> (
Float[Array, "*shape 3"] | Float[Array, "3"]
): # TODO: use symbolic expression to link to 'shape' parameter
) -> Float[Array, "*shape 3"]:
"""
Sample point(s) in a 3D bounding box.
Args:
bounding_box: The bounding box (min. and max. coordinates).
shape: The sample shape or :data:`None`. If :data:`None`,
the returned array is 1D. Otherwise, the shape
of the returned array is ``(*shape, 3)``.
shape: The sample shape.
key: The :func:`jax.random.PRNGKey` to be used.
Returns:
Expand All @@ -315,9 +295,6 @@ def sample_points_in_bounding_box(
amax = bounding_box[1, :]
scale = amax - amin

if shape is None:
shape = ()

r = jax.random.uniform(key, shape=(*shape, 3))

return r * scale + amin
Expand Down
6 changes: 3 additions & 3 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ dev = [
]
docs = [
"differt[all]",
"beartype>=0.19.0",
"beartype>=0.20",
"myst-nb>=0.17.2",
"pillow>=10.1.0",
"sphinx>=8.1.3",
Expand All @@ -35,7 +35,7 @@ prof = [
]
tests = [
"differt[all]",
"beartype>=0.19.0",
"beartype>=0.20",
"chex>=0.1.84",
"pytest>=7.4.3",
"pytest-benchmark>=4.0.0",
Expand Down Expand Up @@ -141,7 +141,7 @@ venvPath = "."

[tool.pytest.ini_options]
addopts = [
"--jaxtyping-packages=differt,beartype.beartype",
"--jaxtyping-packages=differt,beartype.beartype(conf=beartype.BeartypeConf(strategy=beartype.BeartypeStrategy.On))",
"--numprocesses=logical",
"--cov-report=xml",
"--cov=differt/src/differt",
Expand Down
8 changes: 4 additions & 4 deletions uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit d44c67e

Please sign in to comment.