Skip to content

Commit

Permalink
Split mol_to_smiles into a util (#48)
Browse files Browse the repository at this point in the history
  • Loading branch information
SimonBoothroyd authored Nov 4, 2023
1 parent 9120b99 commit f35492f
Show file tree
Hide file tree
Showing 3 changed files with 65 additions and 37 deletions.
55 changes: 18 additions & 37 deletions descent/targets/dimers.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,11 @@
import smee.utils
import torch

import descent.utils.molecule
import descent.utils.reporting

if typing.TYPE_CHECKING:
import pandas
from rdkit import Chem


EnergyFn = typing.Callable[
Expand Down Expand Up @@ -42,11 +42,11 @@ class Dimer(typing.TypedDict):
source: str


def create_dataset(entries: list[Dimer]) -> pyarrow.Table:
def create_dataset(dimers: list[Dimer]) -> pyarrow.Table:
"""Create a dataset from a list of existing dimers.
Args:
entries: The dimers to create the dataset from.
dimers: The dimers to create the dataset from.
Returns:
The created dataset.
Expand All @@ -55,37 +55,18 @@ def create_dataset(entries: list[Dimer]) -> pyarrow.Table:
return pyarrow.Table.from_pylist(
[
{
"smiles_a": entry["smiles_a"],
"smiles_b": entry["smiles_b"],
"coords": torch.tensor(entry["coords"]).flatten().tolist(),
"energy": torch.tensor(entry["energy"]).flatten().tolist(),
"source": entry["source"],
"smiles_a": dimer["smiles_a"],
"smiles_b": dimer["smiles_b"],
"coords": torch.tensor(dimer["coords"]).flatten().tolist(),
"energy": torch.tensor(dimer["energy"]).flatten().tolist(),
"source": dimer["source"],
}
for entry in entries
for dimer in dimers
],
schema=DATA_SCHEMA,
)


def _mol_to_smiles(mol: "Chem.Mol") -> str:
    """Build an atom-mapped SMILES string for a molecule.

    Hydrogens are made explicit first, and every atom is then tagged with a
    map number equal to its index in the molecule plus one.

    Args:
        mol: The molecule to convert.

    Returns:
        The mapped SMILES string.
    """
    from rdkit import Chem

    mapped_mol = Chem.AddHs(mol)

    for atom in mapped_mol.GetAtoms():
        # Map numbers are one-based, atom indices are zero-based.
        map_number = atom.GetIdx() + 1
        atom.SetAtomMapNum(map_number)

    return Chem.MolToSmiles(mapped_mol)


def create_from_des(
data_dir: pathlib.Path,
energy_fn: EnergyFn,
Expand All @@ -109,7 +90,7 @@ def create_from_des(
metadata = pandas.read_csv(data_dir / f"{data_dir.name}.csv", index_col=False)

system_ids = metadata["system_id"].unique()
entries: list[Dimer] = []
dimers: list[Dimer] = []

for system_id in system_ids:
system_data = metadata[metadata["system_id"] == system_id]
Expand All @@ -128,8 +109,8 @@ def create_from_des(
)
mol_a, mol_b = Chem.GetMolFrags(dimer_example, asMols=True)

smiles_a = _mol_to_smiles(mol_a)
smiles_b = _mol_to_smiles(mol_b)
smiles_a = descent.utils.molecule.mol_to_smiles(mol_a, False)
smiles_b = descent.utils.molecule.mol_to_smiles(mol_b, False)

source = (
f"{data_dir.name} system={system_id} orig={group_orig} group={group_id}"
Expand All @@ -149,7 +130,7 @@ def create_from_des(
coords = torch.tensor(coords_raw)
energy = energy_fn(group_data, geometry_ids, coords)

entries.append(
dimers.append(
{
"smiles_a": smiles_a,
"smiles_b": smiles_b,
Expand All @@ -159,7 +140,7 @@ def create_from_des(
}
)

return create_dataset(entries)
return create_dataset(dimers)


def extract_smiles(dataset: pyarrow.Table) -> list[str]:
Expand Down Expand Up @@ -328,17 +309,17 @@ def report(

rows = []

for entry in dataset.to_pylist():
energies = {"ref": torch.tensor(entry["energy"])}
for dimer in dataset.to_pylist():
energies = {"ref": torch.tensor(dimer["energy"])}
energies.update(
(force_field_name, _predict(entry, force_field, topologies)[1])
(force_field_name, _predict(dimer, force_field, topologies)[1])
for force_field_name, force_field in force_fields.items()
)

plot_img = _plot_energies(energies)

mol_img = descent.utils.reporting.mols_to_img(
entry["smiles_a"], entry["smiles_b"]
dimer["smiles_a"], dimer["smiles_b"]
)
rows.append({"Dimer": mol_img, "Energy [kcal/mol]": plot_img})

Expand Down
18 changes: 18 additions & 0 deletions descent/tests/utils/test_molecule.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
import pytest
from rdkit import Chem

from descent.utils.molecule import mol_to_smiles


@pytest.mark.parametrize(
    "input_smiles, expected_smiles, canonical",
    [
        # Atoms renumbered by canonical rank before map numbers are assigned.
        ("OC", "[H:1][C:4]([H:2])([O:3][H:5])[H:6]", True),
        # Map numbers follow the atom order of the hydrogen-added input.
        ("OC", "[O:1]([C:2]([H:4])([H:5])[H:6])[H:3]", False),
    ],
)
def test_mol_to_smiles(input_smiles, expected_smiles, canonical):
    """Check the mapped SMILES produced with and without canonical ordering."""
    parsed_mol = Chem.MolFromSmiles(input_smiles)

    assert mol_to_smiles(parsed_mol, canonical) == expected_smiles
29 changes: 29 additions & 0 deletions descent/utils/molecule.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
import typing

if typing.TYPE_CHECKING:
from rdkit import Chem


def mol_to_smiles(mol: "Chem.Mol", canonical: bool = True) -> str:
    """Build an atom-mapped SMILES string for a molecule.

    Hydrogens are made explicit first, and every atom is then tagged with a
    map number equal to its index in the molecule plus one.

    Args:
        mol: The molecule to convert.
        canonical: Whether to renumber the atoms using RDKit's canonical
            atom ranking before the map numbers are assigned. If ``False``
            the map numbers follow the atom order of the hydrogen-added
            input molecule.

    Returns:
        The mapped SMILES string.
    """
    from rdkit import Chem

    mapped_mol = Chem.AddHs(mol)

    if canonical:
        # NOTE(review): ``CanonicalRankAtoms`` yields one rank per atom, and
        # that sequence is handed to ``RenumberAtoms`` unchanged as the new
        # atom order. Confirm against the RDKit docs whether the inverse
        # permutation (atom indices sorted by rank) was intended before
        # relying on the result being independent of the input atom order.
        ranks = list(Chem.CanonicalRankAtoms(mapped_mol, includeChirality=True))
        mapped_mol = Chem.RenumberAtoms(mapped_mol, ranks)

    for atom in mapped_mol.GetAtoms():
        # Map numbers are one-based, atom indices are zero-based.
        map_number = atom.GetIdx() + 1
        atom.SetAtomMapNum(map_number)

    return Chem.MolToSmiles(mapped_mol)

0 comments on commit f35492f

Please sign in to comment.