From 173e2c75f9962c6052a7aac75e7739539153aaff Mon Sep 17 00:00:00 2001
From: SimonBoothroyd
Date: Wed, 8 Nov 2023 17:09:11 -0500
Subject: [PATCH] Add optional report function to LM (#51)

---
 descent/optim/_lm.py           | 23 ++++++++++++++++++++---
 descent/tests/optim/test_lm.py |  4 ++--
 devtools/envs/base.yaml        |  2 +-
 3 files changed, 23 insertions(+), 6 deletions(-)

diff --git a/descent/optim/_lm.py b/descent/optim/_lm.py
index bc0a9a1..16bc19b 100644
--- a/descent/optim/_lm.py
+++ b/descent/optim/_lm.py
@@ -25,6 +25,10 @@
 ]
 
 CorrectFn = typing.Callable[[torch.Tensor], torch.Tensor]
+ReportFn = typing.Callable[
+    [int, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, float, bool], None
+]
+
 Mode = typing.Literal["adaptive", "hessian-search"]
 
 _ADAPTIVE, _HESSIAN_SEARCH = typing.get_args(Mode)
@@ -477,9 +481,10 @@ def _has_converged(
 
 @torch.no_grad()
 def levenberg_marquardt(
     x: torch.Tensor,
+    config: LevenbergMarquardtConfig,
     closure_fn: ClosureFn,
     correct_fn: CorrectFn | None = None,
-    config: LevenbergMarquardtConfig | None = None,
+    report_fn: ReportFn | None = None,
 ) -> torch.Tensor:
     """Optimize a given set of parameters using the Levenberg-Marquardt algorithm.
@@ -490,12 +495,21 @@ def levenberg_marquardt(
 
     Args:
         x: The initial guess of the parameters with ``shape=(n,)``.
+        config: The optimizer config.
         closure_fn: A function that computes the loss (``shape=()``), its
-            gradient (``shape=(n,)``), and hessian (``shape=(n, n)``)..
+            gradient (``shape=(n,)``), and hessian (``shape=(n, n)``). It should
+            accept as arguments the current parameter tensor, and two booleans
+            indicating whether the gradient and hessian are required.
         correct_fn: A function that can be used to correct the parameters after
             each step is taken and before the new loss is computed. This may
             include, for example, ensuring that vdW parameters are all positive.
-        config: The optimizer config.
+            It should accept as arguments the current parameter tensor and return
+            the corrected parameter tensor.
+        report_fn: An optional function that should be called at the end of every
+            step. This can be used to report the current state of the optimization.
+            It should accept as arguments the step number, the current parameter tensor,
+            the loss, gradient and hessian, the step 'quality', and a bool indicating
+            whether the step was accepted or rejected.
 
     Returns:
         The optimized parameters.
@@ -506,6 +520,8 @@ def levenberg_marquardt(
     correct_fn = correct_fn if correct_fn is not None else lambda y: y
     closure_fn = torch.enable_grad()(closure_fn)
 
+    report_fn = report_fn if report_fn is not None else lambda *_, **__: None
+
     closure_prev = closure_fn(x, True, True)
     trust_radius = torch.tensor(config.trust_radius).to(x.device)
 
@@ -560,6 +576,7 @@ def levenberg_marquardt(
 
         closure_prev = (loss, gradient, hessian)
 
+        report_fn(step, x, loss, gradient, hessian, step_quality, accept_step)
         _LOGGER.info(f"step={step} loss={loss.detach().cpu().item():.4e}")
 
         if _has_converged(dx, loss_history, gradient, step_quality, config):
diff --git a/descent/tests/optim/test_lm.py b/descent/tests/optim/test_lm.py
index fc72219..24462d5 100644
--- a/descent/tests/optim/test_lm.py
+++ b/descent/tests/optim/test_lm.py
@@ -255,7 +255,7 @@ def mock_loss_fn(_x, *_):
 
     with caplog.at_level(logging.INFO):
         x_new = levenberg_marquardt(
-            x, mock_loss_fn, None, config=LevenbergMarquardtConfig(max_steps=3)
+            x, LevenbergMarquardtConfig(max_steps=3), mock_loss_fn, None
         )
 
     expected_x_traj = [
@@ -330,7 +330,7 @@ def closure_fn(
 
     config = LevenbergMarquardtConfig(max_steps=n_steps, mode=mode)
 
-    theta_new = levenberg_marquardt(theta, closure_fn, None, config)
+    theta_new = levenberg_marquardt(theta, config, closure_fn, None)
 
     assert theta_new.shape == expected.shape
     assert torch.allclose(theta_new, expected)
diff --git a/devtools/envs/base.yaml b/devtools/envs/base.yaml
index 255ad87..43f6946 100644
--- a/devtools/envs/base.yaml
+++ b/devtools/envs/base.yaml
@@ -9,7 +9,7 @@ dependencies:
   - pip
 
   # Core packages
-  - smee >=0.4.0
+  - smee >=0.7.0
   - pytorch
   - pydantic
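
For reference, below is a minimal sketch of how the reordered signature and the new report_fn hook might be called from user code after this patch. The import path (descent.optim), the toy quadratic closure, and the callback body are illustrative assumptions and are not part of the patch itself; only the argument order and the report_fn argument list are taken from the diff above.

    # Illustrative sketch only - assumes `levenberg_marquardt` and
    # `LevenbergMarquardtConfig` are re-exported from `descent.optim`.
    import logging

    import torch

    from descent.optim import LevenbergMarquardtConfig, levenberg_marquardt

    _LOGGER = logging.getLogger(__name__)


    def closure_fn(x: torch.Tensor, compute_gradient: bool, compute_hessian: bool):
        """Toy quadratic loss; the gradient and hessian are cheap here, so the
        `compute_*` flags are ignored and all three values are always returned."""
        loss = (x**2).sum()
        gradient = 2.0 * x
        hessian = 2.0 * torch.eye(len(x), dtype=x.dtype)
        return loss, gradient, hessian


    def report_fn(step, x, loss, gradient, hessian, step_quality, accept_step):
        """Hypothetical callback: log the optimizer state at the end of each step."""
        status = "accepted" if accept_step else "rejected"
        _LOGGER.info(
            f"step={step} loss={loss.item():.4e} "
            f"quality={float(step_quality):.2f} {status}"
        )


    x_0 = torch.tensor([1.0, -2.0, 3.0])

    x_final = levenberg_marquardt(
        x_0,
        LevenbergMarquardtConfig(max_steps=10),
        closure_fn,
        correct_fn=None,
        report_fn=report_fn,
    )

Note that report_fn is purely observational: when it is omitted, the optimizer substitutes a no-op lambda, so existing callers only need to account for the config argument moving ahead of closure_fn.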