Skip to content

Commit

Permalink
Add utils to print SMEE force field parameters (#57)
Browse files Browse the repository at this point in the history
  • Loading branch information
SimonBoothroyd authored Nov 12, 2023
1 parent f4231cd commit 032865c
Show file tree
Hide file tree
Showing 2 changed files with 129 additions and 1 deletion.
34 changes: 33 additions & 1 deletion descent/tests/utils/test_reporting.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,15 @@
import openff.interchange
import openff.toolkit
import smee.converters
from matplotlib import pyplot
from rdkit import Chem

from descent.utils.reporting import _mol_from_smiles, figure_to_img, mols_to_img
from descent.utils.reporting import (
_mol_from_smiles,
figure_to_img,
mols_to_img,
print_force_field_summary,
)


def test_mol_from_smiles():
Expand All @@ -22,3 +30,27 @@ def test_figure_to_img():

assert img.startswith('<img src="data:image/svg+xml;base64,')
assert img.endswith('"></img>')


def test_print_force_field_summary(capsys):
interchange = openff.interchange.Interchange.from_smirnoff(
openff.toolkit.ForceField("tip4p_fb.offxml"),
openff.toolkit.Molecule.from_smiles("O").to_topology(),
)

force_field, _ = smee.converters.convert_interchange(interchange)
print_force_field_summary(force_field)

captured = capsys.readouterr().out

assert "ID distance [Å] inPlaneAngle [rad] outOfPlaneAngle [rad]" in captured
assert (
"[#1:2]-[#8X2H2+0:1]-[#1:3] EP -0.1053 "
"3.1416 0.0000" in captured
)

assert "fn=4*epsilon*((sigma/r)**12-(sigma/r)**6)" in captured
assert "scale_12 scale_13 scale_14 scale_15 cutoff [Å] switch_width [Å]" in captured

assert "ID epsilon [kcal/mol] sigma [Å]" in captured
assert "[#1:2]-[#8X2H2+0:1]-[#1:3] EP 0.0000 1.0000" in captured
96 changes: 96 additions & 0 deletions descent/utils/reporting.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@
import itertools
import typing

import openff.units
import smee

if typing.TYPE_CHECKING:
from matplotlib import pyplot
from rdkit import Chem
Expand Down Expand Up @@ -71,3 +74,96 @@ def figure_to_img(figure: "pyplot.Figure") -> str:
data = base64.b64encode(stream.getvalue()).decode()

return f'<img src="data:image/svg+xml;base64,{data}"></img>'


def _format_unit(unit: openff.units.Unit | None) -> str:
"""Format a unit for display in a table."""

if unit is None or unit == openff.units.unit.dimensionless:
return ""

return f" [{unit: ~P}]"


def _format_parameter_id(id_: typing.Any) -> str:
"""Format a parameter ID for display in a table."""

id_str = id_ if "EP" not in id_ else id_[: id_.index("EP") + 2]
return id_str[:60] + (id_str[60:] and "...")


def print_potential_summary(potential: smee.TensorPotential):
"""Print a summary of the potential parameters to the terminal.
Args:
potential: The potential.
"""
import pandas

parameter_rows = []

for key, value in zip(potential.parameter_keys, potential.parameters.detach()):
row = {"ID": _format_parameter_id(key.id)}
row.update(
{
f"{col}{_format_unit(potential.parameter_units[idx])}": f"{value[idx].item():.4f}"
for idx, col in enumerate(potential.parameter_cols)
}
)
parameter_rows.append(row)

print(f" {potential.type} ".center(88, "="), flush=True)
print(f"fn={potential.fn}", flush=True)

if potential.attributes is not None:
attribute_rows = [
{
f"{col}{_format_unit(potential.attribute_units[idx])}": f"{potential.attributes[idx].item():.4f} "
for idx, col in enumerate(potential.attribute_cols)
}
]
print("")
print("attributes=", flush=True)
print("")
print(pandas.DataFrame(attribute_rows).to_string(index=False), flush=True)

print("")
print("parameters=", flush=True)
print("")
print(pandas.DataFrame(parameter_rows).to_string(index=False), flush=True)


def print_v_site_summary(v_sites: smee.TensorVSites):
import pandas

parameter_rows = []

for key, value in zip(v_sites.keys, v_sites.parameters.detach()):
row = {"ID": _format_parameter_id(key.id)}
row.update(
{
f"{col}{_format_unit(unit)}": f"{value[idx].item():.4f}"
for idx, (col, unit) in enumerate(v_sites.parameter_units.items())
}
)
parameter_rows.append(row)

print(" v-sites ".center(88, "="), flush=True)
print("parameters:", flush=True)
print(pandas.DataFrame(parameter_rows).to_string(index=False), flush=True)


def print_force_field_summary(force_field: smee.TensorForceField):
"""Print a summary of the force field parameters to the terminal.
Args:
force_field: The force field.
"""

if force_field.v_sites is not None:
print_v_site_summary(force_field.v_sites)
print("")

for potential in force_field.potentials:
print_potential_summary(potential)
print("")

0 comments on commit 032865c

Please sign in to comment.