diff --git a/descent/targets/thermo.py b/descent/targets/thermo.py
index f3e808b..3c52b2f 100644
--- a/descent/targets/thermo.py
+++ b/descent/targets/thermo.py
@@ -1,7 +1,9 @@
 """Train against thermodynamic properties."""
 import contextlib
 import hashlib
+import logging
 import pathlib
+import pickle
 import typing
 
 import numpy
@@ -13,6 +15,9 @@
 import descent.utils.molecule
 
+_LOGGER = logging.getLogger(__name__)
+
+
 DataType = typing.Literal["density", "hvap", "hmix"]
 DATA_TYPES = typing.get_args(DataType)
 
 
@@ -235,7 +240,7 @@ def _bulk_config(temperature: float, pressure: float) -> SimulationConfig:
             n_steps=500000,
             timestep=2.0 * openmm.unit.femtosecond,
         ),
-        production_frequency=500,
+        production_frequency=1000,
     )
 
 
@@ -396,7 +401,7 @@ def _compute_averages(
     output_dir: pathlib.Path,
     cached_dir: pathlib.Path | None,
 ) -> dict[str, torch.Tensor]:
-    traj_hash = hashlib.sha256(key, usedforsecurity=False).hexdigest()
+    traj_hash = hashlib.sha256(pickle.dumps(key)).hexdigest()
     traj_name = f"{phase}-{traj_hash}-frames.msgpack"
 
     cached_path = None if cached_dir is None else cached_dir / traj_name
@@ -410,6 +415,9 @@ def _compute_averages(
                 system, force_field, cached_path, temperature, pressure
             )
 
+    if cached_path is not None:
+        _LOGGER.debug(f"unable to re-weight {key}: data exists={cached_path.exists()}")
+
     output_path = output_dir / traj_name
 
     config = default_config(phase, key.temperature, key.pressure)
@@ -548,4 +556,7 @@ def predict(
             torch.tensor(entry["value"]) * per_type_scales.get(entry["type"], 1.0)
         )
 
-    return torch.stack(reference), torch.stack(predicted)
+    predicted = torch.stack(predicted)
+    reference = torch.stack(reference).to(predicted.device)
+
+    return reference, predicted
diff --git a/descent/tests/utils/test_loss.py b/descent/tests/utils/test_loss.py
index 7b67a81..61f2451 100644
--- a/descent/tests/utils/test_loss.py
+++ b/descent/tests/utils/test_loss.py
@@ -1,6 +1,6 @@
 import torch
 
-from descent.utils.loss import to_closure
+from descent.utils.loss import approximate_hessian, to_closure
 
 
 def test_to_closure():
@@ -35,3 +35,21 @@ def mock_loss_fn(x: torch.Tensor, a: float, b: float) -> torch.Tensor:
     assert loss is not None
     assert grad is not None
     assert hess is None
+
+
+def test_approximate_hessian():
+    x = torch.tensor([1.0, 2.0, 3.0, 4.0], requires_grad=True)
+    y_pred = 5.0 * x**2 + 3.0 * x + 2.0
+
+    actual_hessian = approximate_hessian(x, y_pred)
+    expected_hess = torch.tensor(
+        [
+            [338.0, 0.0, 0.0, 0.0],
+            [0.0, 1058.0, 0.0, 0.0],
+            [0.0, 0.0, 2178.0, 0.0],
+            [0.0, 0.0, 0.0, 3698.0],
+        ]
+    )
+
+    assert actual_hessian.shape == expected_hess.shape
+    assert torch.allclose(actual_hessian, expected_hess)
diff --git a/descent/utils/loss.py b/descent/utils/loss.py
index 909ac70..b652d59 100644
--- a/descent/utils/loss.py
+++ b/descent/utils/loss.py
@@ -51,3 +51,23 @@ def closure_fn(
         return loss.detach(), gradient, hessian
 
     return closure_fn
+
+
+def approximate_hessian(x: torch.Tensor, y_pred: torch.Tensor) -> torch.Tensor:
+    """Compute the outer product approximation of the hessian of a least squares
+    loss function of the form ``sum((y_pred - y_ref)**2)``.
+
+    Args:
+        x: The parameter tensor with ``shape=(n_parameters,)``.
+        y_pred: The values predicted using ``x`` with ``shape=(n_predictions,)``.
+
+    Returns:
+        The outer product approximation of the hessian with ``shape=(n_parameters, n_parameters)``.
+    """
+
+    y_pred_grad = [torch.autograd.grad(y, x, retain_graph=True)[0] for y in y_pred]
+    y_pred_grad = torch.stack(y_pred_grad, dim=0)
+
+    return (
+        2.0 * torch.einsum("bi,bj->bij", y_pred_grad, y_pred_grad).sum(dim=0)
+    ).detach()
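
For reference, the expected matrix in test_approximate_hessian follows directly from the outer-product (Gauss-Newton) form that approximate_hessian implements, H ~= 2 * sum_i J_i J_i^T, where J_i is the gradient of the i-th prediction with respect to the parameters. A minimal standalone sketch of that calculation, assuming only torch (illustrative only, not part of the patch):

import torch

# Toy predictions matching the test: y_i = 5 * x_i**2 + 3 * x_i + 2.
x = torch.tensor([1.0, 2.0, 3.0, 4.0], requires_grad=True)
y_pred = 5.0 * x**2 + 3.0 * x + 2.0

# Per-prediction gradients J_i = d y_i / d x (rows of the Jacobian).
jacobian = torch.stack(
    [torch.autograd.grad(y, x, retain_graph=True)[0] for y in y_pred]
)

# Outer-product approximation of the hessian of sum((y_pred - y_ref)**2).
hessian = 2.0 * torch.einsum("bi,bj->bij", jacobian, jacobian).sum(dim=0)

# Each y_i depends only on x_i, so the result is diagonal with entries
# 2 * (10 * x_i + 3)**2 = 338, 1058, 2178, 3698.
assert torch.allclose(hessian, torch.diag(2.0 * (10.0 * x + 3.0) ** 2))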