diff --git a/descent/optim/_lm.py b/descent/optim/_lm.py index 16bc19b..d639b38 100644 --- a/descent/optim/_lm.py +++ b/descent/optim/_lm.py @@ -512,7 +512,7 @@ def levenberg_marquardt( whether the step was accepted or rejected. Returns: - The optimized parameters. + The parameters that minimize the loss. """ x = x.clone().detach().requires_grad_(x.requires_grad) @@ -526,9 +526,10 @@ def levenberg_marquardt( trust_radius = torch.tensor(config.trust_radius).to(x.device) loss_history = [] - has_converged = False + best_x, best_loss = x.clone(), closure_prev[0] + for step in range(config.max_steps): loss_prev, gradient_prev, hessian_prev = closure_prev @@ -574,6 +575,9 @@ def levenberg_marquardt( x.data.copy_(x_next.data) loss_history.append(loss.detach().cpu().clone()) + if loss < best_loss: + best_x, best_loss = x.clone(), loss.detach().clone() + closure_prev = (loss, gradient, hessian) report_fn(step, x, loss, gradient, hessian, step_quality, accept_step) @@ -588,4 +592,4 @@ def levenberg_marquardt( if not has_converged: _LOGGER.info(f"optimization has not converged after {config.max_steps} steps.") - return x + return best_x diff --git a/descent/tests/optim/test_lm.py b/descent/tests/optim/test_lm.py index 24462d5..e4d55f9 100644 --- a/descent/tests/optim/test_lm.py +++ b/descent/tests/optim/test_lm.py @@ -234,6 +234,11 @@ def test_levenberg_marquardt_adaptive(mocker, caplog): torch.tensor([19.0, 20.0]), torch.tensor([[21.0, 22.0], [23.0, 24.0]]), ), + ( + torch.tensor(-2.11), + torch.tensor([19.0, 20.0]), + torch.tensor([[21.0, 22.0], [23.0, 24.0]]), + ), ] expected_loss_traj = [ ( @@ -264,9 +269,10 @@ def mock_loss_fn(_x, *_): # previous step should have been rejected torch.tensor([0.1, 0.2]), torch.tensor([0.15, 0.21]), + torch.tensor([0.2, 0.3]), ] - assert x_new.shape == expected_x_traj[-1].shape - assert torch.allclose(x_new, expected_x_traj[-1]) + assert x_new.shape == expected_x_traj[-2].shape # -2 has lowest loss. + assert torch.allclose(x_new, expected_x_traj[-2]) trust_radius_messages = [m for m in caplog.messages if "trust radius" in m]