From 53bd12ec0bb41a1d725490873aa94c92f05afa5b Mon Sep 17 00:00:00 2001 From: SimonBoothroyd Date: Thu, 2 Nov 2023 07:06:49 -0400 Subject: [PATCH] Add hessian diagonal search support to LM optimizer (#45) --- descent/optim/__init__.py | 4 +- descent/optim/_lm.py | 278 ++++++++++++++++++++++++--------- descent/tests/optim/test_lm.py | 134 +++++++++++----- 3 files changed, 297 insertions(+), 119 deletions(-) diff --git a/descent/optim/__init__.py b/descent/optim/__init__.py index 387a9ca..8ec8e9a 100644 --- a/descent/optim/__init__.py +++ b/descent/optim/__init__.py @@ -1,5 +1,5 @@ """Custom parameter optimizers.""" -from descent.optim._lm import LevenbergMarquardt, LevenbergMarquardtConfig +from descent.optim._lm import LevenbergMarquardtConfig, levenberg_marquardt -__all__ = ["LevenbergMarquardt", "LevenbergMarquardtConfig"] +__all__ = ["LevenbergMarquardtConfig", "levenberg_marquardt"] diff --git a/descent/optim/_lm.py b/descent/optim/_lm.py index d7b4ee7..6fb517f 100644 --- a/descent/optim/_lm.py +++ b/descent/optim/_lm.py @@ -20,31 +20,74 @@ ClosureFn = typing.Callable[ - [torch.Tensor], tuple[torch.Tensor, torch.Tensor, torch.Tensor] + [torch.Tensor, bool, bool], tuple[torch.Tensor, torch.Tensor, torch.Tensor] ] CorrectFn = typing.Callable[[torch.Tensor], torch.Tensor] +Mode = typing.Literal["adaptive", "hessian-search"] +_ADAPTIVE, _HESSIAN_SEARCH = typing.get_args(Mode) + + class LevenbergMarquardtConfig(pydantic.BaseModel): """Configuration for the Levenberg-Marquardt optimizer.""" type: typing.Literal["levenberg-marquardt"] = "levenberg-marquardt" + mode: Mode = pydantic.Field( + _ADAPTIVE, description="The mode to run the optimizer in." + ) + trust_radius: float = pydantic.Field( 0.2, description="Target trust radius.", gt=0.0 ) - min_trust_radius: float = pydantic.Field(0.05, description="Minimum trust radius.") + trust_radius_min: float = pydantic.Field(0.05, description="Minimum trust radius.") + + min_eigenvalue: float = pydantic.Field( + 1.0e-4, + description="Lower bound on hessian eigenvalue. If the smallest eigenvalue " + "is smaller than this, a small amount of steepest descent is mixed in prior " + "to taking a next step to try and correct this.", + ) + min_damping_factor: float = pydantic.Field( + 1.0, description="Minimum damping factor.", gt=0.0 + ) adaptive_factor: float = pydantic.Field( - 0.25, description="Adaptive trust radius adjustment factor.", gt=0.0 + 0.25, + description="Adaptive trust radius adjustment factor to use when running in " + "adaptive mode.", + gt=0.0, ) adaptive_damping: float = pydantic.Field( - 1.0, description="Adaptive trust radius adjustment damping.", gt=0.0 + 1.0, + description="Adaptive trust radius adjustment damping to use when running in " + "adaptive mode.", + gt=0.0, + ) + + search_tolerance: float = pydantic.Field( + 1.0e-4, + description="The tolerance used when searching for the optimal damping factor " + "with hessian diagonal search (i.e. 
``mode='hessian-search'``).",
+        gt=0.0,
+    )
+    search_trust_radius_max: float = pydantic.Field(
+        1.0e-3,
+        description="The maximum trust radius to use when falling back to a second "
+        "line search if the loss would increase after the first.",
+        gt=0.0,
+    )
+    search_trust_radius_factor: float = pydantic.Field(
+        0.1,
+        description="The factor to scale the trust radius by when falling back to a "
+        "second line search.",
+        gt=0.0,
+    )
 
     error_tolerance: float = pydantic.Field(
         1.0,
-        description="Steps where the loss increases more than this amount are rejected.",
+        description="Steps that increase the loss more than this amount are rejected.",
     )
 
     quality_threshold_low: float = pydantic.Field(
@@ -56,6 +99,10 @@ class LevenbergMarquardtConfig(pydantic.BaseModel):
         description="The threshold above which the step is considered high quality.",
     )
 
+    max_steps: int = pydantic.Field(
+        ..., description="The maximum number of full steps to perform.", gt=0
+    )
+
 
 def _invert_svd(matrix: torch.Tensor, threshold: float = 1e-12) -> torch.Tensor:
     """Invert a matrix using SVD.
@@ -125,7 +172,7 @@ def _damping_factor_loss_fn(
     dx, _ = _solver(damping_factor, gradient, hessian)
     dx_norm = torch.linalg.norm(dx)
 
-    _LOGGER.info(
+    _LOGGER.debug(
         f"finding trust radius: length {dx_norm:.4e} (target {trust_radius:.4e})"
     )
 
@@ -136,20 +183,15 @@ def _step(
     gradient: torch.Tensor,
     hessian: torch.Tensor,
     trust_radius: torch.Tensor,
-    initial_damping_factor: float = 1.0,
-    min_eigenvalue: float = 1.0e-4,
-) -> tuple[torch.Tensor, torch.Tensor, bool]:
-    """Compute the Levenberg–Marquardt step.
+    config: LevenbergMarquardtConfig,
+) -> tuple[torch.Tensor, torch.Tensor, bool, torch.Tensor]:
+    """Compute the next Levenberg–Marquardt step.
 
     Args:
         gradient: The gradient with ``shape=(n,)``.
         hessian: The hessian with ``shape=(n, n)``.
         trust_radius: The target trust radius.
-        initial_damping_factor: An initial guess of the Levenberg-Marquardt damping
-            factor
-        min_eigenvalue: Lower bound on hessian eigenvalue. If the smallest eigenvalue
-            is smaller than this, a small amount of steepest descent is mixed in to
-            try and correct this.
+        config: The optimizer config.
 
     Notes:
         * the code to 'excise' certain parameters is for now removed until its clear
@@ -156,18 +198,20 @@ def _step(
         * only trust region is implemented (i.e., only trust0 > 0 is supported)
 
     Returns:
-        The step with ``shape=(n,)``, the expected improvement with ``shape=()``, and
-        a boolean indicating whether the damping factor was adjusted.
+        The step with ``shape=(n,)``, the expected improvement with ``shape=()``,
+        a boolean indicating whether the damping factor was adjusted, and the damping
+        factor with ``shape=(1,)``.
     """
     from scipy import optimize
 
     eigenvalues, _ = torch.linalg.eigh(hessian)
     eigenvalue_smallest = eigenvalues.min()
 
-    if eigenvalue_smallest < min_eigenvalue:
+    if eigenvalue_smallest < config.min_eigenvalue:
         # Mix in SD step if Hessian minimum eigenvalue is negative - experimental.
        adjacency = (
-            max(min_eigenvalue, 0.01 * abs(eigenvalue_smallest)) - eigenvalue_smallest
+            max(config.min_eigenvalue, 0.01 * abs(eigenvalue_smallest))
+            - eigenvalue_smallest
         )
 
         _LOGGER.info(
@@ -184,7 +228,7 @@ def _step(
     dx, improvement = _solver(damping_factor, gradient, hessian)
     dx_norm = torch.linalg.norm(dx)
 
-    adjust_damping = (dx_norm > trust_radius).item()
+    adjust_damping = bool(dx_norm > trust_radius)
 
     if adjust_damping:
         # LPW tried a few optimizers and found Brent works well, but also that the
@@ -197,16 +241,87 @@ def _step(
                 hessian.detach().cpu(),
                 trust_radius.detach().cpu(),
             ),
-            brack=(initial_damping_factor, initial_damping_factor * 4),
-            tol=1e-6,
+            brack=(config.min_damping_factor, config.min_damping_factor * 4),
+            tol=1.0e-6 if config.mode.lower() == _ADAPTIVE else 1.0e-4,
        )
 
         dx, improvement = _solver(damping_factor, gradient, hessian)
-    dx_norm = torch.linalg.norm(dx)
 
-    _LOGGER.info(f"trust-radius step found (length {dx_norm:.4e})")
+    return dx, improvement, adjust_damping, damping_factor
+
+
+def _hessian_diagonal_search(
+    x: torch.Tensor,
+    closure: tuple[torch.Tensor, torch.Tensor, torch.Tensor],
+    closure_fn: ClosureFn,
+    correct_fn: CorrectFn,
+    damping_factor: torch.Tensor,
+    trust_radius: torch.Tensor,
+    config: LevenbergMarquardtConfig,
+) -> tuple[torch.Tensor, float]:
+    """Search for a damping factor that directly minimizes the loss, rather than
+    one that matches the target trust radius.
+
+    Args:
+        x: The current parameters.
+        closure: The loss, gradient and hessian evaluated at ``x``.
+        closure_fn: The closure function.
+        correct_fn: The parameter 'correction' function.
+        damping_factor: The current damping factor.
+        trust_radius: The current trust radius.
+        config: The optimizer config.
+
+    Returns:
+        The step with ``shape=(n,)`` and the expected improvement with ``shape=()``.
+    """
+    from scipy import optimize
+
+    loss, gradient, hessian = closure
+
+    def search_fn(factor: torch.Tensor):
+        dx_next, _ = _solver(factor, gradient, hessian)
+        x_next = correct_fn(dx_next + x).requires_grad_(x.requires_grad)
+
+        loss_micro, _, _ = closure_fn(x_next, False, False)
+        return loss_micro - loss
+
+    damping_factor, expected_improvement, _, _ = optimize.brent(
+        search_fn,
+        (),
+        (damping_factor, damping_factor * 4),
+        config.search_tolerance,
+        True,
+    )
+
+    if expected_improvement > 0.0:
+        trust_radius = min(
+            config.search_trust_radius_factor * trust_radius,
+            config.search_trust_radius_max,
+        )
+        brent_args = (gradient.detach().cpu(), hessian.detach().cpu(), trust_radius)
+
+        damping_factor = optimize.brent(
+            _damping_factor_loss_fn,
+            brent_args,
+            (config.min_damping_factor, config.min_damping_factor * 4),
+            1e-6,
+        )
+
+        dx, _ = _solver(damping_factor, gradient, hessian)
+        dx_norm = torch.linalg.norm(dx)
+
+        _LOGGER.info(f"restarting search with step size {dx_norm}")
+
+        damping_factor, expected_improvement, _, _ = optimize.brent(
+            search_fn,
+            (),
+            (damping_factor, damping_factor * 4),
+            config.search_tolerance,
+            True,
+        )
 
-    return dx, improvement, adjust_damping
+    dx, _ = _solver(damping_factor, gradient, hessian)
+    return dx, expected_improvement
 
 
 def _reduce_trust_radius(
@@ -222,7 +337,7 @@ def _reduce_trust_radius(
         The reduced trust radius.
""" trust_radius = max( - dx_norm * (1.0 / (1.0 + config.adaptive_factor)), config.min_trust_radius + dx_norm * (1.0 / (1.0 + config.adaptive_factor)), config.trust_radius_min ) _LOGGER.info(f"reducing trust radius to {trust_radius:.4e}") @@ -253,7 +368,7 @@ def _update_trust_radius( if step_quality <= config.quality_threshold_low: trust_radius = max( dx_norm * (1.0 / (1.0 + config.adaptive_factor)), - smee.utils.tensor_like(config.min_trust_radius, dx_norm), + smee.utils.tensor_like(config.trust_radius_min, dx_norm), ) _LOGGER.info( f"low quality step detected - reducing trust radius to {trust_radius:.4e}" @@ -272,78 +387,87 @@ def _update_trust_radius( return trust_radius -class LevenbergMarquardt: - """A Levenberg-Marquardt optimizer. +@torch.no_grad() +def levenberg_marquardt( + x: torch.Tensor, + closure_fn: ClosureFn, + correct_fn: CorrectFn | None = None, + config: LevenbergMarquardtConfig | None = None, +) -> torch.Tensor: + """Optimize a given set of parameters using the Levenberg-Marquardt algorithm. Notes: - This is a reimplementation of the Levenberg-Marquardt optimizer from the - ForceBalance package, and so may differ from a standard implementation. - """ - - def __init__(self, config: LevenbergMarquardtConfig | None = None): - self.config = config if config is not None else LevenbergMarquardtConfig() + * This optimizer assumes a least-square loss function. + * This is a reimplementation of the Levenberg-Marquardt optimizer from the + ForceBalance package, and so may differ from a standard implementation. - self._closure_prev = None - self._trust_radius = torch.tensor(self.config.trust_radius) + Args: + x: The initial guess of the parameters with ``shape=(n,)``. + closure_fn: A function that computes the loss (``shape=()``), its + gradient (``shape=(n,)``), and hessian (``shape=(n, n)``).. + correct_fn: A function that can be used to correct the parameters after + each step is taken and before the new loss is computed. This may + include, for example, ensuring that vdW parameters are all positive. + config: The optimizer config. - @torch.no_grad() - def step( - self, - x: torch.Tensor, - closure_fn: ClosureFn, - correct_fn: CorrectFn | None = None, - ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - """Performs a single optimization step. + Returns: + The optimized parameters. + """ - Args: - x: The initial guess of the parameters. - closure_fn: The closure that computes the loss (``shape=()``), its - gradient (``shape=(n,)``), and hessian (``shape=(n, n)``).. - correct_fn: A function that can be used to correct the parameters after - each step is taken and before the new loss is computed. This may - include, for example, ensuring that vdW parameters are all positive. + x = x.clone().detach().requires_grad_(x.requires_grad) - Returns: - The optimized parameters. 
- """ + correct_fn = correct_fn if correct_fn is not None else lambda y: y + closure_fn = torch.enable_grad()(closure_fn) - correct_fn = correct_fn if correct_fn is not None else lambda x: x - closure_fn = torch.enable_grad()(closure_fn) + closure_prev = closure_fn(x, True, True) + trust_radius = torch.tensor(config.trust_radius).to(x.device) - if self._closure_prev is None: - # compute the initial loss, gradient and hessian - self._closure_prev = closure_fn(x) + for step in range(config.max_steps): + loss_prev, gradient_prev, hessian_prev = closure_prev - if self._trust_radius.device != x.device: - self._trust_radius = self._trust_radius.to(x.device) + dx, expected_improvement, damping_adjusted, damping_factor = _step( + gradient_prev, hessian_prev, trust_radius, config + ) - loss_prev, gradient_prev, hessian_prev = self._closure_prev + if config.mode.lower() == _HESSIAN_SEARCH: + dx, expected_improvement = _hessian_diagonal_search( + x, + closure_prev, + closure_fn, + correct_fn, + damping_factor, + trust_radius, + config, + ) - dx, expected_improvement, damping_adjusted = _step( - gradient_prev, hessian_prev, self._trust_radius - ) dx_norm = torch.linalg.norm(dx) + _LOGGER.info(f"{config.mode} step found (length {dx_norm:.4e})") x_next = correct_fn(x + dx).requires_grad_(x.requires_grad) - loss_next, gradient_next, hessian_next = closure_fn(x_next) - loss_delta = loss_next - loss_prev + loss, gradient, hessian = closure_fn(x_next, True, True) + loss_delta = loss - loss_prev step_quality = loss_delta / expected_improvement + accept_step = True - if loss_next > (loss_prev + self.config.error_tolerance): + if loss > (loss_prev + config.error_tolerance): # reject the 'bad' step and try again from where we were loss, gradient, hessian = (loss_prev, gradient_prev, hessian_prev) + trust_radius = _reduce_trust_radius(dx_norm, config) - self._trust_radius = _reduce_trust_radius(dx_norm, self.config) - else: - # accept the step - loss, gradient, hessian = (loss_next, gradient_next, hessian_next) + accept_step = False + elif config.mode.lower() == _ADAPTIVE: + # this was a 'good' step - we can maybe increase the trust radius + trust_radius = _update_trust_radius( + dx_norm, step_quality, trust_radius, damping_adjusted, config + ) + + if accept_step: x.data.copy_(x_next.data) - self._trust_radius = _update_trust_radius( - dx_norm, step_quality, self._trust_radius, damping_adjusted, self.config - ) + closure_prev = (loss, gradient, hessian) + + _LOGGER.info(f"step={step} loss={loss.detach().cpu().item()}") - self._closure_prev = (loss, gradient, hessian) - return self._closure_prev + return x diff --git a/descent/tests/optim/test_lm.py b/descent/tests/optim/test_lm.py index 955e038..e4d1148 100644 --- a/descent/tests/optim/test_lm.py +++ b/descent/tests/optim/test_lm.py @@ -4,13 +4,20 @@ import torch from descent.optim._lm import ( - LevenbergMarquardt, + LevenbergMarquardtConfig, _damping_factor_loss_fn, + _hessian_diagonal_search, _solver, _step, + levenberg_marquardt, ) +@pytest.fixture +def config(): + return LevenbergMarquardtConfig(max_steps=10) + + def test_solver(): gradient = torch.tensor([0.5, -0.3, 0.7]) hessian = torch.tensor( @@ -36,7 +43,7 @@ def test_solver(): assert torch.allclose(solution, expected_solution) -def test_step(): +def test_step(config): gradient = torch.tensor([0.5, -0.3, 0.7]) hessian = torch.tensor( [ @@ -48,8 +55,8 @@ def test_step(): expected_trust_radius = torch.tensor(0.123) - dx, solution, adjusted = _step( - gradient, hessian, trust_radius=expected_trust_radius 
+ dx, solution, adjusted, damping_factor = _step( + gradient, hessian, trust_radius=expected_trust_radius, config=config ) assert isinstance(dx, torch.Tensor) assert dx.shape == gradient.shape @@ -60,19 +67,56 @@ def test_step(): assert torch.isclose(torch.norm(dx), torch.tensor(expected_trust_radius)) assert adjusted is True + assert isinstance(damping_factor, torch.Tensor) + assert damping_factor.shape == torch.Size([]) -def test_step_sd(caplog): + +def test_step_sd(config, caplog): gradient = torch.tensor([1.0, 1.0]) hessian = torch.tensor([[1.0, 1.0], [1.0, 1.0]]) with caplog.at_level(logging.INFO): - _ = _step(gradient, hessian, trust_radius=1.0) + _ = _step(gradient, hessian, trust_radius=torch.tensor(1.0), config=config) # TODO: not 100% sure on what cases LPW was trying to correct for here, # for now settle to double check the SD is applied. assert "hessian has a small or negative eigenvalue" in caplog.text +def test_hessian_diagonal_search(caplog, mocker): + mocker.patch( + "scipy.optimize.brent", + autospec=True, + side_effect=[ + (torch.tensor(1.0), 2.0, None, None), + (torch.tensor(1.0)), + (torch.tensor(0.2), -2.0, None, None), + ], + ) + + config = LevenbergMarquardtConfig(max_steps=1) + + closure = (torch.tensor(1.0), torch.tensor([2.0]), torch.tensor([[3.0]])) + + theta = torch.tensor([0.0], requires_grad=True) + + with caplog.at_level(logging.INFO): + dx, expected = _hessian_diagonal_search( + theta, + closure, + mocker.MagicMock(), + mocker.MagicMock(), + torch.tensor(1.0), + torch.tensor(1.0), + config, + ) + + assert "restarting search with step size" in caplog.text + + assert dx.shape == theta.shape + assert torch.isclose(torch.tensor(expected), torch.tensor(-2.0)) + + def test_damping_factor_loss_fn(mocker): dx = torch.tensor([3.0, 4.0, 0.0]) dx_norm = torch.linalg.norm(dx) @@ -101,21 +145,9 @@ def test_levenberg_marquardt_adaptive(mocker, caplog): """Make sure the trust radius is adjusted correctly based on the loss.""" mock_dx_traj = [ - ( - torch.tensor([10.0, 20]), - torch.tensor(-100.0), - False, - ), - ( - torch.tensor([0.1, 0.2]), - torch.tensor(-0.5), - False, - ), - ( - torch.tensor([0.05, 0.01]), - torch.tensor(-2.0), - True, - ), + (torch.tensor([10.0, 20]), torch.tensor(-100.0), False, torch.tensor(1.0)), + (torch.tensor([0.1, 0.2]), torch.tensor(-0.5), False, torch.tensor(0.5)), + (torch.tensor([0.05, 0.01]), torch.tensor(-2.0), True, torch.tensor(0.25)), ] mock_step_fn = mocker.patch( "descent.optim._lm._step", autospec=True, side_effect=mock_dx_traj @@ -148,23 +180,23 @@ def test_levenberg_marquardt_adaptive(mocker, caplog): pytest.approx(mock_loss_traj[i][1]), pytest.approx(mock_loss_traj[i][2]), mocker.ANY, + mocker.ANY, ) for i in [0, 0, 2] ] x_traj = [] - def mock_loss_fn(_x): + def mock_loss_fn(_x, *_): x_traj.append(_x.clone()) return mock_loss_traj.pop(0) x = torch.tensor([0.0, 0.0]) - optimizer = LevenbergMarquardt() - with caplog.at_level(logging.INFO): - for _ in range(3): - optimizer.step(x, mock_loss_fn) + x_new = levenberg_marquardt( + x, mock_loss_fn, None, config=LevenbergMarquardtConfig(max_steps=3) + ) expected_x_traj = [ torch.tensor([0.0, 0.0]), @@ -173,8 +205,8 @@ def mock_loss_fn(_x): torch.tensor([0.1, 0.2]), torch.tensor([0.15, 0.21]), ] - assert x.shape == expected_x_traj[-1].shape - assert torch.allclose(x, expected_x_traj[-1]) + assert x_new.shape == expected_x_traj[-1].shape + assert torch.allclose(x_new, expected_x_traj[-1]) trust_radius_messages = [m for m in caplog.messages if "trust radius" in m] @@ -194,7 +226,16 @@ def 
mock_loss_fn(_x): assert mock_step_fn_calls == expected_loss_traj -def test_levenberg_marquardt(): +@pytest.mark.parametrize( + "mode,n_steps,expected_n_grad_calls,expected_n_hess_calls", + [ + ("hessian-search", 2, 2 + 1, 2 + 1), + ("adaptive", 15, 15 + 1, 15 + 1), + ], +) +def test_levenberg_marquardt( + mode, n_steps, expected_n_grad_calls, expected_n_hess_calls +): expected = torch.tensor([5.0, 3.0, 2.0]) x_ref = torch.linspace(-2.0, 2.0, 100) @@ -206,20 +247,33 @@ def loss_fn(_theta: torch.Tensor) -> torch.Tensor: y = _theta[0] * x_ref**2 + _theta[1] * x_ref + _theta[2] return torch.sum((y - y_ref) ** 2) + n_grad_calls = 0 + n_hess_calls = 0 + @torch.enable_grad() - def target_fn( - _theta: torch.Tensor, + def closure_fn( + _theta: torch.Tensor, compute_gradient: bool, compute_hessian: bool ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - loss = loss_fn(_theta) - (grad,) = torch.autograd.grad(loss, _theta, torch.tensor(1.0)) - hess = torch.autograd.functional.hessian(loss_fn, _theta) + loss, grad, hess = loss_fn(_theta), None, None + + nonlocal n_grad_calls, n_hess_calls + + if compute_gradient: + (grad,) = torch.autograd.grad(loss, _theta, torch.tensor(1.0)) + grad = grad.detach() + n_grad_calls += 1 + if compute_hessian: + hess = torch.autograd.functional.hessian(loss_fn, _theta) + n_hess_calls += 1 + + return loss.detach(), grad, hess - return loss.detach(), grad.detach(), hess.detach() + config = LevenbergMarquardtConfig(max_steps=n_steps, mode=mode) - optimizer = LevenbergMarquardt() + theta_new = levenberg_marquardt(theta, closure_fn, None, config) - for _ in range(15): - optimizer.step(theta, target_fn) + assert theta_new.shape == expected.shape + assert torch.allclose(theta_new, expected) - assert theta.shape == expected.shape - assert torch.allclose(theta, expected) + assert n_grad_calls == n_steps + 1 # +1 for initial closure call + assert n_hess_calls == n_steps + 1
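-- 
Reviewer note (not part of the patch): a minimal usage sketch of the new functional
API introduced above, modelled on ``test_levenberg_marquardt`` in this diff. The
quadratic model, parameter names and values below are illustrative only.

    import torch
    from descent.optim import LevenbergMarquardtConfig, levenberg_marquardt

    x_ref = torch.linspace(-2.0, 2.0, 100)
    y_ref = 5.0 * x_ref**2 + 3.0 * x_ref + 2.0  # data from known coefficients

    def loss_fn(theta: torch.Tensor) -> torch.Tensor:
        y = theta[0] * x_ref**2 + theta[1] * x_ref + theta[2]
        return torch.sum((y - y_ref) ** 2)

    @torch.enable_grad()
    def closure_fn(theta, compute_gradient: bool, compute_hessian: bool):
        # the closure now takes two flags and returns (loss, gradient | None,
        # hessian | None), matching the updated ``ClosureFn`` signature
        loss, grad, hess = loss_fn(theta), None, None
        if compute_gradient:
            (grad,) = torch.autograd.grad(loss, theta, torch.tensor(1.0))
            grad = grad.detach()
        if compute_hessian:
            hess = torch.autograd.functional.hessian(loss_fn, theta)
        return loss.detach(), grad, hess

    theta_0 = torch.tensor([1.0, 1.0, 1.0], requires_grad=True)
    config = LevenbergMarquardtConfig(max_steps=15, mode="adaptive")
    theta_fit = levenberg_marquardt(theta_0, closure_fn, None, config)

Setting ``mode="hessian-search"`` in the config switches to the hessian diagonal
search added by this patch; the parametrized test above runs that mode for only two
full steps on this problem.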
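Reviewer note (not part of the patch): both ``_step`` and the fallback branch of
``_hessian_diagonal_search`` use ``scipy.optimize.brent`` with
``_damping_factor_loss_fn`` to pick a damping factor whose step length matches the
trust radius. ``_solver`` itself is not shown in this diff, so the sketch below uses
a generic ``H + lambda * I`` damping purely to illustrate the idea; the gradient,
hessian and trust radius values are taken from ``test_step``.

    import torch
    from scipy import optimize

    gradient = torch.tensor([0.5, -0.3, 0.7])
    hessian = torch.tensor([[2.0, 0.1, 0.0], [0.1, 1.5, 0.0], [0.0, 0.0, 1.0]])
    trust_radius = 0.123

    def step_length_error(damping_factor: float) -> float:
        # larger damping -> shorter step; minimize the mismatch with the trust radius
        damped = hessian + damping_factor * torch.eye(len(gradient))
        dx = torch.linalg.solve(damped, -gradient)
        return float((torch.linalg.norm(dx) - trust_radius) ** 2)

    damping = optimize.brent(step_length_error, brack=(1.0, 4.0), tol=1.0e-6)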