From 47df575175bd4e7fa3bd9a15c0c731c771f248f3 Mon Sep 17 00:00:00 2001
From: SimonBoothroyd
Date: Thu, 2 Nov 2023 07:00:12 -0400
Subject: [PATCH] Add hessian search test

---
 descent/optim/_lm.py           |  2 +-
 descent/tests/optim/test_lm.py | 59 ++++++++++++++++++++++++++++------
 2 files changed, 50 insertions(+), 11 deletions(-)

diff --git a/descent/optim/_lm.py b/descent/optim/_lm.py
index c47c279..6fb517f 100644
--- a/descent/optim/_lm.py
+++ b/descent/optim/_lm.py
@@ -252,7 +252,7 @@ def _step(
 
 def _hessian_diagonal_search(
     x: torch.Tensor,
-    closure: torch.Tensor,
+    closure: tuple[torch.Tensor, torch.Tensor, torch.Tensor],
     closure_fn: ClosureFn,
     correct_fn: CorrectFn,
     damping_factor: torch.Tensor,
diff --git a/descent/tests/optim/test_lm.py b/descent/tests/optim/test_lm.py
index 72bc322..e4d1148 100644
--- a/descent/tests/optim/test_lm.py
+++ b/descent/tests/optim/test_lm.py
@@ -6,6 +6,7 @@
 from descent.optim._lm import (
     LevenbergMarquardtConfig,
     _damping_factor_loss_fn,
+    _hessian_diagonal_search,
     _solver,
     _step,
     levenberg_marquardt,
@@ -82,6 +83,40 @@ def test_step_sd(config, caplog):
     assert "hessian has a small or negative eigenvalue" in caplog.text
 
 
+def test_hessian_diagonal_search(caplog, mocker):
+    mocker.patch(
+        "scipy.optimize.brent",
+        autospec=True,
+        side_effect=[
+            (torch.tensor(1.0), 2.0, None, None),
+            (torch.tensor(1.0)),
+            (torch.tensor(0.2), -2.0, None, None),
+        ],
+    )
+
+    config = LevenbergMarquardtConfig(max_steps=1)
+
+    closure = (torch.tensor(1.0), torch.tensor([2.0]), torch.tensor([[3.0]]))
+
+    theta = torch.tensor([0.0], requires_grad=True)
+
+    with caplog.at_level(logging.INFO):
+        dx, expected = _hessian_diagonal_search(
+            theta,
+            closure,
+            mocker.MagicMock(),
+            mocker.MagicMock(),
+            torch.tensor(1.0),
+            torch.tensor(1.0),
+            config,
+        )
+
+    assert "restarting search with step size" in caplog.text
+
+    assert dx.shape == theta.shape
+    assert torch.isclose(torch.tensor(expected), torch.tensor(-2.0))
+
+
 def test_damping_factor_loss_fn(mocker):
     dx = torch.tensor([3.0, 4.0, 0.0])
     dx_norm = torch.linalg.norm(dx)
@@ -191,7 +226,16 @@ def mock_loss_fn(_x, *_):
     assert mock_step_fn_calls == expected_loss_traj
 
 
-def test_levenberg_marquardt():
+@pytest.mark.parametrize(
+    "mode,n_steps,expected_n_grad_calls,expected_n_hess_calls",
+    [
+        ("hessian-search", 2, 2 + 1, 2 + 1),
+        ("adaptive", 15, 15 + 1, 15 + 1),
+    ],
+)
+def test_levenberg_marquardt(
+    mode, n_steps, expected_n_grad_calls, expected_n_hess_calls
+):
     expected = torch.tensor([5.0, 3.0, 2.0])
 
     x_ref = torch.linspace(-2.0, 2.0, 100)
@@ -203,7 +247,6 @@ def loss_fn(_theta: torch.Tensor) -> torch.Tensor:
         y = _theta[0] * x_ref**2 + _theta[1] * x_ref + _theta[2]
         return torch.sum((y - y_ref) ** 2)
 
-    n_loss_calls = 0
     n_grad_calls = 0
     n_hess_calls = 0
 
@@ -213,8 +256,7 @@ def closure_fn(
     ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
         loss, grad, hess = loss_fn(_theta), None, None
 
-        nonlocal n_loss_calls, n_grad_calls, n_hess_calls
-        n_loss_calls += 1
+        nonlocal n_grad_calls, n_hess_calls
 
         if compute_gradient:
             (grad,) = torch.autograd.grad(loss, _theta, torch.tensor(1.0))
@@ -226,15 +268,12 @@ def closure_fn(
 
         return loss.detach(), grad, hess
 
-    n_steps = 15
+    config = LevenbergMarquardtConfig(max_steps=n_steps, mode=mode)
 
-    theta_new = levenberg_marquardt(
-        theta, closure_fn, None, config=LevenbergMarquardtConfig(max_steps=n_steps)
-    )
+    theta_new = levenberg_marquardt(theta, closure_fn, None, config)
 
     assert theta_new.shape == expected.shape
     assert torch.allclose(theta_new, expected)
-    assert n_loss_calls == n_steps + 1  # +1 for initial closure call
-    assert n_grad_calls == n_steps + 1
+    assert n_grad_calls == n_steps + 1  # +1 for initial closure call
     assert n_hess_calls == n_steps + 1
 
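
Note for reviewers (not part of the patch): the snippet below is a minimal, standalone sketch of the usage pattern the parametrized test now covers, for trying the new "hessian-search" mode outside of pytest. It is an illustration only: the closure signature with the two boolean flags, the zero initial guess for theta, and the use of torch.autograd.functional.hessian are assumptions based on the test body and surrounding context rather than lines shown in this diff; the mode strings and step counts come from the parametrize list above.

    import torch

    from descent.optim._lm import LevenbergMarquardtConfig, levenberg_marquardt

    # Reference parabola y = 5 x^2 + 3 x + 2, as in test_levenberg_marquardt.
    expected = torch.tensor([5.0, 3.0, 2.0])

    x_ref = torch.linspace(-2.0, 2.0, 100)
    y_ref = expected[0] * x_ref**2 + expected[1] * x_ref + expected[2]

    def loss_fn(_theta: torch.Tensor) -> torch.Tensor:
        # Sum of squared residuals of the quadratic model.
        y = _theta[0] * x_ref**2 + _theta[1] * x_ref + _theta[2]
        return torch.sum((y - y_ref) ** 2)

    def closure_fn(
        _theta: torch.Tensor, compute_gradient: bool, compute_hessian: bool
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        # Return (loss, gradient, hessian); the derivatives are only evaluated
        # when the optimizer requests them.
        loss, grad, hess = loss_fn(_theta), None, None

        if compute_gradient:
            (grad,) = torch.autograd.grad(loss, _theta, torch.tensor(1.0))
        if compute_hessian:
            # One way to obtain the Hessian; assumed here, not shown in the diff.
            hess = torch.autograd.functional.hessian(loss_fn, _theta)

        return loss.detach(), grad, hess

    theta = torch.tensor([0.0, 0.0, 0.0], requires_grad=True)  # assumed start

    # In the test, "hessian-search" recovers the reference parameters in 2
    # steps, versus 15 for the default "adaptive" damping schedule.
    config = LevenbergMarquardtConfig(max_steps=2, mode="hessian-search")
    theta_new = levenberg_marquardt(theta, closure_fn, None, config)

    assert torch.allclose(theta_new, expected)

For both modes the test asserts that each optimizer step triggers exactly one gradient and one Hessian evaluation, plus one for the initial closure call, via the n_grad_calls == n_steps + 1 and n_hess_calls == n_steps + 1 checks.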