Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add utility to find SMIRNOFF parameters exercised by dataset #11

Merged
merged 2 commits into from
Sep 13, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
194 changes: 193 additions & 1 deletion descent/tests/utilities/test_smirnoff.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
import pytest
import torch
from openff.interchange.models import PotentialKey
from openff.toolkit.typing.engines.smirnoff import ForceField
from simtk import unit as simtk_unit

from descent.data import DatasetEntry
from descent.tests import is_close
from descent.utilities.smirnoff import perturb_force_field
from descent.utilities.smirnoff import exercised_parameters, perturb_force_field


def test_perturb_force_field():
Expand Down Expand Up @@ -40,3 +42,193 @@ def test_perturb_force_field():
assert is_close(
perturbed_force_field["ProperTorsions"].parameters[smirks].idivf2, 1.2
)


@pytest.mark.parametrize(
"handlers_to_include,"
"handlers_to_exclude,"
"ids_to_include,"
"ids_to_exclude,"
"attributes_to_include,"
"attributes_to_exclude,"
"n_expected,"
"expected_handlers,"
"expected_potential_keys,"
"expected_attributes",
[
(
None,
None,
None,
None,
None,
None,
36,
{"Bonds", "Angles"},
{
PotentialKey(id=smirks, mult=mult, associated_handler=handler)
for smirks in ("a", "b", "c")
for mult in (None, 0, 1)
for handler in ("Bonds", "Angles")
},
{"k", "length", "angle"},
),
(
["Bonds"],
None,
None,
None,
None,
None,
18,
{"Bonds"},
{
PotentialKey(id=smirks, mult=mult, associated_handler="Bonds")
for smirks in ("a", "b", "c")
for mult in (None, 0, 1)
},
{"k", "length"},
),
(
None,
["Bonds"],
None,
None,
None,
None,
18,
{"Angles"},
{
PotentialKey(id=smirks, mult=mult, associated_handler="Angles")
for smirks in ("a", "b", "c")
for mult in (None, 0, 1)
},
{"k", "angle"},
),
(
None,
None,
[PotentialKey(id="b", mult=0, associated_handler="Bonds")],
None,
None,
None,
2,
{"Bonds"},
{PotentialKey(id="b", mult=0, associated_handler="Bonds")},
{"k", "length"},
),
(
None,
None,
None,
[
PotentialKey(id="b", mult=0, associated_handler="Bonds"),
],
None,
None,
34,
{"Bonds", "Angles"},
{
PotentialKey(id=smirks, mult=mult, associated_handler=handler)
for handler in ("Bonds", "Angles")
for smirks in ("a", "b", "c")
for mult in (None, 0, 1)
if (smirks != "b" or mult != 0 or handler != "Bonds")
},
{"k", "length", "angle"},
),
(
None,
None,
None,
None,
["length"],
None,
9,
{"Bonds"},
{
PotentialKey(id=smirks, mult=mult, associated_handler="Bonds")
for smirks in ("a", "b", "c")
for mult in (None, 0, 1)
},
{"length"},
),
(
None,
None,
None,
None,
None,
["length"],
27,
{"Bonds", "Angles"},
{
PotentialKey(id=smirks, mult=mult, associated_handler=handler)
for handler in ("Bonds", "Angles")
for smirks in ("a", "b", "c")
for mult in (None, 0, 1)
},
{"k", "angle"},
),
],
)
def test_exercised_parameters(
handlers_to_include,
handlers_to_exclude,
ids_to_include,
ids_to_exclude,
attributes_to_include,
attributes_to_exclude,
n_expected,
expected_handlers,
expected_potential_keys,
expected_attributes,
):
class MockEntry(DatasetEntry):
def evaluate_loss(self, model, **kwargs):
pass

def mock_entry(handler, patterns, mult):

attributes = {"Bonds": ["k", "length"], "Angles": ["k", "angle"]}[handler]

entry = MockEntry.__new__(MockEntry)
entry._model_input = {
(handler, ""): (
None,
None,
[
(
PotentialKey(id=smirks, mult=mult, associated_handler=handler),
attributes,
)
for smirks in patterns
],
)
}
return entry

entries = [
mock_entry(handler, patterns, mult)
for handler in ["Bonds", "Angles"]
for patterns in [("a", "b"), ("b", "c")]
for mult in [None, 0, 1]
]

parameter_keys = exercised_parameters(
entries,
handlers_to_include,
handlers_to_exclude,
ids_to_include,
ids_to_exclude,
attributes_to_include,
attributes_to_exclude,
)

assert len(parameter_keys) == n_expected

actual_handlers, actual_keys, actual_attributes = zip(*parameter_keys)

