64 bit precision breaks LBFGS with scan #1210

Open
SNMS95 opened this issue Feb 26, 2025 · 4 comments

@SNMS95

SNMS95 commented Feb 26, 2025

I was trying to use `lax.scan` with `optax.lbfgs` and ran into the following issue. Everything works when I don't enable float64 precision with `jax.config.update("jax_enable_x64", True)`, but as soon as I turn it on, I get an error saying that the LBFGS state changes from iteration to iteration. The error message shows that this is due to the line-search info: `num_linesearch_steps` changes dtype from i64 to i32.

```python
import optax
import jax
import jax.numpy as jnp
import equinox.internal as eqxi

# Enable 64-bit precision
jax.config.update("jax_enable_x64", True)

def f(x):
    return jnp.sum(x ** 2)

solver = optax.lbfgs()
params = jnp.array([1., 2., 3.])
print('Objective function: ', f(params))
opt_state = solver.init(params)
value_and_grad = jax.value_and_grad(f)
#optax.value_and_grad_from_state(f)

def body_fn(carry, _):  # one L-BFGS step per scan iteration
    params, opt_state = carry
    value, grad = value_and_grad(params)
    updates, opt_state = solver.update(grad, opt_state, params, value=value, grad=grad, value_fn=f)
    params = optax.apply_updates(params, updates)
    return (params, opt_state), None

(params, opt_state,), _ = eqxi.scan(
    body_fn,
    (params, opt_state),
    None,
    length=100,
    kind="checkpoint")  # equinox's checkpointed scan; the carry must keep identical dtypes
print('Objective function: ', f(params))
```
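The failure is not specific to L-BFGS: a scan requires the carry to have the same structure, shapes and dtypes on input and output. With `jax_enable_x64` on, a leaf created with the default integer dtype is int64, so if another code path produces the same leaf as int32 the two no longer match. A minimal sketch of just that constraint, using only `jax.lax.scan`; `drifting_body` and `scan_rejects_drift` are hypothetical names, with the scalar counter standing in for `num_linesearch_steps`:

```python
import jax
import jax.numpy as jnp

# Enable 64-bit types before any arrays are created.
jax.config.update("jax_enable_x64", True)


def drifting_body(count, _):
    # Hypothetical scan body: the carry enters as int64 but leaves as int32,
    # the same kind of drift reported for num_linesearch_steps.
    return (count + 1).astype(jnp.int32), None


# Explicit dtype, so the carry starts as a strongly typed int64 scalar.
init = jnp.zeros((), dtype=jnp.int64)


def scan_rejects_drift() -> bool:
    """Return True if lax.scan refuses the dtype-changing carry."""
    try:
        jax.lax.scan(drifting_body, init, None, length=3)
    except TypeError as err:
        print(err)
        return True
    return False


print(scan_rejects_drift())
```

`jax.lax.scan` reports this as a `TypeError` naming the mismatched carry component, whereas `equinox.internal.scan` in the report raises its own `ValueError` with a structure diff.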
@rdyro
Collaborator

rdyro commented Feb 26, 2025

Hey, thanks for pointing this out! Do you think this could be related to [https://github.com//issues/1207]?

Pasting the output from your repro just for reference:

```
Objective function:  14.0
Traceback (most recent call last):
  File "/home/rdyro/test.py", line 26, in <module>
    (params, opt_state,), _ = eqxi.scan(
                              ^^^^^^^^^^
  File "/home/rdyro/.pyenv/versions/dev/lib/python3.12/site-packages/equinox/internal/_loop/loop.py", line 228, in scan
    _, carry, ys = checkpointed_while_loop(
                   ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/rdyro/.pyenv/versions/dev/lib/python3.12/site-packages/equinox/internal/_loop/checkpointed.py", line 247, in checkpointed_while_loop
    body_fun_ = filter_closure_convert(body_fun_, init_val_)
                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/rdyro/.pyenv/versions/dev/lib/python3.12/site-packages/equinox/_ad.py", line 709, in filter_closure_convert
    closed_jaxpr, out_dynamic_struct, out_static = filter_make_jaxpr(fn)(
                                                   ^^^^^^^^^^^^^^^^^^^^^^
  File "/home/rdyro/.pyenv/versions/dev/lib/python3.12/site-packages/equinox/_make_jaxpr.py", line 43, in __call__
    jaxpr, out_struct = jax.make_jaxpr(_fn, return_shape=True)(*dynamic_flat)
                        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/rdyro/.pyenv/versions/dev/lib/python3.12/site-packages/equinox/internal/_loop/common.py", line 497, in new_body_fun
    raise ValueError(
ValueError: `body_fun` must have the same input and output structure. Difference is:
  (
    i64[],
    (
      f64[3],
      (
        ScaleByLBFGSState(
          count=i32[],
          params=f64[3],
          updates=f64[3],
          diff_params_memory=f64[10,3],
          diff_updates_memory=f64[10,3],
          weights_memory=f64[10]
        ),
        EmptyState(),
        ScaleByZoomLinesearchState(
          learning_rate=f64[],
          value=f64[],
          grad=f64[3],
          info=ZoomLinesearchInfo(
-           num_linesearch_steps=i64[],
+           num_linesearch_steps=i32[],
            decrease_error=f64[],
            curvature_error=f64[]
          )
        )
      )
    ),
    None
  )
--------------------
For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.
```
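The `-`/`+` lines in the diff mark the single leaf whose dtype changed. For a large nested state it can help to list such leaves before handing the body to a scan at all. A minimal sketch, assuming only the public `jax.eval_shape` and `jax.tree_util` path helpers; `ToyInfo`, `drifting_body` and `dtype_drift` are hypothetical names, and `ToyInfo` only loosely mimics two fields of `ZoomLinesearchInfo`:

```python
from typing import Any, Callable, NamedTuple

import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)


class ToyInfo(NamedTuple):
    # Hypothetical stand-in for a line-search info record.
    num_linesearch_steps: jax.Array
    decrease_error: jax.Array


def drifting_body(carry: ToyInfo, _):
    # Hypothetical body: the step counter comes back as int32.
    steps = (carry.num_linesearch_steps + 1).astype(jnp.int32)
    return ToyInfo(steps, carry.decrease_error * 0.5), None


def dtype_drift(body: Callable, carry: Any) -> list[str]:
    """List carry leaves whose dtype changes across one call of `body`."""
    # eval_shape traces `body` abstractly; no computation is run.
    out_carry, _ = jax.eval_shape(body, carry, None)
    if jax.tree_util.tree_structure(out_carry) != jax.tree_util.tree_structure(carry):
        raise ValueError("carry tree structure changed, not just dtypes")
    in_leaves, _ = jax.tree_util.tree_flatten_with_path(carry)
    out_leaves = jax.tree_util.tree_leaves(out_carry)
    return [
        f"{jax.tree_util.keystr(path)}: {x.dtype} -> {y.dtype}"
        for (path, x), y in zip(in_leaves, out_leaves)
        if x.dtype != y.dtype
    ]


init = ToyInfo(jnp.zeros((), dtype=jnp.int64), jnp.ones((), dtype=jnp.float64))
print(dtype_drift(drifting_body, init))
```

In the issue's setting the analogous call would be `dtype_drift(body_fn, (params, opt_state))`, which would presumably point at the `num_linesearch_steps` leaf.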

rdyro self-assigned this Feb 26, 2025
@SNMS95
Author

SNMS95 commented Feb 26, 2025

I think it's the same issue.
Is there a way to work around it temporarily?

@rdyro
Collaborator

rdyro commented Feb 26, 2025

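In your body_fn, try casting the updated optimizer state back to the dtypes it had on entry:

```python
prev_opt_state = opt_state  # right after unpacking the carry
...
# after solver.update: cast each leaf back to its dtype from the previous iteration
opt_state = jax.tree.map(lambda x, y: x.astype(y.dtype), opt_state, prev_opt_state)
```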
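The same cast can be packaged as a wrapper around any scan body, so the previous state does not have to be saved by hand. A minimal sketch, assuming `jax.tree.map` and `jax.lax.scan`; `keep_carry_dtypes` is a hypothetical helper name, and the toy body drifts an int64 counter to int32 the way `num_linesearch_steps` does in the traceback:

```python
from typing import Callable

import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)


def keep_carry_dtypes(body: Callable) -> Callable:
    """Hypothetical wrapper: cast every output carry leaf back to the dtype
    of the matching input carry leaf (the jax.tree.map cast, as a decorator)."""
    def wrapped(carry, x):
        new_carry, y = body(carry, x)
        new_carry = jax.tree.map(lambda new, old: new.astype(old.dtype), new_carry, carry)
        return new_carry, y
    return wrapped


@keep_carry_dtypes
def body(carry, _):
    count, value = carry
    # Toy drift: the counter would come back as int32 without the wrapper.
    return ((count + 1).astype(jnp.int32), value * 0.5), None


init = (jnp.zeros((), dtype=jnp.int64), jnp.ones((), dtype=jnp.float64))
(final_count, final_value), _ = jax.lax.scan(body, init, None, length=4)
print(final_count, final_count.dtype, final_value)
```

Decorating the issue's `body_fn` this way would cast both `params` and `opt_state` back to their incoming dtypes on every step. It is a stopgap: it hides the mismatch instead of removing the int32 at its source, and it is lossless here only because int32 values fit in int64. If an output leaf were ever wider than its input (say float64 coming back where float32 went in), the cast would silently narrow it.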

@rdyro
Collaborator

rdyro commented Feb 26, 2025

I'm working on fixing the underlying issue ASAP
