From f35492f60dc93f25851a69f2f2dc4ed62d0c5329 Mon Sep 17 00:00:00 2001 From: SimonBoothroyd Date: Sat, 4 Nov 2023 12:44:05 -0400 Subject: [PATCH] Split `mol_to_smiles` into a util (#48) --- descent/targets/dimers.py | 55 +++++++++------------------- descent/tests/utils/test_molecule.py | 18 +++++++++ descent/utils/molecule.py | 29 +++++++++++++++ 3 files changed, 65 insertions(+), 37 deletions(-) create mode 100644 descent/tests/utils/test_molecule.py create mode 100644 descent/utils/molecule.py diff --git a/descent/targets/dimers.py b/descent/targets/dimers.py index dae32d6..b76da37 100644 --- a/descent/targets/dimers.py +++ b/descent/targets/dimers.py @@ -7,11 +7,11 @@ import smee.utils import torch +import descent.utils.molecule import descent.utils.reporting if typing.TYPE_CHECKING: import pandas - from rdkit import Chem EnergyFn = typing.Callable[ @@ -42,11 +42,11 @@ class Dimer(typing.TypedDict): source: str -def create_dataset(entries: list[Dimer]) -> pyarrow.Table: +def create_dataset(dimers: list[Dimer]) -> pyarrow.Table: """Create a dataset from a list of existing dimers. Args: - entries: The dimers to create the dataset from. + dimers: The dimers to create the dataset from. Returns: The created dataset. @@ -55,37 +55,18 @@ def create_dataset(entries: list[Dimer]) -> pyarrow.Table: return pyarrow.Table.from_pylist( [ { - "smiles_a": entry["smiles_a"], - "smiles_b": entry["smiles_b"], - "coords": torch.tensor(entry["coords"]).flatten().tolist(), - "energy": torch.tensor(entry["energy"]).flatten().tolist(), - "source": entry["source"], + "smiles_a": dimer["smiles_a"], + "smiles_b": dimer["smiles_b"], + "coords": torch.tensor(dimer["coords"]).flatten().tolist(), + "energy": torch.tensor(dimer["energy"]).flatten().tolist(), + "source": dimer["source"], } - for entry in entries + for dimer in dimers ], schema=DATA_SCHEMA, ) -def _mol_to_smiles(mol: "Chem.Mol") -> str: - """Convert a molecule to a SMILES string with atom mapping. - - Args: - mol: The molecule to convert. - - Returns: - The SMILES string. - """ - from rdkit import Chem - - mol = Chem.AddHs(mol) - - for atom in mol.GetAtoms(): - atom.SetAtomMapNum(atom.GetIdx() + 1) - - return Chem.MolToSmiles(mol) - - def create_from_des( data_dir: pathlib.Path, energy_fn: EnergyFn, @@ -109,7 +90,7 @@ def create_from_des( metadata = pandas.read_csv(data_dir / f"{data_dir.name}.csv", index_col=False) system_ids = metadata["system_id"].unique() - entries: list[Dimer] = [] + dimers: list[Dimer] = [] for system_id in system_ids: system_data = metadata[metadata["system_id"] == system_id] @@ -128,8 +109,8 @@ def create_from_des( ) mol_a, mol_b = Chem.GetMolFrags(dimer_example, asMols=True) - smiles_a = _mol_to_smiles(mol_a) - smiles_b = _mol_to_smiles(mol_b) + smiles_a = descent.utils.molecule.mol_to_smiles(mol_a, False) + smiles_b = descent.utils.molecule.mol_to_smiles(mol_b, False) source = ( f"{data_dir.name} system={system_id} orig={group_orig} group={group_id}" @@ -149,7 +130,7 @@ def create_from_des( coords = torch.tensor(coords_raw) energy = energy_fn(group_data, geometry_ids, coords) - entries.append( + dimers.append( { "smiles_a": smiles_a, "smiles_b": smiles_b, @@ -159,7 +140,7 @@ def create_from_des( } ) - return create_dataset(entries) + return create_dataset(dimers) def extract_smiles(dataset: pyarrow.Table) -> list[str]: @@ -328,17 +309,17 @@ def report( rows = [] - for entry in dataset.to_pylist(): - energies = {"ref": torch.tensor(entry["energy"])} + for dimer in dataset.to_pylist(): + energies = {"ref": torch.tensor(dimer["energy"])} energies.update( - (force_field_name, _predict(entry, force_field, topologies)[1]) + (force_field_name, _predict(dimer, force_field, topologies)[1]) for force_field_name, force_field in force_fields.items() ) plot_img = _plot_energies(energies) mol_img = descent.utils.reporting.mols_to_img( - entry["smiles_a"], entry["smiles_b"] + dimer["smiles_a"], dimer["smiles_b"] ) rows.append({"Dimer": mol_img, "Energy [kcal/mol]": plot_img}) diff --git a/descent/tests/utils/test_molecule.py b/descent/tests/utils/test_molecule.py new file mode 100644 index 0000000..844f677 --- /dev/null +++ b/descent/tests/utils/test_molecule.py @@ -0,0 +1,18 @@ +import pytest +from rdkit import Chem + +from descent.utils.molecule import mol_to_smiles + + +@pytest.mark.parametrize( + "input_smiles, expected_smiles, canonical", + [ + ("OC", "[H:1][C:4]([H:2])([O:3][H:5])[H:6]", True), + ("OC", "[O:1]([C:2]([H:4])([H:5])[H:6])[H:3]", False), + ], +) +def test_mol_to_smiles(input_smiles, expected_smiles, canonical): + mol = Chem.MolFromSmiles(input_smiles) + actual_smiles = mol_to_smiles(mol, canonical) + + assert actual_smiles == expected_smiles diff --git a/descent/utils/molecule.py b/descent/utils/molecule.py new file mode 100644 index 0000000..fce1191 --- /dev/null +++ b/descent/utils/molecule.py @@ -0,0 +1,29 @@ +import typing + +if typing.TYPE_CHECKING: + from rdkit import Chem + + +def mol_to_smiles(mol: "Chem.Mol", canonical: bool = True) -> str: + """Convert a molecule to a SMILES string with atom mapping. + + Args: + mol: The molecule to convert. + canonical: Whether to canonicalize the atom ordering prior to assigning + map indices. + + Returns: + The SMILES string. + """ + from rdkit import Chem + + mol = Chem.AddHs(mol) + + if canonical: + order = Chem.CanonicalRankAtoms(mol, includeChirality=True) + mol = Chem.RenumberAtoms(mol, list(order)) + + for atom in mol.GetAtoms(): + atom.SetAtomMapNum(atom.GetIdx() + 1) + + return Chem.MolToSmiles(mol)