diff --git a/descent/optim/_lm.py b/descent/optim/_lm.py index 6fb517f..85baf4b 100644 --- a/descent/optim/_lm.py +++ b/descent/optim/_lm.py @@ -99,6 +99,37 @@ class LevenbergMarquardtConfig(pydantic.BaseModel): description="The threshold above which the step is considered high quality.", ) + convergence_loss: float = pydantic.Field( + 1.0e-4, + description="The loss will be considered converged if its std deviation over " + "the last `n_convergence_steps` steps is less than this value.", + gt=0.0, + ) + convergence_gradient: float = pydantic.Field( + 1.0e-3, + description="The gradient will be considered converged if its norm is less than" + "this value.", + gt=0.0, + ) + convergence_step: float = pydantic.Field( + 1.0e-4, + description="The step size will be considered converged if its norm is less " + "than this value.", + gt=0.0, + ) + n_convergence_steps: int = pydantic.Field( + 2, + description="The number of steps to consider when checking for convergence " + "in the loss.", + ) + n_convergence_criteria: int = pydantic.Field( + 1, + description="The number of convergence criteria that must be satisfied before " + "the optimization is considered converged. If 0, no convergence criteria will " + "be used and the optimizer will run for ``max_steps`` full steps.", + ge=0, + ) + max_steps: int = pydantic.Field( ..., description="The maximum number of full steps to perform.", gt=0 ) @@ -387,6 +418,61 @@ def _update_trust_radius( return trust_radius +def _has_converged( + dx: torch.Tensor, + loss_history: list[torch.Tensor], + gradient: torch.Tensor, + step_quality: float, + config: LevenbergMarquardtConfig, +) -> bool: + """Check whether the optimization has converged. + + Args: + dx: The current step. + loss_history: The loss history. + gradient: The current gradient. + step_quality: The quality of the current step. + config: The optimizer config. + + Returns: + Whether the optimization has converged. + """ + if config.n_convergence_criteria == 0: + return False + + if step_quality <= config.quality_threshold_low: + # don't converge on low quality steps + return False + + grad_norm = torch.linalg.norm(gradient) + grad_converged = grad_norm < config.convergence_gradient + + step_norm = torch.linalg.norm(dx) + step_converged = 0.0 <= step_norm < config.convergence_step + + loss_std = ( + torch.inf + if len(loss_history) == 0 + else torch.std(torch.stack(loss_history[-config.n_convergence_steps :])) + ) + loss_converged = ( + len(loss_history) >= config.n_convergence_steps + and loss_std < config.convergence_loss + ) + + if grad_converged: + _LOGGER.info(f"gradient norm is converged: ({grad_norm:.2e})") + if step_converged: + _LOGGER.info(f"step size is converged: ({step_norm:.2e})") + if loss_converged: + _LOGGER.info(f"loss is converged: ({loss_std:.2e})") + + return bool( + sum((grad_converged, step_converged, loss_converged)) + >= config.n_convergence_criteria + ) + + @torch.no_grad() def levenberg_marquardt( x: torch.Tensor, @@ -422,6 +508,10 @@ def levenberg_marquardt( closure_prev = closure_fn(x, True, True) trust_radius = torch.tensor(config.trust_radius).to(x.device) + loss_history = [] + + has_converged = False + for step in range(config.max_steps): loss_prev, gradient_prev, hessian_prev = closure_prev @@ -465,9 +555,19 @@ def levenberg_marquardt( if accept_step: x.data.copy_(x_next.data) + loss_history.append(loss.detach().cpu().clone()) closure_prev = (loss, gradient, hessian) - _LOGGER.info(f"step={step} loss={loss.detach().cpu().item()}") + _LOGGER.info(f"step={step} loss={loss.detach().cpu().item():.4e}") + + if _has_converged(dx, loss_history, gradient, step_quality, config): + _LOGGER.info(f"optimization has converged after {step + 1} steps.") + has_converged = True + + break + + if not has_converged: + _LOGGER.info(f"optimization has not converged after {config.max_steps} steps.") return x diff --git a/descent/tests/optim/test_lm.py b/descent/tests/optim/test_lm.py index e4d1148..fc72219 100644 --- a/descent/tests/optim/test_lm.py +++ b/descent/tests/optim/test_lm.py @@ -6,6 +6,7 @@ from descent.optim._lm import ( LevenbergMarquardtConfig, _damping_factor_loss_fn, + _has_converged, _hessian_diagonal_search, _solver, _step, @@ -141,6 +142,65 @@ def test_damping_factor_loss_fn(mocker): assert torch.isclose(difference, expected_difference) +@pytest.mark.parametrize( + "n_convergence_criteria, n_convergence_steps, step_quality, expected_converged, expected_logs", + [ + (0, 2, 1.0, False, []), + (1, 2, 0.0, False, []), + ( + 1, + 2, + 1.0, + True, + [ + "gradient norm is converged", + "step size is converged", + "loss is converged", + ], + ), + ( + 3, + 3, + 1.0, + False, + [ + "gradient norm is converged", + "step size is converged", + ], + ), + ], +) +def test_has_converged( + n_convergence_criteria, + n_convergence_steps, + step_quality, + expected_converged, + expected_logs, + caplog, +): + dx = torch.tensor([0.0, 0.01, 0.0]) + gradient = torch.tensor([0.0, -0.02, 0.0]) + loss_history = [torch.tensor(1.0), torch.tensor(0.1), torch.tensor(0.11)] + + with caplog.at_level(logging.INFO): + has_converged = _has_converged( + dx, + loss_history, + gradient, + step_quality, + LevenbergMarquardtConfig( + max_steps=1, + n_convergence_steps=n_convergence_steps, + n_convergence_criteria=n_convergence_criteria, + convergence_step=0.2, + convergence_loss=0.2, + convergence_gradient=0.2, + ), + ) + assert has_converged == expected_converged + assert all(log in caplog.text for log in expected_logs) + + def test_levenberg_marquardt_adaptive(mocker, caplog): """Make sure the trust radius is adjusted correctly based on the loss."""