Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Ensure thermo dataset SMILES are mapped #62

Merged
merged 3 commits into from
Feb 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions descent/optim/_lm.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
References:
[1]: https://github.com/leeping/forcebalance/blob/b395fd4b/src/optimizer.py
"""

import logging
import math
import typing
Expand Down
1 change: 1 addition & 0 deletions descent/targets/dimers.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Train against dimer energies."""

import pathlib
import typing

Expand Down
1 change: 1 addition & 0 deletions descent/targets/energy.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Train against relative energies and forces."""

import typing

import datasets
Expand Down
33 changes: 25 additions & 8 deletions descent/targets/thermo.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Train against thermodynamic properties."""

import contextlib
import hashlib
import logging
Expand Down Expand Up @@ -113,9 +114,9 @@ class SimulationConfig(pydantic.BaseModel):
False, description="Whether to apply hydrogen mass repartitioning."
)

equilibrate: list[
smee.mm.MinimizationConfig | smee.mm.SimulationConfig
] = pydantic.Field(..., description="Configuration for equilibration simulations.")
equilibrate: list[smee.mm.MinimizationConfig | smee.mm.SimulationConfig] = (
pydantic.Field(..., description="Configuration for equilibration simulations.")
)

production: smee.mm.SimulationConfig = pydantic.Field(
..., description="Configuration for the production simulation."
Expand All @@ -137,6 +138,24 @@ class _Observables(typing.NamedTuple):
_SystemDict = dict[SimulationKey, smee.TensorSystem]


def _map_smiles(smiles: str) -> str:
    """Add atom mapping to a SMILES string if it is not already present.

    The SMILES is parsed with hydrogen removal disabled and hydrogens are then
    added, so the mapping check and any new map numbers cover every atom,
    hydrogens included.

    Args:
        smiles: The SMILES pattern to (potentially) map.

    Returns:
        The input string unchanged if the map numbers of its atoms are exactly
        ``1..n_atoms`` (in any order), otherwise a new SMILES in which every
        atom is mapped ``1..n_atoms`` in molecule atom order.

    Raises:
        ValueError: If the SMILES pattern cannot be parsed.
    """
    params = Chem.SmilesParserParams()
    params.removeHs = False

    parsed = Chem.MolFromSmiles(smiles, params)

    # the parser signals failure by returning ``None`` rather than raising, so
    # fail here with a message naming the offending pattern instead of letting
    # ``AddHs`` raise an opaque argument error.
    if parsed is None:
        raise ValueError(f"could not parse SMILES pattern: {smiles!r}")

    mol = Chem.AddHs(parsed)

    map_idxs = sorted(atom.GetAtomMapNum() for atom in mol.GetAtoms())

    # a complete mapping uses each of 1..n exactly once; leave such input as is
    # so that any caller-provided ordering is preserved.
    if map_idxs == list(range(1, len(map_idxs) + 1)):
        return smiles

    # partial, duplicated, or missing map numbers are discarded and replaced.
    for map_idx, atom in enumerate(mol.GetAtoms(), start=1):
        atom.SetAtomMapNum(map_idx)

    return Chem.MolToSmiles(mol)


def create_dataset(*rows: DataEntry) -> datasets.Dataset:
"""Create a dataset from a list of existing data points.

Expand All @@ -148,12 +167,12 @@ def create_dataset(*rows: DataEntry) -> datasets.Dataset:
"""

for row in rows:
row["smiles_a"] = Chem.MolToSmiles(Chem.MolFromSmiles(row["smiles_a"]))
row["smiles_a"] = _map_smiles(row["smiles_a"])

if row["smiles_b"] is None:
continue

row["smiles_b"] = Chem.MolToSmiles(Chem.MolFromSmiles(row["smiles_b"]))
row["smiles_b"] = _map_smiles(row["smiles_b"])

# TODO: validate rows
table = pyarrow.Table.from_pylist([*rows], schema=DATA_SCHEMA)
Expand Down Expand Up @@ -519,9 +538,7 @@ def _predict_hmix(

value = enthalpy_mix - x_0 * enthalpy_0 - x_1 * enthalpy_1
std = torch.sqrt(
enthalpy_mix_std**2
+ x_0**2 * enthalpy_0_std**2
+ x_1**2 * enthalpy_1_std**2
enthalpy_mix_std**2 + x_0**2 * enthalpy_0_std**2 + x_1**2 * enthalpy_1_std**2
)

return value, std
Expand Down
25 changes: 22 additions & 3 deletions descent/tests/targets/test_thermo.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
SimulationKey,
_compute_observables,
_convert_entry_to_system,
_map_smiles,
_Observables,
_plan_simulations,
_predict,
Expand Down Expand Up @@ -90,6 +91,21 @@ def mock_hmix() -> DataEntry:
}


@pytest.mark.parametrize(
    ("smiles", "expected"),
    [
        # no map numbers at all: every atom, including hydrogens, gets one
        ("C", "[C:1]([H:2])([H:3])([H:4])[H:5]"),
        # only the heavy atom is mapped: implicit hydrogens still need numbers
        ("[CH4:1]", "[C:1]([H:2])([H:3])([H:4])[H:5]"),
        # complete mappings are returned untouched, whatever their order
        ("[Cl:1][H:2]", "[Cl:1][H:2]"),
        ("[Cl:2][H:1]", "[Cl:2][H:1]"),
        # duplicated map numbers are replaced
        ("[Cl:2][H:2]", "[Cl:1][H:2]"),
        # a partially mapped molecule is completed
        ("[Cl:1][H]", "[Cl:1][H:2]"),
    ],
)
def test_map_smiles(smiles, expected):
    """Check that ``_map_smiles`` returns the expected mapped SMILES."""
    mapped = _map_smiles(smiles)

    assert mapped == expected


def test_create_dataset(mock_density_pure, mock_density_binary):
expected_entries = [mock_density_pure, mock_density_binary]

Expand All @@ -105,7 +121,10 @@ def test_extract_smiles(mock_density_pure, mock_density_binary):
dataset = create_dataset(mock_density_pure, mock_density_binary)
smiles = extract_smiles(dataset)

expected_smiles = ["CCO", "CO"]
expected_smiles = [
"[C:1]([C:2]([O:3][H:9])([H:7])[H:8])([H:4])([H:5])[H:6]",
"[C:1]([O:2][H:6])([H:3])([H:4])[H:5]",
]
assert smiles == expected_smiles


Expand Down Expand Up @@ -497,7 +516,7 @@ def test_predict_hmix(mock_hmix, mocker):
def test_predict(tmp_cwd, mock_density_pure, mocker):
dataset = create_dataset(mock_density_pure)

mock_topologies = {"CO": mocker.Mock()}
mock_topologies = {"[C:1]([O:2][H:6])([H:3])([H:4])[H:5]": mocker.Mock()}
mock_ff = mocker.Mock()

mock_density = torch.tensor(123.0)
Expand All @@ -520,7 +539,7 @@ def test_predict(tmp_cwd, mock_density_pure, mocker):
mock_compute.assert_called_once_with(
"bulk",
SimulationKey(
("CO",),
("[C:1]([O:2][H:6])([H:3])([H:4])[H:5]",),
(256,),
mock_density_pure["temperature"],
mock_density_pure["pressure"],
Expand Down
1 change: 1 addition & 0 deletions descent/utils/dataset.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Utilities for working with datasets."""

import typing

import datasets
Expand Down
1 change: 1 addition & 0 deletions descent/utils/loss.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Utilities for defining loss functions."""

import functools
import typing

Expand Down
1 change: 1 addition & 0 deletions descent/utils/reporting.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Utilities for reporting results."""

import base64
import io
import itertools
Expand Down
Loading