From d6011ea7ba32b9504a7b84e59f2ba4de3693b706 Mon Sep 17 00:00:00 2001 From: SimonBoothroyd Date: Fri, 27 Oct 2023 08:34:11 -0400 Subject: [PATCH] Add Levenberg Marquardt optimizer (#43) --- LICENSE_3RD_PARTY | 42 ++++ descent/optimizers/__init__.py | 5 + descent/optimizers/_lm.py | 349 +++++++++++++++++++++++++++ descent/tests/optimizers/__init__.py | 0 descent/tests/optimizers/test_lm.py | 225 +++++++++++++++++ devtools/envs/base.yaml | 8 +- 6 files changed, 625 insertions(+), 4 deletions(-) create mode 100644 LICENSE_3RD_PARTY create mode 100644 descent/optimizers/__init__.py create mode 100644 descent/optimizers/_lm.py create mode 100644 descent/tests/optimizers/__init__.py create mode 100644 descent/tests/optimizers/test_lm.py diff --git a/LICENSE_3RD_PARTY b/LICENSE_3RD_PARTY new file mode 100644 index 0000000..444bf09 --- /dev/null +++ b/LICENSE_3RD_PARTY @@ -0,0 +1,42 @@ +/* -------------------------------------------------------------------------- * + * ForceBalance * + * -------------------------------------------------------------------------- * + + BSD 3-clause (aka BSD 2.0) License + +Copyright 2011-2015 Stanford University and the Authors +Copyright 2015-2018 Regents of the University of California and the Authors + +Developed by: Lee-Ping Wang + University of California, Davis + http://www.ucdavis.edu + +Contributors: Yudong Qiu, Keri A. McKiernan, Jeffrey R. Wagner, Hyesu Jang, +Simon Boothroyd, Arthur Vigil, Erik G. Brandt, Johnny Israeli, John Stoppelman + +Redistribution and use in source and binary forms, with or without modification, +are permitted provided that the following conditions are met: + +1. Redistributions of source code must retain the above copyright notice, +this list of conditions and the following disclaimer. + +2. Redistributions in binary form must reproduce the above copyright notice, +this list of conditions and the following disclaimer in the documentation +and/or other materials provided with the distribution. + +3. Neither the name of the copyright holder nor the names of its contributors +may be used to endorse or promote products derived from this software +without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. +IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, +INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT +NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, +WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE +POSSIBILITY OF SUCH DAMAGE. 
+ +* -------------------------------------------------------------------------- */ \ No newline at end of file diff --git a/descent/optimizers/__init__.py b/descent/optimizers/__init__.py new file mode 100644 index 0000000..85e4b09 --- /dev/null +++ b/descent/optimizers/__init__.py @@ -0,0 +1,5 @@ +"""Custom parameter optimizers.""" + +from descent.optimizers._lm import LevenbergMarquardt, LevenbergMarquardtConfig + +__all__ = ["LevenbergMarquardt", "LevenbergMarquardtConfig"] diff --git a/descent/optimizers/_lm.py b/descent/optimizers/_lm.py new file mode 100644 index 0000000..d7b4ee7 --- /dev/null +++ b/descent/optimizers/_lm.py @@ -0,0 +1,349 @@ +"""Levenberg-Marquardt optimizer. + +Notes: + This is a reimplementation of the Levenberg-Marquardt optimizer from the fantastic + ForceBalance [1] package. The original code is licensed under the BSD 3-clause + license which can be found in the LICENSE_3RD_PARTY file. + +References: + [1]: https://github.com/leeping/forcebalance/blob/b395fd4b/src/optimizer.py +""" +import logging +import math +import typing + +import pydantic +import smee.utils +import torch + +_LOGGER = logging.getLogger(__name__) + + +ClosureFn = typing.Callable[ + [torch.Tensor], tuple[torch.Tensor, torch.Tensor, torch.Tensor] +] +CorrectFn = typing.Callable[[torch.Tensor], torch.Tensor] + + +class LevenbergMarquardtConfig(pydantic.BaseModel): + """Configuration for the Levenberg-Marquardt optimizer.""" + + type: typing.Literal["levenberg-marquardt"] = "levenberg-marquardt" + + trust_radius: float = pydantic.Field( + 0.2, description="Target trust radius.", gt=0.0 + ) + min_trust_radius: float = pydantic.Field(0.05, description="Minimum trust radius.") + + adaptive_factor: float = pydantic.Field( + 0.25, description="Adaptive trust radius adjustment factor.", gt=0.0 + ) + adaptive_damping: float = pydantic.Field( + 1.0, description="Adaptive trust radius adjustment damping.", gt=0.0 + ) + + error_tolerance: float = pydantic.Field( + 1.0, + description="Steps where the loss increases more than this amount are rejected.", + ) + + quality_threshold_low: float = pydantic.Field( + 0.25, + description="The threshold below which the step is considered low quality.", + ) + quality_threshold_high: float = pydantic.Field( + 0.75, + description="The threshold above which the step is considered high quality.", + ) + + +def _invert_svd(matrix: torch.Tensor, threshold: float = 1e-12) -> torch.Tensor: + """Invert a matrix using SVD. + + Args: + matrix: The matrix to invert. + threshold: The threshold below which singular values are considered zero. + + Returns: + The inverted matrix. + """ + u, s, vh = torch.linalg.svd(matrix) + + non_zero_idxs = s > threshold + + s_inverse = torch.zeros_like(s) + s_inverse[non_zero_idxs] = 1.0 / s[non_zero_idxs] + + return vh.T @ torch.diag(s_inverse) @ u.T + + +def _solver( + damping_factor: torch.Tensor, gradient: torch.Tensor, hessian: torch.Tensor +) -> tuple[torch.Tensor, torch.Tensor]: + """Solve the Levenberg–Marquardt step. + + Args: + damping_factor: The damping factor with ``shape=(1,)``. + gradient: The gradient with ``shape=(n,)``. + hessian: The Hessian with ``shape=(n, n)``. + + Returns: + The step with ``shape=(n,)`` and the expected improvement with ``shape=()``. 
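+
+    Notes:
+        The step solves the regularized system
+        ``(hessian + (damping_factor - 1)^2 I) dx = -gradient`` using an SVD
+        pseudo-inverse of the regularized hessian, while the expected improvement
+        is the quadratic model value ``0.5 * dx^T hessian dx + gradient^T dx``
+        evaluated with the unregularized hessian.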
+    """
+
+    hessian_regular = hessian + (damping_factor - 1) ** 2 * torch.eye(
+        len(hessian), device=hessian.device, dtype=hessian.dtype
+    )
+    hessian_inverse = _invert_svd(hessian_regular)
+
+    dx = -(hessian_inverse @ gradient)
+    solution = 0.5 * dx @ hessian @ dx + (dx * gradient).sum()
+
+    return dx, solution
+
+
+def _damping_factor_loss_fn(
+    damping_factor: torch.Tensor,
+    gradient: torch.Tensor,
+    hessian: torch.Tensor,
+    trust_radius: float,
+) -> torch.Tensor:
+    """Compute the squared difference between the target trust radius and the step
+    size proposed by the Levenberg-Marquardt solver.
+
+    This is used when finding the optimal damping factor.
+
+    Args:
+        damping_factor: The damping factor with ``shape=(1,)``.
+        gradient: The gradient with ``shape=(n,)``.
+        hessian: The hessian with ``shape=(n, n)``.
+        trust_radius: The target trust radius.
+
+    Returns:
+        The squared difference.
+    """
+    dx, _ = _solver(damping_factor, gradient, hessian)
+    dx_norm = torch.linalg.norm(dx)
+
+    _LOGGER.info(
+        f"finding trust radius: length {dx_norm:.4e} (target {trust_radius:.4e})"
+    )
+
+    return (dx_norm - trust_radius) ** 2
+
+
+def _step(
+    gradient: torch.Tensor,
+    hessian: torch.Tensor,
+    trust_radius: torch.Tensor,
+    initial_damping_factor: float = 1.0,
+    min_eigenvalue: float = 1.0e-4,
+) -> tuple[torch.Tensor, torch.Tensor, bool]:
+    """Compute the Levenberg-Marquardt step.
+
+    Args:
+        gradient: The gradient with ``shape=(n,)``.
+        hessian: The hessian with ``shape=(n, n)``.
+        trust_radius: The target trust radius.
+        initial_damping_factor: An initial guess of the Levenberg-Marquardt damping
+            factor.
+        min_eigenvalue: Lower bound on the hessian eigenvalues. If the smallest
+            eigenvalue is smaller than this, a small amount of steepest descent is
+            mixed in to try to correct this.
+
+    Notes:
+        * The code to 'excise' certain parameters has been removed for now, until
+          it's clear that it is needed.
+        * Only the trust-region variant is implemented (i.e., only ``trust0 > 0``
+          is supported).
+
+    Returns:
+        The step with ``shape=(n,)``, the expected improvement with ``shape=()``, and
+        a boolean indicating whether the damping factor was adjusted.
+    """
+    from scipy import optimize
+
+    eigenvalues, _ = torch.linalg.eigh(hessian)
+    eigenvalue_smallest = eigenvalues.min()
+
+    if eigenvalue_smallest < min_eigenvalue:
+        # Mix in SD step if Hessian minimum eigenvalue is negative - experimental.
+        adjacency = (
+            max(min_eigenvalue, 0.01 * abs(eigenvalue_smallest)) - eigenvalue_smallest
+        )
+
+        _LOGGER.info(
+            f"hessian has a small or negative eigenvalue ({eigenvalue_smallest:.1e}), "
+            f"mixing in some steepest descent ({adjacency:.1e}) to correct this."
+        )
+        hessian += adjacency * torch.eye(
+            hessian.shape[0], device=hessian.device, dtype=hessian.dtype
+        )
+
+    damping_factor = torch.tensor(1.0)
+
+    dx, improvement = _solver(damping_factor, gradient, hessian)
+    dx_norm = torch.linalg.norm(dx)
+
+    adjust_damping = (dx_norm > trust_radius).item()
+
+    if adjust_damping:
+        # LPW tried a few optimizers and found Brent works well, but also that the
+        # tolerance is fractional - if the optimized value is zero it takes a lot of
+        # meaningless steps.
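+        #
+        # Brent's method searches the bracket below for the damping factor whose
+        # proposed step length matches the target trust radius (the squared
+        # mismatch is computed by ``_damping_factor_loss_fn``); the optimized
+        # damping factor is then fed back into ``_solver`` to build the final step.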
+        damping_factor = optimize.brent(
+            _damping_factor_loss_fn,
+            (
+                gradient.detach().cpu(),
+                hessian.detach().cpu(),
+                trust_radius.detach().cpu(),
+            ),
+            brack=(initial_damping_factor, initial_damping_factor * 4),
+            tol=1e-6,
+        )
+
+        dx, improvement = _solver(damping_factor, gradient, hessian)
+        dx_norm = torch.linalg.norm(dx)
+
+    _LOGGER.info(f"trust-radius step found (length {dx_norm:.4e})")
+
+    return dx, improvement, adjust_damping
+
+
+def _reduce_trust_radius(
+    dx_norm: torch.Tensor, config: LevenbergMarquardtConfig
+) -> torch.Tensor:
+    """Reduce the trust radius.
+
+    Args:
+        dx_norm: The size of the previous step.
+        config: The optimizer config.
+
+    Returns:
+        The reduced trust radius.
+    """
+    trust_radius = max(
+        dx_norm * (1.0 / (1.0 + config.adaptive_factor)), config.min_trust_radius
+    )
+    _LOGGER.info(f"reducing trust radius to {trust_radius:.4e}")
+
+    return smee.utils.tensor_like(trust_radius, dx_norm)
+
+
+def _update_trust_radius(
+    dx_norm: torch.Tensor,
+    step_quality: float,
+    trust_radius: torch.Tensor,
+    damping_adjusted: bool,
+    config: LevenbergMarquardtConfig,
+) -> torch.Tensor:
+    """Adjust the trust radius based on the quality of the previous step.
+
+    Args:
+        dx_norm: The size of the previous step.
+        step_quality: The quality of the previous step.
+        trust_radius: The current trust radius.
+        damping_adjusted: Whether the LM damping factor was adjusted during the
+            previous step.
+        config: The optimizer config.
+
+    Returns:
+        The updated trust radius.
+    """
+
+    if step_quality <= config.quality_threshold_low:
+        trust_radius = max(
+            dx_norm * (1.0 / (1.0 + config.adaptive_factor)),
+            smee.utils.tensor_like(config.min_trust_radius, dx_norm),
+        )
+        _LOGGER.info(
+            f"low quality step detected - reducing trust radius to {trust_radius:.4e}"
+        )
+
+    elif step_quality >= config.quality_threshold_high and damping_adjusted:
+        trust_radius += (
+            config.adaptive_factor
+            * trust_radius
+            * math.exp(
+                -config.adaptive_damping * (trust_radius / config.trust_radius - 1.0)
+            )
+        )
+        _LOGGER.info(f"updating trust radius to {trust_radius:.4e}")
+
+    return trust_radius
+
+
+class LevenbergMarquardt:
+    """A Levenberg-Marquardt optimizer.
+
+    Notes:
+        This is a reimplementation of the Levenberg-Marquardt optimizer from the
+        ForceBalance package, and so may differ from a standard implementation.
+    """
+
+    def __init__(self, config: LevenbergMarquardtConfig | None = None):
+        self.config = config if config is not None else LevenbergMarquardtConfig()
+
+        self._closure_prev = None
+        self._trust_radius = torch.tensor(self.config.trust_radius)
+
+    @torch.no_grad()
+    def step(
+        self,
+        x: torch.Tensor,
+        closure_fn: ClosureFn,
+        correct_fn: CorrectFn | None = None,
+    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        """Performs a single optimization step.
+
+        Args:
+            x: The initial guess of the parameters.
+            closure_fn: The closure that computes the loss (``shape=()``), its
+                gradient (``shape=(n,)``), and hessian (``shape=(n, n)``).
+            correct_fn: A function that can be used to correct the parameters after
+                each step is taken and before the new loss is computed. This may
+                include, for example, ensuring that vdW parameters are all positive.
+
+        Returns:
+            The loss, gradient and hessian evaluated at the accepted parameters.
+            ``x`` is updated in-place when the step is accepted and left unchanged
+            when the step is rejected.
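+
+        Example:
+            A minimal sketch of a training loop (``closure_fn``, ``x`` and
+            ``n_steps`` are placeholders for user-supplied values)::
+
+                optimizer = LevenbergMarquardt()
+
+                for _ in range(n_steps):
+                    loss, gradient, hessian = optimizer.step(x, closure_fn)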
+ """ + + correct_fn = correct_fn if correct_fn is not None else lambda x: x + closure_fn = torch.enable_grad()(closure_fn) + + if self._closure_prev is None: + # compute the initial loss, gradient and hessian + self._closure_prev = closure_fn(x) + + if self._trust_radius.device != x.device: + self._trust_radius = self._trust_radius.to(x.device) + + loss_prev, gradient_prev, hessian_prev = self._closure_prev + + dx, expected_improvement, damping_adjusted = _step( + gradient_prev, hessian_prev, self._trust_radius + ) + dx_norm = torch.linalg.norm(dx) + + x_next = correct_fn(x + dx).requires_grad_(x.requires_grad) + + loss_next, gradient_next, hessian_next = closure_fn(x_next) + loss_delta = loss_next - loss_prev + + step_quality = loss_delta / expected_improvement + + if loss_next > (loss_prev + self.config.error_tolerance): + # reject the 'bad' step and try again from where we were + loss, gradient, hessian = (loss_prev, gradient_prev, hessian_prev) + + self._trust_radius = _reduce_trust_radius(dx_norm, self.config) + else: + # accept the step + loss, gradient, hessian = (loss_next, gradient_next, hessian_next) + x.data.copy_(x_next.data) + + self._trust_radius = _update_trust_radius( + dx_norm, step_quality, self._trust_radius, damping_adjusted, self.config + ) + + self._closure_prev = (loss, gradient, hessian) + return self._closure_prev diff --git a/descent/tests/optimizers/__init__.py b/descent/tests/optimizers/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/descent/tests/optimizers/test_lm.py b/descent/tests/optimizers/test_lm.py new file mode 100644 index 0000000..913c579 --- /dev/null +++ b/descent/tests/optimizers/test_lm.py @@ -0,0 +1,225 @@ +import logging + +import pytest +import torch + +from descent.optimizers._lm import ( + LevenbergMarquardt, + _damping_factor_loss_fn, + _solver, + _step, +) + + +def test_solver(): + gradient = torch.tensor([0.5, -0.3, 0.7]) + hessian = torch.tensor( + [ + [2.0, 0.5, 0.3], + [0.5, 1.8, 0.2], + [0.3, 0.2, 1.5], + ] + ) + + damping_factor = torch.tensor(1.2) + + dx, solution = _solver(damping_factor, gradient, hessian) + + # computed using ForceBalance 1.9.3 + expected_dx = torch.tensor([-0.24833229, 0.27860679, -0.44235173]) + expected_solution = torch.tensor(-0.26539651272205717) + + assert dx.shape == expected_dx.shape + assert torch.allclose(dx, expected_dx) + + assert solution.shape == expected_solution.shape + assert torch.allclose(solution, expected_solution) + + +def test_step(): + gradient = torch.tensor([0.5, -0.3, 0.7]) + hessian = torch.tensor( + [ + [2.0, 0.5, 0.3], + [0.5, 1.8, 0.2], + [0.3, 0.2, 1.5], + ] + ) + + expected_trust_radius = torch.tensor(0.123) + + dx, solution, adjusted = _step( + gradient, hessian, trust_radius=expected_trust_radius + ) + assert isinstance(dx, torch.Tensor) + assert dx.shape == gradient.shape + + assert isinstance(solution, torch.Tensor) + assert solution.shape == torch.Size([]) + + assert torch.isclose(torch.norm(dx), torch.tensor(expected_trust_radius)) + assert adjusted is True + + +def test_step_sd(caplog): + gradient = torch.tensor([1.0, 1.0]) + hessian = torch.tensor([[1.0, 1.0], [1.0, 1.0]]) + + with caplog.at_level(logging.INFO): + _ = _step(gradient, hessian, trust_radius=1.0) + + # TODO: not 100% sure on what cases LPW was trying to correct for here, + # for now settle to double check the SD is applied. 
+ assert "hessian has a small or negative eigenvalue" in caplog.text + + +def test_damping_factor_loss_fn(mocker): + dx = torch.tensor([3.0, 4.0, 0.0]) + dx_norm = torch.linalg.norm(dx) + + damping_factor = mocker.Mock() + gradient = mocker.Mock() + hessian = mocker.Mock() + + solver_fn = mocker.patch( + "descent.optimizers._lm._solver", autospec=True, return_value=(dx, 0.0) + ) + + trust_radius = 12 + + difference = _damping_factor_loss_fn( + damping_factor, gradient, hessian, trust_radius + ) + + solver_fn.assert_called_once_with(damping_factor, gradient, hessian) + + expected_difference = (dx_norm - trust_radius) ** 2 + assert torch.isclose(difference, expected_difference) + + +def test_levenberg_marquardt_adaptive(mocker, caplog): + """Make sure the trust radius is adjusted correctly based on the loss.""" + + mock_dx_traj = [ + ( + torch.tensor([10.0, 20]), + torch.tensor(-100.0), + False, + ), + ( + torch.tensor([0.1, 0.2]), + torch.tensor(-0.5), + False, + ), + ( + torch.tensor([0.05, 0.01]), + torch.tensor(-2.0), + True, + ), + ] + mock_step_fn = mocker.patch( + "descent.optimizers._lm._step", autospec=True, side_effect=mock_dx_traj + ) + + mock_loss_traj = [ + ( + torch.tensor(0.0), + torch.tensor([1.0, 2.0]), + torch.tensor([[3.0, 4.0], [5.0, 6.0]]), + ), + ( + torch.tensor(150.0), + torch.tensor([7.0, 8.0]), + torch.tensor([[9.0, 10.0], [11.0, 12.0]]), + ), + ( + torch.tensor(-0.1), + torch.tensor([13.0, 14.0]), + torch.tensor([[15.0, 16.0], [17.0, 18.0]]), + ), + ( + torch.tensor(-2.1), + torch.tensor([19.0, 20.0]), + torch.tensor([[21.0, 22.0], [23.0, 24.0]]), + ), + ] + expected_loss_traj = [ + ( + pytest.approx(mock_loss_traj[i][1]), + pytest.approx(mock_loss_traj[i][2]), + mocker.ANY, + ) + for i in [0, 0, 2] + ] + + x_traj = [] + + def mock_loss_fn(_x): + x_traj.append(_x.clone()) + return mock_loss_traj.pop(0) + + x = torch.tensor([0.0, 0.0]) + + optimizer = LevenbergMarquardt() + + with caplog.at_level(logging.INFO): + for _ in range(3): + optimizer.step(x, mock_loss_fn) + + expected_x_traj = [ + torch.tensor([0.0, 0.0]), + torch.tensor([10.0, 20.0]), + # previous step should have been rejected + torch.tensor([0.1, 0.2]), + torch.tensor([0.15, 0.21]), + ] + assert x.shape == expected_x_traj[-1].shape + assert torch.allclose(x, expected_x_traj[-1]) + + trust_radius_messages = [m for m in caplog.messages if "trust radius" in m] + + expected_messages = [ + "reducing trust radius to", + "low quality step detected - reducing trust radius to", + "updating trust radius to", + ] + assert len(trust_radius_messages) == len(expected_messages) + + for message, expected in zip(trust_radius_messages, expected_messages): + assert message.startswith(expected) + + # mock_step_fn.assert_has_calls(expected_loss_traj, any_order=False) + mock_step_fn_calls = [call.args for call in mock_step_fn.call_args_list] + + assert mock_step_fn_calls == expected_loss_traj + + +def test_levenberg_marquardt(): + expected = torch.tensor([5.0, 3.0, 2.0]) + + x_ref = torch.linspace(-2.0, 2.0, 100) + y_ref = expected[0] * x_ref**2 + expected[1] * x_ref + expected[2] + + theta = torch.tensor([0.0, 0.0, 0.0], requires_grad=True) + + def loss_fn(_theta: torch.Tensor) -> torch.Tensor: + y = _theta[0] * x_ref**2 + _theta[1] * x_ref + _theta[2] + return torch.sum((y - y_ref) ** 2) + + @torch.enable_grad() + def target_fn( + _theta: torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + loss = loss_fn(_theta) + (grad,) = torch.autograd.grad(loss, _theta, torch.tensor(1.0)) + hess = 
torch.autograd.functional.hessian(loss_fn, _theta) + + return loss.detach(), grad.detach(), hess.detach() + + optimizer = LevenbergMarquardt() + + for _ in range(15): + optimizer.step(theta, target_fn) + + assert theta.shape == expected.shape + assert torch.allclose(theta, expected) diff --git a/devtools/envs/base.yaml b/devtools/envs/base.yaml index 7bc8174..6c1e58f 100644 --- a/devtools/envs/base.yaml +++ b/devtools/envs/base.yaml @@ -9,16 +9,16 @@ dependencies: - pip # Core packages - # - smee + - smee >=0.4.0 - pytorch - pydantic - # Optional packages - - ### Optimize + ### Levenberg Marquardt - scipy + # Optional packages + # Examples - jupyter - nbconvert