diff --git a/descent/tests/utils/test_reporting.py b/descent/tests/utils/test_reporting.py index 0fa140a..065790f 100644 --- a/descent/tests/utils/test_reporting.py +++ b/descent/tests/utils/test_reporting.py @@ -1,7 +1,15 @@ +import openff.interchange +import openff.toolkit +import smee.converters from matplotlib import pyplot from rdkit import Chem -from descent.utils.reporting import _mol_from_smiles, figure_to_img, mols_to_img +from descent.utils.reporting import ( + _mol_from_smiles, + figure_to_img, + mols_to_img, + print_force_field_summary, +) def test_mol_from_smiles(): @@ -22,3 +30,27 @@ def test_figure_to_img(): assert img.startswith('') + + +def test_print_force_field_summary(capsys): + interchange = openff.interchange.Interchange.from_smirnoff( + openff.toolkit.ForceField("tip4p_fb.offxml"), + openff.toolkit.Molecule.from_smiles("O").to_topology(), + ) + + force_field, _ = smee.converters.convert_interchange(interchange) + print_force_field_summary(force_field) + + captured = capsys.readouterr().out + + assert "ID distance [Å] inPlaneAngle [rad] outOfPlaneAngle [rad]" in captured + assert ( + "[#1:2]-[#8X2H2+0:1]-[#1:3] EP -0.1053 " + "3.1416 0.0000" in captured + ) + + assert "fn=4*epsilon*((sigma/r)**12-(sigma/r)**6)" in captured + assert "scale_12 scale_13 scale_14 scale_15 cutoff [Å] switch_width [Å]" in captured + + assert "ID epsilon [kcal/mol] sigma [Å]" in captured + assert "[#1:2]-[#8X2H2+0:1]-[#1:3] EP 0.0000 1.0000" in captured diff --git a/descent/utils/reporting.py b/descent/utils/reporting.py index f128d56..e27e294 100644 --- a/descent/utils/reporting.py +++ b/descent/utils/reporting.py @@ -4,6 +4,9 @@ import itertools import typing +import openff.units +import smee + if typing.TYPE_CHECKING: from matplotlib import pyplot from rdkit import Chem @@ -71,3 +74,96 @@ def figure_to_img(figure: "pyplot.Figure") -> str: data = base64.b64encode(stream.getvalue()).decode() return f'' + + +def _format_unit(unit: openff.units.Unit | None) -> str: + """Format a unit for display in a table.""" + + if unit is None or unit == openff.units.unit.dimensionless: + return "" + + return f" [{unit: ~P}]" + + +def _format_parameter_id(id_: typing.Any) -> str: + """Format a parameter ID for display in a table.""" + + id_str = id_ if "EP" not in id_ else id_[: id_.index("EP") + 2] + return id_str[:60] + (id_str[60:] and "...") + + +def print_potential_summary(potential: smee.TensorPotential): + """Print a summary of the potential parameters to the terminal. + + Args: + potential: The potential. + """ + import pandas + + parameter_rows = [] + + for key, value in zip(potential.parameter_keys, potential.parameters.detach()): + row = {"ID": _format_parameter_id(key.id)} + row.update( + { + f"{col}{_format_unit(potential.parameter_units[idx])}": f"{value[idx].item():.4f}" + for idx, col in enumerate(potential.parameter_cols) + } + ) + parameter_rows.append(row) + + print(f" {potential.type} ".center(88, "="), flush=True) + print(f"fn={potential.fn}", flush=True) + + if potential.attributes is not None: + attribute_rows = [ + { + f"{col}{_format_unit(potential.attribute_units[idx])}": f"{potential.attributes[idx].item():.4f} " + for idx, col in enumerate(potential.attribute_cols) + } + ] + print("") + print("attributes=", flush=True) + print("") + print(pandas.DataFrame(attribute_rows).to_string(index=False), flush=True) + + print("") + print("parameters=", flush=True) + print("") + print(pandas.DataFrame(parameter_rows).to_string(index=False), flush=True) + + +def print_v_site_summary(v_sites: smee.TensorVSites): + import pandas + + parameter_rows = [] + + for key, value in zip(v_sites.keys, v_sites.parameters.detach()): + row = {"ID": _format_parameter_id(key.id)} + row.update( + { + f"{col}{_format_unit(unit)}": f"{value[idx].item():.4f}" + for idx, (col, unit) in enumerate(v_sites.parameter_units.items()) + } + ) + parameter_rows.append(row) + + print(" v-sites ".center(88, "="), flush=True) + print("parameters:", flush=True) + print(pandas.DataFrame(parameter_rows).to_string(index=False), flush=True) + + +def print_force_field_summary(force_field: smee.TensorForceField): + """Print a summary of the force field parameters to the terminal. + + Args: + force_field: The force field. + """ + + if force_field.v_sites is not None: + print_v_site_summary(force_field.v_sites) + print("") + + for potential in force_field.potentials: + print_potential_summary(potential) + print("")