Optax currently doesn't seem to enforce dtype stability of the entries in the state returned by the update function, trusting the individual optimizers to handle any dtype promotion of their arguments. However, this can lead to unnecessary recompilations or dtype errors in JAX control-flow functions. For example see #1171.
This repro triggers the problem in the zoom line search when the value passed to the LBFGS update has a different dtype than the parameters:
import jax
from jax import numpy as jnp
import optax

if __name__ == "__main__":
    for dtype in [jnp.float16, jnp.bfloat16, jnp.float32, jnp.float64]:
        a = jnp.array([1.0, 2.0, 3.0], dtype=dtype)
        opt = optax.lbfgs()
        state = opt.init(a)

        def cond(a_state):
            _, state = a_state
            return state[0].count < 10

        def step(a_state):
            a, state = a_state
            ga = a
            value = jnp.array(1.0)  # this defaults to float64 or float32 typically
            updates, state = opt.update(
                ga, state, a, value=value, grad=ga, value_fn=lambda x: jnp.mean(x**2)
            )
            a = optax.apply_updates(a, updates)  # throws the error
            return a, state

        a_final, state_final = jax.lax.while_loop(cond, step, (a, state))
which throws:
TypeError: true_fun output and false_fun output must have identical types, got
ZoomLinesearchState(count='ShapedArray(int32[])', params='ShapedArray(float16[3])', updates='ShapedArray(float16[3])', stepsize_guess='ShapedArray(float64[])', stepsize='ShapedArray(float64[])', value='DIFFERENT ShapedArray(float64[]) vs. ShapedArray(float16[])', grad='ShapedArray(float16[3])', slope='ShapedArray(float16[])', value_init='ShapedArray(float64[])', slope_init='ShapedArray(float16[])', decrease_error='DIFFERENT ShapedArray(float64[]) vs. ShapedArray(float16[])', curvature_error='DIFFERENT ShapedArray(float64[]) vs. ShapedArray(float16[])', error='DIFFERENT ShapedArray(float64[]) vs. ShapedArray(float16[])', interval_found='ShapedArray(bool[])', done='ShapedArray(bool[])', failed='ShapedArray(bool[])', low='ShapedArray(float64[])', value_low='DIFFERENT ShapedArray(float64[]) vs. ShapedArray(float16[])', slope_low='ShapedArray(float16[])', high='ShapedArray(float64[])', value_high='DIFFERENT ShapedArray(float64[]) vs. ShapedArray(float16[])', slope_high='ShapedArray(float16[])', cubic_ref='ShapedArray(float64[])', value_cubic_ref='DIFFERENT ShapedArray(float64[]) vs. ShapedArray(float16[])', safe_stepsize='ShapedArray(float64[])', safe_value='DIFFERENT ShapedArray(float64[]) vs. ShapedArray(float16[])', safe_grad='ShapedArray(float16[3])').
--------------------
For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.
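The entries flagged as DIFFERENT are all derived from value, so a caller-side workaround (assuming you control how value is built) may be to construct it in the params dtype, keeping the line-search state at a single dtype across while_loop iterations:

# Workaround sketch (assumption, not verified for every dtype combination):
# build `value` in the params dtype so the value-derived entries of
# ZoomLinesearchState keep a stable dtype across iterations.
value = jnp.array(1.0, dtype=a.dtype)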
Should optax optimizers cast the state back to its original dtypes before returning it?
For example: return jax.tree.map(lambda x, y: x.astype(y.dtype), new_state, state)
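Until something like that lands in optax itself, a user-side sketch of the same idea is a wrapper that applies this cast after every update. The name with_stable_state_dtypes is hypothetical (not part of optax), and the sketch assumes the wrapped transformation already accepts extra arguments the way optax.lbfgs does:

import jax
import optax


def with_stable_state_dtypes(tx):
    # Hypothetical wrapper: after each update, cast every leaf of the new
    # state back to the dtype it had in the previous state, so the state
    # pytree keeps stable dtypes across jax.lax.while_loop iterations.
    def init_fn(params):
        return tx.init(params)

    def update_fn(updates, state, params=None, **extra_args):
        updates, new_state = tx.update(updates, state, params, **extra_args)
        new_state = jax.tree.map(
            lambda new, old: new.astype(old.dtype), new_state, state
        )
        return updates, new_state

    return optax.GradientTransformationExtraArgs(init_fn, update_fn)


# Usage in the repro above (hypothetical):
# opt = with_stable_state_dtypes(optax.lbfgs())

Whether the cast belongs in each optimizer or in a generic wrapper like this is exactly the question above; the wrapper only papers over the mismatch from the outside.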