Skip to content

Commit

Permalink
Add function to summarize SMIRNOFF model (#10)
Browse files Browse the repository at this point in the history
  • Loading branch information
SimonBoothroyd authored Sep 13, 2021
1 parent a74a086 commit 665a67e
Show file tree
Hide file tree
Showing 4 changed files with 281 additions and 43 deletions.
4 changes: 4 additions & 0 deletions descent/models/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,3 +12,7 @@ class ParameterizationModel(Protocol):

def forward(self, graph: Any) -> VectorizedSystem:
    """Outputs a vectorised view of a parameterized molecule.

    Args:
        graph: The input to parameterize. The protocol leaves the concrete type
            open (``Any``); presumably a representation of the molecule -
            confirm against the implementing model.

    Returns:
        The vectorised view of the parameterized molecule.
    """

def summarise(self):
    """Print a summary of the status of this model.

    For example, the differences between the initial and current state during
    training.
    """
182 changes: 182 additions & 0 deletions descent/models/smirnoff.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
import io
from collections import defaultdict
from typing import Dict, List, Optional, Tuple, Union

import torch.nn
from openff.interchange.models import PotentialKey
from openff.toolkit.typing.engines.smirnoff import ForceField
from smirnoffee.potentials import add_parameter_delta
from typing_extensions import Literal

from descent.models.models import VectorizedSystem
from descent.utilities.smirnoff import perturb_force_field
Expand Down Expand Up @@ -166,3 +168,183 @@ def to_force_field(self) -> ForceField:
for (smirks, attribute) in handler_ids
],
)

