64 bit precision breaks LBFGS with scan #1210
Hey, thanks for pointing this out. Do you think this could be related to [https://github.com//issues/1207]? Pasting the output from your repro just for reference:
I think it's the same issue.
In your `body_fn`, try:
I'm working on fixing the underlying issue ASAP.
I was trying to use `lax.scan` with `optax.lbfgs` and ran into the following issue. Everything works well without float64 precision, but as soon as I enable it with `jax.config.update("jax_enable_x64", True)`, I get an error saying that the LBFGS state changes from iteration to iteration. According to the error message, this is because `num_linesearch_steps` in the line-search info changes dtype from `i64` to `i32`.