Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Split mol_to_smiles into a util #48

Merged
merged 1 commit into from
Nov 4, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 18 additions & 37 deletions descent/targets/dimers.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,11 @@
import smee.utils
import torch

import descent.utils.molecule
import descent.utils.reporting

if typing.TYPE_CHECKING:
import pandas
from rdkit import Chem


EnergyFn = typing.Callable[
Expand Down Expand Up @@ -42,11 +42,11 @@ class Dimer(typing.TypedDict):
source: str


def create_dataset(entries: list[Dimer]) -> pyarrow.Table:
def create_dataset(dimers: list[Dimer]) -> pyarrow.Table:
"""Create a dataset from a list of existing dimers.

Args:
entries: The dimers to create the dataset from.
dimers: The dimers to create the dataset from.

Returns:
The created dataset.
Expand All @@ -55,37 +55,18 @@ def create_dataset(entries: list[Dimer]) -> pyarrow.Table:
return pyarrow.Table.from_pylist(
[
{
"smiles_a": entry["smiles_a"],
"smiles_b": entry["smiles_b"],
"coords": torch.tensor(entry["coords"]).flatten().tolist(),
"energy": torch.tensor(entry["energy"]).flatten().tolist(),
"source": entry["source"],
"smiles_a": dimer["smiles_a"],
"smiles_b": dimer["smiles_b"],
"coords": torch.tensor(dimer["coords"]).flatten().tolist(),
"energy": torch.tensor(dimer["energy"]).flatten().tolist(),
"source": dimer["source"],
}
for entry in entries
for dimer in dimers
],
schema=DATA_SCHEMA,
)


def _mol_to_smiles(mol: "Chem.Mol") -> str:
    """Build an atom-mapped SMILES string for a molecule.

    Hydrogens are added to the molecule first, then every atom is tagged with a
    map number equal to its atom index plus one before the SMILES is written.

    Args:
        mol: The molecule to convert.

    Returns:
        The SMILES string, with a map number on every atom.
    """
    from rdkit import Chem

    hydrogenated = Chem.AddHs(mol)

    for atom in hydrogenated.GetAtoms():
        # Map numbers are 1-based, atom indices are 0-based.
        map_number = atom.GetIdx() + 1
        atom.SetAtomMapNum(map_number)

    return Chem.MolToSmiles(hydrogenated)


def create_from_des(
data_dir: pathlib.Path,
energy_fn: EnergyFn,
Expand All @@ -109,7 +90,7 @@ def create_from_des(
metadata = pandas.read_csv(data_dir / f"{data_dir.name}.csv", index_col=False)

system_ids = metadata["system_id"].unique()
entries: list[Dimer] = []
dimers: list[Dimer] = []

for system_id in system_ids:
system_data = metadata[metadata["system_id"] == system_id]
Expand All @@ -128,8 +109,8 @@ def create_from_des(
)
mol_a, mol_b = Chem.GetMolFrags(dimer_example, asMols=True)

smiles_a = _mol_to_smiles(mol_a)
smiles_b = _mol_to_smiles(mol_b)
smiles_a = descent.utils.molecule.mol_to_smiles(mol_a, False)
smiles_b = descent.utils.molecule.mol_to_smiles(mol_b, False)

source = (
f"{data_dir.name} system={system_id} orig={group_orig} group={group_id}"
Expand All @@ -149,7 +130,7 @@ def create_from_des(
coords = torch.tensor(coords_raw)
energy = energy_fn(group_data, geometry_ids, coords)

entries.append(
dimers.append(
{
"smiles_a": smiles_a,
"smiles_b": smiles_b,
Expand All @@ -159,7 +140,7 @@ def create_from_des(
}
)

return create_dataset(entries)
return create_dataset(dimers)


def extract_smiles(dataset: pyarrow.Table) -> list[str]:
Expand Down Expand Up @@ -328,17 +309,17 @@ def report(

rows = []

for entry in dataset.to_pylist():
energies = {"ref": torch.tensor(entry["energy"])}
for dimer in dataset.to_pylist():
energies = {"ref": torch.tensor(dimer["energy"])}
energies.update(
(force_field_name, _predict(entry, force_field, topologies)[1])
(force_field_name, _predict(dimer, force_field, topologies)[1])
for force_field_name, force_field in force_fields.items()
)

plot_img = _plot_energies(energies)

mol_img = descent.utils.reporting.mols_to_img(
entry["smiles_a"], entry["smiles_b"]
dimer["smiles_a"], dimer["smiles_b"]
)
rows.append({"Dimer": mol_img, "Energy [kcal/mol]": plot_img})

Expand Down
18 changes: 18 additions & 0 deletions descent/tests/utils/test_molecule.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
import pytest
from rdkit import Chem

from descent.utils.molecule import mol_to_smiles


@pytest.mark.parametrize(
    ("input_smiles", "expected_smiles", "canonical"),
    [
        ("OC", "[H:1][C:4]([H:2])([O:3][H:5])[H:6]", True),
        ("OC", "[O:1]([C:2]([H:4])([H:5])[H:6])[H:3]", False),
    ],
)
def test_mol_to_smiles(input_smiles, expected_smiles, canonical):
    """Check the mapped SMILES of methanol with and without canonical ordering."""
    # The expected strings pin the exact map numbers assigned in each mode.
    assert (
        mol_to_smiles(Chem.MolFromSmiles(input_smiles), canonical)
        == expected_smiles
    )
29 changes: 29 additions & 0 deletions descent/utils/molecule.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
import typing

if typing.TYPE_CHECKING:
from rdkit import Chem


def mol_to_smiles(mol: "Chem.Mol", canonical: bool = True) -> str:
    """Build an atom-mapped SMILES string for a molecule.

    Hydrogens are added first. The atoms are then optionally renumbered, and
    finally every atom is tagged with a map number equal to its (possibly new)
    atom index plus one before the SMILES is written.

    Args:
        mol: The molecule to convert.
        canonical: Whether to renumber the atoms using RDKit's canonical atom
            ranking before the map numbers are assigned. When ``False`` the map
            numbers follow the atom order of the hydrogen-added molecule.

    Returns:
        The SMILES string, with a map number on every atom.
    """
    from rdkit import Chem

    mapped_mol = Chem.AddHs(mol)

    if canonical:
        # NOTE(review): CanonicalRankAtoms yields one rank per atom (indexed by
        # the current atom index), whereas the RDKit docs describe the second
        # argument of RenumberAtoms as "old index for each new position". If so,
        # passing the ranks directly applies the inverse permutation and the
        # result may still depend on the input atom order -- confirm, and if
        # needed pass the argsort of the ranks instead. Behaviour is left
        # unchanged here because existing tests pin the current output.
        ranks = list(Chem.CanonicalRankAtoms(mapped_mol, includeChirality=True))
        mapped_mol = Chem.RenumberAtoms(mapped_mol, ranks)

    for atom in mapped_mol.GetAtoms():
        # Map numbers are 1-based, atom indices are 0-based.
        atom.SetAtomMapNum(atom.GetIdx() + 1)

    return Chem.MolToSmiles(mapped_mol)