From 9695a2256cf5043070620f5120a910d4024c0d5a Mon Sep 17 00:00:00 2001
From: Jake Harmon
Date: Wed, 4 Dec 2024 15:45:33 -0800
Subject: [PATCH] Update references to JAX's GitHub repo

JAX has moved from https://github.com/google/jax to
https://github.com/jax-ml/jax

PiperOrigin-RevId: 702886640
---
 docs/index.rst                                    | 2 +-
 examples/contrib/differentially_private_sgd.ipynb | 2 +-
 optax/_src/alias_test.py                          | 4 ++--
 optax/contrib/_common_test.py                     | 2 +-
 4 files changed, 5 insertions(+), 5 deletions(-)

diff --git a/docs/index.rst b/docs/index.rst
index 694dadc97..925c85308 100644
--- a/docs/index.rst
+++ b/docs/index.rst
@@ -35,7 +35,7 @@ can be used to obtain the most recent version of Optax::
 
   pip install git+git://github.com/google-deepmind/optax.git
 
 Note that Optax is built on top of JAX.
-See `here <https://github.com/google/jax>`_
+See `here <https://github.com/jax-ml/jax>`_
 for instructions on installing JAX.
 
diff --git a/examples/contrib/differentially_private_sgd.ipynb b/examples/contrib/differentially_private_sgd.ipynb
index 85c917d90..ba79ea228 100644
--- a/examples/contrib/differentially_private_sgd.ipynb
+++ b/examples/contrib/differentially_private_sgd.ipynb
@@ -12,7 +12,7 @@
     "\n",
     "A large portion of this code is forked from the differentially private SGD\n",
     "example in the [JAX repo](\n",
-    "https://github.com/google/jax/blob/main/examples/differentially_private_sgd.py).\n",
+    "https://github.com/jax-ml/jax/blob/main/examples/differentially_private_sgd.py).\n",
     "\n",
     "[Differentially Private Stochastic Gradient Descent](https://arxiv.org/abs/1607.00133) requires clipping the per-example parameter\n",
     "gradients, which is non-trivial to implement efficiently for convolutional\n",
diff --git a/optax/_src/alias_test.py b/optax/_src/alias_test.py
index 3e713734e..2df610b09 100644
--- a/optax/_src/alias_test.py
+++ b/optax/_src/alias_test.py
@@ -264,7 +264,7 @@ def test_preserve_dtype(self, opt_name, opt_kwargs, dtype):
     # x = 0.5**jnp.asarray(1, dtype=jnp.int32)
     # (appearing in e.g. optax.tree_utils.tree_bias_correction)
     # are promoted (strictly) to float32 when jitted
-    # see https://github.com/google/jax/issues/23337
+    # see https://github.com/jax-ml/jax/issues/23337
     # This may end up letting updates have a dtype different from params.
     # The solution is to fix the dtype of the result to the desired dtype
     # (just as done in optax.tree_utils.tree_bias_correction).
@@ -851,7 +851,7 @@ def test_minimize_bad_initialization(self):
     chex.assert_trees_all_close(jnp_fun(optax_sol), minimum, atol=tol, rtol=tol)
 
   def test_steep_objective(self):
-    # See jax related issue https://github.com/google/jax/issues/4594
+    # See jax related issue https://github.com/jax-ml/jax/issues/4594
     tol = 1e-5
     n = 2
     mat = jnp.eye(n) * 1e4
diff --git a/optax/contrib/_common_test.py b/optax/contrib/_common_test.py
index b20118483..c21b5a879 100644
--- a/optax/contrib/_common_test.py
+++ b/optax/contrib/_common_test.py
@@ -301,7 +301,7 @@ def test_preserve_dtype(
     # x = 0.5**jnp.asarray(1, dtype=jnp.int32)
     # (appearing in e.g. optax.tree_utils.tree_bias_correction)
     # are promoted (strictly) to float32 when jitted
-    # see https://github.com/google/jax/issues/23337
+    # see https://github.com/jax-ml/jax/issues/23337
     # This may end up letting updates have a dtype different from params.
     # The solution is to fix the dtype of the result to the desired dtype
     # (just as done in optax.tree_utils.tree_bias_correction).
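
Note on the dtype-promotion comment touched in optax/_src/alias_test.py and optax/contrib/_common_test.py: below is a minimal JAX sketch of the workaround that comment describes, namely casting the bias-correction result back to the parameter dtype. The bias_correction helper and the bfloat16 dtype are illustrative assumptions for this sketch, not the actual optax.tree_utils.tree_bias_correction implementation.

import jax
import jax.numpy as jnp


@jax.jit
def bias_correction(decay, count):
  # decay**count mixes a Python float with an int32 array; under jit the
  # result can be promoted to float32 (see jax-ml/jax#23337). Casting back
  # to the parameter dtype keeps updates and params in the same dtype.
  return (1 - decay**count).astype(jnp.bfloat16)


correction = bias_correction(0.9, jnp.asarray(1, dtype=jnp.int32))
print(correction.dtype)  # bfloat16, matching a bfloat16 parameter tree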