Skip to content

Commit

Permalink
Enable Energy and Force fitting on a GPU (#72)
Browse files Browse the repository at this point in the history
* make sure energy and force predictions are on the same device to allow for GPU training

* fix tests dtype
  • Loading branch information
jthorton authored Oct 8, 2024
1 parent 92a1396 commit 6bb05aa
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 13 deletions.
29 changes: 20 additions & 9 deletions descent/targets/energy.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import datasets.table
import pyarrow
import smee
import smee.utils
import torch

DATA_SCHEMA = pyarrow.schema(
Expand Down Expand Up @@ -110,12 +111,13 @@ def predict(
energy_ref = entry["energy"]
forces_ref = entry["forces"].reshape(len(energy_ref), -1, 3)

coords = (
entry["coords"]
.reshape(len(energy_ref), -1, 3)
.detach()
.requires_grad_(True)
coords_flat = smee.utils.tensor_like(
entry["coords"], force_field.potentials[0].parameters
)

coords = (
coords_flat.reshape(len(energy_ref), -1, 3)
).detach().requires_grad_(True)
topology = topologies[smiles]

energy_pred = smee.compute_energy(topology, force_field, coords)
Expand Down Expand Up @@ -150,9 +152,18 @@ def predict(
energy_pred_all.append(scale_energy * (energy_pred - energy_pred_0))
forces_pred_all.append(scale_forces * forces_pred.reshape(-1, 3))

energy_pred_all = torch.cat(energy_pred_all)
forces_pred_all = torch.cat(forces_pred_all)

energy_ref_all = torch.cat(energy_ref_all)
energy_ref_all = smee.utils.tensor_like(energy_ref_all, energy_pred_all)

forces_ref_all = torch.cat(forces_ref_all)
forces_ref_all = smee.utils.tensor_like(forces_ref_all, forces_pred_all)

return (
torch.cat(energy_ref_all),
torch.cat(energy_pred_all),
torch.cat(forces_ref_all),
torch.cat(forces_pred_all),
energy_ref_all,
energy_pred_all,
forces_ref_all,
forces_pred_all,
)
8 changes: 4 additions & 4 deletions descent/tests/targets/test_energy.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ def test_extract_smiles(mock_meoh_entry, mock_hoh_entry):
[9.0, 10.0, 11.0],
[12.0, 13.0, 14.0],
[15.0, 16.0, 17.0],
]
], dtype=torch.float64
)
/ math.sqrt(6.0 * 3.0),
torch.tensor([7.899425506591797, -7.89942741394043]) / math.sqrt(2.0),
Expand All @@ -90,7 +90,7 @@ def test_extract_smiles(mock_meoh_entry, mock_hoh_entry):
[0.0, -137.45770263671875, 0.0],
[102.62999725341797, 68.72884368896484, 0.0],
[-102.62999725341797, 68.72884368896484, 0.0],
]
], dtype=torch.float64
)
/ math.sqrt(6.0 * 3.0),
),
Expand All @@ -106,7 +106,7 @@ def test_extract_smiles(mock_meoh_entry, mock_hoh_entry):
[9.0, 10.0, 11.0],
[12.0, 13.0, 14.0],
[15.0, 16.0, 17.0],
]
], dtype=torch.float64
),
torch.tensor([0.0, -15.798852920532227]),
-torch.tensor(
Expand All @@ -117,7 +117,7 @@ def test_extract_smiles(mock_meoh_entry, mock_hoh_entry):
[0.0, -137.45770263671875, 0.0],
[102.62999725341797, 68.72884368896484, 0.0],
[-102.62999725341797, 68.72884368896484, 0.0],
]
], dtype=torch.float64
),
),
],
Expand Down

0 comments on commit 6bb05aa

Please sign in to comment.