From e07fc0612a6ae3f653dcba464f505dea00c0de30 Mon Sep 17 00:00:00 2001
From: SimonBoothroyd
Date: Sat, 4 Nov 2023 16:33:01 -0400
Subject: [PATCH] Add initial thermo target (#49)

---
 descent/targets/thermo.py            | 551 +++++++++++++++++++++++++++
 descent/tests/targets/test_thermo.py | 504 ++++++++++++++++++++++++
 2 files changed, 1055 insertions(+)
 create mode 100644 descent/targets/thermo.py
 create mode 100644 descent/tests/targets/test_thermo.py

diff --git a/descent/targets/thermo.py b/descent/targets/thermo.py
new file mode 100644
index 0000000..f3e808b
--- /dev/null
+++ b/descent/targets/thermo.py
@@ -0,0 +1,551 @@
+"""Train against thermodynamic properties."""
+import contextlib
+import hashlib
+import pathlib
+import pickle
+import typing
+
+import numpy
+import openmm.unit
+import pyarrow
+import pydantic
+import smee.mm
+import torch
+
+import descent.utils.molecule
+
+DataType = typing.Literal["density", "hvap", "hmix"]
+
+DATA_TYPES = typing.get_args(DataType)
+
+DATA_SCHEMA = pyarrow.schema(
+    [
+        ("type", pyarrow.string()),
+        ("smiles_a", pyarrow.string()),
+        ("x_a", pyarrow.float64()),
+        ("smiles_b", pyarrow.string()),
+        ("x_b", pyarrow.float64()),
+        ("temperature", pyarrow.float64()),
+        ("pressure", pyarrow.float64()),
+        ("value", pyarrow.float64()),
+        ("std", pyarrow.float64()),
+        ("units", pyarrow.string()),
+        ("source", pyarrow.string()),
+    ]
+)
+
+_REQUIRES_BULK_SIM = {"density": True, "hvap": True, "hmix": True}
+"""Whether a bulk simulation is required for each data type."""
+_REQUIRES_PURE_SIM = {"density": False, "hvap": False, "hmix": True}
+"""Whether a simulation of each component is required for each data type."""
+_REQUIRES_VACUUM_SIM = {"density": False, "hvap": True, "hmix": False}
+"""Whether a vacuum simulation is required for each data type."""
+
+Phase = typing.Literal["bulk", "vacuum"]
+PHASES = typing.get_args(Phase)
+
+
+class DataEntry(typing.TypedDict):
+    """Represents a single experimental data point."""
+
+    type: DataType
+    """The type of data point."""
+
+    smiles_a: str
+    """The SMILES definition of the first component."""
+    x_a: float | None
+    """The mole fraction of the first component. This must be set to 1.0 if the data
+    point was measured for the pure component."""
+
+    smiles_b: str | None
+    """The SMILES definition of the second component if present."""
+    x_b: float | None
+    """The mole fraction of the second component if present."""
+
+    temperature: float
+    """The temperature [K] at which the data point was measured."""
+    pressure: float
+    """The pressure [atm] at which the data point was measured."""
+
+    value: float
+    """The value of the data point."""
+    std: float | None
+    """The standard deviation of the data point if available."""
+    units: str
+    """The units of the data point."""
+
+    source: str
+    """The source of the data point."""
+
+
+class SimulationKey(typing.NamedTuple):
+    """A key used to identify a simulation."""
+
+    smiles: tuple[str, ...]
+    """The SMILES definitions of the components present in the system."""
+    counts: tuple[int, ...]
+    """The number of copies of each component present in the system."""
+
+    temperature: float
+    """The temperature [K] at which the simulation was run."""
+    pressure: float | None
+    """The pressure [atm] at which the simulation was run."""
+
+
+class SimulationConfig(pydantic.BaseModel):
+    """Configuration for a simulation to run."""
+
+    max_mols: int = pydantic.Field(
+        ..., description="The maximum number of molecules to simulate."
+    )
+    gen_coords: smee.mm.GenerateCoordsConfig = pydantic.Field(
+        ..., description="Configuration for generating initial coordinates."
+ ) + + equilibrate: list[ + smee.mm.MinimizationConfig | smee.mm.SimulationConfig + ] = pydantic.Field(..., description="Configuration for equilibration simulations.") + + production: smee.mm.SimulationConfig = pydantic.Field( + ..., description="Configuration for the production simulation." + ) + production_frequency: int = pydantic.Field( + ..., description="The frequency at which to write frames during production." + ) + + +_SystemDict = dict[SimulationKey, smee.TensorSystem] + + +def create_dataset(*rows: DataEntry) -> pyarrow.Table: + """Create a dataset from a list of existing data points. + + Args: + rows: The data points to create the dataset from. + + Returns: + The created dataset. + """ + # TODO: validate rows + return pyarrow.Table.from_pylist([*rows], schema=DATA_SCHEMA) + + +def extract_smiles(dataset: pyarrow.Table) -> list[str]: + """Return a list of unique SMILES strings in the dataset. + + Args: + dataset: The dataset to extract the SMILES strings from. + + Returns: + The unique SMILES strings with full atom mapping. + """ + + from rdkit import Chem + + smiles_a = dataset["smiles_a"].drop_null().unique().to_pylist() + smiles_b = dataset["smiles_b"].drop_null().unique().to_pylist() + + smiles_unique = sorted({*smiles_a, *smiles_b}) + smiles_mapped = [ + descent.utils.molecule.mol_to_smiles(Chem.MolFromSmiles(smiles)) + for smiles in smiles_unique + ] + + return smiles_mapped + + +def _convert_entry_to_system( + entry: DataEntry, topologies: dict[str, smee.TensorTopology], max_mols: int +) -> tuple[SimulationKey, smee.TensorSystem]: + """Convert a data entry into a system ready to simulate. + + Args: + entry: The data entry to convert. + topologies: The topologies of the molecules present in the dataset, with keys + of mapped SMILES patterns. + max_mols: The maximum number of molecules to simulate. + + Returns: + The system and its associated key. + """ + smiles_a = entry["smiles_a"] + smiles_b = entry["smiles_b"] + + fraction_a = 0.0 if entry["x_a"] is None else entry["x_a"] + fraction_b = 0.0 if entry["x_b"] is None else entry["x_b"] + + assert numpy.isclose(fraction_a + fraction_b, 1.0) + + n_copies_a = int(max_mols * fraction_a) + n_copies_b = int(max_mols * fraction_b) + + smiles = [smiles_a] + + system_topologies = [topologies[smiles_a]] + n_copies = [n_copies_a] + + if n_copies_b > 0: + smiles.append(smiles_b) + + system_topologies.append(topologies[smiles_b]) + n_copies.append(n_copies_b) + + key = SimulationKey( + tuple(smiles), tuple(n_copies), entry["temperature"], entry["pressure"] + ) + system = smee.TensorSystem(system_topologies, n_copies, True) + + return key, system + + +def _bulk_config(temperature: float, pressure: float) -> SimulationConfig: + """Return a default simulation configuration for simulations of the bulk phase. + + Args: + temperature: The temperature [K] at which to run the simulation. + pressure: The pressure [atm] at which to run the simulation. + + Returns: + The default simulation configuration. 
+ """ + temperature = temperature * openmm.unit.kelvin + pressure = pressure * openmm.unit.atmosphere + + return SimulationConfig( + max_mols=256, + gen_coords=smee.mm.GenerateCoordsConfig(), + equilibrate=[ + smee.mm.MinimizationConfig(), + # short NVT equilibration simulation + smee.mm.SimulationConfig( + temperature=temperature, + pressure=None, + n_steps=50000, + timestep=1.0 * openmm.unit.femtosecond, + ), + # short NPT equilibration simulation + smee.mm.SimulationConfig( + temperature=temperature, + pressure=pressure, + n_steps=50000, + timestep=1.0 * openmm.unit.femtosecond, + ), + ], + production=smee.mm.SimulationConfig( + temperature=temperature, + pressure=pressure, + n_steps=500000, + timestep=2.0 * openmm.unit.femtosecond, + ), + production_frequency=500, + ) + + +def _vacuum_config(temperature: float, pressure: float | None) -> SimulationConfig: + """Return a default simulation configuration for simulations of the vacuum phase. + + Args: + temperature: The temperature [K] at which to run the simulation. + pressure: The pressure [atm] at which to run the simulation. + + Returns: + The default simulation configuration. + """ + temperature = temperature * openmm.unit.kelvin + assert pressure is None + + return SimulationConfig( + max_mols=1, + gen_coords=smee.mm.GenerateCoordsConfig(), + equilibrate=[ + smee.mm.MinimizationConfig(), + smee.mm.SimulationConfig( + temperature=temperature, + pressure=None, + n_steps=50000, + timestep=1.0 * openmm.unit.femtosecond, + ), + ], + production=smee.mm.SimulationConfig( + temperature=temperature, + pressure=None, + n_steps=1000000, + timestep=1.0 * openmm.unit.femtosecond, + ), + production_frequency=500, + ) + + +def default_config( + phase: Phase, temperature: float, pressure: float | None +) -> SimulationConfig: + """Return a default simulation configuration for the specified phase. + + Args: + phase: The phase to return the default configuration for. + temperature: The temperature [K] at which to run the simulation. + pressure: The pressure [atm] at which to run the simulation. + + Returns: + The default simulation configuration. + """ + + if phase.lower() == "bulk": + return _bulk_config(temperature, pressure) + elif phase.lower() == "vacuum": + return _vacuum_config(temperature, pressure) + else: + raise NotImplementedError(phase) + + +def _plan_simulations( + entries: list[DataEntry], topologies: dict[str, smee.TensorTopology] +) -> tuple[dict[Phase, _SystemDict], list[dict[str, SimulationKey]]]: + """Plan the simulations required to compute the properties in a dataset. + + Args: + entries: The entries in the dataset. + topologies: The topologies of the molecules present in the dataset, with keys + of mapped SMILES patterns. + + Returns: + The systems to simulate and the simulations required to compute each property. 
+    """
+    systems_per_phase: dict[Phase, _SystemDict] = {phase: {} for phase in PHASES}
+    simulations_per_entry = []
+
+    for entry in entries:
+        data_type = entry["type"].lower()
+
+        if data_type not in DATA_TYPES:
+            raise NotImplementedError(data_type)
+
+        required_sims: dict[str, SimulationKey] = {}
+
+        bulk_config = default_config("bulk", entry["temperature"], entry["pressure"])
+        max_mols = bulk_config.max_mols
+
+        if _REQUIRES_BULK_SIM[data_type]:
+            key, system = _convert_entry_to_system(entry, topologies, max_mols)
+
+            systems_per_phase["bulk"][key] = system
+            required_sims["bulk"] = key
+
+        if _REQUIRES_PURE_SIM[data_type]:
+            for i, smiles in enumerate((entry["smiles_a"], entry["smiles_b"])):
+                key = SimulationKey(
+                    (smiles,), (max_mols,), entry["temperature"], entry["pressure"]
+                )
+                system = smee.TensorSystem([topologies[smiles]], [max_mols], True)
+
+                systems_per_phase["bulk"][key] = system
+                required_sims[f"bulk_{i}"] = key
+
+        if _REQUIRES_VACUUM_SIM[data_type]:
+            assert entry["smiles_b"] is None, "vacuum sims only support pure systems"
+
+            system = smee.TensorSystem([topologies[entry["smiles_a"]]], [1], False)
+            key = SimulationKey((entry["smiles_a"],), (1,), entry["temperature"], None)
+
+            systems_per_phase["vacuum"][key] = system
+            required_sims["vacuum"] = key
+
+        simulations_per_entry.append(required_sims)
+
+    return systems_per_phase, simulations_per_entry
+
+
+def _simulate(
+    system: smee.TensorSystem,
+    force_field: smee.TensorForceField,
+    config: SimulationConfig,
+    output_path: pathlib.Path,
+):
+    """Simulate a system.
+
+    Args:
+        system: The system to simulate.
+        force_field: The force field to use.
+        config: The simulation configuration to use.
+        output_path: The path at which to write the simulation trajectory.
+    """
+    coords, box_vectors = smee.mm.generate_system_coords(system, config.gen_coords)
+
+    beta = 1.0 / (openmm.unit.MOLAR_GAS_CONSTANT_R * config.production.temperature)
+
+    output_path.parent.mkdir(parents=True, exist_ok=True)
+
+    with output_path.open("wb") as output:
+        reporter = smee.mm.TensorReporter(
+            output, config.production_frequency, beta, config.production.pressure
+        )
+        smee.mm.simulate(
+            system,
+            force_field,
+            coords,
+            box_vectors,
+            config.equilibrate,
+            config.production,
+            [reporter],
+        )
+
+
+def _compute_averages(
+    phase: Phase,
+    key: SimulationKey,
+    system: smee.TensorSystem,
+    force_field: smee.TensorForceField,
+    output_dir: pathlib.Path,
+    cached_dir: pathlib.Path | None,
+) -> dict[str, torch.Tensor]:
+    """Compute ensemble averages for a system, reweighting a cached trajectory
+    when one is available and running a new simulation otherwise."""
+    traj_hash = hashlib.sha256(pickle.dumps(key), usedforsecurity=False).hexdigest()
+    traj_name = f"{phase}-{traj_hash}-frames.msgpack"
+
+    cached_path = None if cached_dir is None else cached_dir / traj_name
+
+    temperature = key.temperature * openmm.unit.kelvin
+    pressure = None if key.pressure is None else key.pressure * openmm.unit.atmospheres
+
+    if cached_path is not None and cached_path.exists():
+        with contextlib.suppress(smee.mm._ops.NotEnoughSamplesError):
+            return smee.mm.reweight_ensemble_averages(
+                system, force_field, cached_path, temperature, pressure
+            )
+
+    output_path = output_dir / traj_name
+
+    config = default_config(phase, key.temperature, key.pressure)
+    _simulate(system, force_field, config, output_path)
+
+    return smee.mm.compute_ensemble_averages(
+        system, force_field, output_path, temperature, pressure
+    )
+
+
+def _predict_density(
+    entry: DataEntry, averages: dict[str, torch.Tensor]
+) -> torch.Tensor:
+    assert entry["units"] == "g/mL"
+    return averages["density"]
+
+
+def _predict_hvap(
+    entry: DataEntry,
+
averages_bulk: dict[str, torch.Tensor], + averages_vacuum: dict[str, torch.Tensor], + system_bulk: smee.TensorSystem, +) -> torch.Tensor: + assert entry["units"] == "kcal/mol" + + temperature = entry["temperature"] * openmm.unit.kelvin + + potential_bulk = averages_bulk["potential_energy"] / sum(system_bulk.n_copies) + potential_vacuum = averages_vacuum["potential_energy"] + + rt = (temperature * openmm.unit.MOLAR_GAS_CONSTANT_R).value_in_unit( + openmm.unit.kilocalorie_per_mole + ) + return potential_vacuum - potential_bulk + rt + + +def _predict_hmix( + entry: DataEntry, + averages_mix: dict[str, torch.Tensor], + averages_0: dict[str, torch.Tensor], + averages_1: dict[str, torch.Tensor], + system_mix: smee.TensorSystem, + system_0: smee.TensorSystem, + system_1: smee.TensorSystem, +) -> torch.Tensor: + assert entry["units"] == "kcal/mol" + + x_0 = system_mix.n_copies[0] / sum(system_mix.n_copies) + x_1 = 1.0 - x_0 + + enthalpy_mix = averages_mix["enthalpy"] / sum(system_mix.n_copies) + + enthalpy_0 = averages_0["enthalpy"] / sum(system_0.n_copies) + enthalpy_1 = averages_1["enthalpy"] / sum(system_1.n_copies) + + return enthalpy_mix - x_0 * enthalpy_0 - x_1 * enthalpy_1 + + +def _predict( + entry: DataEntry, + keys: dict[str, SimulationKey], + averages: dict[Phase, dict[SimulationKey, dict[str, torch.Tensor]]], + systems: dict[Phase, dict[SimulationKey, smee.TensorSystem]], +): + if entry["type"] == "density": + value = _predict_density(entry, averages["bulk"][keys["bulk"]]) + elif entry["type"] == "hvap": + value = _predict_hvap( + entry, + averages["bulk"][keys["bulk"]], + averages["vacuum"][keys["vacuum"]], + systems["bulk"][keys["bulk"]], + ) + elif entry["type"] == "hmix": + value = _predict_hmix( + entry, + averages["bulk"][keys["bulk"]], + averages["bulk"][keys["bulk_0"]], + averages["bulk"][keys["bulk_1"]], + systems["bulk"][keys["bulk"]], + systems["bulk"][keys["bulk_0"]], + systems["bulk"][keys["bulk_1"]], + ) + else: + raise NotImplementedError(entry["type"]) + + return value + + +def predict( + dataset: pyarrow.Table, + force_field: smee.TensorForceField, + topologies: dict[str, smee.TensorTopology], + output_dir: pathlib.Path, + cached_dir: pathlib.Path | None = None, + per_type_scales: dict[DataType, float] | None = None, +) -> tuple[torch.Tensor, torch.Tensor]: + """Predict the properties in a dataset using molecular simulation, or by reweighting + previous simulation data. + + Args: + dataset: The dataset to predict the properties of. + force_field: The force field to use. + topologies: The topologies of the molecules present in the dataset, with keys + of mapped SMILES patterns. + output_dir: The directory to write the simulation trajectories to. + cached_dir: The (optional) directory to read cached simulation trajectories + from. + per_type_scales: The scale factor to apply to each data type. A default of 1.0 + will be used for any data type not specified. 
+ """ + + entries: list[DataEntry] = dataset.to_pylist() + + required_simulations, entry_to_simulation = _plan_simulations(entries, topologies) + averages = { + phase: { + key: _compute_averages( + phase, key, system, force_field, output_dir, cached_dir + ) + for key, system in systems.items() + } + for phase, systems in required_simulations.items() + } + + predicted = [] + reference = [] + + per_type_scales = per_type_scales if per_type_scales is not None else {} + + for entry, keys in zip(entries, entry_to_simulation): + value = _predict(entry, keys, averages, required_simulations) + + predicted.append(value * per_type_scales.get(entry["type"], 1.0)) + reference.append( + torch.tensor(entry["value"]) * per_type_scales.get(entry["type"], 1.0) + ) + + return torch.stack(reference), torch.stack(predicted) diff --git a/descent/tests/targets/test_thermo.py b/descent/tests/targets/test_thermo.py new file mode 100644 index 0000000..d2b5c42 --- /dev/null +++ b/descent/tests/targets/test_thermo.py @@ -0,0 +1,504 @@ +import numpy +import openmm.unit +import pytest +import smee.mm +import torch + +from descent.targets.thermo import ( + DataEntry, + SimulationKey, + _compute_averages, + _convert_entry_to_system, + _plan_simulations, + _predict, + _simulate, + create_dataset, + default_config, + extract_smiles, + predict, +) + + +@pytest.fixture +def mock_density_pure() -> DataEntry: + return { + "type": "density", + "smiles_a": "CO", + "x_a": 1.0, + "smiles_b": None, + "x_b": None, + "temperature": 298.15, + "pressure": 1.0, + "value": 0.785, + "std": 0.001, + "units": "g/mL", + "source": None, + } + + +@pytest.fixture +def mock_density_binary() -> DataEntry: + return { + "type": "density", + "smiles_a": "CCO", + "x_a": 0.5, + "smiles_b": "CO", + "x_b": 0.5, + "temperature": 298.15, + "pressure": 1.0, + "value": 0.9, + "std": 0.002, + "units": "g/mL", + "source": None, + } + + +@pytest.fixture +def mock_hvap() -> DataEntry: + return { + "type": "hvap", + "smiles_a": "CCCC", + "x_a": 1.0, + "smiles_b": None, + "x_b": None, + "temperature": 298.15, + "pressure": 1.0, + "value": 1.234, + "std": 0.004, + "units": "kcal/mol", + "source": None, + } + + +@pytest.fixture +def mock_hmix() -> DataEntry: + return { + "type": "hmix", + "smiles_a": "CCO", + "x_a": 0.5, + "smiles_b": "CO", + "x_b": 0.5, + "temperature": 298.15, + "pressure": 1.0, + "value": 0.4321, + "std": 0.0025, + "units": "kcal/mol", + "source": None, + } + + +def test_create_dataset(mock_density_pure, mock_density_binary): + expected_data_entries = [mock_density_pure, mock_density_binary] + + dataset = create_dataset(*expected_data_entries) + assert len(dataset) == 2 + + data_entries = dataset.to_pylist() + + assert data_entries == pytest.approx(expected_data_entries) + + +def test_extract_smiles(mock_density_pure, mock_density_binary): + dataset = create_dataset(mock_density_pure, mock_density_binary) + smiles = extract_smiles(dataset) + + expected_smiles = [ + "[H:1][C:4]([H:3])([O:5][H:2])[C:9]([H:6])([H:7])[H:8]", + "[H:1][O:3][C:6]([H:2])([H:4])[H:5]", + ] + assert smiles == expected_smiles + + +def test_convert_entry_to_system_pure(mock_density_pure, mocker): + topology = mocker.Mock() + topologies = {"CO": topology} + + n_mols = 123 + + key, system = _convert_entry_to_system(mock_density_pure, topologies, n_mols) + + assert key == ( + ("CO",), + (123,), + mock_density_pure["temperature"], + mock_density_pure["pressure"], + ) + + assert system.topologies == [topology] + assert system.n_copies == [n_mols] + assert system.is_periodic is 
True + + +def test_convert_entry_to_system_binary(mock_density_binary, mocker): + topology_a = mocker.Mock() + topology_b = mocker.Mock() + topologies = {"CO": topology_a, "CCO": topology_b} + + n_mols = 128 + + key, system = _convert_entry_to_system(mock_density_binary, topologies, n_mols) + + assert key == ( + ("CCO", "CO"), + (n_mols // 2, n_mols // 2), + mock_density_binary["temperature"], + mock_density_binary["pressure"], + ) + + assert system.topologies == [topology_b, topology_a] + assert system.n_copies == [n_mols // 2, n_mols // 2] + assert system.is_periodic is True + + +@pytest.mark.parametrize( + "phase, pressure, expected_n_mols", [("bulk", 1.23, 256), ("vacuum", None, 1)] +) +def test_default_config(phase, pressure, expected_n_mols): + expected_temperature = 298.15 + + config = default_config(phase, expected_temperature, pressure) + + assert config.max_mols == expected_n_mols + + assert ( + config.production.temperature.value_in_unit(openmm.unit.kelvin) + == expected_temperature + ) + + if pressure is None: + assert config.production.pressure is None + else: + assert ( + config.production.pressure.value_in_unit(openmm.unit.atmosphere) == pressure + ) + + +def test_plan_simulations( + mock_density_pure, mock_density_binary, mock_hvap, mock_hmix, mocker +): + topology_co = mocker.Mock() + topology_cco = mocker.Mock() + topology_cccc = mocker.Mock() + + topologies = {"CO": topology_co, "CCO": topology_cco, "CCCC": topology_cccc} + + required_simulations, entry_to_simulation = _plan_simulations( + [mock_density_pure, mock_density_binary, mock_hvap, mock_hmix], topologies + ) + + assert sorted(required_simulations) == ["bulk", "vacuum"] + + expected_vacuum_key = SimulationKey(("CCCC",), (1,), mock_hvap["temperature"], None) + assert sorted(required_simulations["vacuum"]) == [expected_vacuum_key] + assert required_simulations["vacuum"][expected_vacuum_key].n_copies == [1] + assert required_simulations["vacuum"][expected_vacuum_key].topologies == [ + topology_cccc + ] + + expected_cccc_key = SimulationKey( + ("CCCC",), + (256,), + mock_hvap["temperature"], + mock_hvap["pressure"], + ) + expected_co_key = SimulationKey( + ("CO",), + (256,), + mock_density_pure["temperature"], + mock_density_pure["pressure"], + ) + expected_cco_key = SimulationKey( + ("CCO",), + (256,), + mock_density_binary["temperature"], + mock_density_binary["pressure"], + ) + expected_cco_co_key = SimulationKey( + ("CCO", "CO"), + (128, 128), + mock_density_binary["temperature"], + mock_density_binary["pressure"], + ) + + expected_bulk_keys = [ + expected_cccc_key, + expected_co_key, + expected_cco_key, + expected_cco_co_key, + ] + + assert sorted(required_simulations["bulk"]) == sorted(expected_bulk_keys) + + assert required_simulations["bulk"][expected_cccc_key].n_copies == [256] + assert required_simulations["bulk"][expected_cccc_key].topologies == [topology_cccc] + + assert required_simulations["bulk"][expected_cco_key].n_copies == [256] + assert required_simulations["bulk"][expected_cco_key].topologies == [topology_cco] + + assert required_simulations["bulk"][expected_co_key].n_copies == [256] + assert required_simulations["bulk"][expected_co_key].topologies == [topology_co] + + assert required_simulations["bulk"][expected_cco_co_key].n_copies == [128, 128] + assert required_simulations["bulk"][expected_cco_co_key].topologies == [ + topology_cco, + topology_co, + ] + + assert entry_to_simulation == [ + {"bulk": expected_co_key}, + {"bulk": expected_cco_co_key}, + {"bulk": expected_cccc_key, "vacuum": 
expected_vacuum_key}, + { + "bulk": expected_cco_co_key, + "bulk_0": expected_cco_key, + "bulk_1": expected_co_key, + }, + ] + + +def test_simulation(tmp_cwd, mocker): + coords = numpy.zeros((1, 3)) * openmm.unit.angstrom + box_vectors = numpy.eye(3) * openmm.unit.angstrom + + expected_temperature = 298.15 + config = default_config("bulk", expected_temperature, 1.0) + + mock_system = mocker.MagicMock() + mock_ff = mocker.MagicMock() + + mock_gen_coords = mocker.patch( + "smee.mm.generate_system_coords", + autospec=True, + return_value=(coords, box_vectors), + ) + mock_simulate = mocker.patch("smee.mm.simulate", autospec=True) + + spied_reporter = mocker.spy(smee.mm, "TensorReporter") + + expected_output = tmp_cwd / "frames.msgpack" + _simulate(mock_system, mock_ff, config, expected_output) + + mock_gen_coords.assert_called_once_with(mock_system, config.gen_coords) + + mock_simulate.assert_called_once_with( + mock_system, + mock_ff, + coords, + box_vectors, + config.equilibrate, + config.production, + [mocker.ANY], + ) + assert expected_output.exists() + + expected_beta = 1.0 / ( + openmm.unit.MOLAR_GAS_CONSTANT_R * expected_temperature * openmm.unit.kelvin + ) + + spied_reporter.assert_called_once_with( + mocker.ANY, + config.production_frequency, + expected_beta, + config.production.pressure, + ) + + +def test_compute_averages_reweighted(tmp_cwd, mocker): + mock_result = mocker.Mock() + mock_reweight = mocker.patch( + "smee.mm.reweight_ensemble_averages", autospec=True, return_value=mock_result + ) + + expected_hash = "1234567890abcdef" + + mock_hash = mocker.MagicMock() + mock_hash.hexdigest.return_value = expected_hash + + mocker.patch("hashlib.sha256", autospec=True, return_value=mock_hash) + + phase = "vacuum" + key = SimulationKey(("CCCC",), (1,), 298.15, None) + + mock_system = mocker.Mock() + mock_ff = mocker.Mock() + + cached_dir = tmp_cwd / "cached" + cached_dir.mkdir() + + expected_path = cached_dir / f"{phase}-{expected_hash}-frames.msgpack" + expected_path.touch() + + result = _compute_averages(phase, key, mock_system, mock_ff, tmp_cwd, cached_dir) + assert result == mock_result + + mock_reweight.assert_called_once_with( + mock_system, mock_ff, expected_path, 298.15 * openmm.unit.kelvin, None + ) + + +def test_compute_averages_simulated(tmp_cwd, mocker): + mock_result = mocker.Mock() + mocker.patch( + "smee.mm.reweight_ensemble_averages", + autospec=True, + side_effect=smee.mm._ops.NotEnoughSamplesError(), + ) + mock_simulate = mocker.patch("descent.targets.thermo._simulate", autospec=True) + mock_compute = mocker.patch( + "smee.mm.compute_ensemble_averages", autospec=True, return_value=mock_result + ) + + expected_hash = "1234567890abcdef" + + mock_hash = mocker.MagicMock() + mock_hash.hexdigest.return_value = expected_hash + + mocker.patch("hashlib.sha256", autospec=True, return_value=mock_hash) + + phase = "vacuum" + key = SimulationKey(("CCCC",), (1,), 298.15, None) + + mock_system = mocker.Mock() + mock_ff = mocker.Mock() + + cached_dir = tmp_cwd / "cached" + cached_dir.mkdir() + (cached_dir / f"{phase}-{expected_hash}-frames.msgpack").touch() + + expected_path = tmp_cwd / f"{phase}-{expected_hash}-frames.msgpack" + expected_path.touch() + + result = _compute_averages(phase, key, mock_system, mock_ff, tmp_cwd, cached_dir) + assert result == mock_result + + mock_simulate.assert_called_once_with( + mock_system, mock_ff, mocker.ANY, expected_path + ) + mock_compute.assert_called_once_with( + mock_system, mock_ff, expected_path, 298.15 * openmm.unit.kelvin, None + ) + + +def 
test_predict_density(mock_density_pure, mocker): + topologies = {"CO": mocker.Mock()} + key, system = _convert_entry_to_system(mock_density_pure, topologies, 123) + + expected_result = mocker.Mock() + + averages = {"bulk": {key: {"density": expected_result}}} + systems = {"bulk": {key: system}} + + result = _predict(mock_density_pure, {"bulk": key}, averages, systems) + assert result == expected_result + + +def test_predict_hvap(mock_hvap, mocker): + topologies = {"CCCC": mocker.Mock()} + + n_mols = 123 + + key_bulk, system_bulk = _convert_entry_to_system(mock_hvap, topologies, n_mols) + key_vaccum = SimulationKey(("CCCC",), (1,), mock_hvap["temperature"], None) + + system_vacuum = smee.TensorSystem([topologies["CCCC"]], [1], False) + + potential_bulk = torch.tensor([7.0]) + potential_vacuum = torch.tensor([3.0]) + + averages = { + "bulk": {key_bulk: {"potential_energy": potential_bulk}}, + "vacuum": {key_vaccum: {"potential_energy": potential_vacuum}}, + } + systems = {"bulk": {key_bulk: system_bulk}, "vacuum": {key_vaccum: system_vacuum}} + keys = {"bulk": key_bulk, "vacuum": key_vaccum} + + rt = ( + mock_hvap["temperature"] * openmm.unit.kelvin * openmm.unit.MOLAR_GAS_CONSTANT_R + ).value_in_unit(openmm.unit.kilocalorie_per_mole) + + expected = potential_vacuum - potential_bulk / n_mols + rt + + result = _predict(mock_hvap, keys, averages, systems) + assert result == pytest.approx(expected) + + +def test_predict_hmix(mock_hmix, mocker): + topologies = {"CO": mocker.Mock(), "CCO": mocker.Mock()} + + n_mols = 100 + + key_bulk, system_bulk = _convert_entry_to_system(mock_hmix, topologies, n_mols) + key_0 = SimulationKey( + ("CCO",), (n_mols,), mock_hmix["temperature"], mock_hmix["pressure"] + ) + key_1 = SimulationKey( + ("CO",), (n_mols,), mock_hmix["temperature"], mock_hmix["pressure"] + ) + + system_0 = smee.TensorSystem([topologies["CCO"]], [n_mols], False) + system_1 = smee.TensorSystem([topologies["CO"]], [n_mols], False) + + enthalpy_bulk = torch.tensor([16.0]) + enthalpy_0 = torch.tensor([4.0]) + enthalpy_1 = torch.tensor([3.0]) + + averages = { + "bulk": { + key_bulk: {"enthalpy": enthalpy_bulk}, + key_0: {"enthalpy": enthalpy_0}, + key_1: {"enthalpy": enthalpy_1}, + }, + } + systems = {"bulk": {key_bulk: system_bulk, key_0: system_0, key_1: system_1}} + keys = {"bulk": key_bulk, "bulk_0": key_0, "bulk_1": key_1} + + expected = ( + enthalpy_bulk / n_mols - 0.5 * enthalpy_0 / n_mols - 0.5 * enthalpy_1 / n_mols + ) + + result = _predict(mock_hmix, keys, averages, systems) + assert result == pytest.approx(expected) + + +def test_predict(tmp_cwd, mock_density_pure, mocker): + dataset = create_dataset(mock_density_pure) + + mock_topologies = {"CO": mocker.Mock()} + mock_ff = mocker.Mock() + + mock_density = torch.tensor(123.0) + + mock_compute = mocker.patch( + "descent.targets.thermo._compute_averages", + autospec=True, + return_value={"density": mock_density}, + ) + + mock_scale = 3.0 + + y_ref, y_pred = predict( + dataset, mock_ff, mock_topologies, tmp_cwd, None, {"density": mock_scale} + ) + + mock_compute.assert_called_once_with( + "bulk", + SimulationKey( + ("CO",), + (256,), + mock_density_pure["temperature"], + mock_density_pure["pressure"], + ), + mocker.ANY, + mock_ff, + tmp_cwd, + None, + ) + + expected_y_ref = torch.tensor([mock_density_pure["value"] * mock_scale]) + expected_y_pred = torch.tensor([mock_density * mock_scale]) + + assert y_ref.shape == expected_y_ref.shape + assert torch.allclose(y_ref, expected_y_ref) + + assert y_pred.shape == expected_y_pred.shape + 
assert torch.allclose(y_pred, expected_y_pred)
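
Example usage: the sketch below is not part of the patch; it shows one way the new
target could be driven, grounded only in the functions the patch adds
(create_dataset, extract_smiles, predict). The data point, output directory, and
scale factor are illustrative placeholders, and the force field and topology
dictionary passed to the loss function are assumed to have been built with smee
beforehand for every SMILES returned by extract_smiles.

    import pathlib

    import smee
    import torch

    from descent.targets.thermo import create_dataset, extract_smiles, predict

    # A single, illustrative pure-density data point (see DataEntry for the schema).
    dataset = create_dataset(
        {
            "type": "density",
            "smiles_a": "CCO",
            "x_a": 1.0,
            "smiles_b": None,
            "x_b": None,
            "temperature": 298.15,  # [K]
            "pressure": 1.0,  # [atm]
            "value": 0.789,
            "std": 0.001,
            "units": "g/mL",
            "source": "example",
        }
    )

    # The mapped SMILES returned here are the keys that ``topologies`` must contain.
    unique_smiles = extract_smiles(dataset)


    def loss(
        force_field: smee.TensorForceField,
        topologies: dict[str, smee.TensorTopology],
    ) -> torch.Tensor:
        """Return a simple SSE loss; ``force_field`` and ``topologies`` are assumed
        to have been built with smee for every SMILES in ``unique_smiles``."""
        y_ref, y_pred = predict(
            dataset,
            force_field,
            topologies,
            output_dir=pathlib.Path("thermo-outputs"),
            cached_dir=None,
            per_type_scales={"density": 10.0},
        )
        return torch.sum((y_pred - y_ref) ** 2)

Supplying cached_dir with trajectories from an earlier iteration lets predict
reweight those frames via smee rather than re-simulating, and it only falls back
to fresh simulations when too few effective samples remain.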