Skip to content

Commit

Permalink
chore(lib): change legacy PRNGKey to key
Browse files Browse the repository at this point in the history
  • Loading branch information
jeertmans committed Oct 7, 2024
1 parent 7cef8d5 commit de27def
Show file tree
Hide file tree
Showing 8 changed files with 11 additions and 11 deletions.
2 changes: 1 addition & 1 deletion differt/src/differt/geometry/paths.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,7 +251,7 @@ def group_by_objects(self) -> Int[Array, " *batch"]:
>>> from differt.geometry.paths import Paths
>>>
>>> key = jax.random.PRNGKey(1234)
>>> key = jax.random.key(1234)
>>> key_v, key_o = jax.random.split(key, 2)
>>> *batch, path_length = (2, 6, 3)
>>> vertices = jax.random.uniform(key_v, (*batch, path_length, 3))
Expand Down
2 changes: 1 addition & 1 deletion differt/src/differt/geometry/triangle_mesh.py
Original file line number Diff line number Diff line change
Expand Up @@ -409,7 +409,7 @@ def sample(
Args:
size: The size of the sample, i.e., the number of triangles.
replace: Whether to sample with or without replacement.
key: The :func:`jax.random.PRNGKey` to be used.
key: The :func:`jax.random.key` to be used.
Returns:
A new random mesh.
Expand Down
2 changes: 1 addition & 1 deletion differt/src/differt/rt/image_method.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ def image_of_vertices_with_respect_to_mirrors(
... image_of_vertices_with_respect_to_mirrors,
... )
>>>
>>> key = jax.random.PRNGKey(0)
>>> key = jax.random.key(0)
>>> (
... key0,
... key1,
Expand Down
6 changes: 3 additions & 3 deletions differt/src/differt/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def sorted_array2(array: Shaped[Array, "m n"]) -> Shaped[Array, "m n"]:
... )
>>>
>>> arr = jnp.arange(10).reshape(5, 2)
>>> key = jax.random.PRNGKey(1234)
>>> key = jax.random.key(1234)
>>> (
... key1,
... key2,
Expand Down Expand Up @@ -175,7 +175,7 @@ def minimize(
>>>
>>> batch = (1, 2, 3)
>>> n = 10
>>> key = jax.random.PRNGKey(1234)
>>> key = jax.random.key(1234)
>>> offset = jax.random.uniform(key, (*batch, n))
>>>
>>> def f(x, offset, scale=2.0):
Expand Down Expand Up @@ -246,7 +246,7 @@ def sample_points_in_bounding_box(
bounding_box: The bounding box (min. and max. coordinates).
size: The sample size or :data:`None`. If :data:`None`,
the returned array is 1D. Otherwise, it is 2D.
key: The :func:`jax.random.PRNGKey` to be used.
key: The :func:`jax.random.key` to be used.
Returns:
An array of points randomly sampled.
Expand Down
2 changes: 1 addition & 1 deletion differt/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ def seed() -> int:

@pytest.fixture
def key(seed: int) -> PRNGKeyArray:
return jax.random.PRNGKey(seed)
return jax.random.key(seed)


@pytest.fixture
Expand Down
2 changes: 1 addition & 1 deletion differt/tests/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def wrapper(fun: Callable[..., Any]) -> Callable[..., Any]:
@wraps(fun)
def _wrapper_(*args: Any, **kwargs: Any) -> Any:
bound_args = sig.bind(*args, **kwargs)
keys = jax.random.split(jax.random.PRNGKey(seed), len(arg_names))
keys = jax.random.split(jax.random.key(seed), len(arg_names))
for key, arg_name in zip(keys, arg_names, strict=False):
shape = bound_args.arguments[arg_name]
bound_args.arguments[arg_name] = sampler(key, shape)
Expand Down
4 changes: 2 additions & 2 deletions docs/source/notebooks/performance_tips.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@
},
"outputs": [],
"source": [
"key = jax.random.PRNGKey(1234)\n",
"key = jax.random.key(1234)\n",
"key1, key2, key3 = jax.random.split(key, 3)\n",
"\n",
"batch = (10, 100)\n",
Expand Down Expand Up @@ -250,7 +250,7 @@
"source": [
"from beartype import beartype as typechecker\n",
"\n",
"key = jax.random.PRNGKey(1234)\n",
"key = jax.random.key(1234)\n",
"key1, key2, key3 = jax.random.split(key, 3)\n",
"\n",
"batch = (100, 10, 2)\n",
Expand Down
2 changes: 1 addition & 1 deletion docs/source/notebooks/type_checking.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@
"metadata": {},
"outputs": [],
"source": [
"key = jax.random.PRNGKey(1234)\n",
"key = jax.random.key(1234)\n",
"\n",
"arr = jax.random.randint(key, (10, 4), 0, 2)\n",
"arr"
Expand Down

0 comments on commit de27def

Please sign in to comment.