From 1d28c0f4b0c8f2e40ca9a316efd7377f81465d80 Mon Sep 17 00:00:00 2001
From: SimonBoothroyd
Date: Thu, 23 Nov 2023 08:46:51 -0500
Subject: [PATCH] Add energy and force target (#59)

---
 descent/targets/energy.py            | 157 ++++++++++++++++++++++++++
 descent/tests/targets/test_energy.py | 158 +++++++++++++++++++++++++++
 2 files changed, 315 insertions(+)
 create mode 100644 descent/targets/energy.py
 create mode 100644 descent/tests/targets/test_energy.py

diff --git a/descent/targets/energy.py b/descent/targets/energy.py
new file mode 100644
index 0000000..d00ddbe
--- /dev/null
+++ b/descent/targets/energy.py
@@ -0,0 +1,157 @@
+"""Train against relative energies and forces."""
+import typing
+
+import datasets
+import datasets.table
+import pyarrow
+import smee
+import torch
+
+DATA_SCHEMA = pyarrow.schema(
+    [
+        ("smiles", pyarrow.string()),
+        ("coords", pyarrow.list_(pyarrow.float64())),
+        ("energy", pyarrow.list_(pyarrow.float64())),
+        ("forces", pyarrow.list_(pyarrow.float64())),
+    ]
+)
+
+
+class Entry(typing.TypedDict):
+    """Represents a set of reference energies and forces."""
+
+    smiles: str
+    """The indexed SMILES description of the molecule the energies and forces were
+    computed for."""
+
+    coords: torch.Tensor
+    """The coordinates [Å] the energies and forces were evaluated at with
+    ``shape=(n_confs, n_particles, 3)``."""
+    energy: torch.Tensor
+    """The reference energies [kcal/mol] with ``shape=(n_confs,)``."""
+    forces: torch.Tensor
+    """The reference forces [kcal/mol/Å] with ``shape=(n_confs, n_particles, 3)``."""
+
+
+def create_dataset(entries: list[Entry]) -> datasets.Dataset:
+    """Create a dataset from a list of existing entries.
+
+    Args:
+        entries: The entries to create the dataset from.
+
+    Returns:
+        The created dataset.
+    """
+
+    table = pyarrow.Table.from_pylist(
+        [
+            {
+                "smiles": entry["smiles"],
+                "coords": torch.tensor(entry["coords"]).flatten().tolist(),
+                "energy": torch.tensor(entry["energy"]).flatten().tolist(),
+                "forces": torch.tensor(entry["forces"]).flatten().tolist(),
+            }
+            for entry in entries
+        ],
+        schema=DATA_SCHEMA,
+    )
+    # TODO: validate rows
+    dataset = datasets.Dataset(datasets.table.InMemoryTable(table))
+    dataset.set_format("torch")
+
+    return dataset
+
+
+def extract_smiles(dataset: datasets.Dataset) -> list[str]:
+    """Return a list of unique SMILES strings in the dataset.
+
+    Args:
+        dataset: The dataset to extract the SMILES strings from.
+
+    Returns:
+        The list of unique SMILES strings.
+    """
+    return sorted({*dataset.unique("smiles")})
+
+
+def predict(
+    dataset: datasets.Dataset,
+    force_field: smee.TensorForceField,
+    topologies: dict[str, smee.TensorTopology],
+    reference: typing.Literal["mean", "min"] = "mean",
+    normalize: bool = True,
+) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+    """Predict the relative energies [kcal/mol] and forces [kcal/mol/Å] of a dataset.
+
+    Args:
+        dataset: The dataset to predict the energies and forces of.
+        force_field: The force field to use to predict the energies and forces.
+        topologies: The topologies of the molecules in the dataset. Each key should be
+            a fully indexed SMILES string.
+        reference: The reference energy to compute the relative energies with respect
+            to. This should be either the "mean" energy of all conformers, or the
+            energy of the conformer with the lowest reference energy ("min").
+        normalize: Whether to scale the relative energies by ``1/sqrt(n_confs_i)``
+            and the forces by ``1/sqrt(n_confs_i * n_atoms_per_conf_i * 3)``. This
+            is useful when wanting to compute the MSE per entry.
+
+    Returns:
+        The reference and predicted relative energies [kcal/mol] with
+        ``shape=(n_confs,)``, and reference and predicted forces [kcal/mol/Å] with
+        ``shape=(n_confs * n_atoms_per_conf, 3)``.
+    """
+    energy_ref_all, energy_pred_all = [], []
+    forces_ref_all, forces_pred_all = [], []
+
+    for entry in dataset:
+        smiles = entry["smiles"]
+
+        energy_ref = entry["energy"]
+        forces_ref = entry["forces"].reshape(len(energy_ref), -1, 3)
+
+        coords = (
+            entry["coords"]
+            .reshape(len(energy_ref), -1, 3)
+            .detach()
+            .requires_grad_(True)
+        )
+        topology = topologies[smiles]
+
+        energy_pred = smee.compute_energy(topology, force_field, coords)
+        forces_pred = -torch.autograd.grad(  # force is the negative gradient
+            energy_pred.sum(),
+            coords,
+            create_graph=True,
+            retain_graph=True,
+            allow_unused=True,
+        )[0]
+
+        if reference.lower() == "mean":
+            energy_ref_0 = energy_ref.mean()
+            energy_pred_0 = energy_pred.mean()
+        elif reference.lower() == "min":
+            min_idx = energy_ref.argmin()
+
+            energy_ref_0 = energy_ref[min_idx]
+            energy_pred_0 = energy_pred[min_idx]
+        else:
+            raise NotImplementedError(f"invalid reference energy {reference}")
+
+        scale_energy, scale_forces = 1.0, 1.0
+
+        if normalize:
+            scale_energy = 1.0 / torch.sqrt(torch.tensor(energy_pred.numel()))
+            scale_forces = 1.0 / torch.sqrt(torch.tensor(forces_pred.numel()))
+
+        energy_ref_all.append(scale_energy * (energy_ref - energy_ref_0))
+        forces_ref_all.append(scale_forces * forces_ref.reshape(-1, 3))
+
+        energy_pred_all.append(scale_energy * (energy_pred - energy_pred_0))
+        forces_pred_all.append(scale_forces * forces_pred.reshape(-1, 3))
+
+    return (
+        torch.cat(energy_ref_all),
+        torch.cat(energy_pred_all),
+        torch.cat(forces_ref_all),
+        torch.cat(forces_pred_all),
+    )
diff --git a/descent/tests/targets/test_energy.py b/descent/tests/targets/test_energy.py
new file mode 100644
index 0000000..518e685
--- /dev/null
+++ b/descent/tests/targets/test_energy.py
@@ -0,0 +1,158 @@
+import math
+
+import openff.interchange
+import openff.toolkit
+import pytest
+import smee.converters
+import torch
+
+import descent.utils.dataset
+from descent.targets.energy import Entry, create_dataset, extract_smiles, predict
+
+
+@pytest.fixture
+def mock_meoh_entry() -> Entry:
+    return {
+        "smiles": "[C:1]([O:2][H:6])([H:3])([H:4])[H:5]",
+        "coords": torch.arange(36, dtype=torch.float32).reshape(2, 6, 3),
+        "energy": 3.0 * torch.arange(2, dtype=torch.float32),
+        "forces": torch.arange(36, dtype=torch.float32).reshape(2, 6, 3) + 36.0,
+    }
+
+
+@pytest.fixture
+def mock_hoh_entry() -> Entry:
+    return {
+        "smiles": "[H:2][O:1][H:3]",
+        "coords": torch.tensor(
+            [
+                [[0.0, 0.0, 0.0], [-1.0, -0.5, 0.0], [1.0, -0.5, 0.0]],
+                [[0.0, 0.0, 0.0], [-0.7, -0.5, 0.0], [0.7, -0.5, 0.0]],
+            ]
+        ),
+        "energy": torch.tensor([2.0, 3.0]),
+        "forces": torch.arange(18, dtype=torch.float32).reshape(2, 3, 3),
+    }
+
+
+def test_create_dataset(mock_meoh_entry):
+    expected_entries = [
+        {
+            "smiles": mock_meoh_entry["smiles"],
+            "coords": pytest.approx(mock_meoh_entry["coords"].flatten()),
+            "energy": pytest.approx(mock_meoh_entry["energy"]),
+            "forces": pytest.approx(mock_meoh_entry["forces"].flatten()),
+        },
+    ]
+
+    dataset = create_dataset([mock_meoh_entry])
+    assert len(dataset) == 1
+
+    entries = list(descent.utils.dataset.iter_dataset(dataset))
+    assert entries == expected_entries
+
+
+def test_extract_smiles(mock_meoh_entry, mock_hoh_entry):
+    expected_smiles = ["[C:1]([O:2][H:6])([H:3])([H:4])[H:5]", "[H:2][O:1][H:3]"]
+
+    dataset = create_dataset([mock_meoh_entry, mock_hoh_entry])
+    smiles = extract_smiles(dataset)
+
+    assert smiles == expected_smiles
+
+
+@pytest.mark.parametrize(
+    "reference, normalize,"
+    "expected_energy_ref, expected_forces_ref, "
+    "expected_energy_pred, expected_forces_pred",
+    [
+        (
+            "mean",
+            True,
+            torch.tensor([-0.5, 0.5]) / math.sqrt(2.0),
+            torch.tensor(
+                [
+                    [0.0, 1.0, 2.0],
+                    [3.0, 4.0, 5.0],
+                    [6.0, 7.0, 8.0],
+                    [9.0, 10.0, 11.0],
+                    [12.0, 13.0, 14.0],
+                    [15.0, 16.0, 17.0],
+                ]
+            )
+            / math.sqrt(6.0 * 3.0),
+            torch.tensor([7.899425506591797, -7.89942741394043]) / math.sqrt(2.0),
+            torch.tensor(
+                [
+                    [0.0, -83.55978393554688, 0.0],
+                    [161.40325927734375, 41.77988815307617, 0.0],
+                    [-161.40325927734375, 41.77988815307617, 0.0],
+                    [0.0, 137.45770263671875, 0.0],
+                    [-102.62999725341797, -68.72884368896484, 0.0],
+                    [102.62999725341797, -68.72884368896484, 0.0],
+                ]
+            )
+            / math.sqrt(6.0 * 3.0),
+        ),
+        (
+            "min",
+            False,
+            torch.tensor([0.0, 1.0]),
+            torch.tensor(
+                [
+                    [0.0, 1.0, 2.0],
+                    [3.0, 4.0, 5.0],
+                    [6.0, 7.0, 8.0],
+                    [9.0, 10.0, 11.0],
+                    [12.0, 13.0, 14.0],
+                    [15.0, 16.0, 17.0],
+                ]
+            ),
+            torch.tensor([0.0, -15.798852920532227]),
+            torch.tensor(
+                [
+                    [0.0, -83.55978393554688, 0.0],
+                    [161.40325927734375, 41.77988815307617, 0.0],
+                    [-161.40325927734375, 41.77988815307617, 0.0],
+                    [0.0, 137.45770263671875, 0.0],
+                    [-102.62999725341797, -68.72884368896484, 0.0],
+                    [102.62999725341797, -68.72884368896484, 0.0],
+                ]
+            ),
+        ),
+    ],
+)
+def test_predict(
+    reference,
+    normalize,
+    expected_energy_ref,
+    expected_forces_ref,
+    expected_energy_pred,
+    expected_forces_pred,
+    mock_hoh_entry,
+):
+    dataset = create_dataset([mock_hoh_entry])
+
+    force_field, [topology] = smee.converters.convert_interchange(
+        openff.interchange.Interchange.from_smirnoff(
+            openff.toolkit.ForceField("openff-1.3.0.offxml"),
+            openff.toolkit.Molecule.from_mapped_smiles(
+                mock_hoh_entry["smiles"]
+            ).to_topology(),
+        )
+    )
+    topologies = {mock_hoh_entry["smiles"]: topology}
+
+    energy_ref, energy_pred, forces_ref, forces_pred = predict(
+        dataset, force_field, topologies, reference=reference, normalize=normalize
+    )
+
+    assert energy_pred.shape == expected_energy_pred.shape
+    assert torch.allclose(energy_pred, expected_energy_pred)
+    assert energy_ref.shape == expected_energy_ref.shape
+    assert torch.allclose(energy_ref, expected_energy_ref)
+
+    assert forces_pred.shape == expected_forces_pred.shape
+    assert torch.allclose(forces_pred, expected_forces_pred)
+    assert forces_ref.shape == expected_forces_ref.shape
+    assert torch.allclose(forces_ref, expected_forces_ref)