Replies: 1 comment
-
I think it makes sense to add this option. Would you be interested in contributing a PR? If you're interested in a quick workaround, optax has a very simple implementation of this functionality which you could directly reproduce in your code (see line 287 at commit 9b682ab).
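For reference, here is a minimal sketch of what an aux-aware copy of that implementation might look like, assuming optax.tree_utils.tree_get and a loss that returns (value, aux) as with jax.value_and_grad(..., has_aux=True); the names value_and_grad_with_aux_from_state and aux_fallback are made up for this sketch, not part of optax:

```python
import jax
import jax.numpy as jnp
import optax.tree_utils as otu


def value_and_grad_with_aux_from_state(value_fn):
  # Sketch of an aux-aware variant: reuse the value/grad cached in the
  # optimizer state by the line search, recompute (with aux) otherwise.
  def _value_and_grad(params, *, state, aux_fallback):
    cached_value = otu.tree_get(state, 'value')
    cached_grad = otu.tree_get(state, 'grad')
    if cached_value is None or cached_grad is None:
      raise ValueError('The optimizer state does not cache value and grad.')

    def _recompute(params, _):
      # Cache not usable: evaluate the loss, its gradient, and aux.
      (value, aux), grad = jax.value_and_grad(value_fn, has_aux=True)(params)
      return value, grad, aux

    def _reuse(_, aux_prev):
      # Cache usable: reuse value/grad; aux is not stored in the state,
      # so fall back to the aux carried over from the previous step.
      return cached_value, cached_grad, aux_prev

    value, grad, aux = jax.lax.cond(
        jnp.isfinite(cached_value), _reuse, _recompute, params, aux_fallback)
    return (value, aux), grad

  return _value_and_grad
```

Note that aux_fallback has to match the structure and shapes of the aux produced by value_fn (e.g. from one initial evaluation), since both lax.cond branches must return identical pytrees; on the first step the cached value should be non-finite, so the loss gets recomputed anyway.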
-
I have an expensive loss function that returns a lot of intermediate results as auxiliary data.
I want to optimize it with LBFGS (with line search). However, to cache the computed gradients I need to use opt.value_and_grad_from_state, but it does not allow threading auxiliary values through. For a simple case, I have adapted the example into a MWE, roughly along the lines sketched below.
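(A minimal sketch of that pattern, with a toy quadratic standing in for the expensive loss; the actual MWE is not reproduced here.)

```python
import jax.numpy as jnp
import optax

# Stand-in for the expensive loss; the real one also returns intermediates
# as auxiliary data, which is what cannot be threaded through below.
def fun(w):
  return jnp.sum((w - 1.0) ** 2)

opt = optax.lbfgs()
params = jnp.zeros(3)
state = opt.init(params)

# Reuses the value/gradient cached by the line search instead of
# re-evaluating the loss -- but it has no has_aux option.
value_and_grad = optax.value_and_grad_from_state(fun)

for _ in range(10):
  value, grad = value_and_grad(params, state=state)
  updates, state = opt.update(
      grad, state, params, value=value, grad=grad, value_fn=fun)
  params = optax.apply_updates(params, updates)
```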
Is there an option like has_aux, as is available in jax.value_and_grad?