Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow different topologies for each force field #64

Merged
merged 2 commits into from
Jun 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 8 additions & 4 deletions descent/targets/dimers.py
Original file line number Diff line number Diff line change
Expand Up @@ -337,16 +337,17 @@ def _plot_energies(energies: dict[str, torch.Tensor]) -> str:
def report(
dataset: datasets.Dataset,
force_fields: dict[str, smee.TensorForceField],
topologies: dict[str, smee.TensorTopology],
topologies: dict[str, dict[str, smee.TensorTopology]],
output_path: pathlib.Path,
):
"""Generate a report comparing the predicted and reference energies of each dimer.

Args:
dataset: The dataset to generate the report for.
force_fields: The force fields to use to predict the energies.
topologies: The topologies of each monomer. Each key should be a fully
mapped SMILES string.
topologies: The topologies of each monomer for the given force field. Each key
should be a fully mapped SMILES string. The name of the force field must
also be present in force_fields
output_path: The path to write the report to.
"""
import pandas
Expand All @@ -361,7 +362,10 @@ def report(
for dimer in descent.utils.dataset.iter_dataset(dataset):
energies = {"ref": dimer["energy"]}
energies.update(
(force_field_name, _predict(dimer, force_field, topologies)[1])
(
force_field_name,
_predict(dimer, force_field, topologies[force_field_name])[1]
)
for force_field_name, force_field in force_fields.items()
)

Expand Down
2 changes: 1 addition & 1 deletion descent/tests/targets/test_dimers.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,7 +230,7 @@ def test_report(tmp_cwd, mock_dimer, mocker):
mock_tops = mocker.MagicMock()

expected_path = tmp_cwd / "report.html"
report(dataset, {"A": mock_ff}, mock_tops, expected_path)
report(dataset, {"A": mock_ff}, {"A": mock_tops}, expected_path)

assert expected_path.exists()
assert expected_path.read_text().startswith("<!DOCTYPE html>")
Expand Down
Loading