"""Public API of the :mod:`descent.data` package."""
from descent.data.data import Dataset, DatasetEntry

# ``__all__`` must be a list of *strings* naming the public objects, not the
# objects themselves.  Listing the classes directly breaks wildcard imports
# (``from descent.data import *`` raises a ``TypeError``) and confuses static
# analysis tools that inspect ``__all__``.
__all__ = ["Dataset", "DatasetEntry"]
+ """ + raise NotImplementedError() + + def __call__(self, model: ParameterizationModel, **kwargs) -> torch.Tensor: + """Evaluate the objective using a specified model. + + Args: + model: The model that will return vectorized view of a parameterised + molecule. + + Returns: + The loss contribution of this entry. + """ + return self.evaluate_loss(model, **kwargs) + + +class Dataset(torch.utils.data.IterableDataset[T_co], Generic[T_co]): + r"""An class representing a :class:`Dataset`.""" + + def __init__(self, entries: Sequence): + self._entries = entries + + def __getitem__(self, index: int) -> T_co: + return self._entries[index] + + def __iter__(self) -> Iterator[T_co]: + return self._entries.__iter__() + + def __len__(self) -> int: + return len(self._entries) diff --git a/descent/objectives/energy.py b/descent/data/energy.py similarity index 72% rename from descent/objectives/energy.py rename to descent/data/energy.py index d0efe0f..223d994 100644 --- a/descent/objectives/energy.py +++ b/descent/data/energy.py @@ -3,10 +3,8 @@ from multiprocessing import Pool from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union -import dill import torch from openff.interchange.components.interchange import Interchange -from openff.interchange.models import PotentialKey from openff.toolkit.topology import Molecule, Topology from openff.toolkit.typing.engines.smirnoff import ForceField from openff.units import unit @@ -15,16 +13,15 @@ detect_internal_coordinates, ) from smirnoffee.potentials.potentials import evaluate_vectorized_system_energy -from smirnoffee.smirnoff import vectorize_system from torch._vmap_internals import vmap from torch.autograd import grad from tqdm import tqdm from typing_extensions import Literal from descent import metrics, transforms +from descent.data import Dataset, DatasetEntry from descent.metrics import LossMetric from descent.models import ParameterizationModel -from descent.objectives import ObjectiveContribution from 
descent.transforms import LossTransform if TYPE_CHECKING: @@ -39,49 +36,19 @@ ) _INVERSE_BOHR_TO_ANGSTROM = (1.0 * unit.bohr ** -1).to(unit.angstrom ** -1).magnitude -_LAMBDA_FIELDS = [ - "energy_transforms", - "energy_metric", - "gradient_transforms", - "gradient_metric", - "hessian_transforms", - "hessian_metric", -] - - -class EnergyObjective(ObjectiveContribution): - """An objective term which measures the deviations of a set of MM - energies, gradients, and hessians from a set of reference (usually QM) values. - """ - - @property - def parameter_ids(self) -> List[Tuple[str, PotentialKey, str]]: - - return sorted( - { - (handler_type, potential_key, attribute) - for (handler_type, _), (_, _, parameters) in self._system.items() - for potential_key, attributes in parameters - for attribute in attributes - }, - key=lambda x: x[0], - reverse=True, - ) + +class EnergyEntry(DatasetEntry): + """A object that stores reference energy, gradient and hessian labels for a molecule + in multiple conforms.""" def __init__( self, - system: Interchange, + model_input: Union[Interchange], conformers: torch.Tensor, reference_energies: Optional[torch.Tensor] = None, - energy_transforms: Optional[Union[LossTransform, List[LossTransform]]] = None, - energy_metric: Optional[LossMetric] = None, reference_gradients: Optional[torch.Tensor] = None, - gradient_transforms: Optional[Union[LossTransform, List[LossTransform]]] = None, - gradient_metric: Optional[LossMetric] = None, gradient_coordinate_system: Literal["cartesian", "ric"] = "cartesian", reference_hessians: Optional[torch.Tensor] = None, - hessian_transforms: Optional[Union[LossTransform, List[LossTransform]]] = None, - hessian_metric: Optional[LossMetric] = None, hessian_coordinate_system: Literal["cartesian", "ric"] = "cartesian", ): """ @@ -89,39 +56,27 @@ def __init__( Args: reference_energies: The reference energies with shape=(n_conformers, 1) and units of [kJ / mol]. 
- energy_transforms: Transforms to apply to the QM and MM energies - before computing the loss metric. - energy_metric: The loss metric (e.g. MSE) to compute from the QM and MM - energies. reference_gradients: The reference gradients with shape=(n_conformers, n_atoms, 3) and units of [kJ / mol / A]. - gradient_transforms: Transforms to apply to the QM and MM gradients - before computing the loss metric. - gradient_metric: The loss metric (e.g. MSE) to compute from the QM and MM - gradients. gradient_coordinate_system: The coordinate system to project the QM and MM gradients to before computing the loss metric. reference_hessians: The reference gradients with shape=(n_conformers, n_atoms * 3, n_atoms * 3) and units of [kJ / mol / A^2]. - hessian_transforms: Transforms to apply to the QM and MM hessians - before computing the loss metric. - hessian_metric: The loss metric (e.g. MSE) to compute from the QM and MM - hessians hessian_coordinate_system: The coordinate system to project the QM and MM hessians to before computing the loss metric. 
""" + super(EnergyEntry, self).__init__(model_input) + self._validate_inputs( conformers, reference_energies, reference_gradients, reference_hessians, - system, + model_input, ) - self._system = vectorize_system(system) - self._conformers = conformers internal_coordinate_systems = { @@ -129,65 +84,31 @@ def __init__( } self._inverse_b_matrices = { coordinate_system.lower(): self._initialize_internal_coordinates( - coordinate_system, system.topology, reference_hessians is not None + coordinate_system, model_input.topology, reference_hessians is not None ) for coordinate_system in internal_coordinate_systems if coordinate_system is not None and coordinate_system.lower() != "cartesian" } - if reference_energies is not None: - - energy_transforms = ( - transforms.relative() - if energy_transforms is None - else energy_transforms - ) - energy_metric = metrics.mse() if energy_metric is None else energy_metric - - reference_energies = transforms.transform_tensor( - reference_energies, energy_transforms - ) - self._reference_energies = reference_energies - self._energy_transforms = energy_transforms - self._energy_metric = energy_metric if reference_hessians is not None: - ( - hessian_metric, - hessian_transforms, - reference_hessians, - ) = self._initialize_reference_hessians( - reference_hessians, - hessian_transforms, - hessian_metric, - hessian_coordinate_system, - reference_gradients, + reference_hessians = self._project_hessians( + reference_hessians, reference_gradients, hessian_coordinate_system ) self._reference_hessians = reference_hessians - self._hessian_transforms = hessian_transforms - self._hessian_metric = hessian_metric self._hessian_coordinate_system = hessian_coordinate_system if reference_gradients is not None: - ( - gradient_metric, - gradient_transforms, - reference_gradients, - ) = self._initialize_reference_gradients( - reference_gradients, - gradient_transforms, - gradient_metric, - gradient_coordinate_system, + reference_gradients = 
self._project_gradients( + reference_gradients, gradient_coordinate_system ) self._reference_gradients = reference_gradients - self._gradient_transforms = gradient_transforms - self._gradient_metric = gradient_metric self._gradient_coordinate_system = gradient_coordinate_system @classmethod @@ -418,37 +339,6 @@ def get_vjp(v): torch.stack(b_matrix_gradients), ) - def _initialize_reference_gradients( - self, - reference_gradients: torch.Tensor, - gradient_transforms: Optional[Union[LossTransform, List[LossTransform]]], - gradient_metric: Optional[LossMetric], - gradient_coordinate_system: Literal["cartesian", "ric"], - ) -> Tuple[List[LossTransform], LossMetric, torch.Tensor]: - """Applies the relevant transforms and projects to the reference gradients and - populates missing transforms (identity) and metrics (MSE). - - Returns: - The the gradient transforms, metric and transformed reference values. - """ - - gradient_transforms = ( - [transforms.identity()] - if gradient_transforms is None - else gradient_transforms - ) - gradient_metric = ( - metrics.mse(dim=()) if gradient_metric is None else gradient_metric - ) - - reference_gradients = transforms.transform_tensor( - self._project_gradients(reference_gradients, gradient_coordinate_system), - gradient_transforms, - ) - - # noinspection PyTypeChecker - return gradient_metric, gradient_transforms, reference_gradients - def _project_gradients( self, gradients: torch.Tensor, coordinate_system: Literal["cartesian", "ric"] ) -> torch.Tensor: @@ -469,41 +359,6 @@ def _project_gradients( return gradients - def _initialize_reference_hessians( - self, - reference_hessians: torch.Tensor, - hessian_transforms: Optional[Union[LossTransform, List[LossTransform]]], - hessian_metric: Optional[LossMetric], - hessian_coordinate_system: Literal["cartesian", "ric"], - reference_gradients: torch.Tensor, - ) -> Tuple[List[LossTransform], LossMetric, torch.Tensor]: - """Applies the relevant transforms and projects to the reference 
hessians and - populates missing transforms (identity) and metrics (MSE). - - Returns: - The the hessians transforms, metric and transformed reference values. - """ - - hessian_transforms = ( - [transforms.identity()] - if hessian_transforms is None - else hessian_transforms - ) - - hessian_metric = ( - metrics.mse(dim=()) if hessian_metric is None else hessian_metric - ) - - reference_hessians = transforms.transform_tensor( - self._project_hessians( - reference_hessians, reference_gradients, hessian_coordinate_system - ), - hessian_transforms, - ) - - # noinspection PyTypeChecker - return hessian_metric, hessian_transforms, reference_hessians - def _project_hessians( self, hessians: torch.Tensor, @@ -556,14 +411,14 @@ def _evaluate_mm_energies( compute_hessians: bool = False, ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """Evaluate the perturbed MM energies, gradients and hessians of the system - associated with this term. + associated with this entry. Args: model: The model that will return vectorized view of a parameterised molecule. """ - vectorized_system = model.forward(self._system) + vectorized_system = model.forward(self._model_input) conformers = self._conformers.detach().clone().requires_grad_() mm_energies, mm_gradients, mm_hessians = [], [], [] @@ -605,8 +460,68 @@ def _evaluate_mm_energies( None if not compute_hessians else torch.stack(mm_hessians), ) - def evaluate(self, model: ParameterizationModel) -> torch.Tensor: + @staticmethod + def _evaluate_loss_contribution( + reference_tensor: torch.Tensor, + computed_tensor: torch.Tensor, + data_transforms: Union[LossTransform, List[LossTransform]], + data_metric: LossMetric, + ) -> torch.Tensor: + """Computes the loss contribution for a set of computed and reference labels. + + Args: + reference_tensor: The reference tensor. + computed_tensor: The computed tensor. + data_transforms: Transforms to apply to the reference and computed tensors. + data_metric: The loss metric (e.g. MSE) to compute. 
+ Returns: + The loss contribution. + """ + + transformed_reference_tensor = transforms.transform_tensor( + reference_tensor, data_transforms + ) + transformed_computed_tensor = transforms.transform_tensor( + computed_tensor, data_transforms + ) + + return data_metric(transformed_computed_tensor, transformed_reference_tensor) + + def evaluate_loss( + self, + model: ParameterizationModel, + energy_transforms: Optional[Union[LossTransform, List[LossTransform]]] = None, + energy_metric: Optional[LossMetric] = None, + gradient_transforms: Optional[Union[LossTransform, List[LossTransform]]] = None, + gradient_metric: Optional[LossMetric] = None, + hessian_transforms: Optional[Union[LossTransform, List[LossTransform]]] = None, + hessian_metric: Optional[LossMetric] = None, + ) -> torch.Tensor: + """ + + Args: + model: The model that will return vectorized view of a parameterised + molecule. + energy_transforms: Transforms to apply to the QM and MM energies + before computing the loss metric. By default + ``descent.transforms.relative(index=0)`` is used if no value is provided. + energy_metric: The loss metric (e.g. MSE) to compute from the QM and MM + energies. By default ``descent.metrics.mse()`` is used if no value is + provided. + gradient_transforms: Transforms to apply to the QM and MM gradients + before computing the loss metric. By default + ``descent.transforms.identity()`` is used if no value is provided. + gradient_metric: The loss metric (e.g. MSE) to compute from the QM and MM + gradients. By default ``descent.metrics.mse()`` is used if no value is + provided. + hessian_transforms: Transforms to apply to the QM and MM hessians + before computing the loss metric. By default + ``descent.transforms.identity()`` is used if no value is provided. + hessian_metric: The loss metric (e.g. MSE) to compute from the QM and MM + hessians. By default ``descent.metrics.mse()`` is used if no value is + provided. 
+ """ mm_energies, mm_gradients, mm_hessians = self._evaluate_mm_energies( model, compute_gradients=( @@ -620,37 +535,45 @@ def evaluate(self, model: ParameterizationModel) -> torch.Tensor: if self._reference_energies is not None: - transformed_mm_energies = transforms.transform_tensor( - mm_energies, self._energy_transforms - ) - loss += self._energy_metric( - transformed_mm_energies, self._reference_energies + loss += self._evaluate_loss_contribution( + self._reference_energies, + mm_energies, + energy_transforms + if energy_transforms is not None + else transforms.relative(index=0), + energy_metric if energy_metric is not None else metrics.mse(), ) if self._reference_gradients is not None: - transformed_mm_gradients = transforms.transform_tensor( + loss += self._evaluate_loss_contribution( + self._reference_gradients, self._project_gradients(mm_gradients, self._gradient_coordinate_system), - self._gradient_transforms, - ) - loss += self._gradient_metric( - transformed_mm_gradients, self._reference_gradients + gradient_transforms + if gradient_transforms is not None + else transforms.relative(index=0), + gradient_metric if gradient_metric is not None else metrics.mse(), ) if self._reference_hessians is not None: - transformed_mm_hessians = transforms.transform_tensor( + loss += self._evaluate_loss_contribution( + self._reference_hessians, self._project_hessians( mm_hessians, mm_gradients, self._hessian_coordinate_system ), - self._hessian_transforms, - ) - loss += self._hessian_metric( - transformed_mm_hessians, self._reference_hessians + hessian_transforms + if hessian_transforms is not None + else transforms.identity(), + hessian_metric if hessian_metric is not None else metrics.mse(), ) return loss + +class EnergyDataset(Dataset[EnergyEntry]): + """A data set that stores reference energy, gradient and hessian labels.""" + @classmethod def _retrieve_gradient_and_hessians( cls, @@ -731,19 +654,14 @@ def _from_grouped_results( ], force_field: ForceField, 
**kwargs, - ) -> "EnergyObjective": + ) -> "EnergyEntry": cmiles, conformers, qc_energies, qc_gradients, qc_hessians = grouped_data molecule = Molecule.from_mapped_smiles(cmiles, allow_undefined_stereo=True) system = Interchange.from_smirnoff(force_field, molecule.to_topology()) - # We need to un-dill any potential lambda functions as the default - # multiprocessing pickler cannot handle these by default. - for field_name in _LAMBDA_FIELDS: - kwargs[field_name] = dill.loads(kwargs[field_name], None) - - return EnergyObjective( + return EnergyEntry( system, conformers, reference_energies=qc_energies, @@ -758,49 +676,31 @@ def from_optimization_results( optimization_results: "OptimizationResultCollection", initial_force_field: ForceField, include_energies: bool = True, - energy_transforms: Optional[Union[LossTransform, List[LossTransform]]] = None, - energy_metric: Optional[LossMetric] = None, include_gradients: bool = False, - gradient_transforms: Optional[Union[LossTransform, List[LossTransform]]] = None, - gradient_metric: Optional[LossMetric] = None, gradient_coordinate_system: Literal["cartesian", "ric"] = "cartesian", include_hessians: bool = False, - hessian_transforms: Optional[Union[LossTransform, List[LossTransform]]] = None, - hessian_metric: Optional[LossMetric] = None, hessian_coordinate_system: Literal["cartesian", "ric"] = "cartesian", n_processes: int = 1, verbose: bool = True, - ) -> List["EnergyObjective"]: - """Creates a list of energy objective contribution terms (one per unique - molecule) from the **final** structures a set of QC optimization results. + ) -> "EnergyDataset": + """Creates a dataset of energy entries (one per unique molecule) from the + **final** structures a set of QC optimization results. Args: optimization_results: The collection of result records. initial_force_field: The force field that will be trained. include_energies: Whether to include energies. 
- energy_transforms: Transforms to apply to the QM and MM energies - before computing the loss metric. - energy_metric: The loss metric (e.g. MSE) to compute from the QM and MM - energies. include_gradients: Whether to include gradients. - gradient_transforms: Transforms to apply to the QM and MM gradients - before computing the loss metric. - gradient_metric: The loss metric (e.g. MSE) to compute from the QM and MM - gradients. gradient_coordinate_system: The coordinate system to project the QM and MM gradients to before computing the loss metric. include_hessians: Whether to include hessians. - hessian_transforms: Transforms to apply to the QM and MM hessians - before computing the loss metric. - hessian_metric: The loss metric (e.g. MSE) to compute from the QM and MM - hessians hessian_coordinate_system: The coordinate system to project the QM and MM hessians to before computing the loss metric. n_processes: The number of processes to parallelize this function across. verbose: Whether to log progress to the terminal. Returns: - A list of the energy objective terms. + A dataset of the energy entries. """ from simtk import unit as simtk_unit @@ -865,35 +765,15 @@ def from_optimization_results( with Pool(n_processes) as pool: - # We need to dill any potential lambda functions as the default - # multiprocessing pickler cannot handle these by default. 
- contributions = list( + entries = list( tqdm( pool.imap( functools.partial( cls._from_grouped_results, force_field=initial_force_field, - energy_transforms=dill.dumps( - energy_transforms if include_energies else None - ), - energy_metric=dill.dumps( - energy_metric if include_energies else None - ), - gradient_transforms=dill.dumps( - gradient_transforms if include_gradients else None - ), - gradient_metric=dill.dumps( - gradient_metric if include_gradients else None - ), gradient_coordinate_system=gradient_coordinate_system if include_gradients else None, - hessian_transforms=dill.dumps( - hessian_transforms if include_hessians else None - ), - hessian_metric=dill.dumps( - hessian_metric if include_hessians else None - ), hessian_coordinate_system=hessian_coordinate_system if include_hessians else None, @@ -902,30 +782,8 @@ def from_optimization_results( ), total=len(result_tensors), disable=not verbose, - desc="Building energy contribution objects.", + desc="Building entries.", ) ) - return contributions - - def __getstate__(self): - """A custom pickle function that ensures any lambda functions are correctly - serialized.""" - - return_value = {**self.__dict__} - - for field_name in _LAMBDA_FIELDS: - return_value[f"_{field_name}"] = dill.dumps(return_value[f"_{field_name}"]) - - return return_value - - def __setstate__(self, state): - """A custom pickle function that ensures any lambda functions are correctly - serialized.""" - - state = {**state} - - for field_name in _LAMBDA_FIELDS: - state[f"_{field_name}"] = dill.loads(state[f"_{field_name}"]) - - self.__dict__.update(state) + return cls(entries) diff --git a/descent/objectives/__init__.py b/descent/objectives/__init__.py deleted file mode 100644 index f29332d..0000000 --- a/descent/objectives/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from descent.objectives.objectives import ObjectiveContribution - -__all__ = [ObjectiveContribution] diff --git a/descent/objectives/objectives.py 
b/descent/objectives/objectives.py deleted file mode 100644 index 5040dce..0000000 --- a/descent/objectives/objectives.py +++ /dev/null @@ -1,44 +0,0 @@ -import abc -from typing import List, Tuple - -import torch -from openff.interchange.models import PotentialKey - -from descent.models import ParameterizationModel - - -class ObjectiveContribution(abc.ABC): - """The base class for contributions to a total objective ( / loss) function.""" - - @property - @abc.abstractmethod - def parameter_ids(self) -> List[Tuple[str, PotentialKey, str]]: - """The ids of the parameters that are exercised by this contribution to the - total objective function. - """ - raise NotImplementedError() - - @abc.abstractmethod - def evaluate(self, model: ParameterizationModel) -> torch.Tensor: - """Evaluate the objective using a specified model. - - Args: - model: The model that will return vectorized view of a parameterised - molecule. - - Returns: - The loss contribution of this term. - """ - raise NotImplementedError() - - def __call__(self, model: ParameterizationModel) -> torch.Tensor: - """Evaluate the objective using a specified model. - - Args: - model: The model that will return vectorized view of a parameterised - molecule. - - Returns: - The loss contribution of this term. 
- """ - return self.evaluate(model) diff --git a/descent/tests/objectives/__init__.py b/descent/tests/data/__init__.py similarity index 100% rename from descent/tests/objectives/__init__.py rename to descent/tests/data/__init__.py diff --git a/descent/tests/data/test_data.py b/descent/tests/data/test_data.py new file mode 100644 index 0000000..f226ae9 --- /dev/null +++ b/descent/tests/data/test_data.py @@ -0,0 +1,48 @@ +import pytest + +from descent.data import Dataset, DatasetEntry +from descent.models.smirnoff import SMIRNOFFModel +from descent.tests.mocking.systems import generate_mock_hcl_system + + +class DummyEntry(DatasetEntry): + def evaluate_loss(self, model, **kwargs): + pass + + +class DummyDataset(Dataset[DummyEntry]): + pass + + +def test_call(monkeypatch): + + evaluate_called = False + evaluate_kwargs = {} + + class LocalEntry(DatasetEntry): + def evaluate_loss(self, model, **kwargs): + nonlocal evaluate_called + evaluate_called = True + evaluate_kwargs.update(kwargs) + + LocalEntry(generate_mock_hcl_system())(SMIRNOFFModel([], None), a="a", b=2) + + assert evaluate_called + assert evaluate_kwargs == {"a": "a", "b": 2} + + +def test_dataset(): + + model_input = generate_mock_hcl_system() + + dataset = DummyDataset(entries=[DummyEntry(model_input), DummyEntry(model_input)]) + + assert dataset[0] is not None + assert dataset[1] is not None + + with pytest.raises(IndexError): + assert dataset[2] + + assert len(dataset) == 2 + + assert all(isinstance(entry.model_input, dict) for entry in dataset) diff --git a/descent/tests/objectives/test_energy.py b/descent/tests/data/test_energy.py similarity index 65% rename from descent/tests/objectives/test_energy.py rename to descent/tests/data/test_energy.py index 4e32970..3074ae9 100644 --- a/descent/tests/objectives/test_energy.py +++ b/descent/tests/data/test_energy.py @@ -1,18 +1,16 @@ import copy from typing import Tuple -import dill import numpy import pytest import torch -from openff.interchange.models 
import PotentialKey from openff.toolkit.topology import Molecule from openff.toolkit.typing.engines.smirnoff import ForceField from smirnoffee.geometry.internal import detect_internal_coordinates from descent import metrics, transforms +from descent.data.energy import EnergyDataset, EnergyEntry from descent.models.smirnoff import SMIRNOFFModel -from descent.objectives.energy import EnergyObjective from descent.tests.geometric import geometric_project_derivatives from descent.tests.mocking.qcdata import mock_optimization_result_collection from descent.tests.mocking.systems import generate_mock_hcl_system @@ -67,24 +65,6 @@ def mock_hcl_mm_values() -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: ) -def test_parameter_ids(mock_hcl_conformers, mock_hcl_system): - - term = EnergyObjective( - mock_hcl_system, - mock_hcl_conformers, - reference_energies=torch.zeros((len(mock_hcl_conformers), 1)), - ) - - assert {*term.parameter_ids} == { - ("Bonds", PotentialKey(id="[#1:1]-[#17:2]", associated_handler="Bonds"), "k"), - ( - "Bonds", - PotentialKey(id="[#1:1]-[#17:2]", associated_handler="Bonds"), - "length", - ), - } - - def test_initialize_internal_coordinates(): """Test that the internal coordinate matrices can be correctly constructed and padding when different conformers of a molecule have different numbers of internal @@ -102,10 +82,10 @@ def test_initialize_internal_coordinates(): requires_grad=True, ) - objective = EnergyObjective.__new__(EnergyObjective) - objective._conformers = conformers + entry = EnergyEntry.__new__(EnergyEntry) + entry._conformers = conformers - b_matrix, g_inverse, b_matrix_gradient = objective._initialize_internal_coordinates( + b_matrix, g_inverse, b_matrix_gradient = entry._initialize_internal_coordinates( "ric", topology, True ) @@ -161,7 +141,7 @@ def test_gradient_hessian_projection(ethanol, ethanol_conformer, ethanol_system) reference_hessians, ) - ethanol_objective = EnergyObjective( + entry = EnergyEntry( ethanol_system, 
ethanol_conformer.reshape(1, len(ethanol_conformer), 3), reference_gradients=reference_gradients, @@ -170,8 +150,8 @@ def test_gradient_hessian_projection(ethanol, ethanol_conformer, ethanol_system) hessian_coordinate_system="ric", ) - actual_gradiant = ethanol_objective._reference_gradients.numpy() - actual_hessian = ethanol_objective._reference_hessians.numpy() + actual_gradiant = entry._reference_gradients.numpy() + actual_hessian = entry._reference_hessians.numpy() assert numpy.allclose( actual_gradiant.reshape(expected_gradiant.shape), expected_gradiant, atol=1.0e-3 @@ -191,11 +171,9 @@ def test_evaluate_mm_energies( mock_hcl_mm_values, ): - energy_objective = EnergyObjective( - mock_hcl_system, mock_hcl_conformers, torch.zeros((2, 1)) - ) + entry = EnergyEntry(mock_hcl_system, mock_hcl_conformers, torch.zeros((2, 1))) - mm_energies, mm_gradients, mm_hessians = energy_objective._evaluate_mm_energies( + mm_energies, mm_gradients, mm_hessians = entry._evaluate_mm_energies( SMIRNOFFModel([], None), compute_gradients, compute_hessians ) @@ -222,16 +200,18 @@ def test_evaluate_energies(mock_hcl_conformers, mock_hcl_system, mock_hcl_mm_val expected_energies, *_ = mock_hcl_mm_values expected_scale = torch.rand(1) - energy_objective = EnergyObjective( + entry = EnergyEntry( mock_hcl_system, mock_hcl_conformers, reference_energies=expected_energies + torch.ones_like(expected_energies), + ) + + loss = entry.evaluate_loss( + SMIRNOFFModel([], None), energy_transforms=lambda x: expected_scale * x, energy_metric=metrics.mse(), ) - loss = energy_objective.evaluate(SMIRNOFFModel([], None)) - assert loss.shape == (1,) assert torch.isclose(loss, expected_scale.square()) @@ -241,19 +221,21 @@ def test_evaluate_gradients(mock_hcl_conformers, mock_hcl_system, mock_hcl_mm_va expected_energies, expected_gradients, _ = mock_hcl_mm_values expected_scale = torch.rand(1) - energy_objective = EnergyObjective( + entry = EnergyEntry( mock_hcl_system, mock_hcl_conformers, # Set a reference 
energy to make sure gradient contributions don't # bleed between loss functions reference_energies=expected_energies, reference_gradients=expected_gradients + torch.ones_like(expected_gradients), + ) + + loss = entry.evaluate_loss( + SMIRNOFFModel([], None), gradient_transforms=lambda x: expected_scale * x, gradient_metric=metrics.mse(()), ) - loss = energy_objective.evaluate(SMIRNOFFModel([], None)) - assert loss.shape == (1,) assert torch.isclose(loss, expected_scale.square()) @@ -263,7 +245,7 @@ def test_evaluate_hessians(mock_hcl_conformers, mock_hcl_system, mock_hcl_mm_val expected_energies, expected_gradients, expected_hessians = mock_hcl_mm_values expected_scale = torch.rand(1) - energy_objective = EnergyObjective( + entry = EnergyEntry( mock_hcl_system, mock_hcl_conformers, # Set a reference energy to make sure gradient contributions don't @@ -271,23 +253,35 @@ def test_evaluate_hessians(mock_hcl_conformers, mock_hcl_system, mock_hcl_mm_val reference_energies=expected_energies, reference_gradients=expected_gradients, reference_hessians=expected_hessians + torch.ones_like(expected_hessians), + ) + + loss = entry.evaluate_loss( + SMIRNOFFModel([], None), hessian_transforms=lambda x: expected_scale * x, hessian_metric=metrics.mse(()), ) - loss = energy_objective.evaluate(SMIRNOFFModel([], None)) - assert loss.shape == (1,) assert torch.isclose(loss, expected_scale.square()) +def test_evaluate_loss_contribution(): + + reference_tensor = torch.tensor([[1.0], [2.0]]) + computed_tensor = torch.tensor([[4.0], [8.0]]) + + loss = EnergyEntry._evaluate_loss_contribution( + reference_tensor, computed_tensor, transforms.relative(), metrics.mse() + ) + + assert torch.isclose(loss, torch.tensor((4.0 - 1.0) ** 2 * 0.5)) + + def test_from_grouped_results(mock_hcl_conformers, mock_hcl_mm_values): - def energy_transforms(x): - return x * 2.0 mock_energies, mock_gradients, mock_hessians = mock_hcl_mm_values - created_term = EnergyObjective._from_grouped_results( + created_term 
= EnergyDataset._from_grouped_results( ( "[Cl:1][Cl:2]", mock_hcl_conformers, @@ -296,20 +290,12 @@ def energy_transforms(x): mock_hessians, ), ForceField("openff_unconstrained-1.0.0.offxml"), - energy_transforms=dill.dumps(energy_transforms), - energy_metric=dill.dumps(None), - gradient_transforms=dill.dumps(None), - gradient_metric=dill.dumps(None), - hessian_transforms=dill.dumps(None), - hessian_metric=dill.dumps(None), ) - assert created_term._system is not None + assert created_term._model_input is not None assert torch.allclose(created_term._conformers, mock_hcl_conformers) - assert torch.allclose( - created_term._reference_energies, energy_transforms(mock_energies) - ) + assert torch.allclose(created_term._reference_energies, mock_energies) assert torch.allclose(created_term._reference_gradients, mock_gradients) assert torch.allclose(created_term._reference_hessians, mock_hessians) @@ -344,145 +330,56 @@ def test_from_optimization_results( molecules, monkeypatch ) - energy_terms = EnergyObjective.from_optimization_results( + energy_dataset = EnergyDataset.from_optimization_results( optimization_collection, initial_force_field=ForceField(), include_energies=include_energies, - energy_transforms=transforms.relative(index=0), include_gradients=include_gradients, gradient_coordinate_system="cartesian", include_hessians=include_hessians, hessian_coordinate_system="cartesian", ) - assert len(energy_terms) == 2 + assert len(energy_dataset) == 2 - for energy_term, n_atoms in zip(energy_terms, [5, 8]): + for energy_entry, n_atoms in zip(energy_dataset, [5, 8]): if not include_energies: - assert energy_term._reference_energies is None + assert energy_entry._reference_energies is None else: - assert energy_term._reference_energies is not None - assert energy_term._reference_energies.shape == (2, 1) + assert energy_entry._reference_energies is not None + assert energy_entry._reference_energies.shape == (2, 1) assert not torch.allclose( - 
energy_term._reference_energies, - torch.zeros_like(energy_term._reference_energies), + energy_entry._reference_energies, + torch.zeros_like(energy_entry._reference_energies), ) if not include_gradients: - assert energy_term._reference_gradients is None + assert energy_entry._reference_gradients is None else: - assert energy_term._reference_gradients is not None - assert energy_term._reference_gradients.shape == (2, n_atoms, 3) + assert energy_entry._reference_gradients is not None + assert energy_entry._reference_gradients.shape == (2, n_atoms, 3) assert not torch.allclose( - energy_term._reference_gradients, - torch.zeros_like(energy_term._reference_gradients), + energy_entry._reference_gradients, + torch.zeros_like(energy_entry._reference_gradients), ) if not include_hessians: - assert energy_term._reference_hessians is None + assert energy_entry._reference_hessians is None else: - assert energy_term._reference_hessians is not None - assert energy_term._reference_hessians.shape == ( + assert energy_entry._reference_hessians is not None + assert energy_entry._reference_hessians.shape == ( 2, n_atoms * 3, n_atoms * 3, ) assert not torch.allclose( - energy_term._reference_hessians, - torch.zeros_like(energy_term._reference_hessians), + energy_entry._reference_hessians, + torch.zeros_like(energy_entry._reference_hessians), ) - - -def test_get_state(mock_hcl_conformers, mock_hcl_system, mock_hcl_mm_values): - - mock_mm_energies, mock_mm_gradients, mock_mm_hessians = mock_hcl_mm_values - - term = EnergyObjective( - mock_hcl_system, - mock_hcl_conformers, - mock_mm_energies, - None, - None, - mock_mm_gradients, - None, - None, - "cartesian", - mock_mm_hessians, - None, - None, - "cartesian", - ) - - state = term.__getstate__() - - assert "_system" in state - - assert "_conformers" in state - assert torch.allclose(state["_conformers"], mock_hcl_conformers) - - assert "_reference_energies" in state - assert torch.allclose(state["_reference_energies"], mock_mm_energies) - 
- assert callable(dill.loads(state["_energy_transforms"])) - assert callable(dill.loads(state["_energy_metric"])) - - assert "_reference_gradients" in state - assert torch.allclose(state["_reference_gradients"], mock_mm_gradients) - - assert callable(dill.loads(state["_gradient_transforms"])[0]) - assert callable(dill.loads(state["_gradient_metric"])) - - assert "_reference_hessians" in state - assert torch.allclose(state["_reference_hessians"], mock_mm_hessians) - - assert callable(dill.loads(state["_hessian_transforms"])[0]) - assert callable(dill.loads(state["_hessian_metric"])) - - -def test_set_state(mock_hcl_conformers, mock_hcl_system, mock_hcl_mm_values): - - mock_mm_energies, mock_mm_gradients, mock_mm_hessians = mock_hcl_mm_values - - term = EnergyObjective( - mock_hcl_system, - mock_hcl_conformers, - mock_mm_energies, - None, - None, - mock_mm_gradients, - None, - None, - "cartesian", - mock_mm_hessians, - None, - None, - "cartesian", - ) - - term_new = EnergyObjective.__new__(EnergyObjective) - term_new.__setstate__(term.__getstate__()) - - assert term_new._system == term._system - assert torch.allclose(term._conformers, term_new._conformers) - - assert torch.allclose(term._reference_energies, term_new._reference_energies) - - assert callable(term_new._energy_transforms) - assert callable(term_new._energy_metric) - - assert torch.allclose(term._reference_gradients, term_new._reference_gradients) - - assert callable(term_new._gradient_transforms[0]) - assert callable(term_new._gradient_metric) - - assert torch.allclose(term._reference_hessians, term_new._reference_hessians) - - assert callable(term_new._hessian_transforms[0]) - assert callable(term_new._hessian_metric) diff --git a/descent/tests/objectives/test_objectives.py b/descent/tests/objectives/test_objectives.py deleted file mode 100644 index 13b5644..0000000 --- a/descent/tests/objectives/test_objectives.py +++ /dev/null @@ -1,22 +0,0 @@ -from openff.toolkit.typing.engines.smirnoff import 
ForceField - -from descent.models.smirnoff import SMIRNOFFModel -from descent.objectives import ObjectiveContribution - - -def test_call(monkeypatch): - - evaluate_called = False - - class DummyObjective(ObjectiveContribution): - @property - def parameter_ids(self): - return [] - - def evaluate(self, model): - - nonlocal evaluate_called - evaluate_called = True - - DummyObjective()(SMIRNOFFModel([], ForceField)) - assert evaluate_called diff --git a/devtools/conda-envs/meta.yaml b/devtools/conda-envs/meta.yaml index ea92c53..25e1d85 100644 --- a/devtools/conda-envs/meta.yaml +++ b/devtools/conda-envs/meta.yaml @@ -10,7 +10,6 @@ dependencies: # Core dependencies - python - pip - - dill - tqdm - openff-toolkit-base >=0.9.2 diff --git a/examples/energy-and-gradient.ipynb b/examples/energy-and-gradient.ipynb index c88e0ae..3047ccd 100644 --- a/examples/energy-and-gradient.ipynb +++ b/examples/energy-and-gradient.ipynb @@ -39,7 +39,7 @@ { "cell_type": "markdown", "source": [ - "### Retrieving the QM training set\n", + "### Curating a QC training set\n", "\n", "For this example we will be training against QM energies which have been computed by and stored within the\n", "[QCArchive](https://qcarchive.molssi.org/), which are easily retrieved using the [OpenFF QCSubmit](https://github.com/openforcefield/openff-qcsubmit)\n", @@ -133,16 +133,9 @@ "You should see that our filtered collection contains the 6 results, which corresponds to 6 minimized conformers (and\n", "their associated energy computed using the OpenFF default B3LYP-D3BJ spec) for the molecule we filtered for above.\n", "\n", - "### Defining the objective ( / loss) function\n", - "\n", - "For this example we will be training our force field parameters against:\n", - "\n", - "* the relative energies between each conformer with the first conformer of the molecule\n", - "* the deviations between the QM and MM gradients projected along the redundant internal coordinates (RIC) of\n", - " the molecule.\n", - 
"\n", - "The construction of such a loss function is made trivial using the built-in ``EnergyObjective`` class which can be\n", - "created directly from the collection of optimization results we retrieved above:\n", + "In order to be able to train our parameter against this data we need to wrap it in a PyTorch dataset object. This\n", + "is made trivial thanks to the built-in ``EnergyDataset`` object that ships with the framework. The energy dataset\n", + "will extract and store any energy, gradient, and hessian data in a format ready for evaluating a loss function.\n", "\n", "We first load in the initial force field parameters ($\\theta$) using the [OpenFF Toolkit](https://github.com/openforcefield/openff-toolkit):" ], @@ -168,7 +161,7 @@ { "cell_type": "markdown", "source": [ - "which we can then use to construct our contribution objects:" + "which we can then use to construct our dataset:" ], "metadata": { "collapsed": false, @@ -180,29 +173,28 @@ { "cell_type": "code", "execution_count": 5, - "outputs": [], + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Pulling main optimisation records: 100%|██████████| 3/3 [00:00<00:00, 182.09it/s]\n", + "Pulling gradient / hessian data: 100%|██████████| 3/3 [00:00<00:00, 2606.78it/s]\n", + "Building entries.: 0%| | 0/1 [00:00