Skip to content

Commit

Permalink
Add utility to find SMIRNOFF parameters exercised by dataset (#11)
Browse files Browse the repository at this point in the history
  • Loading branch information
SimonBoothroyd authored Sep 13, 2021
1 parent 665a67e commit fcefc51
Show file tree
Hide file tree
Showing 7 changed files with 344 additions and 41 deletions.
194 changes: 193 additions & 1 deletion descent/tests/utilities/test_smirnoff.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
import pytest
import torch
from openff.interchange.models import PotentialKey
from openff.toolkit.typing.engines.smirnoff import ForceField
from simtk import unit as simtk_unit

from descent.data import DatasetEntry
from descent.tests import is_close
from descent.utilities.smirnoff import perturb_force_field
from descent.utilities.smirnoff import exercised_parameters, perturb_force_field


def test_perturb_force_field():
Expand Down Expand Up @@ -40,3 +42,193 @@ def test_perturb_force_field():
assert is_close(
perturbed_force_field["ProperTorsions"].parameters[smirks].idivf2, 1.2
)


@pytest.mark.parametrize(
    (
        "handlers_to_include",
        "handlers_to_exclude",
        "ids_to_include",
        "ids_to_exclude",
        "attributes_to_include",
        "attributes_to_exclude",
        "n_expected",
        "expected_handlers",
        "expected_potential_keys",
        "expected_attributes",
    ),
    [
        # No filters: 2 handlers x 3 patterns x 3 multiplicities x 2 attributes.
        (
            None,
            None,
            None,
            None,
            None,
            None,
            36,
            {"Bonds", "Angles"},
            {
                PotentialKey(id=smirks, mult=mult, associated_handler=handler)
                for handler in ("Bonds", "Angles")
                for smirks in ("a", "b", "c")
                for mult in (None, 0, 1)
            },
            {"k", "length", "angle"},
        ),
        # Only keep the bond handler.
        (
            ["Bonds"],
            None,
            None,
            None,
            None,
            None,
            18,
            {"Bonds"},
            {
                PotentialKey(id=smirks, mult=mult, associated_handler="Bonds")
                for smirks in ("a", "b", "c")
                for mult in (None, 0, 1)
            },
            {"k", "length"},
        ),
        # Drop the bond handler, leaving only angles.
        (
            None,
            ["Bonds"],
            None,
            None,
            None,
            None,
            18,
            {"Angles"},
            {
                PotentialKey(id=smirks, mult=mult, associated_handler="Angles")
                for smirks in ("a", "b", "c")
                for mult in (None, 0, 1)
            },
            {"k", "angle"},
        ),
        # Only keep a single potential key (both of its attributes survive).
        (
            None,
            None,
            [PotentialKey(id="b", mult=0, associated_handler="Bonds")],
            None,
            None,
            None,
            2,
            {"Bonds"},
            {PotentialKey(id="b", mult=0, associated_handler="Bonds")},
            {"k", "length"},
        ),
        # Drop a single potential key (and hence its two attributes).
        (
            None,
            None,
            None,
            [
                PotentialKey(id="b", mult=0, associated_handler="Bonds"),
            ],
            None,
            None,
            34,
            {"Bonds", "Angles"},
            {
                PotentialKey(id=smirks, mult=mult, associated_handler=handler)
                for handler in ("Bonds", "Angles")
                for smirks in ("a", "b", "c")
                for mult in (None, 0, 1)
                if not (smirks == "b" and mult == 0 and handler == "Bonds")
            },
            {"k", "length", "angle"},
        ),
        # Only keep the ``length`` attribute, which only bonds expose here.
        (
            None,
            None,
            None,
            None,
            ["length"],
            None,
            9,
            {"Bonds"},
            {
                PotentialKey(id=smirks, mult=mult, associated_handler="Bonds")
                for smirks in ("a", "b", "c")
                for mult in (None, 0, 1)
            },
            {"length"},
        ),
        # Drop the ``length`` attribute: bonds keep ``k``, angles keep both.
        (
            None,
            None,
            None,
            None,
            None,
            ["length"],
            27,
            {"Bonds", "Angles"},
            {
                PotentialKey(id=smirks, mult=mult, associated_handler=handler)
                for handler in ("Bonds", "Angles")
                for smirks in ("a", "b", "c")
                for mult in (None, 0, 1)
            },
            {"k", "angle"},
        ),
    ],
)
def test_exercised_parameters(
    handlers_to_include,
    handlers_to_exclude,
    ids_to_include,
    ids_to_exclude,
    attributes_to_include,
    attributes_to_exclude,
    n_expected,
    expected_handlers,
    expected_potential_keys,
    expected_attributes,
):
    """Check ``exercised_parameters`` against mock entries for each filter type."""

    class MockEntry(DatasetEntry):
        # Minimal concrete subclass so that entries can be created for the test.
        def evaluate_loss(self, model, **kwargs):
            pass

    def mock_entry(handler, patterns, mult):
        """Build an entry whose model input holds one key per pattern."""

        attributes = {"Bonds": ["k", "length"], "Angles": ["k", "angle"]}[handler]

        keys_and_attributes = [
            (
                PotentialKey(id=smirks, mult=mult, associated_handler=handler),
                attributes,
            )
            for smirks in patterns
        ]

        # ``__new__`` is used so that ``__init__`` is not run for the mock.
        entry = MockEntry.__new__(MockEntry)
        entry._model_input = {(handler, ""): (None, None, keys_and_attributes)}

        return entry

    # The two pattern pairs overlap on "b" so that de-duplication is exercised.
    entries = []

    for handler in ["Bonds", "Angles"]:
        for patterns in [("a", "b"), ("b", "c")]:
            for mult in [None, 0, 1]:
                entries.append(mock_entry(handler, patterns, mult))

    parameter_keys = exercised_parameters(
        entries,
        handlers_to_include=handlers_to_include,
        handlers_to_exclude=handlers_to_exclude,
        ids_to_include=ids_to_include,
        ids_to_exclude=ids_to_exclude,
        attributes_to_include=attributes_to_include,
        attributes_to_exclude=attributes_to_exclude,
    )

    assert len(parameter_keys) == n_expected

    actual_handlers, actual_keys, actual_attributes = zip(*parameter_keys)

    assert set(actual_handlers) == expected_handlers
    assert set(actual_keys) == expected_potential_keys
    assert set(actual_attributes) == expected_attributes
13 changes: 13 additions & 0 deletions descent/tests/utilities/test_utilities.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
import pytest

from descent.utilities import value_or_list_to_list


@pytest.mark.parametrize(
    ("function_input", "expected_output"),
    [
        # ``None`` is passed straight through.
        (None, None),
        # Scalars are wrapped in a single element list.
        (2, [2]),
        ("a", ["a"]),
        # Lists are returned as they are.
        ([1, 2], [1, 2]),
        (["a", "b"], ["a", "b"]),
    ],
)
def test_value_or_list_to_list(function_input, expected_output):
    """Check that values are wrapped while lists and ``None`` are left alone."""

    assert value_or_list_to_list(function_input) == expected_output
3 changes: 3 additions & 0 deletions descent/utilities/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from descent.utilities.utilities import value_or_list_to_list

# ``__all__`` must list the public names as *strings*: listing the function object
# itself makes ``from descent.utilities import *`` fail, as the import machinery
# expects every item to be a ``str``.
__all__ = ["value_or_list_to_list"]
81 changes: 79 additions & 2 deletions descent/utilities/smirnoff.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,17 @@
import copy
from typing import List, Tuple
from typing import TYPE_CHECKING, Iterable, List, Optional, Tuple, Union

import torch
from openff.interchange.components.interchange import Interchange
from openff.interchange.models import PotentialKey
from openff.toolkit.typing.engines.smirnoff import ForceField
from openff.toolkit.utils import string_to_unit
from smirnoffee.smirnoff import _DEFAULT_UNITS
from smirnoffee.smirnoff import _DEFAULT_UNITS, vectorize_system

from descent.utilities import value_or_list_to_list

if TYPE_CHECKING:
from descent.data import Dataset, DatasetEntry


def perturb_force_field(
Expand Down Expand Up @@ -52,3 +58,74 @@ def perturb_force_field(
setattr(parameter, attribute, original_value + delta)

return force_field


def exercised_parameters(
    dataset: Union["Dataset", Iterable["DatasetEntry"], Iterable[Interchange]],
    handlers_to_include: Optional[Union[str, List[str]]] = None,
    handlers_to_exclude: Optional[Union[str, List[str]]] = None,
    ids_to_include: Optional[Union[PotentialKey, List[PotentialKey]]] = None,
    ids_to_exclude: Optional[Union[PotentialKey, List[PotentialKey]]] = None,
    attributes_to_include: Optional[Union[str, List[str]]] = None,
    attributes_to_exclude: Optional[Union[str, List[str]]] = None,
) -> List[Tuple[str, PotentialKey, str]]:
    """Returns the identifiers of each parameter that has been assigned to each molecule
    in a dataset.

    Notes:
        This function assumes that the dataset was created using an OpenFF interchange
        object as the main input.

    Args:
        dataset: The dataset, list of dataset entries, or list of interchange objects
            that track a set of SMIRNOFF parameters assigned to a set of molecules.
        handlers_to_include: An optional list of the parameter handlers that the
            returned parameters should be associated with.
        handlers_to_exclude: An optional list of the parameter handlers that the
            returned parameters should **not** be associated with.
        ids_to_include: An optional list of the potential keys that the parameters
            should match with to be returned.
        ids_to_exclude: An optional list of the potential keys that the parameters
            should **not** match with to be returned.
        attributes_to_include: An optional list of the attributes that the parameters
            should match with to be returned.
        attributes_to_exclude: An optional list of the attributes that the parameters
            should **not** match with to be returned.

    Returns:
        A list of tuples of the form ``(handler_type, potential_key, attribute_name)``.
        Duplicates are removed and the list is sorted by handler type, potential key
        id, multiplicity (a ``None`` multiplicity is treated as ``-1``) and attribute
        name.
    """

    # Normalise each filter to either ``None`` or a list exactly once up front, rather
    # than on every membership test made while walking the dataset.
    handlers_to_include = value_or_list_to_list(handlers_to_include)
    handlers_to_exclude = value_or_list_to_list(handlers_to_exclude)
    ids_to_include = value_or_list_to_list(ids_to_include)
    ids_to_exclude = value_or_list_to_list(ids_to_exclude)
    attributes_to_include = value_or_list_to_list(attributes_to_include)
    attributes_to_exclude = value_or_list_to_list(attributes_to_exclude)

    def should_skip(value, to_include, to_exclude) -> bool:
        """Returns whether ``value`` is rejected by a pair of normalised filters,
        where a filter of ``None`` means that the filter is disabled."""

        return (to_include is not None and value not in to_include) or (
            to_exclude is not None and value in to_exclude
        )

    def sort_key(parameter):
        handler_type, potential_key, attribute = parameter

        # ``None`` cannot be compared with an integer so is mapped to ``-1``.
        mult = potential_key.mult if potential_key.mult is not None else -1

        return handler_type, potential_key.id, mult, attribute

    # Interchange objects are vectorized here, any other kind of entry is expected to
    # already expose its vectorized form through ``model_input``.
    vectorized_systems = [
        vectorize_system(entry) if isinstance(entry, Interchange) else entry.model_input
        for entry in dataset
    ]

    # A set is used so that a parameter exercised by several entries is only
    # reported once.
    unique_parameters = {
        (handler_type, potential_key, attribute)
        for vectorized_system in vectorized_systems
        for (handler_type, _), (*_, potential_keys) in vectorized_system.items()
        if not should_skip(handler_type, handlers_to_include, handlers_to_exclude)
        for (potential_key, attributes) in potential_keys
        if not should_skip(potential_key, ids_to_include, ids_to_exclude)
        for attribute in attributes
        if not should_skip(attribute, attributes_to_include, attributes_to_exclude)
    }

    return sorted(unique_parameters, key=sort_key)
21 changes: 21 additions & 0 deletions descent/utilities/utilities.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
from typing import List, TypeVar, Union, overload

T = TypeVar("T")


# The ``None`` overload has to be declared before the generic one: type checkers
# select the first overload that matches, and an unconstrained ``T`` also matches
# ``None``, so in the opposite order this variant would never be selected.
@overload
def value_or_list_to_list(value: None) -> None:
    ...


@overload
def value_or_list_to_list(value: Union[T, List[T]]) -> List[T]:
    ...


def value_or_list_to_list(value):
    """Wraps a single value in a list, leaving lists and ``None`` untouched.

    Args:
        value: Either ``None``, a list of values, or a single value. Only ``list``
            instances count as lists - any other container (e.g. a tuple) is treated
            as a single value and wrapped.

    Returns:
        ``None`` if ``value`` is ``None``, ``value`` itself (not a copy) if it is
        already a list, otherwise a new single element list containing ``value``.
    """

    if value is None:
        return None

    return value if isinstance(value, list) else [value]
Loading

0 comments on commit fcefc51

Please sign in to comment.