Add LM optimizer convergence checks (#46)
SimonBoothroyd authored Nov 4, 2023
1 parent 53bd12e commit 867d801
Showing 2 changed files with 161 additions and 1 deletion.
102 changes: 101 additions & 1 deletion descent/optim/_lm.py
@@ -99,6 +99,37 @@ class LevenbergMarquardtConfig(pydantic.BaseModel):
description="The threshold above which the step is considered high quality.",
)

convergence_loss: float = pydantic.Field(
1.0e-4,
description="The loss will be considered converged if its std deviation over "
"the last `n_convergence_steps` steps is less than this value.",
gt=0.0,
)
convergence_gradient: float = pydantic.Field(
1.0e-3,
description="The gradient will be considered converged if its norm is less than"
"this value.",
gt=0.0,
)
convergence_step: float = pydantic.Field(
1.0e-4,
description="The step size will be considered converged if its norm is less "
"than this value.",
gt=0.0,
)
n_convergence_steps: int = pydantic.Field(
2,
description="The number of steps to consider when checking for convergence "
"in the loss.",
)
n_convergence_criteria: int = pydantic.Field(
1,
description="The number of convergence criteria that must be satisfied before "
"the optimization is considered converged. If 0, no convergence criteria will "
"be used and the optimizer will run for ``max_steps`` full steps.",
ge=0,
)

max_steps: int = pydantic.Field(
..., description="The maximum number of full steps to perform.", gt=0
)
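
For orientation, here is an illustrative sketch (not part of this commit) of how the new convergence options might be set when building a config. It mirrors the constructor call used in the new `test_has_converged` test further down and assumes, as there, that `max_steps` is the only other field that needs to be supplied explicitly:

from descent.optim._lm import LevenbergMarquardtConfig

config = LevenbergMarquardtConfig(
    max_steps=100,             # required field, see above
    n_convergence_criteria=1,  # stop once any single criterion is satisfied
    n_convergence_steps=3,     # loss std deviation taken over the last 3 accepted steps
    convergence_loss=1.0e-4,
    convergence_gradient=1.0e-3,
    convergence_step=1.0e-4,
)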
@@ -387,6 +418,61 @@ def _update_trust_radius(
    return trust_radius


def _has_converged(
    dx: torch.Tensor,
    loss_history: list[torch.Tensor],
    gradient: torch.Tensor,
    step_quality: float,
    config: LevenbergMarquardtConfig,
) -> bool:
    """Check whether the optimization has converged.

    Args:
        dx: The current step.
        loss_history: The loss history.
        gradient: The current gradient.
        step_quality: The quality of the current step.
        config: The optimizer config.

    Returns:
        Whether the optimization has converged.
    """
    if config.n_convergence_criteria == 0:
        return False

    if step_quality <= config.quality_threshold_low:
        # don't converge on low quality steps
        return False

    grad_norm = torch.linalg.norm(gradient)
    grad_converged = grad_norm < config.convergence_gradient

    step_norm = torch.linalg.norm(dx)
    step_converged = 0.0 <= step_norm < config.convergence_step

    loss_std = (
        torch.inf
        if len(loss_history) == 0
        else torch.std(torch.stack(loss_history[-config.n_convergence_steps :]))
    )
    loss_converged = (
        len(loss_history) >= config.n_convergence_steps
        and loss_std < config.convergence_loss
    )

    if grad_converged:
        _LOGGER.info(f"gradient norm is converged: ({grad_norm:.2e})")
    if step_converged:
        _LOGGER.info(f"step size is converged: ({step_norm:.2e})")
    if loss_converged:
        _LOGGER.info(f"loss is converged: ({loss_std:.2e})")

    return bool(
        sum((grad_converged, step_converged, loss_converged))
        >= config.n_convergence_criteria
    )
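
As a quick worked example of the counting logic above (illustrative, not part of the diff), the call below uses the same toy tensors as the new `test_has_converged` test further down. With every threshold set to 0.2, the gradient norm (~0.02), the step norm (~0.01), and the standard deviation of the last two losses (~0.007) all fall below their thresholds, so the check succeeds for any `n_convergence_criteria` up to 3:

import torch

from descent.optim._lm import LevenbergMarquardtConfig, _has_converged

dx = torch.tensor([0.0, 0.01, 0.0])
gradient = torch.tensor([0.0, -0.02, 0.0])
loss_history = [torch.tensor(1.0), torch.tensor(0.1), torch.tensor(0.11)]

config = LevenbergMarquardtConfig(
    max_steps=1,
    n_convergence_steps=2,
    n_convergence_criteria=1,
    convergence_step=0.2,
    convergence_loss=0.2,
    convergence_gradient=0.2,
)

# a step quality of 1.0 is above the low-quality threshold, so the step counts
assert _has_converged(dx, loss_history, gradient, 1.0, config)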


@torch.no_grad()
def levenberg_marquardt(
    x: torch.Tensor,
@@ -422,6 +508,10 @@ def levenberg_marquardt(
    closure_prev = closure_fn(x, True, True)
    trust_radius = torch.tensor(config.trust_radius).to(x.device)

    loss_history = []

    has_converged = False

    for step in range(config.max_steps):
        loss_prev, gradient_prev, hessian_prev = closure_prev

@@ -465,9 +555,19 @@

        if accept_step:
            x.data.copy_(x_next.data)
            loss_history.append(loss.detach().cpu().clone())

            closure_prev = (loss, gradient, hessian)

-        _LOGGER.info(f"step={step} loss={loss.detach().cpu().item()}")
+        _LOGGER.info(f"step={step} loss={loss.detach().cpu().item():.4e}")

        if _has_converged(dx, loss_history, gradient, step_quality, config):
            _LOGGER.info(f"optimization has converged after {step + 1} steps.")
            has_converged = True

            break

    if not has_converged:
        _LOGGER.info(f"optimization has not converged after {config.max_steps} steps.")

    return x
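
Finally, a minimal end-to-end usage sketch (illustrative only). The closure contract, returning the loss and optionally its gradient and Hessian, is inferred from the `closure_fn(x, True, True)` call above; the positional argument order of `levenberg_marquardt` after `x` is an assumption here rather than the documented API:

import torch

from descent.optim._lm import LevenbergMarquardtConfig, levenberg_marquardt

TARGET = torch.tensor([1.0, -2.0, 3.0])


def closure_fn(x: torch.Tensor, compute_gradient: bool, compute_hessian: bool):
    """Toy least-squares closure returning the loss, gradient and Hessian."""
    residuals = x - TARGET
    loss = (residuals**2).sum()

    gradient = 2.0 * residuals if compute_gradient else None
    hessian = 2.0 * torch.eye(len(x)) if compute_hessian else None

    return loss, gradient, hessian


config = LevenbergMarquardtConfig(max_steps=50, n_convergence_criteria=1)

x_final = levenberg_marquardt(torch.zeros(3), config, closure_fn)  # argument order assumed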
60 changes: 60 additions & 0 deletions descent/tests/optim/test_lm.py
@@ -6,6 +6,7 @@
from descent.optim._lm import (
    LevenbergMarquardtConfig,
    _damping_factor_loss_fn,
    _has_converged,
    _hessian_diagonal_search,
    _solver,
    _step,
@@ -141,6 +142,65 @@ def test_damping_factor_loss_fn(mocker):
    assert torch.isclose(difference, expected_difference)


@pytest.mark.parametrize(
    "n_convergence_criteria, n_convergence_steps, step_quality, expected_converged, expected_logs",
    [
        (0, 2, 1.0, False, []),
        (1, 2, 0.0, False, []),
        (
            1,
            2,
            1.0,
            True,
            [
                "gradient norm is converged",
                "step size is converged",
                "loss is converged",
            ],
        ),
        (
            3,
            3,
            1.0,
            False,
            [
                "gradient norm is converged",
                "step size is converged",
            ],
        ),
    ],
)
def test_has_converged(
    n_convergence_criteria,
    n_convergence_steps,
    step_quality,
    expected_converged,
    expected_logs,
    caplog,
):
    dx = torch.tensor([0.0, 0.01, 0.0])
    gradient = torch.tensor([0.0, -0.02, 0.0])
    loss_history = [torch.tensor(1.0), torch.tensor(0.1), torch.tensor(0.11)]

    with caplog.at_level(logging.INFO):
        has_converged = _has_converged(
            dx,
            loss_history,
            gradient,
            step_quality,
            LevenbergMarquardtConfig(
                max_steps=1,
                n_convergence_steps=n_convergence_steps,
                n_convergence_criteria=n_convergence_criteria,
                convergence_step=0.2,
                convergence_loss=0.2,
                convergence_gradient=0.2,
            ),
        )
    assert has_converged == expected_converged
    assert all(log in caplog.text for log in expected_logs)


def test_levenberg_marquardt_adaptive(mocker, caplog):
"""Make sure the trust radius is adjusted correctly based on the loss."""
