Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Migrate to Hugging Face datasets #56

Merged
merged 4 commits into from
Nov 11, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 42 additions & 26 deletions descent/targets/dimers.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,15 @@
import pathlib
import typing

import datasets
import datasets.table
import pyarrow
import smee
import smee.utils
import torch
import tqdm

import descent.utils.dataset
import descent.utils.molecule
import descent.utils.reporting

Expand Down Expand Up @@ -42,7 +46,7 @@ class Dimer(typing.TypedDict):
source: str


def create_dataset(dimers: list[Dimer]) -> pyarrow.Table:
def create_dataset(dimers: list[Dimer]) -> datasets.Dataset:
"""Create a dataset from a list of existing dimers.

Args:
Expand All @@ -51,8 +55,8 @@ def create_dataset(dimers: list[Dimer]) -> pyarrow.Table:
Returns:
The created dataset.
"""
# TODO: validate rows
return pyarrow.Table.from_pylist(

table = pyarrow.Table.from_pylist(
[
{
"smiles_a": dimer["smiles_a"],
Expand All @@ -65,12 +69,17 @@ def create_dataset(dimers: list[Dimer]) -> pyarrow.Table:
],
schema=DATA_SCHEMA,
)
# TODO: validate rows
dataset = datasets.Dataset(datasets.table.InMemoryTable(table))
dataset.set_format("torch")

return dataset


def create_from_des(
data_dir: pathlib.Path,
energy_fn: EnergyFn,
) -> pyarrow.Table:
) -> datasets.Dataset:
"""Create a dataset from a DESXXX dimer set.

Args:
Expand All @@ -85,14 +94,16 @@ def create_from_des(
The created dataset.
"""
import pandas
from rdkit import Chem
from rdkit import Chem, RDLogger

RDLogger.DisableLog("rdApp.*")

metadata = pandas.read_csv(data_dir / f"{data_dir.name}.csv", index_col=False)

system_ids = metadata["system_id"].unique()
dimers: list[Dimer] = []

for system_id in system_ids:
for system_id in tqdm.tqdm(system_ids, desc="loading dimers"):
system_data = metadata[metadata["system_id"] == system_id]

group_ids = metadata[metadata["system_id"] == system_id]["group_id"].unique()
Expand Down Expand Up @@ -130,20 +141,21 @@ def create_from_des(
coords = torch.tensor(coords_raw)
energy = energy_fn(group_data, geometry_ids, coords)

dimers.append(
{
"smiles_a": smiles_a,
"smiles_b": smiles_b,
"coords": coords,
"energy": energy,
"source": source,
}
)
dimer = {
"smiles_a": smiles_a,
"smiles_b": smiles_b,
"coords": coords,
"energy": energy,
"source": source,
}
dimers.append(dimer)

RDLogger.EnableLog("rdApp.*")

return create_dataset(dimers)


def extract_smiles(dataset: pyarrow.Table) -> list[str]:
def extract_smiles(dataset: datasets.Dataset) -> list[str]:
"""Return a list of unique SMILES strings in the dataset.

Args:
Expand All @@ -153,8 +165,8 @@ def extract_smiles(dataset: pyarrow.Table) -> list[str]:
The list of unique SMILES strings.
"""

smiles_a = dataset["smiles_a"].drop_null().unique().to_pylist()
smiles_b = dataset["smiles_b"].drop_null().unique().to_pylist()
smiles_a = dataset.unique("smiles_a")
smiles_b = dataset.unique("smiles_b")

return sorted({*smiles_a, *smiles_b})

Expand Down Expand Up @@ -238,7 +250,7 @@ def _predict(


def predict(
dataset: pyarrow.Table,
dataset: datasets.Dataset,
force_field: smee.TensorForceField,
topologies: dict[str, smee.TensorTopology],
) -> tuple[torch.Tensor, torch.Tensor]:
Expand All @@ -255,12 +267,13 @@ def predict(
``shape=(n_dimers * n_conf_per_dimer,)``.
"""

dimers: list[Dimer] = dataset.to_pylist()

reference, predicted = zip(
*[_predict(dimer, force_field, topologies) for dimer in dimers]
*[
_predict(dimer, force_field, topologies)
for dimer in descent.utils.dataset.iter_dataset(dataset)
]
)
return torch.stack(reference).flatten(), torch.stack(predicted).flatten()
return torch.cat(reference), torch.cat(predicted)


def _plot_energies(energies: dict[str, torch.Tensor]) -> str:
Expand Down Expand Up @@ -291,7 +304,7 @@ def _plot_energies(energies: dict[str, torch.Tensor]) -> str:


def report(
dataset: pyarrow.Table,
dataset: datasets.Dataset,
force_fields: dict[str, smee.TensorForceField],
topologies: dict[str, smee.TensorTopology],
output_path: pathlib.Path,
Expand All @@ -314,8 +327,8 @@ def report(
}
delta_sqr_count = 0

for dimer in dataset.to_pylist():
energies = {"ref": torch.tensor(dimer["energy"])}
for dimer in descent.utils.dataset.iter_dataset(dataset):
energies = {"ref": dimer["energy"]}
energies.update(
(force_field_name, _predict(dimer, force_field, topologies)[1])
for force_field_name, force_field in force_fields.items()
Expand All @@ -333,6 +346,8 @@ def report(
rmse = torch.sqrt(delta_sqr / len(energies["ref"]))
data_row[f"RMSE {force_field_name}"] = rmse.item()

data_row["Source"] = dimer["source"]

delta_sqr_count += len(energies["ref"])

rows.append(data_row)
Expand Down Expand Up @@ -380,6 +395,7 @@ def report(
selectable=False,
disabled=True,
formatters=formatters_full,
configuration={"rowHeight": 400},
),
sizing_mode="stretch_width",
scroll=True,
Expand Down
21 changes: 14 additions & 7 deletions descent/targets/thermo.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
import pickle
import typing

import datasets
import datasets.table
import numpy
import openmm.unit
import pyarrow
Expand All @@ -14,6 +16,8 @@
import torch
from rdkit import Chem

import descent.utils.dataset

_LOGGER = logging.getLogger(__name__)


Expand Down Expand Up @@ -119,7 +123,7 @@ class SimulationConfig(pydantic.BaseModel):
_SystemDict = dict[SimulationKey, smee.TensorSystem]


def create_dataset(*rows: DataEntry) -> pyarrow.Table:
def create_dataset(*rows: DataEntry) -> datasets.Dataset:
"""Create a dataset from a list of existing data points.

Args:
Expand All @@ -138,10 +142,13 @@ def create_dataset(*rows: DataEntry) -> pyarrow.Table:
row["smiles_b"] = Chem.MolToSmiles(Chem.MolFromSmiles(row["smiles_b"]))

# TODO: validate rows
return pyarrow.Table.from_pylist([*rows], schema=DATA_SCHEMA)
table = pyarrow.Table.from_pylist([*rows], schema=DATA_SCHEMA)

dataset = datasets.Dataset(datasets.table.InMemoryTable(table))
return dataset


def extract_smiles(dataset: pyarrow.Table) -> list[str]:
def extract_smiles(dataset: datasets.Dataset) -> list[str]:
"""Return a list of unique SMILES strings in the dataset.

Args:
Expand All @@ -150,8 +157,8 @@ def extract_smiles(dataset: pyarrow.Table) -> list[str]:
Returns:
The unique SMILES strings with full atom mapping.
"""
smiles_a = dataset["smiles_a"].drop_null().unique().to_pylist()
smiles_b = dataset["smiles_b"].drop_null().unique().to_pylist()
smiles_a = {smiles for smiles in dataset.unique("smiles_a") if smiles is not None}
smiles_b = {smiles for smiles in dataset.unique("smiles_b") if smiles is not None}

smiles_unique = sorted({*smiles_a, *smiles_b})
return smiles_unique
Expand Down Expand Up @@ -510,7 +517,7 @@ def _predict(


def predict(
dataset: pyarrow.Table,
dataset: datasets.Dataset,
force_field: smee.TensorForceField,
topologies: dict[str, smee.TensorTopology],
output_dir: pathlib.Path,
Expand All @@ -532,7 +539,7 @@ def predict(
will be used for any data type not specified.
"""

entries: list[DataEntry] = dataset.to_pylist()
entries: list[DataEntry] = [*descent.utils.dataset.iter_dataset(dataset)]

required_simulations, entry_to_simulation = _plan_simulations(entries, topologies)
averages = {
Expand Down
19 changes: 10 additions & 9 deletions descent/tests/targets/test_dimers.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import smee.converters
import torch

import descent.utils.dataset
from descent.targets.dimers import (
Dimer,
compute_dimer_energy,
Expand All @@ -28,22 +29,21 @@ def mock_dimer() -> Dimer:


def test_create_dataset(mock_dimer):
expected_data_entries = [
expected_entries = [
{
"smiles_a": mock_dimer["smiles_a"],
"smiles_b": mock_dimer["smiles_b"],
"coords": mock_dimer["coords"].flatten().tolist(),
"energy": mock_dimer["energy"].tolist(),
"coords": pytest.approx(mock_dimer["coords"].flatten()),
"energy": pytest.approx(mock_dimer["energy"]),
"source": mock_dimer["source"],
},
]

dataset = create_dataset([mock_dimer])
assert len(dataset) == 1

data_entries = dataset.to_pylist()

assert data_entries == pytest.approx(expected_data_entries)
entries = list(descent.utils.dataset.iter_dataset(dataset))
assert entries == expected_entries


def test_create_from_des(data_dir):
Expand All @@ -63,12 +63,13 @@ def energy_fn(data, ids, coords):
expected = {
"smiles_a": "[C:1]([O:2][H:6])([H:3])([H:4])[H:5]",
"smiles_b": "[O:1]([H:2])[H:3]",
"coords": expected_coords.flatten().tolist(),
"energy": [-1.23],
"coords": pytest.approx(expected_coords.flatten()),
"energy": pytest.approx(torch.tensor([-1.23])),
"source": "DESMOCK system=4321 orig=MOCK group=1423",
}

assert dataset.to_pylist() == [pytest.approx(expected)]
entries = list(descent.utils.dataset.iter_dataset(dataset))
assert entries == [expected]


def test_extract_smiles(mock_dimer):
Expand Down
9 changes: 5 additions & 4 deletions descent/tests/targets/test_thermo.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import smee.mm
import torch

import descent.utils.dataset
from descent.targets.thermo import (
DataEntry,
SimulationKey,
Expand Down Expand Up @@ -88,14 +89,14 @@ def mock_hmix() -> DataEntry:


def test_create_dataset(mock_density_pure, mock_density_binary):
expected_data_entries = [mock_density_pure, mock_density_binary]
expected_entries = [mock_density_pure, mock_density_binary]

dataset = create_dataset(*expected_data_entries)
dataset = create_dataset(*expected_entries)
assert len(dataset) == 2

data_entries = dataset.to_pylist()
entries = list(descent.utils.dataset.iter_dataset(dataset))

assert data_entries == pytest.approx(expected_data_entries)
assert entries == pytest.approx(expected_entries)


def test_extract_smiles(mock_density_pure, mock_density_binary):
Expand Down
18 changes: 18 additions & 0 deletions descent/tests/utils/test_dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
import datasets

from descent.utils.dataset import iter_dataset


def test_iter_dataset():
    # Build the dataset from an explicit source dict so the expected rows can
    # be derived from the same data rather than written out twice.
    source = {
        "column1": ["data1", "data2", "data3"],
        "column2": ["data4", "data5", "data6"],
    }
    test_dataset = datasets.Dataset.from_dict(source)

    expected_output = [
        {"column1": value_1, "column2": value_2}
        for value_1, value_2 in zip(source["column1"], source["column2"])
    ]

    output = list(iter_dataset(test_dataset))

    assert output == expected_output
22 changes: 22 additions & 0 deletions descent/utils/dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
"""Utilities for working with datasets."""
import typing

import datasets


def iter_dataset(
    dataset: "datasets.Dataset",
) -> typing.Iterator[dict[str, typing.Any]]:
    """Iterate over a Hugging Face Dataset, yielding each 'row' as a dictionary.

    Args:
        dataset: The dataset to iterate over.

    Yields:
        A dictionary representing a single entry in the batch, where each key is a
        column name and the corresponding value is the entry in that column for the
        current row.
    """

    columns = [*dataset.features]

    # Pull each column out once and walk them in lockstep so that each yielded
    # dict pairs the i-th entry of every column.
    for row in zip(*[dataset[column] for column in columns]):
        yield dict(zip(columns, row))
1 change: 1 addition & 0 deletions devtools/envs/base.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ dependencies:
- pytorch
- pydantic
- pyarrow
- datasets

### Levenberg Marquardt
- scipy
Expand Down