Skip to content

Commit

Permalink
Report dimer RMSD using itables
Browse files Browse the repository at this point in the history
  • Loading branch information
SimonBoothroyd committed Nov 9, 2023
1 parent 092e2fe commit 431a00d
Show file tree
Hide file tree
Showing 3 changed files with 63 additions and 5 deletions.
65 changes: 61 additions & 4 deletions descent/targets/dimers.py
Original file line number Diff line number Diff line change
Expand Up @@ -290,6 +290,28 @@ def _plot_energies(energies: dict[str, torch.Tensor]) -> str:
return img


def _pandas_html_style() -> str:
    """Return the default HTML style for Pandas DataFrames.

    An empty DataFrame is rendered with ``itables.to_html_datatable`` and the
    leading part of the result, up to and including the first ``</style>``
    closing tag, is returned.

    Returns:
        The rendered HTML up to and including the first ``</style>`` tag, or
        an empty string if the rendered HTML contains no ``</style>`` tag.
    """
    import itables
    import pandas

    html = itables.to_html_datatable(pandas.DataFrame())
    style, closing_tag, _ = html.partition("</style>")

    # ``partition`` yields an empty separator when the tag is missing. Falling
    # back to "no style" keeps the caller working, whereas ``str.index`` would
    # abort with an uninformative ``ValueError: substring not found``.
    return style + closing_tag if closing_tag else ""


def _pandas_to_html(data_frame: "pandas.DataFrame") -> str:
    """Convert a Pandas DataFrame to an HTML table.

    The frame is rendered with ``itables.to_html_datatable`` and everything up
    to and including the first ``</style>`` closing tag is dropped from the
    result.

    Args:
        data_frame: The frame to render.

    Returns:
        The rendered HTML that follows the first ``</style>`` tag, or the
        complete rendered HTML if it contains no ``</style>`` tag.
    """

    import itables

    html = itables.to_html_datatable(data_frame)
    _, closing_tag, remainder = html.partition("</style>")

    # With no style block there is nothing to strip, so return the rendered
    # table untouched rather than failing with ``ValueError`` from
    # ``str.index``.
    return remainder if closing_tag else html


def report(
dataset: pyarrow.Table,
force_fields: dict[str, smee.TensorForceField],
Expand All @@ -309,19 +331,54 @@ def report(

rows = []

delta_sqr_total = {
force_field_name: torch.zeros(1) for force_field_name in force_fields
}
delta_sqr_count = 0

for dimer in dataset.to_pylist():
energies = {"ref": torch.tensor(dimer["energy"])}
energies.update(
(force_field_name, _predict(dimer, force_field, topologies)[1])
for force_field_name, force_field in force_fields.items()
)

plot_img = _plot_energies(energies)

mol_img = descent.utils.reporting.mols_to_img(
dimer["smiles_a"], dimer["smiles_b"]
)
rows.append({"Dimer": mol_img, "Energy [kcal/mol]": plot_img})
data_row = {"Dimer": mol_img, "Energy [kcal/mol]": _plot_energies(energies)}

for force_field_name in force_fields:
delta_sqr = ((energies["ref"] - energies[force_field_name]) ** 2).sum()
delta_sqr_total[force_field_name] += delta_sqr

rmse = torch.sqrt(delta_sqr / len(energies["ref"]))
data_row[f"RMSE {force_field_name} [kcal/mol]"] = rmse.item()

delta_sqr_count += len(energies["ref"])

rows.append(data_row)

rmse_total_rows = [
{
"Force Field": force_field_name,
"RMSE [kcal/mol]": torch.sqrt(
delta_sqr_total[force_field_name].sum() / delta_sqr_count
).item(),
}
for force_field_name in force_fields
]

html_style = _pandas_html_style()
html = "\n".join(
[
html_style,
"<h2>Statistics</h2>",
_pandas_to_html(pandas.DataFrame(rmse_total_rows)),
"<h2>Energy Plots</h2>",
_pandas_to_html(pandas.DataFrame(rows)),
]
)

output_path.parent.mkdir(parents=True, exist_ok=True)
return pandas.DataFrame(rows).to_html(output_path, escape=False, index=False)
output_path.write_text(html)
2 changes: 1 addition & 1 deletion descent/tests/targets/test_dimers.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,6 @@ def test_report(tmp_cwd, mock_dimer, mocker):
report(dataset, {"A": mock_ff}, mock_tops, expected_path)

assert expected_path.exists()
assert expected_path.read_text().startswith("<table border")
assert expected_path.read_text().startswith("<style>.itable")

mock_predict_fn.assert_called_once_with(mocker.ANY, mock_ff, mock_tops)
1 change: 1 addition & 0 deletions devtools/envs/base.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ dependencies:
# Optional packages
- rdkit
- matplotlib-base
- itables

# Examples
- jupyter
Expand Down

0 comments on commit 431a00d

Please sign in to comment.