From 6bb05aac6b2ac7cbf12e1c86b264fa646e6df699 Mon Sep 17 00:00:00 2001 From: Josh Horton Date: Tue, 8 Oct 2024 09:59:46 +0100 Subject: [PATCH] Enable Energy and Force fitting on a GPU (#72) * make sure energy and force predictions are on the same device to allow for GPU training * fix tests dtype --- descent/targets/energy.py | 29 +++++++++++++++++++--------- descent/tests/targets/test_energy.py | 8 ++++---- 2 files changed, 24 insertions(+), 13 deletions(-) diff --git a/descent/targets/energy.py b/descent/targets/energy.py index 0bbe3ac..ff610d7 100644 --- a/descent/targets/energy.py +++ b/descent/targets/energy.py @@ -6,6 +6,7 @@ import datasets.table import pyarrow import smee +import smee.utils import torch DATA_SCHEMA = pyarrow.schema( @@ -110,12 +111,13 @@ def predict( energy_ref = entry["energy"] forces_ref = entry["forces"].reshape(len(energy_ref), -1, 3) - coords = ( - entry["coords"] - .reshape(len(energy_ref), -1, 3) - .detach() - .requires_grad_(True) + coords_flat = smee.utils.tensor_like( + entry["coords"], force_field.potentials[0].parameters ) + + coords = ( + coords_flat.reshape(len(energy_ref), -1, 3) + ).detach().requires_grad_(True) topology = topologies[smiles] energy_pred = smee.compute_energy(topology, force_field, coords) @@ -150,9 +152,18 @@ def predict( energy_pred_all.append(scale_energy * (energy_pred - energy_pred_0)) forces_pred_all.append(scale_forces * forces_pred.reshape(-1, 3)) + energy_pred_all = torch.cat(energy_pred_all) + forces_pred_all = torch.cat(forces_pred_all) + + energy_ref_all = torch.cat(energy_ref_all) + energy_ref_all = smee.utils.tensor_like(energy_ref_all, energy_pred_all) + + forces_ref_all = torch.cat(forces_ref_all) + forces_ref_all = smee.utils.tensor_like(forces_ref_all, forces_pred_all) + return ( - torch.cat(energy_ref_all), - torch.cat(energy_pred_all), - torch.cat(forces_ref_all), - torch.cat(forces_pred_all), + energy_ref_all, + energy_pred_all, + forces_ref_all, + forces_pred_all, ) diff --git a/descent/tests/targets/test_energy.py b/descent/tests/targets/test_energy.py index 4ca8393..cbd75c0 100644 --- a/descent/tests/targets/test_energy.py +++ b/descent/tests/targets/test_energy.py @@ -78,7 +78,7 @@ def test_extract_smiles(mock_meoh_entry, mock_hoh_entry): [9.0, 10.0, 11.0], [12.0, 13.0, 14.0], [15.0, 16.0, 17.0], - ] + ], dtype=torch.float64 ) / math.sqrt(6.0 * 3.0), torch.tensor([7.899425506591797, -7.89942741394043]) / math.sqrt(2.0), @@ -90,7 +90,7 @@ def test_extract_smiles(mock_meoh_entry, mock_hoh_entry): [0.0, -137.45770263671875, 0.0], [102.62999725341797, 68.72884368896484, 0.0], [-102.62999725341797, 68.72884368896484, 0.0], - ] + ], dtype=torch.float64 ) / math.sqrt(6.0 * 3.0), ), @@ -106,7 +106,7 @@ def test_extract_smiles(mock_meoh_entry, mock_hoh_entry): [9.0, 10.0, 11.0], [12.0, 13.0, 14.0], [15.0, 16.0, 17.0], - ] + ], dtype=torch.float64 ), torch.tensor([0.0, -15.798852920532227]), -torch.tensor( @@ -117,7 +117,7 @@ def test_extract_smiles(mock_meoh_entry, mock_hoh_entry): [0.0, -137.45770263671875, 0.0], [102.62999725341797, 68.72884368896484, 0.0], [-102.62999725341797, 68.72884368896484, 0.0], - ] + ], dtype=torch.float64 ), ), ],