Add to_closure utility (#47)
SimonBoothroyd authored Nov 4, 2023
1 parent 867d801 commit 9120b99
Showing 3 changed files with 96 additions and 5 deletions.
11 changes: 6 additions & 5 deletions descent/optim/_lm.py
@@ -20,7 +20,8 @@
 
 
 ClosureFn = typing.Callable[
-    [torch.Tensor, bool, bool], tuple[torch.Tensor, torch.Tensor, torch.Tensor]
+    [torch.Tensor, bool, bool],
+    tuple[torch.Tensor, torch.Tensor | None, torch.Tensor | None],
 ]
 CorrectFn = typing.Callable[[torch.Tensor], torch.Tensor]
 
@@ -100,19 +101,19 @@ class LevenbergMarquardtConfig(pydantic.BaseModel):
     )
 
     convergence_loss: float = pydantic.Field(
-        1.0e-4,
+        0.1,
         description="The loss will be considered converged if its std deviation over "
        "the last `n_convergence_steps` steps is less than this value.",
         gt=0.0,
     )
     convergence_gradient: float = pydantic.Field(
-        1.0e-3,
+        0.1,
         description="The gradient will be considered converged if its norm is less than "
        "this value.",
         gt=0.0,
     )
     convergence_step: float = pydantic.Field(
-        1.0e-4,
+        0.01,
         description="The step size will be considered converged if its norm is less "
        "than this value.",
         gt=0.0,
@@ -123,7 +124,7 @@ class LevenbergMarquardtConfig(pydantic.BaseModel):
         "in the loss.",
     )
     n_convergence_criteria: int = pydantic.Field(
-        1,
+        2,
         description="The number of convergence criteria that must be satisfied before "
        "the optimization is considered converged. If 0, no convergence criteria will "
        "be used and the optimizer will run for ``max_steps`` full steps.",
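
For orientation, a minimal sketch (not part of this commit) of how the loosened defaults interact: with n_convergence_criteria=2, at least two of the three thresholds below must be met before the optimizer stops early. The max_steps value is illustrative, and any field not listed is assumed to keep its default.

from descent.optim._lm import LevenbergMarquardtConfig

# hypothetical configuration based only on the fields shown in this diff
config = LevenbergMarquardtConfig(
    max_steps=100,             # illustrative cap on the number of full LM steps
    convergence_loss=0.1,      # std deviation of the loss over recent steps
    convergence_gradient=0.1,  # norm of the gradient
    convergence_step=0.01,     # norm of the proposed step
    n_convergence_criteria=2,  # two of the three criteria above must be satisfied
)
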
37 changes: 37 additions & 0 deletions descent/tests/utils/test_loss.py
@@ -0,0 +1,37 @@
import torch

from descent.utils.loss import to_closure


def test_to_closure():
    def mock_loss_fn(x: torch.Tensor, a: float, b: float) -> torch.Tensor:
        return (a * x**2 + b).sum()

    closure_fn = to_closure(mock_loss_fn, a=2.0, b=3.0)

    theta = torch.Tensor([1.0, 2.0]).requires_grad_(True)

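    # for loss = (2 * x**2 + 3).sum() at x = [1, 2]: loss = 5 + 11 = 16,
    # grad = 4 * x = [4, 8], and hess = diag(4, 4)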
    expected_loss = torch.tensor(16.0)
    expected_grad = torch.tensor([4.0, 8.0])
    expected_hess = torch.tensor([[4.0, 0.0], [0.0, 4.0]])

    loss, grad, hess = closure_fn(theta, True, True)

    assert loss.shape == expected_loss.shape
    assert torch.allclose(loss, expected_loss)

    assert grad.shape == expected_grad.shape
    assert torch.allclose(grad, expected_grad)

    assert hess.shape == expected_hess.shape
    assert torch.allclose(hess, expected_hess)

    loss, grad, hess = closure_fn(theta, False, True)
    assert loss is not None
    assert grad is None
    assert hess is not None

    loss, grad, hess = closure_fn(theta, True, False)
    assert loss is not None
    assert grad is not None
    assert hess is None
53 changes: 53 additions & 0 deletions descent/utils/loss.py
@@ -0,0 +1,53 @@
"""Utilities for defining loss functions."""
import functools
import typing

import torch

ClosureFn = typing.Callable[
    [torch.Tensor, bool, bool],
    tuple[torch.Tensor, torch.Tensor | None, torch.Tensor | None],
]
P = typing.ParamSpec("P")


def to_closure(
    loss_fn: typing.Callable[typing.Concatenate[torch.Tensor, P], torch.Tensor],
    *args: P.args,
    **kwargs: P.kwargs,
) -> ClosureFn:
    """Convert a loss function to a closure function used by second-order optimizers.

    Args:
        loss_fn: The loss function to convert. This should take in a tensor of
            parameters with ``shape=(n,)``, and optionally a set of ``args`` and
            ``kwargs``.
        *args: Positional arguments passed to `loss_fn`.
        **kwargs: Keyword arguments passed to `loss_fn`.

    Returns:
        A closure function that takes in a tensor of parameters with ``shape=(n,)``,
        a boolean flag indicating whether to compute the gradient, and a boolean flag
        indicating whether to compute the Hessian. It returns a tuple of the loss
        value, the gradient, and the Hessian.
    """

    loss_fn_wrapped = functools.partial(loss_fn, *args, **kwargs)

    def closure_fn(
        x: torch.Tensor, compute_gradient: bool, compute_hessian: bool
    ) -> tuple[torch.Tensor, torch.Tensor | None, torch.Tensor | None]:
        loss = loss_fn_wrapped(x)
        gradient, hessian = None, None

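        # the full Hessian is computed with ``torch.autograd.functional.hessian`` and
        # detached (as is the gradient below) so the returned tensors carry no
        # autograd graph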
        if compute_hessian:
            hessian = torch.autograd.functional.hessian(
                loss_fn_wrapped, x, vectorize=True, create_graph=False
            ).detach()
        if compute_gradient:
            (gradient,) = torch.autograd.grad(loss, x, create_graph=False)
            gradient = gradient.detach()

        return loss.detach(), gradient, hessian

    return closure_fn
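
Beyond the unit test above, here is a short, hypothetical usage sketch (not part of the commit) of how the helper is intended to be called; the loss function, its extra argument, and the flag combinations are illustrative only.

import torch

from descent.utils.loss import to_closure


def loss_fn(theta: torch.Tensor, scale: float) -> torch.Tensor:
    # a toy quadratic loss; ``scale`` stands in for any extra static arguments
    return (scale * theta**2).sum()


closure = to_closure(loss_fn, scale=3.0)
theta = torch.tensor([1.0, -2.0], requires_grad=True)

# loss, gradient, and Hessian in one call, matching the ClosureFn signature
loss, grad, hess = closure(theta, True, True)

# skip the extra work when only the loss value is needed
loss_only, none_grad, none_hess = closure(theta, False, False)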
