From 9120b9951158f4271cbd1382194747c4e2d15258 Mon Sep 17 00:00:00 2001
From: SimonBoothroyd
Date: Sat, 4 Nov 2023 11:09:08 -0400
Subject: [PATCH] Add `to_closure` utility (#47)

---
 descent/optim/_lm.py             | 11 ++++---
 descent/tests/utils/test_loss.py | 37 ++++++++++++++++++++++
 descent/utils/loss.py            | 53 ++++++++++++++++++++++++++++++++
 3 files changed, 96 insertions(+), 5 deletions(-)
 create mode 100644 descent/tests/utils/test_loss.py
 create mode 100644 descent/utils/loss.py

diff --git a/descent/optim/_lm.py b/descent/optim/_lm.py
index 85baf4b..bc0a9a1 100644
--- a/descent/optim/_lm.py
+++ b/descent/optim/_lm.py
@@ -20,7 +20,8 @@
 
 
 ClosureFn = typing.Callable[
-    [torch.Tensor, bool, bool], tuple[torch.Tensor, torch.Tensor, torch.Tensor]
+    [torch.Tensor, bool, bool],
+    tuple[torch.Tensor, torch.Tensor | None, torch.Tensor | None],
 ]
 CorrectFn = typing.Callable[[torch.Tensor], torch.Tensor]
 
@@ -100,19 +101,19 @@ class LevenbergMarquardtConfig(pydantic.BaseModel):
     )
     convergence_loss: float = pydantic.Field(
-        1.0e-4,
+        0.1,
         description="The loss will be considered converged if its std deviation over "
         "the last `n_convergence_steps` steps is less than this value.",
         gt=0.0,
     )
     convergence_gradient: float = pydantic.Field(
-        1.0e-3,
+        0.1,
         description="The gradient will be considered converged if its norm is less than"
         "this value.",
         gt=0.0,
     )
     convergence_step: float = pydantic.Field(
-        1.0e-4,
+        0.01,
         description="The step size will be considered converged if its norm is less "
         "than this value.",
         gt=0.0,
     )
@@ -123,7 +124,7 @@ class LevenbergMarquardtConfig(pydantic.BaseModel):
         "in the loss.",
     )
     n_convergence_criteria: int = pydantic.Field(
-        1,
+        2,
         description="The number of convergence criteria that must be satisfied before "
         "the optimization is considered converged. If 0, no convergence criteria will "
         "be used and the optimizer will run for ``max_steps`` full steps.",
diff --git a/descent/tests/utils/test_loss.py b/descent/tests/utils/test_loss.py
new file mode 100644
index 0000000..7b67a81
--- /dev/null
+++ b/descent/tests/utils/test_loss.py
@@ -0,0 +1,37 @@
+import torch
+
+from descent.utils.loss import to_closure
+
+
+def test_to_closure():
+    def mock_loss_fn(x: torch.Tensor, a: float, b: float) -> torch.Tensor:
+        return (a * x**2 + b).sum()
+
+    closure_fn = to_closure(mock_loss_fn, a=2.0, b=3.0)
+
+    theta = torch.Tensor([1.0, 2.0]).requires_grad_(True)
+
+    expected_loss = torch.tensor(16.0)
+    expected_grad = torch.tensor([4.0, 8.0])
+    expected_hess = torch.tensor([[4.0, 0.0], [0.0, 4.0]])
+
+    loss, grad, hess = closure_fn(theta, True, True)
+
+    assert loss.shape == expected_loss.shape
+    assert torch.allclose(loss, expected_loss)
+
+    assert grad.shape == expected_grad.shape
+    assert torch.allclose(grad, expected_grad)
+
+    assert hess.shape == expected_hess.shape
+    assert torch.allclose(hess, expected_hess)
+
+    loss, grad, hess = closure_fn(theta, False, True)
+    assert loss is not None
+    assert grad is None
+    assert hess is not None
+
+    loss, grad, hess = closure_fn(theta, True, False)
+    assert loss is not None
+    assert grad is not None
+    assert hess is None
diff --git a/descent/utils/loss.py b/descent/utils/loss.py
new file mode 100644
index 0000000..909ac70
--- /dev/null
+++ b/descent/utils/loss.py
@@ -0,0 +1,53 @@
+"""Utilities for defining loss functions."""
+import functools
+import typing
+
+import torch
+
+ClosureFn = typing.Callable[
+    [torch.Tensor, bool, bool],
+    tuple[torch.Tensor, torch.Tensor | None, torch.Tensor | None],
+]
+P = typing.ParamSpec("P")
+
+
+def to_closure(
+    loss_fn: typing.Callable[typing.Concatenate[torch.Tensor, P], torch.Tensor],
+    *args: P.args,
+    **kwargs: P.kwargs,
+) -> ClosureFn:
+    """Convert a loss function to a closure function used by second-order optimizers.
+
+    Args:
+        loss_fn: The loss function to convert. This should take in a tensor of
+            parameters with ``shape=(n,)``, and optionally a set of ``args`` and
+            ``kwargs``.
+        *args: Positional arguments passed to `loss_fn`.
+        **kwargs: Keyword arguments passed to `loss_fn`.
+
+    Returns:
+        A closure function that takes in a tensor of parameters with ``shape=(n,)``,
+        a boolean flag indicating whether to compute the gradient, and a boolean flag
+        indicating whether to compute the Hessian. It returns a tuple of the loss
+        value, the gradient, and the Hessian.
+    """
+
+    loss_fn_wrapped = functools.partial(loss_fn, *args, **kwargs)
+
+    def closure_fn(
+        x: torch.Tensor, compute_gradient: bool, compute_hessian: bool
+    ) -> tuple[torch.Tensor, torch.Tensor | None, torch.Tensor | None]:
+        loss = loss_fn_wrapped(x)
+        gradient, hessian = None, None
+
+        if compute_hessian:
+            hessian = torch.autograd.functional.hessian(
+                loss_fn_wrapped, x, vectorize=True, create_graph=False
+            ).detach()
+        if compute_gradient:
+            (gradient,) = torch.autograd.grad(loss, x, create_graph=False)
+            gradient = gradient.detach()
+
+        return loss.detach(), gradient, hessian
+
+    return closure_fn
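
Beyond the diff itself, a minimal usage sketch of the new utility may be helpful: it wraps a toy quadratic loss with `to_closure` and evaluates the loss, gradient, and Hessian in a single call. The `loss_fn` and `scale` names below are illustrative only and are not part of this patch; the sketch assumes only the `descent.utils.loss.to_closure` API added above.

import torch

from descent.utils.loss import to_closure


def loss_fn(theta: torch.Tensor, scale: float) -> torch.Tensor:
    # Toy quadratic loss over a flat parameter vector `theta` (illustrative only).
    return (scale * theta**2).sum()


# Bind the extra `scale` argument so the closure depends only on `theta`.
closure_fn = to_closure(loss_fn, scale=0.5)

theta = torch.tensor([1.0, -2.0, 3.0], requires_grad=True)

# Request the loss, its gradient, and its Hessian in one call.
loss, gradient, hessian = closure_fn(theta, True, True)

# Skip the Hessian when only first-order information is needed.
loss, gradient, _ = closure_fn(theta, True, False)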