From ff8f68cbbf418d5e382af8a1bc9597964a597eac Mon Sep 17 00:00:00 2001
From: Amos You
Date: Sun, 21 Jan 2024 13:27:56 -0800
Subject: [PATCH 1/7] fixed underlines in titles

---
 docs/api/control_variates.rst | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/docs/api/control_variates.rst b/docs/api/control_variates.rst
index f4599a7ec..4d9906d4c 100644
--- a/docs/api/control_variates.rst
+++ b/docs/api/control_variates.rst
@@ -9,13 +9,13 @@ Control Variates
     moving_avg_baseline
 
 Control delta method
-~~~~~~~~~~~~~~~~~~~~~~
+~~~~~~~~~~~~~~~~~~~~
 .. autofunction:: control_delta_method
 
 Control variates Jacobians
-~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+~~~~~~~~~~~~~~~~~~~~~~~~~~
 .. autofunction:: control_variates_jacobians
 
 Moving average baseline
-~~~~~~~~~~~~~~~~~~~~~~
+~~~~~~~~~~~~~~~~~~~~~~~
 .. autofunction:: moving_avg_baseline

From 5d4186f9945ffbbaf05f9525d702e78865d27218 Mon Sep 17 00:00:00 2001
From: Amos You
Date: Thu, 15 Feb 2024 01:11:04 -0800
Subject: [PATCH 2/7] fix lookahead typing

---
 optax/_src/lookahead.py | 10 +++++-----
 1 file changed, 5 insertions(+), 5 deletions(-)

diff --git a/optax/_src/lookahead.py b/optax/_src/lookahead.py
index 1ca242cf9..98d6103b5 100644
--- a/optax/_src/lookahead.py
+++ b/optax/_src/lookahead.py
@@ -29,12 +29,12 @@ class LookaheadState(NamedTuple):
   """State of the `GradientTransformation` returned by `lookahead`.
 
   Attributes:
-    fast_state: Optimizer state of the fast optimizer.
-    steps_since_sync: Number of fast optimizer steps taken since slow and fast
+    fast_state (:class:`optax.OptState`): Optimizer state of the fast optimizer.
+    steps_since_sync (``Union[jax.Array, int]``): Number of fast optimizer steps taken since slow and fast
       parameters were synchronized.
   """
   fast_state: base.OptState
-  steps_since_sync: jnp.ndarray
+  steps_since_sync: Union[jax.Array, int]
 
 
 class LookaheadParams(NamedTuple):
   """Holds a pair of slow and fast parameters for the lookahead optimizer.
@@ -48,8 +48,8 @@ class LookaheadParams(NamedTuple):
   [Zhang et al, 2019](https://arxiv.org/pdf/1907.08610v1.pdf)
 
   Attributes:
-    fast: Fast parameters.
-    slow: Slow parameters.
+    fast (:class:`optax.Params`): Fast parameters.
+    slow (:class:`optax.Params`): Slow parameters.
   """
   fast: base.Params
   slow: base.Params
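Patch 2 above re-types LookaheadState. For context, a minimal usage sketch of the lookahead wrapper; it is illustrative only (not part of the series) and assumes nothing beyond the public optax/JAX API:

    import jax.numpy as jnp
    import optax

    fast_optimizer = optax.adam(1e-3)
    opt = optax.lookahead(fast_optimizer, sync_period=5, slow_step_size=0.5)

    params = {'w': jnp.zeros(3)}
    # Lookahead keeps a fast and a slow copy of the parameters.
    lookahead_params = optax.LookaheadParams(fast=params, slow=params)
    state = opt.init(lookahead_params)
    # `state` is a LookaheadState: `fast_state` holds the Adam state, and
    # `steps_since_sync` starts at 0 and is reset every `sync_period` fast
    # steps, when the slow and fast parameters are synchronized.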
From da953da7984adfdeeeeb163dbbef3ee00184ee03 Mon Sep 17 00:00:00 2001
From: Amos You
Date: Thu, 15 Feb 2024 01:12:21 -0800
Subject: [PATCH 3/7] remove members in wrappers + fix wrapper state typing

---
 docs/api/optimizer_wrappers.rst |  5 --
 optax/_src/lookahead.py         |  4 +-
 optax/_src/wrappers.py          | 83 +++++++++++++++++++--------------
 3 files changed, 49 insertions(+), 43 deletions(-)

diff --git a/docs/api/optimizer_wrappers.rst b/docs/api/optimizer_wrappers.rst
index b029f0f0c..21b91724a 100644
--- a/docs/api/optimizer_wrappers.rst
+++ b/docs/api/optimizer_wrappers.rst
@@ -25,7 +25,6 @@ Apply if finite
 ~~~~~~~~~~~~~~~~~
 .. autofunction:: apply_if_finite
 .. autoclass:: ApplyIfFiniteState
-   :members:
 
 Flatten
 ~~~~~~~~
@@ -37,26 +36,22 @@ Lookahead
 .. autoclass:: LookaheadParams
    :members:
 .. autoclass:: LookaheadState
-   :members:
 
 Masked update
 ~~~~~~~~~~~~~~
 .. autofunction:: masked
 .. autoclass:: MaskedState
-   :members:
 
 Maybe update
 ~~~~~~~~~~~~~~
 .. autofunction:: maybe_update
 .. autoclass:: MaybeUpdateState
-   :members:
 
 Multi-step update
 ~~~~~~~~~~~~~~~~~~~~
 .. autoclass:: MultiSteps
    :members:
 .. autoclass:: MultiStepsState
-   :members:
 .. autoclass:: ShouldSkipUpdateFunction
    :members:
 .. autofunction:: skip_large_updates

diff --git a/optax/_src/lookahead.py b/optax/_src/lookahead.py
index 98d6103b5..5ef299424 100644
--- a/optax/_src/lookahead.py
+++ b/optax/_src/lookahead.py
@@ -30,8 +30,8 @@ class LookaheadState(NamedTuple):
 
   Attributes:
     fast_state (:class:`optax.OptState`): Optimizer state of the fast optimizer.
-    steps_since_sync (``Union[jax.Array, int]``): Number of fast optimizer steps taken since slow and fast
-      parameters were synchronized.
+    steps_since_sync (``Union[jax.Array, int]``): Number of fast optimizer steps
+      taken since slow and fast parameters were synchronized.
   """
   fast_state: base.OptState
   steps_since_sync: Union[jax.Array, int]

diff --git a/optax/_src/wrappers.py b/optax/_src/wrappers.py
index 9d0d50bc3..28d0cd148 100644
--- a/optax/_src/wrappers.py
+++ b/optax/_src/wrappers.py
@@ -89,23 +89,22 @@ def update_fn(updates, state, params=None, **extra_args):
 
 class ApplyIfFiniteState(NamedTuple):
   """State of the `GradientTransformation` returned by `apply_if_finite`.
 
-  Fields:
-    notfinite_count: Number of consecutive gradient updates containing an Inf or
-      a NaN. This number is reset to 0 whenever a gradient update without an Inf
-      or a NaN is done.
-    last_finite: Whether or not the last gradient update contained an Inf or a
-      NaN.
-    total_notfinite: Total number of gradient updates containing an Inf or
-      a NaN since this optimizer was initialised. This number is never reset.
-    inner_state: The state of the inner `GradientTransformation`.
+  Attributes:
+    notfinite_count (``Union[jax.Array, int]``): Number of consecutive gradient
+      updates containing an Inf or a NaN. This number is reset to 0 whenever a
+      gradient update without an Inf or a NaN is done.
+    last_finite (``Union[jax.Array, int]``): Whether or not the last gradient
+      update contained an Inf or a NaN.
+    total_notfinite (``Union[jax.Array, int]``): Total number of gradient
+      updates containing an Inf or a NaN since this optimizer was initialised.
+      This number is never reset.
+    inner_state (:class:`optax.OptState`): The state of the inner
+      `GradientTransformation`.
   """
-  # TODO(optax-dev): notfinite_count, last_finite and inner_state used to be
-  # annotated as `jnp.array` but that is not a valid annotation (it's a function
-  # and secretely resolved to `Any`. We should add back typing.
-  notfinite_count: Any
-  last_finite: Any
-  total_notfinite: Any
-  inner_state: Any
+  notfinite_count: Union[jax.Array, int]
+  last_finite: Union[jax.Array, int]
+  total_notfinite: Union[jax.Array, int]
+  inner_state: base.OptState
 
 
 def apply_if_finite(
@@ -175,23 +174,24 @@ def _zeros_tree_like(inp_tree: chex.ArrayTree) -> chex.ArrayTree:
 
 class MultiStepsState(NamedTuple):
   """State of the `GradientTransformation` returned by `MultiSteps`.
 
-  Fields:
-    mini_step: current mini-step counter. At an update, this either increases by
-      1 or is reset to 0.
-    gradient_step: gradient step counter. This only increases after enough
-      mini-steps have been accumulated.
-    inner_opt_state: the state of the wrapped otpimiser.
-    acc_grads: accumulated gradients over multiple mini-steps.
-    skip_state: an arbitrarily nested tree of arrays. This is only
-      relevant when passing a `should_skip_update_fn` to `MultiSteps`. This
-      structure will then contain values for debugging and or monitoring. The
-      actual structure will vary depending on the choice of
+  Attributes:
+    mini_step (``Union[jax.Array, int]``): current mini-step counter. At an
+      update, this either increases by 1 or is reset to 0.
+    gradient_step (``Union[jax.Array, int]``): gradient step counter. This only
+      increases after enough mini-steps have been accumulated.
+    inner_opt_state (:class:`optax.OptState`): the state of the wrapped
+      optimiser.
+    acc_grads (``jax.Array``): accumulated gradients over multiple mini-steps.
+    skip_state (``chex.ArrayTree``): an arbitrarily nested tree of arrays. This
+      is only relevant when passing a `should_skip_update_fn` to `MultiSteps`.
+      This structure will then contain values for debugging and or monitoring.
+      The actual structure will vary depending on the choice of
       `ShouldSkipUpdateFunction`.
   """
-  mini_step: Array
-  gradient_step: Array
-  inner_opt_state: Any
-  acc_grads: Any
+  mini_step: Union[jax.Array, int]
+  gradient_step: Union[jax.Array, int]
+  inner_opt_state: base.OptState
+  acc_grads: jax.Array  # TODO: double check this one
   skip_state: chex.ArrayTree = ()
 
@@ -448,8 +448,13 @@ def gradient_transformation(self) -> base.GradientTransformation:
 
 
 class MaskedState(NamedTuple):
-  """Maintains inner transform state for masked transformations."""
-  inner_state: Any
+  """Maintains inner transform state for masked transformations.
+
+  Attributes:
+    inner_state (:class:`optax.OptState`): The state of the inner
+      `GradientTransformation`.
+  """
+  inner_state: base.OptState
 
 
 class MaskedNode(NamedTuple):
@@ -563,9 +568,15 @@ def update_fn(updates, state, params=None, **extra_args):
 
 
 class MaybeUpdateState(NamedTuple):
-  """Maintains inner transform state and adds a step counter."""
-  inner_state: Any
-  step: Array
+  """Maintains inner transform state and adds a step counter.
+
+  Attributes:
+    inner_state (:class:`optax.OptState`): The state of the inner
+      `GradientTransformation`.
+    step (``Union[jax.Array, int]``): The current step counter.
+  """
+  inner_state: base.OptState
+  step: Union[jax.Array, int]
 
 
 def maybe_update(
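Patch 3 above rewrites the ApplyIfFiniteState docstring. For readers unfamiliar with the wrapper, a short usage sketch, illustrative only and assuming nothing beyond the public optax API:

    import jax.numpy as jnp
    import optax

    opt = optax.apply_if_finite(optax.adam(1e-3), max_consecutive_errors=5)
    params = {'w': jnp.ones(2)}
    state = opt.init(params)

    bad_grads = {'w': jnp.array([jnp.nan, 1.0])}  # one non-finite gradient
    updates, state = opt.update(bad_grads, state, params)
    # The non-finite step is rejected (the returned updates are zeros), the
    # inner state is left untouched, and `notfinite_count` is incremented;
    # a later finite step resets it to 0, as the docstring above describes.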
From 65d4be67c8dd4b537192d3adf031f4ace8088617 Mon Sep 17 00:00:00 2001
From: Amos You
Date: Tue, 20 Feb 2024 19:40:04 -0800
Subject: [PATCH 4/7] remove members in combine, contrib, schedules,
 transformations

---
 docs/api/combining_optimizers.rst |  1 -
 docs/api/contrib.rst              | 15 ---------------
 docs/api/optimizer_schedules.rst  |  1 -
 docs/api/transformations.rst      | 26 --------------------------
 4 files changed, 43 deletions(-)

diff --git a/docs/api/combining_optimizers.rst b/docs/api/combining_optimizers.rst
index 2651fd5f5..48ff3aa49 100644
--- a/docs/api/combining_optimizers.rst
+++ b/docs/api/combining_optimizers.rst
@@ -15,4 +15,3 @@ Multi-transform
 ~~~~~~~~~~~~~~~
 .. autofunction:: multi_transform
 .. autoclass:: MultiTransformState
-   :members:

diff --git a/docs/api/contrib.rst b/docs/api/contrib.rst
index f0e2d1dd3..253e17e8a 100644
--- a/docs/api/contrib.rst
+++ b/docs/api/contrib.rst
@@ -27,47 +27,32 @@ Complex-valued Optimization
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~
 .. autofunction:: split_real_and_imaginary
 .. autoclass:: SplitRealAndImaginaryState
-   :members:
 
 Continuous coin betting
 ~~~~~~~~~~~~~~~~~~~~~~~
 .. autofunction:: cocob
 .. autoclass:: COCOBState
-   :members:
 
 D-adaptation
 ~~~~~~~~~~~~
 .. autofunction:: dadapt_adamw
 .. autoclass:: DAdaptAdamWState
-   :members:
-
-Privacy-Sensitive Optax Methods
--------------------------------
-
-.. autosummary::
-    DifferentiallyPrivateAggregateState
-    differentially_private_aggregate
-
 Differentially Private Aggregate
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 .. autofunction:: differentially_private_aggregate
 .. autoclass:: DifferentiallyPrivateAggregateState
-   :members:
 .. autofunction:: dpsgd
-
 
 Mechanize
 ~~~~~~~~~
 .. autofunction:: mechanize
 .. autoclass:: MechanicState
-   :members:
 
 Prodigy
 ~~~~~~~
 .. autofunction:: prodigy
 .. autoclass:: ProdigyState
-   :members:
 
 Sharpness aware minimization
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~

diff --git a/docs/api/optimizer_schedules.rst b/docs/api/optimizer_schedules.rst
index 299c09b74..5d3867dc3 100644
--- a/docs/api/optimizer_schedules.rst
+++ b/docs/api/optimizer_schedules.rst
@@ -46,7 +46,6 @@ Inject hyperparameters
 ~~~~~~~~~~~~~~~~~~~~~~
 .. autofunction:: inject_hyperparams
 .. autoclass:: InjectHyperparamsState
-   :members:
 
 Linear schedules
 ~~~~~~~~~~~~~~~~

diff --git a/docs/api/transformations.rst b/docs/api/transformations.rst
index 5ca5289f1..30f35116d 100644
--- a/docs/api/transformations.rst
+++ b/docs/api/transformations.rst
@@ -115,19 +115,15 @@ Transformations and states
 
 .. autofunction:: adaptive_grad_clip
 .. autoclass:: AdaptiveGradClipState
-   :members:
 
 .. autofunction:: add_decayed_weights
 .. autoclass:: AddDecayedWeightsState
-   :members:
 
 .. autofunction:: add_noise
 .. autoclass:: AddNoiseState
-   :members:
 
 .. autofunction:: apply_every
 .. autoclass:: ApplyEvery
-   :members:
 
 .. autofunction:: bias_correction
 
 .. autofunction:: centralize
 .. autofunction:: clip
 .. autofunction:: clip_by_block_rms
 .. autoclass:: ClipState
-   :members:
 
 .. autofunction:: clip_by_global_norm
 .. autoclass:: ClipByGlobalNormState
-   :members:
 
 .. autofunction:: ema
 .. autoclass:: EmaState
-   :members:
 
 .. autoclass:: EmptyState
-   :members:
 
 .. autofunction:: global_norm
 
 .. autofunction:: keep_params_nonnegative
 .. autoclass:: NonNegativeParamsState
-   :members:
 
 .. autofunction:: per_example_global_norm_clip
 .. autofunction:: per_example_layer_norm_clip
 
 .. autofunction:: scale
 .. autoclass:: ScaleState
-   :members:
 
 .. autofunction:: scale_by_adadelta
 .. autoclass:: ScaleByAdaDeltaState
-   :members:
 
 .. autofunction:: scale_by_adam
 .. autofunction:: scale_by_adamax
 .. autoclass:: ScaleByAdamState
-   :members:
 
 .. autofunction:: scale_by_amsgrad
 .. autoclass:: ScaleByAmsgradState
-   :members:
 
 .. autofunction:: scale_by_belief
 .. autoclass:: ScaleByBeliefState
-   :members:
 
 .. autofunction:: scale_by_factored_rms
 .. autoclass:: FactoredState
-   :members:
 
 .. autofunction:: scale_by_learning_rate
 .. autofunction:: scale_by_lion
 .. autoclass:: ScaleByLionState
-   :members:
 
 .. autofunction:: scale_by_novograd
 .. autoclass:: ScaleByNovogradState
-   :members:
 
 .. autofunction:: scale_by_optimistic_gradient
 
 .. autofunction:: scale_by_rms
 .. autoclass:: ScaleByRmsState
-   :members:
 
 .. autofunction:: scale_by_rprop
 .. autoclass:: ScaleByRpropState
-   :members:
 
 .. autofunction:: scale_by_rss
 .. autoclass:: ScaleByRssState
-   :members:
 
 .. autofunction:: scale_by_schedule
 .. autoclass:: ScaleByScheduleState
-   :members:
 
 .. autofunction:: scale_by_sm3
 .. autoclass:: ScaleBySM3State
-   :members:
 
 .. autofunction:: scale_by_stddev
 .. autoclass:: ScaleByRStdDevState
-   :members:
 
 .. autofunction:: scale_by_trust_ratio
 .. autoclass:: ScaleByTrustRatioState
-   :members:
 
 .. autofunction:: scale_by_yogi
 
 .. autofunction:: trace
 .. autoclass:: TraceState
-   :members:
 
 .. autofunction:: update_infinity_moment
 .. autofunction:: update_moment
 
 .. autofunction:: zero_nans
 .. autoclass:: ZeroNansState
-   :members:
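Patch 4 above touches the docs entries for multi_transform and its MultiTransformState, among others. As a reference point, a small usage sketch, illustrative only and relying solely on the public optax API:

    import jax.numpy as jnp
    import optax

    params = {'w': jnp.zeros(3), 'b': jnp.zeros(1)}
    labels = {'w': 'adam', 'b': 'sgd'}  # route each parameter leaf to a transform
    tx = optax.multi_transform(
        {'adam': optax.adam(1e-3), 'sgd': optax.sgd(1e-2)}, labels)
    state = tx.init(params)
    # `state` is a MultiTransformState whose `inner_states` mapping has one
    # entry per label ('adam' and 'sgd') holding the state used for the
    # parameters routed to that transformation.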
autofunction:: mechanize .. autoclass:: MechanicState - :members: Prodigy ~~~~~~~ .. autofunction:: prodigy .. autoclass:: ProdigyState - :members: Sharpness aware minimization ~~~~~~~~~~~~~~~~~~~~~~~~~~~~ diff --git a/docs/api/optimizer_schedules.rst b/docs/api/optimizer_schedules.rst index 299c09b74..5d3867dc3 100644 --- a/docs/api/optimizer_schedules.rst +++ b/docs/api/optimizer_schedules.rst @@ -46,7 +46,6 @@ Inject hyperparameters ~~~~~~~~~~~~~~~~~~~~~~ .. autofunction:: inject_hyperparams .. autoclass:: InjectHyperparamsState - :members: Linear schedules ~~~~~~~~~~~~~~~~ diff --git a/docs/api/transformations.rst b/docs/api/transformations.rst index 5ca5289f1..30f35116d 100644 --- a/docs/api/transformations.rst +++ b/docs/api/transformations.rst @@ -115,19 +115,15 @@ Transformations and states .. autofunction:: adaptive_grad_clip .. autoclass:: AdaptiveGradClipState - :members: .. autofunction:: add_decayed_weights .. autoclass:: AddDecayedWeightsState - :members: .. autofunction:: add_noise .. autoclass:: AddNoiseState - :members: .. autofunction:: apply_every .. autoclass:: ApplyEvery - :members: .. autofunction:: bias_correction @@ -136,18 +132,14 @@ Transformations and states .. autofunction:: clip .. autofunction:: clip_by_block_rms .. autoclass:: ClipState - :members: .. autofunction:: clip_by_global_norm .. autoclass:: ClipByGlobalNormState - :members: .. autofunction:: ema .. autoclass:: EmaState - :members: .. autoclass:: EmptyState - :members: .. autofunction:: global_norm @@ -155,45 +147,36 @@ Transformations and states .. autofunction:: keep_params_nonnegative .. autoclass:: NonNegativeParamsState - :members: .. autofunction:: per_example_global_norm_clip .. autofunction:: per_example_layer_norm_clip .. autofunction:: scale .. autoclass:: ScaleState - :members: .. autofunction:: scale_by_adadelta .. autoclass:: ScaleByAdaDeltaState - :members: .. autofunction:: scale_by_adam .. autofunction:: scale_by_adamax .. autoclass:: ScaleByAdamState - :members: .. autofunction:: scale_by_amsgrad .. autoclass:: ScaleByAmsgradState - :members: .. autofunction:: scale_by_belief .. autoclass:: ScaleByBeliefState - :members: .. autofunction:: scale_by_factored_rms .. autoclass:: FactoredState - :members: .. autofunction:: scale_by_learning_rate .. autofunction:: scale_by_lion .. autoclass:: ScaleByLionState - :members: .. autofunction:: scale_by_novograd .. autoclass:: ScaleByNovogradState - :members: .. autofunction:: scale_by_optimistic_gradient @@ -205,31 +188,24 @@ Transformations and states .. autofunction:: scale_by_rms .. autoclass:: ScaleByRmsState - :members: .. autofunction:: scale_by_rprop .. autoclass:: ScaleByRpropState - :members: .. autofunction:: scale_by_rss .. autoclass:: ScaleByRssState - :members: .. autofunction:: scale_by_schedule .. autoclass:: ScaleByScheduleState - :members: .. autofunction:: scale_by_sm3 .. autoclass:: ScaleBySM3State - :members: .. autofunction:: scale_by_stddev .. autoclass:: ScaleByRStdDevState - :members: .. autofunction:: scale_by_trust_ratio .. autoclass:: ScaleByTrustRatioState - :members: .. autofunction:: scale_by_yogi @@ -240,7 +216,6 @@ Transformations and states .. autofunction:: trace .. autoclass:: TraceState - :members: .. autofunction:: update_infinity_moment .. autofunction:: update_moment @@ -250,4 +225,3 @@ Transformations and states .. autofunction:: zero_nans .. 
From 09565f93d666fa63b878dda536e1480c3e006d96 Mon Sep 17 00:00:00 2001
From: Amos You
Date: Fri, 15 Mar 2024 11:27:06 -0700
Subject: [PATCH 6/7] fix typing + add back members for wrappers

---
 docs/api/combining_optimizers.rst | 1 +
 docs/api/optimizer_wrappers.rst   | 5 +++++
 optax/_src/combine.py             | 1 +
 optax/_src/constrain.py           | 3 ++-
 optax/_src/wrappers.py            | 2 +-
 5 files changed, 10 insertions(+), 2 deletions(-)

diff --git a/docs/api/combining_optimizers.rst b/docs/api/combining_optimizers.rst
index 48ff3aa49..0116b4d28 100644
--- a/docs/api/combining_optimizers.rst
+++ b/docs/api/combining_optimizers.rst
@@ -6,6 +6,7 @@ Combining Optimizers
 .. autosummary::
     chain
     multi_transform
+    MultiTransformState
 
 Chain
 ~~~~~

diff --git a/docs/api/optimizer_wrappers.rst b/docs/api/optimizer_wrappers.rst
index 21b91724a..b029f0f0c 100644
--- a/docs/api/optimizer_wrappers.rst
+++ b/docs/api/optimizer_wrappers.rst
@@ -25,6 +25,7 @@ Apply if finite
 ~~~~~~~~~~~~~~~~~
 .. autofunction:: apply_if_finite
 .. autoclass:: ApplyIfFiniteState
+   :members:
 
 Flatten
 ~~~~~~~~
@@ -36,22 +37,26 @@ Lookahead
 .. autoclass:: LookaheadParams
    :members:
 .. autoclass:: LookaheadState
+   :members:
 
 Masked update
 ~~~~~~~~~~~~~~
 .. autofunction:: masked
 .. autoclass:: MaskedState
+   :members:
 
 Maybe update
 ~~~~~~~~~~~~~~
 .. autofunction:: maybe_update
 .. autoclass:: MaybeUpdateState
+   :members:
 
 Multi-step update
 ~~~~~~~~~~~~~~~~~~~~
 .. autoclass:: MultiSteps
    :members:
 .. autoclass:: MultiStepsState
+   :members:
 .. autoclass:: ShouldSkipUpdateFunction
    :members:
 .. autofunction:: skip_large_updates

diff --git a/optax/_src/combine.py b/optax/_src/combine.py
index 4c7585be6..62ee65125 100644
--- a/optax/_src/combine.py
+++ b/optax/_src/combine.py
@@ -128,6 +128,7 @@ def update_fn(updates, state, params=None, **extra_args):
 
 
 class MultiTransformState(NamedTuple):
+  """State of the `GradientTransformation` returned by `multi_transform`."""
   inner_states: Mapping[Hashable, base.OptState]
 
 
diff --git a/optax/_src/constrain.py b/optax/_src/constrain.py
index 045a350ef..d8e78e227 100644
--- a/optax/_src/constrain.py
+++ b/optax/_src/constrain.py
@@ -16,6 +16,7 @@
 
 from typing import NamedTuple
 
+import chex
 import jax
 import jax.numpy as jnp
 
@@ -65,7 +66,7 @@ class ZeroNansState(NamedTuple):
       was detected in the corresponding parameter array at the last call to
       `update`.
   """
-  found_nan: jax.Array
+  found_nan: chex.ArrayTree
 
 
 def zero_nans() -> base.GradientTransformation:

diff --git a/optax/_src/wrappers.py b/optax/_src/wrappers.py
index c9f715dde..50ceb3f3e 100644
--- a/optax/_src/wrappers.py
+++ b/optax/_src/wrappers.py
@@ -191,7 +191,7 @@ class MultiStepsState(NamedTuple):
   mini_step: Union[jax.Array, int]
   gradient_step: Union[jax.Array, int]
   inner_opt_state: base.OptState
-  acc_grads: jax.Array  # TODO: double check this one
+  acc_grads: base.Updates
   skip_state: chex.ArrayTree = ()
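Patch 6 above settles MultiStepsState.acc_grads as base.Updates. To make the accumulation it stores concrete, a brief sketch using only the public optax API (illustrative, not part of the series):

    import jax.numpy as jnp
    import optax

    opt = optax.MultiSteps(optax.adam(1e-3), every_k_schedule=4)
    params = {'w': jnp.zeros(3)}
    state = opt.init(params)

    grads = {'w': jnp.ones(3)}
    updates, state = opt.update(grads, state, params)
    # For the first three of every four calls the returned updates are zeros
    # while the gradients are accumulated in `state.acc_grads` and `mini_step`
    # advances; on the fourth call the inner optimizer takes a real step and
    # `gradient_step` increases by one.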
From 2d01fca964f6a5f20473842538458c10bebbf050 Mon Sep 17 00:00:00 2001
From: Amos You
Date: Tue, 16 Apr 2024 21:51:39 -0700
Subject: [PATCH 7/7] remove members in contrib + wrappers

---
 docs/api/contrib.rst            | 1 -
 docs/api/optimizer_wrappers.rst | 8 --------
 2 files changed, 9 deletions(-)

diff --git a/docs/api/contrib.rst b/docs/api/contrib.rst
index 253e17e8a..de2a31d92 100644
--- a/docs/api/contrib.rst
+++ b/docs/api/contrib.rst
@@ -58,4 +58,3 @@ Sharpness aware minimization
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 .. autofunction:: sam
 .. autoclass:: SAMState
-   :members:

diff --git a/docs/api/optimizer_wrappers.rst b/docs/api/optimizer_wrappers.rst
index b029f0f0c..1c5eb30eb 100644
--- a/docs/api/optimizer_wrappers.rst
+++ b/docs/api/optimizer_wrappers.rst
@@ -25,7 +25,6 @@ Apply if finite
 ~~~~~~~~~~~~~~~~~
 .. autofunction:: apply_if_finite
 .. autoclass:: ApplyIfFiniteState
-   :members:
 
 Flatten
 ~~~~~~~~
@@ -35,29 +34,22 @@ Lookahead
 ~~~~~~~~~~~~~~~~~
 .. autofunction:: lookahead
 .. autoclass:: LookaheadParams
-   :members:
 .. autoclass:: LookaheadState
-   :members:
 
 Masked update
 ~~~~~~~~~~~~~~
 .. autofunction:: masked
 .. autoclass:: MaskedState
-   :members:
 
 Maybe update
 ~~~~~~~~~~~~~~
 .. autofunction:: maybe_update
 .. autoclass:: MaybeUpdateState
-   :members:
 
 Multi-step update
 ~~~~~~~~~~~~~~~~~~~~
 .. autoclass:: MultiSteps
-   :members:
 .. autoclass:: MultiStepsState
-   :members:
 .. autoclass:: ShouldSkipUpdateFunction
-   :members:
 .. autofunction:: skip_large_updates
 .. autofunction:: skip_not_finite