Add hessian search test
SimonBoothroyd committed Nov 2, 2023
1 parent 3564539 commit 47df575
Showing 2 changed files with 50 additions and 11 deletions.
2 changes: 1 addition & 1 deletion descent/optim/_lm.py
@@ -252,7 +252,7 @@ def _step(

def _hessian_diagonal_search(
x: torch.Tensor,
-    closure: torch.Tensor,
+    closure: tuple[torch.Tensor, torch.Tensor, torch.Tensor],
closure_fn: ClosureFn,
correct_fn: CorrectFn,
damping_factor: torch.Tensor,
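
Note: the one-line change above retypes the closure argument of _hessian_diagonal_search from a bare tensor to the tuple of outputs produced by a closure evaluation. A minimal sketch of how such a tuple could be assembled, assuming closure_fn follows the (loss, gradient, hessian) return convention used by the tests below; the call site itself is not part of this diff:

    # Hypothetical illustration only; names mirror the test below, not descent internals.
    loss, grad, hess = closure_fn(x, True, True)  # evaluate loss, gradient and hessian
    closure = (loss, grad, hess)                  # passed on as a single tuple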
59 changes: 49 additions & 10 deletions descent/tests/optim/test_lm.py
@@ -6,6 +6,7 @@
from descent.optim._lm import (
LevenbergMarquardtConfig,
_damping_factor_loss_fn,
+    _hessian_diagonal_search,
_solver,
_step,
levenberg_marquardt,
@@ -82,6 +83,40 @@ def test_step_sd(config, caplog):
assert "hessian has a small or negative eigenvalue" in caplog.text


+def test_hessian_diagonal_search(caplog, mocker):
+    mocker.patch(
+        "scipy.optimize.brent",
+        autospec=True,
+        side_effect=[
+            (torch.tensor(1.0), 2.0, None, None),
+            (torch.tensor(1.0)),
+            (torch.tensor(0.2), -2.0, None, None),
+        ],
+    )
+
+    config = LevenbergMarquardtConfig(max_steps=1)
+
+    closure = (torch.tensor(1.0), torch.tensor([2.0]), torch.tensor([[3.0]]))
+
+    theta = torch.tensor([0.0], requires_grad=True)
+
+    with caplog.at_level(logging.INFO):
+        dx, expected = _hessian_diagonal_search(
+            theta,
+            closure,
+            mocker.MagicMock(),
+            mocker.MagicMock(),
+            torch.tensor(1.0),
+            torch.tensor(1.0),
+            config,
+        )
+
+    assert "restarting search with step size" in caplog.text
+
+    assert dx.shape == theta.shape
+    assert torch.isclose(torch.tensor(expected), torch.tensor(-2.0))
+
+
def test_damping_factor_loss_fn(mocker):
dx = torch.tensor([3.0, 4.0, 0.0])
dx_norm = torch.linalg.norm(dx)
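
Note on the mocked scipy.optimize.brent calls in test_hessian_diagonal_search above: when SciPy's brent is called with full_output=True it returns a (xmin, fval, iter, funcalls) tuple, which is the shape the four-element side_effect entries mimic. A standalone sketch of that behaviour (not part of the commit):

    import scipy.optimize

    def parabola(s):
        # Minimum at s = 0.2 with value -2.0, analogous to the last mocked entry.
        return (s - 0.2) ** 2 - 2.0

    x_min, f_val, n_iter, n_calls = scipy.optimize.brent(
        parabola, brack=(0.0, 1.0), full_output=True
    )
    # x_min is approximately 0.2 and f_val approximately -2.0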
@@ -191,7 +226,16 @@ def mock_loss_fn(_x, *_):
assert mock_step_fn_calls == expected_loss_traj


-def test_levenberg_marquardt():
+@pytest.mark.parametrize(
+    "mode,n_steps,expected_n_grad_calls,expected_n_hess_calls",
+    [
+        ("hessian-search", 2, 2 + 1, 2 + 1),
+        ("adaptive", 15, 15 + 1, 15 + 1),
+    ],
+)
+def test_levenberg_marquardt(
+    mode, n_steps, expected_n_grad_calls, expected_n_hess_calls
+):
expected = torch.tensor([5.0, 3.0, 2.0])

x_ref = torch.linspace(-2.0, 2.0, 100)
@@ -203,7 +247,6 @@ def loss_fn(_theta: torch.Tensor) -> torch.Tensor:
y = _theta[0] * x_ref**2 + _theta[1] * x_ref + _theta[2]
return torch.sum((y - y_ref) ** 2)

-    n_loss_calls = 0
n_grad_calls = 0
n_hess_calls = 0

@@ -213,8 +256,7 @@ def closure_fn(
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
loss, grad, hess = loss_fn(_theta), None, None

-        nonlocal n_loss_calls, n_grad_calls, n_hess_calls
-        n_loss_calls += 1
+        nonlocal n_grad_calls, n_hess_calls

if compute_gradient:
(grad,) = torch.autograd.grad(loss, _theta, torch.tensor(1.0))
@@ -226,15 +268,12 @@

return loss.detach(), grad, hess

-    n_steps = 15
+    config = LevenbergMarquardtConfig(max_steps=n_steps, mode=mode)

-    theta_new = levenberg_marquardt(
-        theta, closure_fn, None, config=LevenbergMarquardtConfig(max_steps=n_steps)
-    )
+    theta_new = levenberg_marquardt(theta, closure_fn, None, config)

assert theta_new.shape == expected.shape
assert torch.allclose(theta_new, expected)

-    assert n_loss_calls == n_steps + 1  # +1 for initial closure call
-    assert n_grad_calls == n_steps + 1
+    assert n_grad_calls == n_steps + 1  # +1 for initial closure call
assert n_hess_calls == n_steps + 1
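
The closure_fn used by test_levenberg_marquardt is only partially visible in this diff (its Hessian branch is collapsed). A self-contained sketch of one way to build such a closure with torch.autograd, assuming the same (loss, gradient, hessian) return convention; quadratic_loss and closure_sketch are illustrative names, not code from the repository:

    import torch

    def quadratic_loss(theta: torch.Tensor) -> torch.Tensor:
        x = torch.linspace(-2.0, 2.0, 100)
        y_ref = 5.0 * x**2 + 3.0 * x + 2.0  # assumed reference curve for the sketch
        y = theta[0] * x**2 + theta[1] * x + theta[2]
        return torch.sum((y - y_ref) ** 2)

    def closure_sketch(theta, compute_gradient=True, compute_hessian=True):
        loss, grad, hess = quadratic_loss(theta), None, None
        if compute_gradient:
            (grad,) = torch.autograd.grad(loss, theta, torch.tensor(1.0))
            grad = grad.detach()
        if compute_hessian:
            # Full second-derivative matrix of the scalar loss w.r.t. theta.
            hess = torch.autograd.functional.hessian(quadratic_loss, theta).detach()
        return loss.detach(), grad, hess

    theta = torch.tensor([0.0, 0.0, 0.0], requires_grad=True)
    loss, grad, hess = closure_sketch(theta)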
