Skip to content

Commit

Permalink
Migrate thermo
Browse files (browse the repository at this point in the history)
  • Loading branch information
SimonBoothroyd committed Nov 11, 2023
1 parent 8b6d0af commit 6920539
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 11 deletions.
2 changes: 2 additions & 0 deletions descent/targets/dimers.py
Original file line number Diff line number Diff line change
Expand Up @@ -346,6 +346,8 @@ def report(
rmse = torch.sqrt(delta_sqr / len(energies["ref"]))
data_row[f"RMSE {force_field_name}"] = rmse.item()

data_row["Source"] = dimer["source"]

delta_sqr_count += len(energies["ref"])

rows.append(data_row)
Expand Down
21 changes: 14 additions & 7 deletions descent/targets/thermo.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
import pickle
import typing

import datasets
import datasets.table
import numpy
import openmm.unit
import pyarrow
Expand All @@ -14,6 +16,8 @@
import torch
from rdkit import Chem

import descent.utils.dataset

_LOGGER = logging.getLogger(__name__)


Expand Down Expand Up @@ -119,7 +123,7 @@ class SimulationConfig(pydantic.BaseModel):
_SystemDict = dict[SimulationKey, smee.TensorSystem]


def create_dataset(*rows: DataEntry) -> pyarrow.Table:
def create_dataset(*rows: DataEntry) -> datasets.Dataset:
"""Create a dataset from a list of existing data points.
Args:
Expand All @@ -138,10 +142,13 @@ def create_dataset(*rows: DataEntry) -> pyarrow.Table:
row["smiles_b"] = Chem.MolToSmiles(Chem.MolFromSmiles(row["smiles_b"]))

# TODO: validate rows
return pyarrow.Table.from_pylist([*rows], schema=DATA_SCHEMA)
table = pyarrow.Table.from_pylist([*rows], schema=DATA_SCHEMA)

dataset = datasets.Dataset(datasets.table.InMemoryTable(table))
return dataset


def extract_smiles(dataset: pyarrow.Table) -> list[str]:
def extract_smiles(dataset: datasets.Dataset) -> list[str]:
"""Return a list of unique SMILES strings in the dataset.
Args:
Expand All @@ -150,8 +157,8 @@ def extract_smiles(dataset: pyarrow.Table) -> list[str]:
Returns:
The unique SMILES strings with full atom mapping.
"""
smiles_a = dataset["smiles_a"].drop_null().unique().to_pylist()
smiles_b = dataset["smiles_b"].drop_null().unique().to_pylist()
smiles_a = {smiles for smiles in dataset.unique("smiles_a") if smiles is not None}
smiles_b = {smiles for smiles in dataset.unique("smiles_b") if smiles is not None}

smiles_unique = sorted({*smiles_a, *smiles_b})
return smiles_unique
Expand Down Expand Up @@ -510,7 +517,7 @@ def _predict(


def predict(
dataset: pyarrow.Table,
dataset: datasets.Dataset,
force_field: smee.TensorForceField,
topologies: dict[str, smee.TensorTopology],
output_dir: pathlib.Path,
Expand All @@ -532,7 +539,7 @@ def predict(
will be used for any data type not specified.
"""

entries: list[DataEntry] = dataset.to_pylist()
entries: list[DataEntry] = [*descent.utils.dataset.iter_dataset(dataset)]

required_simulations, entry_to_simulation = _plan_simulations(entries, topologies)
averages = {
Expand Down
9 changes: 5 additions & 4 deletions descent/tests/targets/test_thermo.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import smee.mm
import torch

import descent.utils.dataset
from descent.targets.thermo import (
DataEntry,
SimulationKey,
Expand Down Expand Up @@ -88,14 +89,14 @@ def mock_hmix() -> DataEntry:


def test_create_dataset(mock_density_pure, mock_density_binary):
expected_data_entries = [mock_density_pure, mock_density_binary]
expected_entries = [mock_density_pure, mock_density_binary]

dataset = create_dataset(*expected_data_entries)
dataset = create_dataset(*expected_entries)
assert len(dataset) == 2

data_entries = dataset.to_pylist()
entries = list(descent.utils.dataset.iter_dataset(dataset))

assert data_entries == pytest.approx(expected_data_entries)
assert entries == pytest.approx(expected_entries)


def test_extract_smiles(mock_density_pure, mock_density_binary):
Expand Down

0 comments on commit 6920539

Please sign in to comment.