Skip to content

Commit

Permalink
Allow different topologies for each force field (#64)
Browse files Browse the repository at this point in the history
  • Loading branch information
jthorton authored Jun 3, 2024
1 parent 5be78b2 commit 9ff3a92
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 5 deletions.
12 changes: 8 additions & 4 deletions descent/targets/dimers.py
Original file line number Diff line number Diff line change
Expand Up @@ -337,16 +337,17 @@ def _plot_energies(energies: dict[str, torch.Tensor]) -> str:
def report(
dataset: datasets.Dataset,
force_fields: dict[str, smee.TensorForceField],
topologies: dict[str, smee.TensorTopology],
topologies: dict[str, dict[str, smee.TensorTopology]],
output_path: pathlib.Path,
):
"""Generate a report comparing the predicted and reference energies of each dimer.
Args:
dataset: The dataset to generate the report for.
force_fields: The force fields to use to predict the energies.
topologies: The topologies of each monomer. Each key should be a fully
mapped SMILES string.
topologies: The topologies of each monomer for the given force field. Each key
should be a fully mapped SMILES string. The name of the force field must
also be present in force_fields
output_path: The path to write the report to.
"""
import pandas
Expand All @@ -361,7 +362,10 @@ def report(
for dimer in descent.utils.dataset.iter_dataset(dataset):
energies = {"ref": dimer["energy"]}
energies.update(
(force_field_name, _predict(dimer, force_field, topologies)[1])
(
force_field_name,
_predict(dimer, force_field, topologies[force_field_name])[1]
)
for force_field_name, force_field in force_fields.items()
)

Expand Down
2 changes: 1 addition & 1 deletion descent/tests/targets/test_dimers.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,7 +230,7 @@ def test_report(tmp_cwd, mock_dimer, mocker):
mock_tops = mocker.MagicMock()

expected_path = tmp_cwd / "report.html"
report(dataset, {"A": mock_ff}, mock_tops, expected_path)
report(dataset, {"A": mock_ff}, {"A": mock_tops}, expected_path)

assert expected_path.exists()
assert expected_path.read_text().startswith("<!DOCTYPE html>")
Expand Down

0 comments on commit 9ff3a92

Please sign in to comment.