Skip to content

Commit

Permalink
Update thermo target to use new smee syntax (#53)
Browse files: browse the repository at this point in the history
  • Loading branch information
SimonBoothroyd authored Nov 8, 2023
1 parent 173e2c7 commit 463b363
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 18 deletions.
28 changes: 15 additions & 13 deletions descent/targets/thermo.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,7 @@
import pydantic
import smee.mm
import torch

import descent.utils.molecule
from rdkit import Chem

_LOGGER = logging.getLogger(__name__)

Expand Down Expand Up @@ -129,6 +128,15 @@ def create_dataset(*rows: DataEntry) -> pyarrow.Table:
Returns:
The created dataset.
"""

for row in rows:
row["smiles_a"] = Chem.MolToSmiles(Chem.MolFromSmiles(row["smiles_a"]))

if row["smiles_b"] is None:
continue

row["smiles_b"] = Chem.MolToSmiles(Chem.MolFromSmiles(row["smiles_b"]))

# TODO: validate rows
return pyarrow.Table.from_pylist([*rows], schema=DATA_SCHEMA)

Expand All @@ -142,19 +150,11 @@ def extract_smiles(dataset: pyarrow.Table) -> list[str]:
Returns:
The unique SMILES strings, sorted, as stored in the dataset (no atom mapping is added).
"""

from rdkit import Chem

smiles_a = dataset["smiles_a"].drop_null().unique().to_pylist()
smiles_b = dataset["smiles_b"].drop_null().unique().to_pylist()

smiles_unique = sorted({*smiles_a, *smiles_b})
smiles_mapped = [
descent.utils.molecule.mol_to_smiles(Chem.MolFromSmiles(smiles))
for smiles in smiles_unique
]

return smiles_mapped
return smiles_unique


def _convert_entry_to_system(
Expand Down Expand Up @@ -372,7 +372,9 @@ def _simulate(
config: The simulation configuration to use.
output_path: The path at which to write the simulation trajectory.
"""
coords, box_vectors = smee.mm.generate_system_coords(system, config.gen_coords)
coords, box_vectors = smee.mm.generate_system_coords(
system, force_field, config.gen_coords
)

beta = 1.0 / (openmm.unit.MOLAR_GAS_CONSTANT_R * config.production.temperature)

Expand Down Expand Up @@ -410,7 +412,7 @@ def _compute_averages(
pressure = None if key.pressure is None else key.pressure * openmm.unit.atmospheres

if cached_path is not None and cached_path.exists():
with contextlib.suppress(smee.mm._ops.NotEnoughSamplesError):
with contextlib.suppress(smee.mm.NotEnoughSamplesError):
return smee.mm.reweight_ensemble_averages(
system, force_field, cached_path, temperature, pressure
)
Expand Down
7 changes: 2 additions & 5 deletions descent/tests/targets/test_thermo.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,10 +102,7 @@ def test_extract_smiles(mock_density_pure, mock_density_binary):
dataset = create_dataset(mock_density_pure, mock_density_binary)
smiles = extract_smiles(dataset)

expected_smiles = [
"[H:1][C:4]([H:3])([O:5][H:2])[C:9]([H:6])([H:7])[H:8]",
"[H:1][O:3][C:6]([H:2])([H:4])[H:5]",
]
expected_smiles = ["CCO", "CO"]
assert smiles == expected_smiles


Expand Down Expand Up @@ -278,7 +275,7 @@ def test_simulation(tmp_cwd, mocker):
expected_output = tmp_cwd / "frames.msgpack"
_simulate(mock_system, mock_ff, config, expected_output)

mock_gen_coords.assert_called_once_with(mock_system, config.gen_coords)
mock_gen_coords.assert_called_once_with(mock_system, mock_ff, config.gen_coords)

mock_simulate.assert_called_once_with(
mock_system,
Expand Down

0 comments on commit 463b363

Please sign in to comment.