diff --git a/descent/tests/utilities/test_smirnoff.py b/descent/tests/utilities/test_smirnoff.py
index ca94cf6..5a3e1a3 100644
--- a/descent/tests/utilities/test_smirnoff.py
+++ b/descent/tests/utilities/test_smirnoff.py
@@ -1,10 +1,12 @@
+import pytest
 import torch
 from openff.interchange.models import PotentialKey
 from openff.toolkit.typing.engines.smirnoff import ForceField
 from simtk import unit as simtk_unit
 
+from descent.data import DatasetEntry
 from descent.tests import is_close
-from descent.utilities.smirnoff import perturb_force_field
+from descent.utilities.smirnoff import exercised_parameters, perturb_force_field
 
 
 def test_perturb_force_field():
@@ -40,3 +42,193 @@ def test_perturb_force_field():
     assert is_close(
         perturbed_force_field["ProperTorsions"].parameters[smirks].idivf2, 1.2
     )
+
+
+@pytest.mark.parametrize(
+    "handlers_to_include,"
+    "handlers_to_exclude,"
+    "ids_to_include,"
+    "ids_to_exclude,"
+    "attributes_to_include,"
+    "attributes_to_exclude,"
+    "n_expected,"
+    "expected_handlers,"
+    "expected_potential_keys,"
+    "expected_attributes",
+    [
+        (
+            None,
+            None,
+            None,
+            None,
+            None,
+            None,
+            36,
+            {"Bonds", "Angles"},
+            {
+                PotentialKey(id=smirks, mult=mult, associated_handler=handler)
+                for smirks in ("a", "b", "c")
+                for mult in (None, 0, 1)
+                for handler in ("Bonds", "Angles")
+            },
+            {"k", "length", "angle"},
+        ),
+        (
+            ["Bonds"],
+            None,
+            None,
+            None,
+            None,
+            None,
+            18,
+            {"Bonds"},
+            {
+                PotentialKey(id=smirks, mult=mult, associated_handler="Bonds")
+                for smirks in ("a", "b", "c")
+                for mult in (None, 0, 1)
+            },
+            {"k", "length"},
+        ),
+        (
+            None,
+            ["Bonds"],
+            None,
+            None,
+            None,
+            None,
+            18,
+            {"Angles"},
+            {
+                PotentialKey(id=smirks, mult=mult, associated_handler="Angles")
+                for smirks in ("a", "b", "c")
+                for mult in (None, 0, 1)
+            },
+            {"k", "angle"},
+        ),
+        (
+            None,
+            None,
+            [PotentialKey(id="b", mult=0, associated_handler="Bonds")],
+            None,
+            None,
+            None,
+            2,
+            {"Bonds"},
+            {PotentialKey(id="b", mult=0, associated_handler="Bonds")},
+            {"k", "length"},
+        ),
+        (
+            None,
+            None,
+            None,
+            [
+                PotentialKey(id="b", mult=0, associated_handler="Bonds"),
+            ],
+            None,
+            None,
+            34,
+            {"Bonds", "Angles"},
+            {
+                PotentialKey(id=smirks, mult=mult, associated_handler=handler)
+                for handler in ("Bonds", "Angles")
+                for smirks in ("a", "b", "c")
+                for mult in (None, 0, 1)
+                if (smirks != "b" or mult != 0 or handler != "Bonds")
+            },
+            {"k", "length", "angle"},
+        ),
+        (
+            None,
+            None,
+            None,
+            None,
+            ["length"],
+            None,
+            9,
+            {"Bonds"},
+            {
+                PotentialKey(id=smirks, mult=mult, associated_handler="Bonds")
+                for smirks in ("a", "b", "c")
+                for mult in (None, 0, 1)
+            },
+            {"length"},
+        ),
+        (
+            None,
+            None,
+            None,
+            None,
+            None,
+            ["length"],
+            27,
+            {"Bonds", "Angles"},
+            {
+                PotentialKey(id=smirks, mult=mult, associated_handler=handler)
+                for handler in ("Bonds", "Angles")
+                for smirks in ("a", "b", "c")
+                for mult in (None, 0, 1)
+            },
+            {"k", "angle"},
+        ),
+    ],
+)
+def test_exercised_parameters(
+    handlers_to_include,
+    handlers_to_exclude,
+    ids_to_include,
+    ids_to_exclude,
+    attributes_to_include,
+    attributes_to_exclude,
+    n_expected,
+    expected_handlers,
+    expected_potential_keys,
+    expected_attributes,
+):
+    class MockEntry(DatasetEntry):
+        def evaluate_loss(self, model, **kwargs):
+            pass
+
+    def mock_entry(handler, patterns, mult):
+
+        attributes = {"Bonds": ["k", "length"], "Angles": ["k", "angle"]}[handler]
+
+        entry = MockEntry.__new__(MockEntry)
+        entry._model_input = {
+            (handler, ""): (
+                None,
+                None,
+                [
+                    (
+                        PotentialKey(id=smirks, mult=mult, associated_handler=handler),
+                        attributes,
+                    )
+                    for smirks in patterns
+                ],
+            )
+        }
+        return entry
+
+    entries = [
+        mock_entry(handler, patterns, mult)
+        for handler in ["Bonds", "Angles"]
+        for patterns in [("a", "b"), ("b", "c")]
+        for mult in [None, 0, 1]
+    ]
+
+    parameter_keys = exercised_parameters(
+        entries,
+        handlers_to_include,
+        handlers_to_exclude,
+        ids_to_include,
+        ids_to_exclude,
+        attributes_to_include,
+        attributes_to_exclude,
+    )
+
+    assert len(parameter_keys) == n_expected
+
+    actual_handlers, actual_keys, actual_attributes = zip(*parameter_keys)
+
+    assert {*actual_handlers} == expected_handlers
+    assert {*actual_keys} == expected_potential_keys
+    assert {*actual_attributes} == expected_attributes
diff --git a/descent/tests/utilities/test_utilities.py b/descent/tests/utilities/test_utilities.py
new file mode 100644
index 0000000..2c19229
--- /dev/null
+++ b/descent/tests/utilities/test_utilities.py
@@ -0,0 +1,13 @@
+import pytest
+
+from descent.utilities import value_or_list_to_list
+
+
+@pytest.mark.parametrize(
+    "function_input, expected_output",
+    [(None, None), (2, [2]), ("a", ["a"]), ([1, 2], [1, 2]), (["a", "b"], ["a", "b"])],
+)
+def test_value_or_list_to_list(function_input, expected_output):
+
+    actual_output = value_or_list_to_list(function_input)
+    assert expected_output == actual_output
diff --git a/descent/utilities/__init__.py b/descent/utilities/__init__.py
index e69de29..0db1acf 100644
--- a/descent/utilities/__init__.py
+++ b/descent/utilities/__init__.py
@@ -0,0 +1,3 @@
+from descent.utilities.utilities import value_or_list_to_list
+
+__all__ = ["value_or_list_to_list"]
diff --git a/descent/utilities/smirnoff.py b/descent/utilities/smirnoff.py
index a334827..bfe0b34 100644
--- a/descent/utilities/smirnoff.py
+++ b/descent/utilities/smirnoff.py
@@ -1,11 +1,17 @@
 import copy
-from typing import List, Tuple
+from typing import TYPE_CHECKING, Iterable, List, Optional, Tuple, Union
 
 import torch
+from openff.interchange.components.interchange import Interchange
 from openff.interchange.models import PotentialKey
 from openff.toolkit.typing.engines.smirnoff import ForceField
 from openff.toolkit.utils import string_to_unit
-from smirnoffee.smirnoff import _DEFAULT_UNITS
+from smirnoffee.smirnoff import _DEFAULT_UNITS, vectorize_system
+
+from descent.utilities import value_or_list_to_list
+
+if TYPE_CHECKING:
+    from descent.data import Dataset, DatasetEntry
 
 
 def perturb_force_field(
@@ -52,3 +58,74 @@ def perturb_force_field(
         setattr(parameter, attribute, original_value + delta)
 
     return force_field
+
+
+def exercised_parameters(
+    dataset: Union["Dataset", Iterable["DatasetEntry"], Iterable[Interchange]],
+    handlers_to_include: Optional[Union[str, List[str]]] = None,
+    handlers_to_exclude: Optional[Union[str, List[str]]] = None,
+    ids_to_include: Optional[Union[PotentialKey, List[PotentialKey]]] = None,
+    ids_to_exclude: Optional[Union[PotentialKey, List[PotentialKey]]] = None,
+    attributes_to_include: Optional[Union[str, List[str]]] = None,
+    attributes_to_exclude: Optional[Union[str, List[str]]] = None,
+) -> List[Tuple[str, PotentialKey, str]]:
+    """Returns the identifiers of each parameter that has been assigned to each molecule
+    in a dataset.
+
+    Notes:
+        This function assumes that the dataset was created using an OpenFF interchange
+        object as the main input.
+
+    Args:
+        dataset: The dataset, list of dataset entries, or list of interchange objects
+            that track a set of SMIRNOFF parameters assigned to a set of molecules.
+        handlers_to_include: An optional list of the parameter handlers that the returned
+            parameters should be associated with.
+        handlers_to_exclude: An optional list of the parameter handlers that the returned
+            parameters should **not** be associated with.
+        ids_to_include: An optional list of the potential keys that the parameters should
+            match with to be returned.
+        ids_to_exclude: An optional list of the potential keys that the parameters should
+            **not** match with to be returned.
+        attributes_to_include: An optional list of the attributes that the parameters
+            should match with to be returned.
+        attributes_to_exclude: An optional list of the attributes that the parameters
+            should **not** match with to be returned.
+
+    Returns:
+        A list of tuples of the form ``(handler_type, potential_key, attribute_name)``.
+    """
+
+    def should_skip(value, to_include, to_exclude) -> bool:
+
+        to_include = value_or_list_to_list(to_include)
+        to_exclude = value_or_list_to_list(to_exclude)
+
+        return (to_include is not None and value not in to_include) or (
+            to_exclude is not None and value in to_exclude
+        )
+
+    vectorized_systems = [
+        entry.model_input
+        if not isinstance(entry, Interchange)
+        else vectorize_system(entry)
+        for entry in dataset
+    ]
+
+    return_value = {
+        (handler_type, potential_key, attribute)
+        for vectorized_system in vectorized_systems
+        for (handler_type, _), (*_, potential_keys) in vectorized_system.items()
+        if not should_skip(handler_type, handlers_to_include, handlers_to_exclude)
+        for (potential_key, attributes) in potential_keys
+        if not should_skip(potential_key, ids_to_include, ids_to_exclude)
+        for attribute in attributes
+        if not should_skip(attribute, attributes_to_include, attributes_to_exclude)
+    }
+
+    return_value = sorted(
+        return_value,
+        key=lambda x: (x[0], x[1].id, x[1].mult if x[1].mult is not None else -1, x[2]),
+    )
+
+    return return_value
diff --git a/descent/utilities/utilities.py b/descent/utilities/utilities.py
new file mode 100644
index 0000000..7ff7431
--- /dev/null
+++ b/descent/utilities/utilities.py
@@ -0,0 +1,21 @@
+from typing import List, TypeVar, Union, overload
+
+T = TypeVar("T")
+
+
+@overload
+def value_or_list_to_list(value: Union[T, List[T]]) -> List[T]:
+    ...
+
+
+@overload
+def value_or_list_to_list(value: None) -> None:
+    ...
+
+
+def value_or_list_to_list(value):
+
+    if value is None:
+        return value
+
+    return value if isinstance(value, list) else [value]
diff --git a/examples/energy-and-gradient.ipynb b/examples/energy-and-gradient.ipynb
index 0c50c68..f9ace63 100644
--- a/examples/energy-and-gradient.ipynb
+++ b/examples/energy-and-gradient.ipynb
@@ -178,10 +178,10 @@
      "name": "stderr",
      "output_type": "stream",
      "text": [
-      "Pulling main optimisation records: 100%|██████████| 3/3 [00:00<00:00, 199.60it/s]\n",
-      "Pulling gradient / hessian data: 100%|██████████| 3/3 [00:00<00:00, 2074.33it/s]\n",
+      "Pulling main optimisation records: 100%|██████████| 3/3 [00:00<00:00, 183.13it/s]\n",
+      "Pulling gradient / hessian data: 100%|██████████| 3/3 [00:00<00:00, 3483.64it/s]\n",
       "Building entries.:   0%|          | 0/1 [00:00
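Outside the patch itself, a minimal usage sketch of the new exercised_parameters helper; the empty entries list and the specific include/exclude filter choices below are illustrative assumptions rather than code from this diff:

    from descent.utilities.smirnoff import exercised_parameters

    # Placeholder: in practice this would be a list of DatasetEntry (or Interchange)
    # objects built from OpenFF Interchange inputs; an empty list keeps the sketch
    # self-contained and simply yields no parameter keys.
    entries = []

    # Collect the (handler, potential key, attribute) triplets exercised by the
    # entries, keeping only bond and angle parameters and dropping bond lengths.
    parameter_keys = exercised_parameters(
        entries,
        handlers_to_include=["Bonds", "Angles"],
        attributes_to_exclude=["length"],
    )

    for handler_type, potential_key, attribute in parameter_keys:
        print(handler_type, potential_key.id, potential_key.mult, attribute)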