Skip to content

Commit

Permalink
Ensure thermo dataset SMILES are mapped (#62)
Browse files Browse the repository at this point in the history
* Ensure thermo dataset SMILES are mapped

* lint codebase
  • Loading branch information
SimonBoothroyd authored Feb 15, 2024
1 parent 9f3a23e commit cd94f82
Show file tree
Hide file tree
Showing 8 changed files with 53 additions and 11 deletions.
1 change: 1 addition & 0 deletions descent/optim/_lm.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
References:
[1]: https://github.com/leeping/forcebalance/blob/b395fd4b/src/optimizer.py
"""

import logging
import math
import typing
Expand Down
1 change: 1 addition & 0 deletions descent/targets/dimers.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Train against dimer energies."""

import pathlib
import typing

Expand Down
1 change: 1 addition & 0 deletions descent/targets/energy.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Train against relative energies and forces."""

import typing

import datasets
Expand Down
33 changes: 25 additions & 8 deletions descent/targets/thermo.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Train against thermodynamic properties."""

import contextlib
import hashlib
import logging
Expand Down Expand Up @@ -113,9 +114,9 @@ class SimulationConfig(pydantic.BaseModel):
False, description="Whether to apply hydrogen mass repartitioning."
)

equilibrate: list[
smee.mm.MinimizationConfig | smee.mm.SimulationConfig
] = pydantic.Field(..., description="Configuration for equilibration simulations.")
equilibrate: list[smee.mm.MinimizationConfig | smee.mm.SimulationConfig] = (
pydantic.Field(..., description="Configuration for equilibration simulations.")
)

production: smee.mm.SimulationConfig = pydantic.Field(
..., description="Configuration for the production simulation."
Expand All @@ -137,6 +138,24 @@ class _Observables(typing.NamedTuple):
_SystemDict = dict[SimulationKey, smee.TensorSystem]


def _map_smiles(smiles: str) -> str:
"""Add atom mapping to a SMILES string if it is not already present."""
params = Chem.SmilesParserParams()
params.removeHs = False

mol = Chem.AddHs(Chem.MolFromSmiles(smiles, params))

map_idxs = sorted(atom.GetAtomMapNum() for atom in mol.GetAtoms())

if map_idxs == list(range(1, len(map_idxs) + 1)):
return smiles

for i, atom in enumerate(mol.GetAtoms()):
atom.SetAtomMapNum(i + 1)

return Chem.MolToSmiles(mol)


def create_dataset(*rows: DataEntry) -> datasets.Dataset:
"""Create a dataset from a list of existing data points.
Expand All @@ -148,12 +167,12 @@ def create_dataset(*rows: DataEntry) -> datasets.Dataset:
"""

for row in rows:
row["smiles_a"] = Chem.MolToSmiles(Chem.MolFromSmiles(row["smiles_a"]))
row["smiles_a"] = _map_smiles(row["smiles_a"])

if row["smiles_b"] is None:
continue

row["smiles_b"] = Chem.MolToSmiles(Chem.MolFromSmiles(row["smiles_b"]))
row["smiles_b"] = _map_smiles(row["smiles_b"])

# TODO: validate rows
table = pyarrow.Table.from_pylist([*rows], schema=DATA_SCHEMA)
Expand Down Expand Up @@ -519,9 +538,7 @@ def _predict_hmix(

value = enthalpy_mix - x_0 * enthalpy_0 - x_1 * enthalpy_1
std = torch.sqrt(
enthalpy_mix_std**2
+ x_0**2 * enthalpy_0_std**2
+ x_1**2 * enthalpy_1_std**2
enthalpy_mix_std**2 + x_0**2 * enthalpy_0_std**2 + x_1**2 * enthalpy_1_std**2
)

return value, std
Expand Down
25 changes: 22 additions & 3 deletions descent/tests/targets/test_thermo.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
SimulationKey,
_compute_observables,
_convert_entry_to_system,
_map_smiles,
_Observables,
_plan_simulations,
_predict,
Expand Down Expand Up @@ -90,6 +91,21 @@ def mock_hmix() -> DataEntry:
}


@pytest.mark.parametrize(
"smiles, expected",
[
("C", "[C:1]([H:2])([H:3])([H:4])[H:5]"),
("[CH4:1]", "[C:1]([H:2])([H:3])([H:4])[H:5]"),
("[Cl:1][H:2]", "[Cl:1][H:2]"),
("[Cl:2][H:1]", "[Cl:2][H:1]"),
("[Cl:2][H:2]", "[Cl:1][H:2]"),
("[Cl:1][H]", "[Cl:1][H:2]"),
],
)
def test_map_smiles(smiles, expected):
assert _map_smiles(smiles) == expected


def test_create_dataset(mock_density_pure, mock_density_binary):
expected_entries = [mock_density_pure, mock_density_binary]

Expand All @@ -105,7 +121,10 @@ def test_extract_smiles(mock_density_pure, mock_density_binary):
dataset = create_dataset(mock_density_pure, mock_density_binary)
smiles = extract_smiles(dataset)

expected_smiles = ["CCO", "CO"]
expected_smiles = [
"[C:1]([C:2]([O:3][H:9])([H:7])[H:8])([H:4])([H:5])[H:6]",
"[C:1]([O:2][H:6])([H:3])([H:4])[H:5]",
]
assert smiles == expected_smiles


Expand Down Expand Up @@ -497,7 +516,7 @@ def test_predict_hmix(mock_hmix, mocker):
def test_predict(tmp_cwd, mock_density_pure, mocker):
dataset = create_dataset(mock_density_pure)

mock_topologies = {"CO": mocker.Mock()}
mock_topologies = {"[C:1]([O:2][H:6])([H:3])([H:4])[H:5]": mocker.Mock()}
mock_ff = mocker.Mock()

mock_density = torch.tensor(123.0)
Expand All @@ -520,7 +539,7 @@ def test_predict(tmp_cwd, mock_density_pure, mocker):
mock_compute.assert_called_once_with(
"bulk",
SimulationKey(
("CO",),
("[C:1]([O:2][H:6])([H:3])([H:4])[H:5]",),
(256,),
mock_density_pure["temperature"],
mock_density_pure["pressure"],
Expand Down
1 change: 1 addition & 0 deletions descent/utils/dataset.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Utilities for working with datasets."""

import typing

import datasets
Expand Down
1 change: 1 addition & 0 deletions descent/utils/loss.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Utilities for defining loss functions."""

import functools
import typing

Expand Down
1 change: 1 addition & 0 deletions descent/utils/reporting.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Utilities for reporting results."""

import base64
import io
import itertools
Expand Down

0 comments on commit cd94f82

Please sign in to comment.