def summarise(
    self,
    parameter_id_type: Literal["smirks", "id"] = "smirks",
    print_to_terminal: bool = True,
) -> str:
    """Tabulate the initial and current value of each trainable parameter.

    One table is emitted per parameter handler. Each table has one row per
    parameter and, for every attribute, an ``INITIAL`` value (read from the
    initial force field) next to a ``FINAL`` value (read from the force field
    returned by ``to_force_field``). Values are shown to four decimal places in
    the units of the first value encountered for that handler and attribute.

    Args:
        parameter_id_type: The type of ID to show for each parameter. Currently
            this can either be the unique ``'id'`` associated with the parameter or
            the ``'smirks'`` pattern that encodes the chemical environment the
            parameter is applied to. Parameters without an ``id`` are labelled
            ``NO ID (<smirks>)`` so that they stay distinguishable from each other.
        print_to_terminal: Whether to print the summary to the terminal
    Returns:
        A string containing the summary.
    """

    from openff.units.simtk import from_simtk

    def format_cell(width, value_pair):
        """Render an (initial, final) pair, or blanks if the pair is missing."""

        if value_pair is None:
            return " " * width

        half_width = (width - 1) // 2
        return " ".join(
            f"{value:.4f}".ljust(half_width, " ") for value in value_pair
        )

    final_force_field = self.to_force_field()

    # table_data[handler_type][column_header][row_label] = (initial, final)
    table_data = defaultdict(lambda: defaultdict(dict))
    attribute_units = {}

    for handler_type, parameter_ids in self._parameter_delta_ids.items():

        for potential_key, attribute in parameter_ids:

            smirks = potential_key.id

            if potential_key.mult is not None:
                attribute = f"{attribute}{potential_key.mult}"

            initial_parameter = self._initial_force_field[handler_type].parameters[
                smirks
            ]
            final_parameter = final_force_field[handler_type].parameters[smirks]

            initial_value = from_simtk(getattr(initial_parameter, attribute))
            final_value = from_simtk(getattr(final_parameter, attribute))

            # Every value in a column is reported in the units of the first value
            # seen for that (handler, attribute) pair.
            unit = attribute_units.setdefault(
                (handler_type, attribute), initial_value.units
            )
            column_header = f"{attribute} ({unit:P~})"

            row_label = smirks

            if parameter_id_type == "id":
                parameter_id = initial_parameter.id
                # A bare "NO ID" label would make all id-less parameters share
                # one row and overwrite each other, so keep the SMIRKS alongside.
                row_label = (
                    parameter_id
                    if parameter_id is not None
                    else f"NO ID ({smirks})"
                )

            table_data[handler_type][column_header][row_label] = (
                initial_value.to(unit).m,
                final_value.to(unit).m,
            )

    # Construct the final return value:
    summary_buffer = io.StringIO()

    for handler_type, attribute_data in table_data.items():

        print(f"\n{handler_type.center(80, '=')}\n", file=summary_buffer)

        attribute_headers = sorted(attribute_data)
        attribute_widths = {}

        for header in attribute_headers:

            widest_value = max(
                len(f"{value:.4f}")
                for value_pair in attribute_data[header].values()
                for value in value_pair
            )
            # A column must fit two values plus a separating space, and also the
            # header text itself, otherwise later columns drift out of alignment.
            width = max(widest_value * 2 + 1, len(header), 15)
            # Keep the width odd so that (width - 1) splits evenly into the
            # INITIAL and FINAL halves.
            attribute_widths[header] = width if width % 2 else width + 1

        row_labels = sorted(
            {label for rows in attribute_data.values() for label in rows}
        )
        label_width = max(len(label) for label in row_labels)
        label_padding = " " * label_width

        first_header = (
            label_padding
            + " "
            + " ".join(
                header.center(attribute_widths[header], " ")
                for header in attribute_headers
            )
        )
        second_header = (
            label_padding
            + " "
            + " ".join(
                "INITIAL".center((attribute_widths[header] - 1) // 2, " ")
                + " "
                + "FINAL".center((attribute_widths[header] - 1) // 2, " ")
                for header in attribute_headers
            )
        )
        border = (
            "-" * label_width
            + " "
            + " ".join("-" * attribute_widths[header] for header in attribute_headers)
        )

        for line in (border, first_header, second_header, border):
            print(line, file=summary_buffer)

        for label in row_labels:

            cells = [
                format_cell(
                    attribute_widths[header], attribute_data[header].get(label)
                )
                for header in attribute_headers
            ]
            print(
                label.ljust(label_width) + " " + " ".join(cells),
                file=summary_buffer,
            )

    return_value = summary_buffer.getvalue()

    if print_to_terminal:
        print(return_value)

    return return_value
35 changes: 35 additions & 0 deletions descent/tests/models/test_smirnoff.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,3 +184,38 @@ def test_model_to_force_field(mock_force_field, covariance_tensor):
(4.0 + 1.0 * (1.0 if covariance_tensor is None else float(covariance_tensor)))
* simtk_unit.angstrom,
)


@pytest.mark.parametrize("parameter_id_type", ["id", "smirks"])
def test_model_summarise(mock_force_field, parameter_id_type):
    """Check the summary names each handler, column header and the requested ID."""

    labelled_smirks = "[#1:1]-[#17:2]"
    labelled_id = "b1"
    angle_smirks = "[#1:1]-[#8:2]-[#1:3]"

    mock_force_field["Bonds"].parameters[labelled_smirks].id = labelled_id

    trainable_parameters = [
        ("Bonds", labelled_smirks, "length"),
        ("Bonds", "[#1:1]-[#9:2]", "length"),
        ("Bonds", labelled_smirks, "k"),
        ("Angles", angle_smirks, "angle"),
        ("Angles", angle_smirks, "k"),
    ]

    model = SMIRNOFFModel(trainable_parameters, mock_force_field)
    model.parameter_delta = torch.nn.Parameter(torch.tensor([1.00001] * 5))

    summary = model.summarise(parameter_id_type=parameter_id_type)

    assert summary is not None
    assert len(summary) > 0

    # Handler titles followed by the expected column headers of each table.
    for expected_text in (
        "Bonds",
        "Angles",
        " k (kJ/mol/Ų) length (Å) ",
        " angle (deg) k (kJ/deg²/mol)",
    ):
        assert expected_text in summary

    # Only the requested kind of row label may appear for the labelled bond.
    if parameter_id_type == "smirks":
        shown, hidden = labelled_smirks, labelled_id
    else:
        shown, hidden = labelled_id, labelled_smirks

    assert shown in summary
    assert hidden not in summary
103 changes: 60 additions & 43 deletions examples/energy-and-gradient.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -178,10 +178,10 @@
"name": "stderr",
"output_type": "stream",
"text": [
"Pulling main optimisation records: 100%|██████████| 3/3 [00:00<00:00, 182.09it/s]\n",
"Pulling gradient / hessian data: 100%|██████████| 3/3 [00:00<00:00, 2606.78it/s]\n",
"Pulling main optimisation records: 100%|██████████| 3/3 [00:00<00:00, 199.60it/s]\n",
"Pulling gradient / hessian data: 100%|██████████| 3/3 [00:00<00:00, 2074.33it/s]\n",
"Building entries.: 0%| | 0/1 [00:00<?, ?it/s]Warning: importing 'simtk.openmm' is deprecated. Import 'openmm' instead.\n",
"Building entries.: 100%|██████████| 1/1 [00:03<00:00, 3.35s/it]\n"
"Building entries.: 100%|██████████| 1/1 [00:03<00:00, 3.52s/it]\n"
]
}
],
Expand Down Expand Up @@ -264,7 +264,7 @@
"outputs": [
{
"data": {
"text/plain": "[('Bonds',\n PotentialKey(id='[#6X3:1]=[#7X2,#7X3+1:2]', mult=None, associated_handler='Bonds'),\n 'k'),\n ('Bonds',\n PotentialKey(id='[#7:1]-[#1:2]', mult=None, associated_handler='Bonds'),\n 'k')]"
"text/plain": "[('Bonds',\n PotentialKey(id='[#6X3:1]=[#6X3:2]', mult=None, associated_handler='Bonds'),\n 'k'),\n ('Bonds',\n PotentialKey(id='[#6X3:1]-[#1:2]', mult=None, associated_handler='Bonds'),\n 'k')]"
},
"execution_count": 7,
"metadata": {},
Expand Down Expand Up @@ -361,14 +361,14 @@
"text": [
"Epoch 0: loss=633.4320678710938\n",
"Epoch 20: loss=359.64111328125\n",
"Epoch 40: loss=340.67962646484375\n",
"Epoch 60: loss=334.7128601074219\n",
"Epoch 80: loss=332.68365478515625\n",
"Epoch 40: loss=340.6795349121094\n",
"Epoch 60: loss=334.71282958984375\n",
"Epoch 80: loss=332.6836242675781\n",
"Epoch 100: loss=332.1078796386719\n",
"Epoch 120: loss=331.9747009277344\n",
"Epoch 120: loss=331.97467041015625\n",
"Epoch 140: loss=331.9459533691406\n",
"Epoch 160: loss=331.9331359863281\n",
"Epoch 180: loss=331.92156982421875\n"
"Epoch 160: loss=331.9331970214844\n",
"Epoch 180: loss=331.9215393066406\n"
]
}
],
Expand Down Expand Up @@ -452,7 +452,7 @@
{
"cell_type": "markdown",
"source": [
"or print out the initial and final values."
"or print out a summary of the trained values"
],
"metadata": {
"collapsed": false,
Expand All @@ -469,49 +469,66 @@
"name": "stdout",
"output_type": "stream",
"text": [
"Bonds SMIRKS=[#6X3:1]=[#7X2,#7X3+1:2] ATTR=k INITIAL=882.4191878243 kcal/(A**2 mol) FINAL=881.926830095258 kcal/(A**2 mol)\n",
"Bonds SMIRKS=[#7:1]-[#1:2] ATTR=k INITIAL=997.7547006218 kcal/(A**2 mol) FINAL=997.2764430951689 kcal/(A**2 mol)\n",
"Bonds SMIRKS=[#6X3:1]=[#6X3:2] ATTR=k INITIAL=857.1115548611 kcal/(A**2 mol) FINAL=856.6721659734958 kcal/(A**2 mol)\n",
"Bonds SMIRKS=[#6X4:1]-[#6X3:2]=[#8X1+0] ATTR=k INITIAL=612.0537081219 kcal/(A**2 mol) FINAL=612.5316377660014 kcal/(A**2 mol)\n",
"Bonds SMIRKS=[#6X4:1]-[#6X3:2] ATTR=k INITIAL=612.5097961064 kcal/(A**2 mol) FINAL=612.0317954609721 kcal/(A**2 mol)\n",
"Bonds SMIRKS=[#6:1]=[#8X1+0,#8X2+1:2] ATTR=k INITIAL=1135.595318618 kcal/(A**2 mol) FINAL=1135.1173284290276 kcal/(A**2 mol)\n",
"Bonds SMIRKS=[#6X3:1]-[#7X3:2] ATTR=k INITIAL=719.219372554 kcal/(A**2 mol) FINAL=718.7531744829888 kcal/(A**2 mol)\n",
"Bonds SMIRKS=[#6X3:1]-[#1:2] ATTR=k INITIAL=808.1394472833 kcal/(A**2 mol) FINAL=807.661493535217 kcal/(A**2 mol)\n",
"Bonds SMIRKS=[#6X3:1]-[#7X2:2] ATTR=k INITIAL=837.2647972972 kcal/(A**2 mol) FINAL=837.640644214469 kcal/(A**2 mol)\n",
"Bonds SMIRKS=[#6:1]-[#7:2] ATTR=k INITIAL=719.6326854584 kcal/(A**2 mol) FINAL=719.1547107119311 kcal/(A**2 mol)\n",
"Bonds SMIRKS=[#6X4:1]-[#7X3:2]-[#6X3]=[#8X1+0] ATTR=k INITIAL=764.7120801727 kcal/(A**2 mol) FINAL=765.1886159741499 kcal/(A**2 mol)\n",
"Bonds SMIRKS=[#6X3:1](=[#8X1+0])-[#7X3:2] ATTR=k INITIAL=1053.970761594 kcal/(A**2 mol) FINAL=1053.4928622935631 kcal/(A**2 mol)\n",
"Bonds SMIRKS=[#6X4:1]-[#1:2] ATTR=k INITIAL=758.0931772913 kcal/(A**2 mol) FINAL=757.6152659958559 kcal/(A**2 mol)\n",
"Angles SMIRKS=[#1:1]-[#7X3$(*~[#6X3,#6X2,#7X2+0]):2]-[*:3] ATTR=k INITIAL=77.52610202633 kcal/(mol rad**2) FINAL=191.91446506590137 kcal/(mol rad**2)\n",
"Angles SMIRKS=[#1:1]-[#6X4:2]-[#1:3] ATTR=k INITIAL=74.28701527177 kcal/(mol rad**2) FINAL=3.1566064657054795 kcal/(mol rad**2)\n",
"Angles SMIRKS=[*:1]~[#7X2+0:2]~[*:3] ATTR=k INITIAL=226.9001499199 kcal/(mol rad**2) FINAL=31.499088021470783 kcal/(mol rad**2)\n",
"Angles SMIRKS=[*:1]~[#6X4:2]-[*:3] ATTR=k INITIAL=101.7373362367 kcal/(mol rad**2) FINAL=26.406365924907732 kcal/(mol rad**2)\n",
"Angles SMIRKS=[*:1]~[#7X3$(*~[#6X3,#6X2,#7X2+0]):2]~[*:3] ATTR=k INITIAL=112.545110149 kcal/(mol rad**2) FINAL=94.10249558095923 kcal/(mol rad**2)\n",
"Angles SMIRKS=[*:1]~;!@[*;X3;r5:2]~;@[*;r5:3] ATTR=k INITIAL=70.66802196994 kcal/(mol rad**2) FINAL=7.836206747420249 kcal/(mol rad**2)\n",
"Angles SMIRKS=[*:1]~[#6X3:2]~[*:3] ATTR=k INITIAL=153.5899485526 kcal/(mol rad**2) FINAL=40.37185554971532 kcal/(mol rad**2)\n"
"\n",
"=====================================Bonds======================================\n",
"\n",
"--- -------------------\n",
" k (kcal/mol/Ų) \n",
" INITIAL FINAL \n",
"--- -------------------\n",
"b10 1053.9708 1053.4929\n",
"b11 837.2648 837.6406 \n",
"b13 882.4192 881.9268 \n",
"b2 612.5098 612.0318 \n",
"b20 1135.5953 1135.1173\n",
"b3 612.0537 612.5316 \n",
"b6 857.1116 856.6722 \n",
"b7 719.6327 719.1547 \n",
"b8 719.2194 718.7532 \n",
"b83 758.0932 757.6153 \n",
"b84 808.1394 807.6615 \n",
"b86 997.7547 997.2764 \n",
"b9 764.7121 765.1886 \n",
"\n",
"=====================================Angles=====================================\n",
"\n",
"--- -----------------\n",
" k (kcal/mol/rad²)\n",
" INITIAL FINAL \n",
"--- -----------------\n",
"a1 101.7373 26.4064 \n",
"a10 153.5899 40.3719 \n",
"a14 70.6680 7.8362 \n",
"a19 112.5451 94.1025 \n",
"a2 74.2870 3.1566 \n",
"a20 77.5261 191.9145\n",
"a22 226.9001 31.4991 \n",
"\n"
]
}
],
"source": [
"for parameter_handler, potential_key, attribute in parameter_delta_ids:\n",
"\n",
" initial_value = getattr(\n",
" initial_force_field[parameter_handler].parameters[potential_key.id], attribute\n",
" )\n",
" final_value = getattr(\n",
" final_force_field[parameter_handler].parameters[potential_key.id], attribute\n",
" )\n",
"\n",
" print(\n",
" f\"{parameter_handler} SMIRKS={potential_key.id} ATTR={attribute} INITIAL={initial_value} FINAL={final_value}\"\n",
" )"
"model.summarise(parameter_id_type=\"id\");"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
}
}
},
{
"cell_type": "markdown",
"source": [
"where here we have chosen to print the unique ID associated with each parameter as opposed to the SMIRKS pattern\n",
"(i.e `parameter_id_type=\"smirks\"`) for improved clarity."
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%% md\n"
}
}
}
],
"metadata": {
Expand Down

0 comments on commit 665a67e

Please sign in to comment.