Skip to content

Commit

Permalink
Report dimer RMSD using panel (#55)
Browse files Browse the repository at this point in the history
  • Loading branch information
SimonBoothroyd authored Nov 9, 2023
1 parent 092e2fe commit c2a172a
Show file tree
Hide file tree
Showing 3 changed files with 68 additions and 5 deletions.
70 changes: 66 additions & 4 deletions descent/targets/dimers.py
Original file line number Diff line number Diff line change
Expand Up @@ -309,19 +309,81 @@ def report(

rows = []

delta_sqr_total = {
force_field_name: torch.zeros(1) for force_field_name in force_fields
}
delta_sqr_count = 0

for dimer in dataset.to_pylist():
energies = {"ref": torch.tensor(dimer["energy"])}
energies.update(
(force_field_name, _predict(dimer, force_field, topologies)[1])
for force_field_name, force_field in force_fields.items()
)

plot_img = _plot_energies(energies)

mol_img = descent.utils.reporting.mols_to_img(
dimer["smiles_a"], dimer["smiles_b"]
)
rows.append({"Dimer": mol_img, "Energy [kcal/mol]": plot_img})
data_row = {"Dimer": mol_img, "Energy [kcal/mol]": _plot_energies(energies)}

for force_field_name in force_fields:
delta_sqr = ((energies["ref"] - energies[force_field_name]) ** 2).sum()
delta_sqr_total[force_field_name] += delta_sqr

rmse = torch.sqrt(delta_sqr / len(energies["ref"]))
data_row[f"RMSE {force_field_name}"] = rmse.item()

delta_sqr_count += len(energies["ref"])

rows.append(data_row)

rmse_total_rows = [
{
"Force Field": force_field_name,
"RMSE [kcal/mol]": torch.sqrt(
delta_sqr_total[force_field_name].sum() / delta_sqr_count
).item(),
}
for force_field_name in force_fields
]

import bokeh.models.widgets.tables
import panel

data_full = pandas.DataFrame(rows)
data_stats = pandas.DataFrame(rmse_total_rows)

rmse_format = bokeh.models.widgets.tables.NumberFormatter(format="0.0000")

formatters_stats = {
col: rmse_format for col in data_stats.columns if col.startswith("RMSE")
}
formatters_full = {
**{col: "html" for col in ["Dimer", "Energy [kcal/mol]"]},
**{col: rmse_format for col in data_full.columns if col.startswith("RMSE")},
}

layout = panel.Column(
"## Statistics",
panel.widgets.Tabulator(
pandas.DataFrame(rmse_total_rows),
show_index=False,
selectable=False,
disabled=True,
formatters=formatters_stats,
configuration={"columnDefaults": {"headerSort": False}},
),
"## Energies",
panel.widgets.Tabulator(
data_full,
show_index=False,
selectable=False,
disabled=True,
formatters=formatters_full,
),
sizing_mode="stretch_width",
scroll=True,
)

output_path.parent.mkdir(parents=True, exist_ok=True)
return pandas.DataFrame(rows).to_html(output_path, escape=False, index=False)
layout.save(output_path, title="Dimers", embed=True)
2 changes: 1 addition & 1 deletion descent/tests/targets/test_dimers.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,6 @@ def test_report(tmp_cwd, mock_dimer, mocker):
report(dataset, {"A": mock_ff}, mock_tops, expected_path)

assert expected_path.exists()
assert expected_path.read_text().startswith("<table border")
assert expected_path.read_text().startswith("<!DOCTYPE html>")

mock_predict_fn.assert_called_once_with(mocker.ANY, mock_ff, mock_tops)
1 change: 1 addition & 0 deletions devtools/envs/base.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ dependencies:
# Optional packages
- rdkit
- matplotlib-base
- panel

# Examples
- jupyter
Expand Down

0 comments on commit c2a172a

Please sign in to comment.