assert {*actual_handlers} == expected_handlers
assert {*actual_keys} == expected_potential_keys
assert {*actual_attributes} == expected_attributes
13 changes: 13 additions & 0 deletions descent/tests/utilities/test_utilities.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
import pytest

from descent.utilities import value_or_list_to_list


@pytest.mark.parametrize(
"function_input, expected_output",
[(None, None), (2, [2]), ("a", ["a"]), ([1, 2], [1, 2]), (["a", "b"], ["a", "b"])],
)
def test_value_or_list_to_list(function_input, expected_output):

actual_output = value_or_list_to_list(function_input)
assert expected_output == actual_output
3 changes: 3 additions & 0 deletions descent/utilities/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from descent.utilities.utilities import value_or_list_to_list

__all__ = [value_or_list_to_list]
81 changes: 79 additions & 2 deletions descent/utilities/smirnoff.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,17 @@
import copy
from typing import List, Tuple
from typing import TYPE_CHECKING, Iterable, List, Optional, Tuple, Union

import torch
from openff.interchange.components.interchange import Interchange
from openff.interchange.models import PotentialKey
from openff.toolkit.typing.engines.smirnoff import ForceField
from openff.toolkit.utils import string_to_unit
from smirnoffee.smirnoff import _DEFAULT_UNITS
from smirnoffee.smirnoff import _DEFAULT_UNITS, vectorize_system

from descent.utilities import value_or_list_to_list

if TYPE_CHECKING:
from descent.data import Dataset, DatasetEntry


def perturb_force_field(
Expand Down Expand Up @@ -52,3 +58,74 @@ def perturb_force_field(
setattr(parameter, attribute, original_value + delta)

return force_field


def exercised_parameters(
dataset: Union["Dataset", Iterable["DatasetEntry"], Iterable[Interchange]],
handlers_to_include: Optional[Union[str, List[str]]] = None,
handlers_to_exclude: Optional[Union[str, List[str]]] = None,
ids_to_include: Optional[Union[PotentialKey, List[PotentialKey]]] = None,
ids_to_exclude: Optional[Union[PotentialKey, List[PotentialKey]]] = None,
attributes_to_include: Optional[Union[str, List[str]]] = None,
attributes_to_exclude: Optional[Union[str, List[str]]] = None,
) -> List[Tuple[str, PotentialKey, str]]:
"""Returns the identifiers of each parameter that has been assigned to each molecule
in a dataset.

Notes:
This function assumes that the dataset was created using an OpenFF interchange
object as the main input.

Args:
dataset: The dataset, list of dataset entries, or list of interchange objects
That track a set of SMIRNOFF parameters assigned to a set of molecules.
handlers_to_include: An optional list of the parameter handlers that the returned
parameters should be associated with.
handlers_to_exclude: An optional list of the parameter handlers that the returned
parameters should **not** be associated with.
ids_to_include: An optional list of the potential keys that the parameters should
match with to be returned.
ids_to_exclude: An optional list of the potential keys that the parameters should
**not** match with to be returned.
attributes_to_include: An optional list of the attributes that the parameters
should match with to be returned.
attributes_to_exclude: An optional list of the attributes that the parameters
should **not** match with to be returned.

Returns:
A list of tuples of the form ``(handler_type, potential_key, attribute_name)``.
"""

def should_skip(value, to_include, to_exclude) -> bool:

to_include = value_or_list_to_list(to_include)
to_exclude = value_or_list_to_list(to_exclude)

return (to_include is not None and value not in to_include) or (
to_exclude is not None and value in to_exclude
)

vectorized_systems = [
entry.model_input
if not isinstance(entry, Interchange)
else vectorize_system(entry)
for entry in dataset
]

return_value = {
(handler_type, potential_key, attribute)
for vectorized_system in vectorized_systems
for (handler_type, _), (*_, potential_keys) in vectorized_system.items()
if not should_skip(handler_type, handlers_to_include, handlers_to_exclude)
for (potential_key, attributes) in potential_keys
if not should_skip(potential_key, ids_to_include, ids_to_exclude)
for attribute in attributes
if not should_skip(attribute, attributes_to_include, attributes_to_exclude)
}

return_value = sorted(
return_value,
key=lambda x: (x[0], x[1].id, x[1].mult if x[1].mult is not None else -1, x[2]),
)

return return_value
21 changes: 21 additions & 0 deletions descent/utilities/utilities.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
from typing import List, TypeVar, Union, overload

T = TypeVar("T")


@overload
def value_or_list_to_list(value: Union[T, List[T]]) -> List[T]:
...


@overload
def value_or_list_to_list(value: None) -> None:
...


def value_or_list_to_list(value):

if value is None:
return value

return value if isinstance(value, list) else [value]
Loading