
State dtype stability in update_fn #1207

Open · rdyro opened this issue Feb 24, 2025 · 0 comments

rdyro (Collaborator) commented Feb 24, 2025

Optax currently doesn't seem to enforce dtype stability of the entries in the optimizer state in the update function, trusting the optimizers to handle argument dtypes themselves. However, this can lead to unnecessary recompilations or dtype errors in JAX control-flow primitives such as lax.cond and lax.while_loop. For example, see:

#1171

The following repro triggers the problem in the zoom line search when value is passed to LBFGS with a dtype different from the parameters:

import jax
from jax import numpy as jnp
import optax

if __name__ == "__main__":
    for dtype in [jnp.float16, jnp.bfloat16, jnp.float32, jnp.float64]:
        a = jnp.array([1.0, 2.0, 3.0], dtype=dtype)
        opt = optax.lbfgs()
        state = opt.init(a)

        def cond(a_state):
            _, state = a_state
            return state[0].count < 10

        def step(a_state):
            a, state = a_state
            ga = a
            # Defaults to float32 (or float64 with x64 enabled); does not match dtype.
            value = jnp.array(1.0)
            updates, state = opt.update(
                ga, state, a, value=value, grad=ga, value_fn=lambda x: jnp.mean(x**2)
            )
            a = optax.apply_updates(a, updates)  # throws the error
            return a, state

        a_final, state_final = jax.lax.while_loop(cond, step, (a, state))

throws

TypeError: true_fun output and false_fun output must have identical types, got
ZoomLinesearchState(count='ShapedArray(int32[])', params='ShapedArray(float16[3])', updates='ShapedArray(float16[3])', stepsize_guess='ShapedArray(float64[])', stepsize='ShapedArray(float64[])', value='DIFFERENT ShapedArray(float64[]) vs. ShapedArray(float16[])', grad='ShapedArray(float16[3])', slope='ShapedArray(float16[])', value_init='ShapedArray(float64[])', slope_init='ShapedArray(float16[])', decrease_error='DIFFERENT ShapedArray(float64[]) vs. ShapedArray(float16[])', curvature_error='DIFFERENT ShapedArray(float64[]) vs. ShapedArray(float16[])', error='DIFFERENT ShapedArray(float64[]) vs. ShapedArray(float16[])', interval_found='ShapedArray(bool[])', done='ShapedArray(bool[])', failed='ShapedArray(bool[])', low='ShapedArray(float64[])', value_low='DIFFERENT ShapedArray(float64[]) vs. ShapedArray(float16[])', slope_low='ShapedArray(float16[])', high='ShapedArray(float64[])', value_high='DIFFERENT ShapedArray(float64[]) vs. ShapedArray(float16[])', slope_high='ShapedArray(float16[])', cubic_ref='ShapedArray(float64[])', value_cubic_ref='DIFFERENT ShapedArray(float64[]) vs. ShapedArray(float16[])', safe_stepsize='ShapedArray(float64[])', safe_value='DIFFERENT ShapedArray(float64[]) vs. ShapedArray(float16[])', safe_grad='ShapedArray(float16[3])').
--------------------
For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.
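A workaround at the call site may be to construct value in the parameter dtype, so the zoom line-search state keeps a single dtype across iterations. A minimal sketch, changing only the step function from the repro above:

        def step(a_state):
            a, state = a_state
            ga = a
            # Match the parameter dtype explicitly so the line-search state
            # has the same dtypes on every while_loop iteration.
            value = jnp.array(1.0, dtype=a.dtype)
            updates, state = opt.update(
                ga, state, a, value=value, grad=ga, value_fn=lambda x: jnp.mean(x**2)
            )
            a = optax.apply_updates(a, updates)
            return a, state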

Should optax optimizers perform a dtype cast on the state before returning it?

For example: return jax.tree.map(lambda x, y: x.astype(y.dtype), new_state, state)
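A minimal sketch of what that cast could look like as a generic wrapper (with_stable_state_dtypes is a hypothetical helper, not an existing optax API; it assumes every state leaf is an array):

import jax
import optax

def with_stable_state_dtypes(opt: optax.GradientTransformationExtraArgs):
    # Hypothetical wrapper: after every update, cast each leaf of the new
    # state back to the dtype it had on entry, so the state pytree is a
    # stable fixed point for lax.while_loop / lax.cond.
    def init_fn(params):
        return opt.init(params)

    def update_fn(updates, state, params=None, **extra_args):
        new_updates, new_state = opt.update(updates, state, params, **extra_args)
        new_state = jax.tree.map(
            lambda new, old: new.astype(old.dtype), new_state, state
        )
        return new_updates, new_state

    return optax.GradientTransformationExtraArgs(init_fn, update_fn)

Used as opt = with_stable_state_dtypes(optax.lbfgs()) in the repro above, the state pytree keeps identical dtypes across iterations.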
