Commit 9120b99 (parent 867d801): 3 changed files with 96 additions and 5 deletions.
@@ -0,0 +1,37 @@
import torch

from descent.utils.loss import to_closure


def test_to_closure():
    def mock_loss_fn(x: torch.Tensor, a: float, b: float) -> torch.Tensor:
        return (a * x**2 + b).sum()

    closure_fn = to_closure(mock_loss_fn, a=2.0, b=3.0)

    theta = torch.Tensor([1.0, 2.0]).requires_grad_(True)

    expected_loss = torch.tensor(16.0)
    expected_grad = torch.tensor([4.0, 8.0])
    expected_hess = torch.tensor([[4.0, 0.0], [0.0, 4.0]])

    loss, grad, hess = closure_fn(theta, True, True)

    assert loss.shape == expected_loss.shape
    assert torch.allclose(loss, expected_loss)

    assert grad.shape == expected_grad.shape
    assert torch.allclose(grad, expected_grad)

    assert hess.shape == expected_hess.shape
    assert torch.allclose(hess, expected_hess)

    loss, grad, hess = closure_fn(theta, False, True)
    assert loss is not None
    assert grad is None
    assert hess is not None

    loss, grad, hess = closure_fn(theta, True, False)
    assert loss is not None
    assert grad is not None
    assert hess is None
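For reference, the mock loss is f(x) = Σᵢ (2·xᵢ² + 3), so at θ = [1, 2] the loss is 5 + 11 = 16, the gradient is 4θ = [4, 8], and the Hessian is the diagonal matrix 4·I, which is where the expected tensors above come from.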
@@ -0,0 +1,53 @@
"""Utilities for defining loss functions.""" | ||
import functools | ||
import typing | ||
|
||
import torch | ||
|
||
ClosureFn = typing.Callable[ | ||
[torch.Tensor, bool, bool], | ||
tuple[torch.Tensor, torch.Tensor | None, torch.Tensor | None], | ||
] | ||
P = typing.ParamSpec("P") | ||
|
||
|
||
def to_closure( | ||
loss_fn: typing.Callable[typing.Concatenate[torch.Tensor, P], torch.Tensor], | ||
*args: P.args, | ||
**kwargs: P.kwargs, | ||
) -> ClosureFn: | ||
"""Convert a loss function to a closure function used by second-order optimizers. | ||
Args: | ||
loss_fn: The loss function to convert. This should take in a tensor of | ||
parameters with ``shape=(n,)``, and optionally a set of ``args`` and | ||
``kwargs``. | ||
*args: Positional arguments passed to `loss_fn`. | ||
**kwargs: Keyword arguments passed to `loss_fn`. | ||
Returns: | ||
A closure function that takes in a tensor of parameters with ``shape=(n,)``, | ||
a boolean flag indicating whether to compute the gradient, and a boolean flag | ||
indicating whether to compute the Hessian. It returns a tuple of the loss | ||
value, the gradient, and the Hessian. | ||
""" | ||
|
||
loss_fn_wrapped = functools.partial(loss_fn, *args, **kwargs) | ||
|
||
def closure_fn( | ||
x: torch.Tensor, compute_gradient: bool, compute_hessian: bool | ||
) -> tuple[torch.Tensor, torch.Tensor | None, torch.Tensor | None]: | ||
loss = loss_fn_wrapped(x) | ||
gradient, hessian = None, None | ||
|
||
if compute_hessian: | ||
hessian = torch.autograd.functional.hessian( | ||
loss_fn_wrapped, x, vectorize=True, create_graph=False | ||
).detach() | ||
if compute_gradient: | ||
(gradient,) = torch.autograd.grad(loss, x, create_graph=False) | ||
gradient = gradient.detach() | ||
|
||
return loss.detach(), gradient, hessian | ||
|
||
return closure_fn |
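As a rough sketch of how such a closure could be consumed, the snippet below takes a single hand-rolled Newton step. The quadratic loss, its ``a``/``b`` values, and the starting ``theta`` are illustrative assumptions borrowed from the test above, not anything prescribed by this package.

import torch

from descent.utils.loss import to_closure


def loss_fn(x: torch.Tensor, a: float, b: float) -> torch.Tensor:
    # Simple quadratic loss, chosen only for illustration.
    return (a * x**2 + b).sum()


closure_fn = to_closure(loss_fn, a=2.0, b=3.0)

theta = torch.tensor([1.0, 2.0], requires_grad=True)

# Ask the closure for the loss, gradient, and Hessian at ``theta``.
loss, grad, hess = closure_fn(theta, True, True)

# One Newton step: theta <- theta - H^-1 g. The closure returns detached
# tensors, so the update happens outside of autograd.
with torch.no_grad():
    theta -= torch.linalg.solve(hess, grad)

# For this loss the step lands exactly on the minimizer, theta = [0.0, 0.0].

Because the closure detaches what it returns, the caller can treat the loss, gradient, and Hessian as plain values in an update like this rather than as nodes in the autograd graph.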