Add LM optimizer convergence checks #46

Merged
1 commit merged on Nov 4, 2023
102 changes: 101 additions & 1 deletion descent/optim/_lm.py
@@ -99,6 +99,37 @@ class LevenbergMarquardtConfig(pydantic.BaseModel):
description="The threshold above which the step is considered high quality.",
)

convergence_loss: float = pydantic.Field(
1.0e-4,
description="The loss will be considered converged if its std deviation over "
"the last `n_convergence_steps` steps is less than this value.",
gt=0.0,
)
convergence_gradient: float = pydantic.Field(
1.0e-3,
description="The gradient will be considered converged if its norm is less than"
"this value.",
gt=0.0,
)
convergence_step: float = pydantic.Field(
1.0e-4,
description="The step size will be considered converged if its norm is less "
"than this value.",
gt=0.0,
)
n_convergence_steps: int = pydantic.Field(
2,
description="The number of steps to consider when checking for convergence "
"in the loss.",
)
n_convergence_criteria: int = pydantic.Field(
1,
description="The number of convergence criteria that must be satisfied before "
"the optimization is considered converged. If 0, no convergence criteria will "
"be used and the optimizer will run for ``max_steps`` full steps.",
ge=0,
)

max_steps: int = pydantic.Field(
..., description="The maximum number of full steps to perform.", gt=0
)
@@ -387,6 +418,61 @@ def _update_trust_radius(
return trust_radius


def _has_converged(
dx: torch.Tensor,
loss_history: list[torch.Tensor],
gradient: torch.Tensor,
step_quality: float,
config: LevenbergMarquardtConfig,
) -> bool:
"""Check whether the optimization has converged.

Args:
dx: The current step.
loss_history: The loss history.
gradient: The current gradient.
step_quality: The quality of the current step.
config: The optimizer config.

Returns:
Whether the optimization has converged.
"""
if config.n_convergence_criteria == 0:
return False

if step_quality <= config.quality_threshold_low:
# don't converge on low quality steps
return False

grad_norm = torch.linalg.norm(gradient)
grad_converged = grad_norm < config.convergence_gradient

step_norm = torch.linalg.norm(dx)
step_converged = 0.0 <= step_norm < config.convergence_step

loss_std = (
torch.inf
if len(loss_history) == 0
else torch.std(torch.stack(loss_history[-config.n_convergence_steps :]))
)
loss_converged = (
len(loss_history) >= config.n_convergence_steps
and loss_std < config.convergence_loss
)

if grad_converged:
_LOGGER.info(f"gradient norm is converged: ({grad_norm:.2e})")
if step_converged:
_LOGGER.info(f"step size is converged: ({step_norm:.2e})")
if loss_converged:
_LOGGER.info(f"loss is converged: ({loss_std:.2e})")

return bool(
sum((grad_converged, step_converged, loss_converged))
>= config.n_convergence_criteria
)


@torch.no_grad()
def levenberg_marquardt(
x: torch.Tensor,
@@ -422,6 +508,10 @@ def levenberg_marquardt(
closure_prev = closure_fn(x, True, True)
trust_radius = torch.tensor(config.trust_radius).to(x.device)

loss_history = []

has_converged = False

for step in range(config.max_steps):
loss_prev, gradient_prev, hessian_prev = closure_prev

@@ -465,9 +555,19 @@

if accept_step:
x.data.copy_(x_next.data)
loss_history.append(loss.detach().cpu().clone())

closure_prev = (loss, gradient, hessian)

_LOGGER.info(f"step={step} loss={loss.detach().cpu().item()}")
_LOGGER.info(f"step={step} loss={loss.detach().cpu().item():.4e}")

if _has_converged(dx, loss_history, gradient, step_quality, config):
_LOGGER.info(f"optimization has converged after {step + 1} steps.")
has_converged = True

break

if not has_converged:
_LOGGER.info(f"optimization has not converged after {config.max_steps} steps.")

return x
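
For context, a minimal usage sketch of the new convergence options follows. It is not part of this diff: the toy closure and the argument order of levenberg_marquardt after x are assumptions that should be checked against the module; only the config fields and the (loss, gradient, hessian) closure contract come from the code above.

import torch

from descent.optim._lm import LevenbergMarquardtConfig, levenberg_marquardt


def closure_fn(x, compute_gradient, compute_hessian):
    """Toy quadratic loss returning the (loss, gradient, hessian) tuple the optimizer expects."""
    loss = (x**2).sum()
    gradient = 2.0 * x if compute_gradient else None
    hessian = 2.0 * torch.eye(len(x)) if compute_hessian else None
    return loss, gradient, hessian


config = LevenbergMarquardtConfig(
    max_steps=100,
    # stop as soon as any single criterion is met (loss std dev, gradient norm, or step norm)
    n_convergence_criteria=1,
    n_convergence_steps=5,
    convergence_loss=1.0e-6,
    convergence_gradient=1.0e-5,
    convergence_step=1.0e-6,
)

x = torch.tensor([1.0, -2.0, 0.5])
# NOTE: the argument order after ``x`` is an assumption, not taken from this diff.
x_final = levenberg_marquardt(x, config, closure_fn)
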
60 changes: 60 additions & 0 deletions descent/tests/optim/test_lm.py
@@ -6,6 +6,7 @@
from descent.optim._lm import (
LevenbergMarquardtConfig,
_damping_factor_loss_fn,
_has_converged,
_hessian_diagonal_search,
_solver,
_step,
@@ -141,6 +142,65 @@ def test_damping_factor_loss_fn(mocker):
assert torch.isclose(difference, expected_difference)


@pytest.mark.parametrize(
"n_convergence_criteria, n_convergence_steps, step_quality, expected_converged, expected_logs",
[
(0, 2, 1.0, False, []),
(1, 2, 0.0, False, []),
(
1,
2,
1.0,
True,
[
"gradient norm is converged",
"step size is converged",
"loss is converged",
],
),
(
3,
3,
1.0,
False,
[
"gradient norm is converged",
"step size is converged",
],
),
],
)
def test_has_converged(
n_convergence_criteria,
n_convergence_steps,
step_quality,
expected_converged,
expected_logs,
caplog,
):
dx = torch.tensor([0.0, 0.01, 0.0])
gradient = torch.tensor([0.0, -0.02, 0.0])
loss_history = [torch.tensor(1.0), torch.tensor(0.1), torch.tensor(0.11)]

with caplog.at_level(logging.INFO):
has_converged = _has_converged(
dx,
loss_history,
gradient,
step_quality,
LevenbergMarquardtConfig(
max_steps=1,
n_convergence_steps=n_convergence_steps,
n_convergence_criteria=n_convergence_criteria,
convergence_step=0.2,
convergence_loss=0.2,
convergence_gradient=0.2,
),
)
assert has_converged == expected_converged
assert all(log in caplog.text for log in expected_logs)


def test_levenberg_marquardt_adaptive(mocker, caplog):
"""Make sure the trust radius is adjusted correctly based on the loss."""
