Skip to content

Commit

Permalink
Add approximate hessian utility (#50)
Browse files Browse the repository at this point in the history
  • Loading branch information
SimonBoothroyd authored Nov 5, 2023
1 parent e07fc06 commit bc4f255
Show file tree
Hide file tree
Showing 3 changed files with 53 additions and 4 deletions.
17 changes: 14 additions & 3 deletions descent/targets/thermo.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
"""Train against thermodynamic properties."""
import contextlib
import hashlib
import logging
import pathlib
import pickle
import typing

import numpy
Expand All @@ -13,6 +15,9 @@

import descent.utils.molecule

_LOGGER = logging.getLogger(__name__)


DataType = typing.Literal["density", "hvap", "hmix"]

DATA_TYPES = typing.get_args(DataType)
Expand Down Expand Up @@ -235,7 +240,7 @@ def _bulk_config(temperature: float, pressure: float) -> SimulationConfig:
n_steps=500000,
timestep=2.0 * openmm.unit.femtosecond,
),
production_frequency=500,
production_frequency=1000,
)


Expand Down Expand Up @@ -396,7 +401,7 @@ def _compute_averages(
output_dir: pathlib.Path,
cached_dir: pathlib.Path | None,
) -> dict[str, torch.Tensor]:
traj_hash = hashlib.sha256(key, usedforsecurity=False).hexdigest()
traj_hash = hashlib.sha256(pickle.dumps(key)).hexdigest()
traj_name = f"{phase}-{traj_hash}-frames.msgpack"

cached_path = None if cached_dir is None else cached_dir / traj_name
Expand All @@ -410,6 +415,9 @@ def _compute_averages(
system, force_field, cached_path, temperature, pressure
)

if cached_path is not None:
_LOGGER.debug(f"unable to re-weight {key}: data exists={cached_path.exists()}")

output_path = output_dir / traj_name

config = default_config(phase, key.temperature, key.pressure)
Expand Down Expand Up @@ -548,4 +556,7 @@ def predict(
torch.tensor(entry["value"]) * per_type_scales.get(entry["type"], 1.0)
)

return torch.stack(reference), torch.stack(predicted)
predicted = torch.stack(predicted)
reference = torch.stack(reference).to(predicted.device)

return reference, predicted
20 changes: 19 additions & 1 deletion descent/tests/utils/test_loss.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import torch

from descent.utils.loss import to_closure
from descent.utils.loss import approximate_hessian, to_closure


def test_to_closure():
Expand Down Expand Up @@ -35,3 +35,21 @@ def mock_loss_fn(x: torch.Tensor, a: float, b: float) -> torch.Tensor:
assert loss is not None
assert grad is not None
assert hess is None


def test_approximate_hessian():
x = torch.tensor([1.0, 2.0, 3.0, 4.0], requires_grad=True)
y_pred = 5.0 * x**2 + 3.0 * x + 2.0

actual_hessian = approximate_hessian(x, y_pred)
expected_hess = torch.tensor(
[
[338.0, 0.0, 0.0, 0.0],
[0.0, 1058.0, 0.0, 0.0],
[0.0, 0.0, 2178.0, 0.0],
[0.0, 0.0, 0.0, 3698.0],
]
)

assert actual_hessian.shape == expected_hess.shape
assert torch.allclose(actual_hessian, expected_hess)
20 changes: 20 additions & 0 deletions descent/utils/loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,3 +51,23 @@ def closure_fn(
return loss.detach(), gradient, hessian

return closure_fn


def approximate_hessian(x: torch.Tensor, y_pred: torch.Tensor):
"""Compute the outer product approximation of the hessian of a least squares
loss function of the sum ``sum((y_pred - y_ref)**2)``.
Args:
x: The parameter tensor with ``shape=(n_parameters,)``.
y_pred: The values predicted using ``x`` with ``shape=(n_predications,)``.
Returns:
The outer product approximation of the hessian with ``shape=n_parameters
"""

y_pred_grad = [torch.autograd.grad(y, x, retain_graph=True)[0] for y in y_pred]
y_pred_grad = torch.stack(y_pred_grad, dim=0)

return (
2.0 * torch.einsum("bi,bj->bij", y_pred_grad, y_pred_grad).sum(dim=0)
).detach()

0 comments on commit bc4f255

Please sign in to comment.