From de27def5c4f04274fa80b33c003318061e5dba3d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=C3=A9rome=20Eertmans?= Date: Mon, 7 Oct 2024 14:08:50 +0200 Subject: [PATCH] chore(lib): change legacy `PRNGKey` to `key` --- differt/src/differt/geometry/paths.py | 2 +- differt/src/differt/geometry/triangle_mesh.py | 2 +- differt/src/differt/rt/image_method.py | 2 +- differt/src/differt/utils.py | 6 +++--- differt/tests/conftest.py | 2 +- differt/tests/utils.py | 2 +- docs/source/notebooks/performance_tips.ipynb | 4 ++-- docs/source/notebooks/type_checking.ipynb | 2 +- 8 files changed, 11 insertions(+), 11 deletions(-) diff --git a/differt/src/differt/geometry/paths.py b/differt/src/differt/geometry/paths.py index f12f5662..30114b28 100644 --- a/differt/src/differt/geometry/paths.py +++ b/differt/src/differt/geometry/paths.py @@ -251,7 +251,7 @@ def group_by_objects(self) -> Int[Array, " *batch"]: >>> from differt.geometry.paths import Paths >>> - >>> key = jax.random.PRNGKey(1234) + >>> key = jax.random.key(1234) >>> key_v, key_o = jax.random.split(key, 2) >>> *batch, path_length = (2, 6, 3) >>> vertices = jax.random.uniform(key_v, (*batch, path_length, 3)) diff --git a/differt/src/differt/geometry/triangle_mesh.py b/differt/src/differt/geometry/triangle_mesh.py index f3d536f4..49b35659 100644 --- a/differt/src/differt/geometry/triangle_mesh.py +++ b/differt/src/differt/geometry/triangle_mesh.py @@ -409,7 +409,7 @@ def sample( Args: size: The size of the sample, i.e., the number of triangles. replace: Whether to sample with or without replacement. - key: The :func:`jax.random.PRNGKey` to be used. + key: The :func:`jax.random.key` to be used. Returns: A new random mesh. diff --git a/differt/src/differt/rt/image_method.py b/differt/src/differt/rt/image_method.py index 44c0fcab..5134fb23 100644 --- a/differt/src/differt/rt/image_method.py +++ b/differt/src/differt/rt/image_method.py @@ -123,7 +123,7 @@ def image_of_vertices_with_respect_to_mirrors( ... image_of_vertices_with_respect_to_mirrors, ... ) >>> - >>> key = jax.random.PRNGKey(0) + >>> key = jax.random.key(0) >>> ( ... key0, ... key1, diff --git a/differt/src/differt/utils.py b/differt/src/differt/utils.py index afeff339..8595c3fb 100644 --- a/differt/src/differt/utils.py +++ b/differt/src/differt/utils.py @@ -43,7 +43,7 @@ def sorted_array2(array: Shaped[Array, "m n"]) -> Shaped[Array, "m n"]: ... ) >>> >>> arr = jnp.arange(10).reshape(5, 2) - >>> key = jax.random.PRNGKey(1234) + >>> key = jax.random.key(1234) >>> ( ... key1, ... key2, @@ -175,7 +175,7 @@ def minimize( >>> >>> batch = (1, 2, 3) >>> n = 10 - >>> key = jax.random.PRNGKey(1234) + >>> key = jax.random.key(1234) >>> offset = jax.random.uniform(key, (*batch, n)) >>> >>> def f(x, offset, scale=2.0): @@ -246,7 +246,7 @@ def sample_points_in_bounding_box( bounding_box: The bounding box (min. and max. coordinates). size: The sample size or :data:`None`. If :data:`None`, the returned array is 1D. Otherwise, it is 2D. - key: The :func:`jax.random.PRNGKey` to be used. + key: The :func:`jax.random.key` to be used. Returns: An array of points randomly sampled. diff --git a/differt/tests/conftest.py b/differt/tests/conftest.py index 1b62cb23..dc663424 100644 --- a/differt/tests/conftest.py +++ b/differt/tests/conftest.py @@ -22,7 +22,7 @@ def seed() -> int: @pytest.fixture def key(seed: int) -> PRNGKeyArray: - return jax.random.PRNGKey(seed) + return jax.random.key(seed) @pytest.fixture diff --git a/differt/tests/utils.py b/differt/tests/utils.py index 836391f0..da593f2b 100644 --- a/differt/tests/utils.py +++ b/differt/tests/utils.py @@ -29,7 +29,7 @@ def wrapper(fun: Callable[..., Any]) -> Callable[..., Any]: @wraps(fun) def _wrapper_(*args: Any, **kwargs: Any) -> Any: bound_args = sig.bind(*args, **kwargs) - keys = jax.random.split(jax.random.PRNGKey(seed), len(arg_names)) + keys = jax.random.split(jax.random.key(seed), len(arg_names)) for key, arg_name in zip(keys, arg_names, strict=False): shape = bound_args.arguments[arg_name] bound_args.arguments[arg_name] = sampler(key, shape) diff --git a/docs/source/notebooks/performance_tips.ipynb b/docs/source/notebooks/performance_tips.ipynb index ecc716d7..ee734bd4 100644 --- a/docs/source/notebooks/performance_tips.ipynb +++ b/docs/source/notebooks/performance_tips.ipynb @@ -109,7 +109,7 @@ }, "outputs": [], "source": [ - "key = jax.random.PRNGKey(1234)\n", + "key = jax.random.key(1234)\n", "key1, key2, key3 = jax.random.split(key, 3)\n", "\n", "batch = (10, 100)\n", @@ -250,7 +250,7 @@ "source": [ "from beartype import beartype as typechecker\n", "\n", - "key = jax.random.PRNGKey(1234)\n", + "key = jax.random.key(1234)\n", "key1, key2, key3 = jax.random.split(key, 3)\n", "\n", "batch = (100, 10, 2)\n", diff --git a/docs/source/notebooks/type_checking.ipynb b/docs/source/notebooks/type_checking.ipynb index 5d230705..1fc04dfc 100644 --- a/docs/source/notebooks/type_checking.ipynb +++ b/docs/source/notebooks/type_checking.ipynb @@ -77,7 +77,7 @@ "metadata": {}, "outputs": [], "source": [ - "key = jax.random.PRNGKey(1234)\n", + "key = jax.random.key(1234)\n", "\n", "arr = jax.random.randint(key, (10, 4), 0, 2)\n", "arr"