diff --git a/descent/targets/dimers.py b/descent/targets/dimers.py index b76da37..95c147a 100644 --- a/descent/targets/dimers.py +++ b/descent/targets/dimers.py @@ -309,6 +309,11 @@ def report( rows = [] + delta_sqr_total = { + force_field_name: torch.zeros(1) for force_field_name in force_fields + } + delta_sqr_count = 0 + for dimer in dataset.to_pylist(): energies = {"ref": torch.tensor(dimer["energy"])} energies.update( @@ -316,12 +321,69 @@ def report( for force_field_name, force_field in force_fields.items() ) - plot_img = _plot_energies(energies) - mol_img = descent.utils.reporting.mols_to_img( dimer["smiles_a"], dimer["smiles_b"] ) - rows.append({"Dimer": mol_img, "Energy [kcal/mol]": plot_img}) + data_row = {"Dimer": mol_img, "Energy [kcal/mol]": _plot_energies(energies)} + + for force_field_name in force_fields: + delta_sqr = ((energies["ref"] - energies[force_field_name]) ** 2).sum() + delta_sqr_total[force_field_name] += delta_sqr + + rmse = torch.sqrt(delta_sqr / len(energies["ref"])) + data_row[f"RMSE {force_field_name}"] = rmse.item() + + delta_sqr_count += len(energies["ref"]) + + rows.append(data_row) + + rmse_total_rows = [ + { + "Force Field": force_field_name, + "RMSE [kcal/mol]": torch.sqrt( + delta_sqr_total[force_field_name].sum() / delta_sqr_count + ).item(), + } + for force_field_name in force_fields + ] + + import bokeh.models.widgets.tables + import panel + + data_full = pandas.DataFrame(rows) + data_stats = pandas.DataFrame(rmse_total_rows) + + rmse_format = bokeh.models.widgets.tables.NumberFormatter(format="0.0000") + + formatters_stats = { + col: rmse_format for col in data_stats.columns if col.startswith("RMSE") + } + formatters_full = { + **{col: "html" for col in ["Dimer", "Energy [kcal/mol]"]}, + **{col: rmse_format for col in data_full.columns if col.startswith("RMSE")}, + } + + layout = panel.Column( + "## Statistics", + panel.widgets.Tabulator( + pandas.DataFrame(rmse_total_rows), + show_index=False, + selectable=False, + disabled=True, + formatters=formatters_stats, + configuration={"columnDefaults": {"headerSort": False}}, + ), + "## Energies", + panel.widgets.Tabulator( + data_full, + show_index=False, + selectable=False, + disabled=True, + formatters=formatters_full, + ), + sizing_mode="stretch_width", + scroll=True, + ) output_path.parent.mkdir(parents=True, exist_ok=True) - return pandas.DataFrame(rows).to_html(output_path, escape=False, index=False) + layout.save(output_path, title="Dimers", embed=True) diff --git a/descent/tests/targets/test_dimers.py b/descent/tests/targets/test_dimers.py index dd6f692..defb3fe 100644 --- a/descent/tests/targets/test_dimers.py +++ b/descent/tests/targets/test_dimers.py @@ -201,6 +201,6 @@ def test_report(tmp_cwd, mock_dimer, mocker): report(dataset, {"A": mock_ff}, mock_tops, expected_path) assert expected_path.exists() - assert expected_path.read_text().startswith("") mock_predict_fn.assert_called_once_with(mocker.ANY, mock_ff, mock_tops) diff --git a/devtools/envs/base.yaml b/devtools/envs/base.yaml index 43f6946..156743b 100644 --- a/devtools/envs/base.yaml +++ b/devtools/envs/base.yaml @@ -21,6 +21,7 @@ dependencies: # Optional packages - rdkit - matplotlib-base + - panel # Examples - jupyter