Return parameters with lowest loss from LM optimizer (#54)
SimonBoothroyd authored Nov 9, 2023
Parent: 463b363 · Commit: 092e2fe
Showing 2 changed files with 15 additions and 5 deletions.
descent/optim/_lm.py  (7 additions, 3 deletions)

@@ -512,7 +512,7 @@ def levenberg_marquardt(
             whether the step was accepted or rejected.
 
     Returns:
-        The optimized parameters.
+        The parameters that minimize the loss.
     """
 
     x = x.clone().detach().requires_grad_(x.requires_grad)
@@ -526,9 +526,10 @@ def levenberg_marquardt(
     trust_radius = torch.tensor(config.trust_radius).to(x.device)
 
     loss_history = []
 
     has_converged = False
 
+    best_x, best_loss = x.clone(), closure_prev[0]
 
     for step in range(config.max_steps):
         loss_prev, gradient_prev, hessian_prev = closure_prev
@@ -574,6 +575,9 @@ def levenberg_marquardt(
             x.data.copy_(x_next.data)
             loss_history.append(loss.detach().cpu().clone())
 
+            if loss < best_loss:
+                best_x, best_loss = x.clone(), loss.detach().clone()
+
         closure_prev = (loss, gradient, hessian)
 
         report_fn(step, x, loss, gradient, hessian, step_quality, accept_step)
@@ -588,4 +592,4 @@ def levenberg_marquardt(
     if not has_converged:
         _LOGGER.info(f"optimization has not converged after {config.max_steps} steps.")
 
-    return x
+    return best_x
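The change above follows a common "best iterate" pattern: the optimizer keeps a copy of the parameters with the lowest loss observed so far and returns that copy, rather than whatever the final iterate happens to be. The sketch below shows the same bookkeeping in isolation; it is a minimal toy example, and toy_loss, minimize_keep_best, the noisy update, and the step size are illustrative assumptions, not part of descent's API.

import torch

torch.manual_seed(0)


def toy_loss(p: torch.Tensor) -> torch.Tensor:
    # A simple quadratic bowl centred at (1.0, -2.0); stand-in for a real objective.
    return ((p - torch.tensor([1.0, -2.0])) ** 2).sum()


def minimize_keep_best(x: torch.Tensor, n_steps: int = 50) -> torch.Tensor:
    """Toy noisy gradient-descent loop that returns the lowest-loss iterate."""
    x = x.clone().detach().requires_grad_(True)

    # Track the best parameters seen so far, starting from the initial guess.
    best_x, best_loss = x.detach().clone(), toy_loss(x).detach()

    for _ in range(n_steps):
        loss = toy_loss(x)
        (gradient,) = torch.autograd.grad(loss, x)

        # Noisy update, so progress is not monotone and the final iterate
        # need not be the best one seen.
        x = (x.detach() - 0.3 * gradient + 0.2 * torch.randn_like(x)).requires_grad_(True)

        new_loss = toy_loss(x).detach()
        if new_loss < best_loss:
            best_x, best_loss = x.detach().clone(), new_loss

    # Return the parameters with the lowest observed loss, not the last iterate.
    return best_x


print(minimize_keep_best(torch.tensor([4.0, 3.0])))  # close to [1.0, -2.0]

When progress is not monotone (here induced by noise), the returned tensor can differ from the final iterate, which mirrors why the optimizer now returns best_x instead of x.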
descent/tests/optim/test_lm.py  (8 additions, 2 deletions)

@@ -234,6 +234,11 @@ def test_levenberg_marquardt_adaptive(mocker, caplog):
             torch.tensor([19.0, 20.0]),
             torch.tensor([[21.0, 22.0], [23.0, 24.0]]),
         ),
+        (
+            torch.tensor(-2.11),
+            torch.tensor([19.0, 20.0]),
+            torch.tensor([[21.0, 22.0], [23.0, 24.0]]),
+        ),
     ]
     expected_loss_traj = [
         (
@@ -264,9 +269,10 @@ def mock_loss_fn(_x, *_):
         # previous step should have been rejected
         torch.tensor([0.1, 0.2]),
         torch.tensor([0.15, 0.21]),
+        torch.tensor([0.2, 0.3]),
     ]
-    assert x_new.shape == expected_x_traj[-1].shape
-    assert torch.allclose(x_new, expected_x_traj[-1])
+    assert x_new.shape == expected_x_traj[-2].shape  # -2 has lowest loss.
+    assert torch.allclose(x_new, expected_x_traj[-2])
 
     trust_radius_messages = [m for m in caplog.messages if "trust radius" in m]