diff --git a/differt2d/optimize.py b/differt2d/optimize.py index ca2f9a0..ed1eef5 100644 --- a/differt2d/optimize.py +++ b/differt2d/optimize.py @@ -39,7 +39,7 @@ def minimize( fun: Callable[[X], Y], x0: Array, steps: int = 100, - optimizer: optax.GradientTransformation = optax.adam(learning_rate=0.01), + optimizer: optax.GradientTransformation = optax.adam(learning_rate=0.1), ) -> Tuple[X, Y]: """ Minimizes a scalar function of one or more variables. @@ -67,7 +67,6 @@ def minimize( f_and_df = jax.value_and_grad(fun) opt_state = optimizer.init(x0) - @jax.jit def f(carry, x): x, opt_state = carry loss, grads = f_and_df(x)