Skip to content

Commit

Permalink
Rename evaluate to evaluate loss and remove dataset version
Browse files Browse the repository at this point in the history
  • Loading branch information
SimonBoothroyd committed Sep 12, 2021
1 parent 7599c8f commit c23dcc0
Show file tree
Hide file tree
Showing 5 changed files with 57 additions and 107 deletions.
4 changes: 2 additions & 2 deletions descent/data/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def __init__(self, model_input: Union[Interchange]):
)

@abc.abstractmethod
def evaluate(self, model: ParameterizationModel, **kwargs) -> torch.Tensor:
def evaluate_loss(self, model: ParameterizationModel, **kwargs) -> torch.Tensor:
"""Evaluates the contribution to the total loss function of the data stored
in this entry using a specified model.
Expand All @@ -59,7 +59,7 @@ def __call__(self, model: ParameterizationModel, **kwargs) -> torch.Tensor:
Returns:
The loss contribution of this entry.
"""
return self.evaluate(model, **kwargs)
return self.evaluate_loss(model, **kwargs)


class Dataset(torch.utils.data.IterableDataset[T_co], Generic[T_co]):
Expand Down
59 changes: 2 additions & 57 deletions descent/data/energy.py
Original file line number Diff line number Diff line change
Expand Up @@ -488,7 +488,7 @@ def _evaluate_loss_contribution(

return data_metric(transformed_computed_tensor, transformed_reference_tensor)

def evaluate(
def evaluate_loss(
self,
model: ParameterizationModel,
energy_transforms: Optional[Union[LossTransform, List[LossTransform]]] = None,
Expand Down Expand Up @@ -782,63 +782,8 @@ def from_optimization_results(
),
total=len(result_tensors),
disable=not verbose,
desc="Building energy contribution objects.",
desc="Building entries.",
)
)

return cls(entries)

def evaluate(
self,
model: ParameterizationModel,
energy_transforms: Optional[Union[LossTransform, List[LossTransform]]] = None,
energy_metric: Optional[LossMetric] = None,
gradient_transforms: Optional[Union[LossTransform, List[LossTransform]]] = None,
gradient_metric: Optional[LossMetric] = None,
hessian_transforms: Optional[Union[LossTransform, List[LossTransform]]] = None,
hessian_metric: Optional[LossMetric] = None,
) -> torch.Tensor:
"""Evaluates the contribution to the total loss function of the data stored
in this set using a specified model.
Args:
model: The model that will return vectorized view of a parameterised
molecule.
energy_transforms: Transforms to apply to the QM and MM energies
before computing the loss metric. By default
``descent.transforms.relative(index=0)`` is used if no value is provided.
energy_metric: The loss metric (e.g. MSE) to compute from the QM and MM
energies. By default ``descent.metrics.mse()`` is used if no value is
provided.
gradient_transforms: Transforms to apply to the QM and MM gradients
before computing the loss metric. By default
``descent.transforms.identity()`` is used if no value is provided.
gradient_metric: The loss metric (e.g. MSE) to compute from the QM and MM
gradients. By default ``descent.metrics.mse()`` is used if no value is
provided.
hessian_transforms: Transforms to apply to the QM and MM hessians
before computing the loss metric. By default
``descent.transforms.identity()`` is used if no value is provided.
hessian_metric: The loss metric (e.g. MSE) to compute from the QM and MM
hessians. By default ``descent.metrics.mse()`` is used if no value is
provided.
Returns:
The loss contribution of this dataset.
"""

loss = torch.zeros(1)

for entry in self._entries:

loss += entry(
model,
energy_transforms=energy_transforms,
energy_metric=energy_metric,
gradient_transforms=gradient_transforms,
gradient_metric=gradient_metric,
hessian_transforms=hessian_transforms,
hessian_metric=hessian_metric,
)

return loss
4 changes: 2 additions & 2 deletions descent/tests/data/test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@


class DummyEntry(DatasetEntry):
def evaluate(self, model, **kwargs):
def evaluate_loss(self, model, **kwargs):
pass


Expand All @@ -20,7 +20,7 @@ def test_call(monkeypatch):
evaluate_kwargs = {}

class LocalEntry(DatasetEntry):
def evaluate(self, model, **kwargs):
def evaluate_loss(self, model, **kwargs):
nonlocal evaluate_called
evaluate_called = True
evaluate_kwargs.update(kwargs)
Expand Down
6 changes: 3 additions & 3 deletions descent/tests/data/test_energy.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,7 +206,7 @@ def test_evaluate_energies(mock_hcl_conformers, mock_hcl_system, mock_hcl_mm_val
reference_energies=expected_energies + torch.ones_like(expected_energies),
)

loss = entry.evaluate(
loss = entry.evaluate_loss(
SMIRNOFFModel([], None),
energy_transforms=lambda x: expected_scale * x,
energy_metric=metrics.mse(),
Expand All @@ -230,7 +230,7 @@ def test_evaluate_gradients(mock_hcl_conformers, mock_hcl_system, mock_hcl_mm_va
reference_gradients=expected_gradients + torch.ones_like(expected_gradients),
)

loss = entry.evaluate(
loss = entry.evaluate_loss(
SMIRNOFFModel([], None),
gradient_transforms=lambda x: expected_scale * x,
gradient_metric=metrics.mse(()),
Expand All @@ -255,7 +255,7 @@ def test_evaluate_hessians(mock_hcl_conformers, mock_hcl_system, mock_hcl_mm_val
reference_hessians=expected_hessians + torch.ones_like(expected_hessians),
)

loss = entry.evaluate(
loss = entry.evaluate_loss(
SMIRNOFFModel([], None),
hessian_transforms=lambda x: expected_scale * x,
hessian_metric=metrics.mse(()),
Expand Down
91 changes: 48 additions & 43 deletions examples/energy-and-gradient.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -178,10 +178,10 @@
"name": "stderr",
"output_type": "stream",
"text": [
"Pulling main optimisation records: 100%|██████████| 3/3 [00:00<00:00, 118.99it/s]\n",
"Pulling gradient / hessian data: 100%|██████████| 3/3 [00:00<00:00, 3734.91it/s]\n",
"Building energy contribution objects.: 0%| | 0/1 [00:00<?, ?it/s]Warning: importing 'simtk.openmm' is deprecated. Import 'openmm' instead.\n",
"Building energy contribution objects.: 100%|██████████| 1/1 [00:04<00:00, 4.26s/it]\n"
"Pulling main optimisation records: 100%|██████████| 3/3 [00:00<00:00, 182.09it/s]\n",
"Pulling gradient / hessian data: 100%|██████████| 3/3 [00:00<00:00, 2606.78it/s]\n",
"Building entries.: 0%| | 0/1 [00:00<?, ?it/s]Warning: importing 'simtk.openmm' is deprecated. Import 'openmm' instead.\n",
"Building entries.: 100%|██████████| 1/1 [00:03<00:00, 3.35s/it]\n"
]
}
],
Expand Down Expand Up @@ -264,7 +264,7 @@
"outputs": [
{
"data": {
"text/plain": "[('Bonds',\n PotentialKey(id='[#7:1]-[#1:2]', mult=None, associated_handler='Bonds'),\n 'k'),\n ('Bonds',\n PotentialKey(id='[#6X3:1]=[#6X3:2]', mult=None, associated_handler='Bonds'),\n 'k')]"
"text/plain": "[('Bonds',\n PotentialKey(id='[#6X3:1]=[#7X2,#7X3+1:2]', mult=None, associated_handler='Bonds'),\n 'k'),\n ('Bonds',\n PotentialKey(id='[#7:1]-[#1:2]', mult=None, associated_handler='Bonds'),\n 'k')]"
},
"execution_count": 7,
"metadata": {},
Expand Down Expand Up @@ -361,14 +361,14 @@
"text": [
"Epoch 0: loss=633.4320678710938\n",
"Epoch 20: loss=359.64111328125\n",
"Epoch 40: loss=340.6795959472656\n",
"Epoch 40: loss=340.67962646484375\n",
"Epoch 60: loss=334.7128601074219\n",
"Epoch 80: loss=332.68359375\n",
"Epoch 100: loss=332.10784912109375\n",
"Epoch 120: loss=331.9746398925781\n",
"Epoch 140: loss=331.9458923339844\n",
"Epoch 80: loss=332.68365478515625\n",
"Epoch 100: loss=332.1078796386719\n",
"Epoch 120: loss=331.9747009277344\n",
"Epoch 140: loss=331.9459533691406\n",
"Epoch 160: loss=331.9331359863281\n",
"Epoch 180: loss=331.9215087890625\n"
"Epoch 180: loss=331.92156982421875\n"
]
}
],
Expand All @@ -384,20 +384,25 @@
"\n",
"for epoch in range(n_epochs):\n",
"\n",
" loss = training_dataset.evaluate(\n",
" model,\n",
" # Specify that we want use energies relative to the first conformer\n",
" # when evaluating the loss function\n",
" energy_transforms=transforms.relative(index=0),\n",
" # Use the built-in MSE metric when comparing the MM and QM relative\n",
" # energies.\n",
" energy_metric=metrics.mse(),\n",
" # For this example with will use the QM and MM gradients directly when\n",
" # computing the loss function.\n",
" gradient_transforms=transforms.identity(),\n",
" # Use the built-in MSE metric when comparing the MM and QM gradients\n",
" gradient_metric=metrics.mse(),\n",
" )\n",
" loss = torch.zeros(1)\n",
"\n",
" for entry in training_dataset:\n",
"\n",
" loss += entry.evaluate_loss(\n",
" model,\n",
" # Specify that we want use energies relative to the first conformer\n",
" # when evaluating the loss function\n",
" energy_transforms=transforms.relative(index=0),\n",
" # Use the built-in MSE metric when comparing the MM and QM relative\n",
" # energies.\n",
" energy_metric=metrics.mse(),\n",
" # For this example with will use the QM and MM gradients directly when\n",
" # computing the loss function.\n",
" gradient_transforms=transforms.identity(),\n",
" # Use the built-in MSE metric when comparing the MM and QM gradients\n",
" gradient_metric=metrics.mse(),\n",
" )\n",
"\n",
" loss.backward()\n",
"\n",
" optimizer.step()\n",
Expand All @@ -416,9 +421,9 @@
{
"cell_type": "markdown",
"source": [
"where the only code of note is the ``training_dataset.evaluate`` function that will evaluate the loss function for us.\n",
"This function accepts a number of arguments, but most notable are those that control exactly how the data is transformed\n",
"(i.e. computed relative energies) and what form the loss function should take.\n",
"where the only code of note is the ``evaluate_loss`` function that will compute the loss function for us. This function\n",
"accepts a number of arguments, but most notable are those that control exactly how the data is transformed (i.e.\n",
"compute relative energies) and what form the loss function should take.\n",
"\n",
"We can save our trained parameters back to a SMIRNOFF `.offxml` file for future use:"
],
Expand Down Expand Up @@ -464,26 +469,26 @@
"name": "stdout",
"output_type": "stream",
"text": [
"Bonds SMIRKS=[#6X3:1]=[#7X2,#7X3+1:2] ATTR=k INITIAL=882.4191878243 kcal/(A**2 mol) FINAL=881.926830095258 kcal/(A**2 mol)\n",
"Bonds SMIRKS=[#7:1]-[#1:2] ATTR=k INITIAL=997.7547006218 kcal/(A**2 mol) FINAL=997.2764430951689 kcal/(A**2 mol)\n",
"Bonds SMIRKS=[#6X3:1]=[#6X3:2] ATTR=k INITIAL=857.1115548611 kcal/(A**2 mol) FINAL=856.6721658595291 kcal/(A**2 mol)\n",
"Bonds SMIRKS=[#6X3:1]-[#7X2:2] ATTR=k INITIAL=837.2647972972 kcal/(A**2 mol) FINAL=837.6406442714525 kcal/(A**2 mol)\n",
"Bonds SMIRKS=[#6X4:1]-[#1:2] ATTR=k INITIAL=758.0931772913 kcal/(A**2 mol) FINAL=757.6152659958559 kcal/(A**2 mol)\n",
"Bonds SMIRKS=[#6:1]=[#8X1+0,#8X2+1:2] ATTR=k INITIAL=1135.595318618 kcal/(A**2 mol) FINAL=1135.1173284290276 kcal/(A**2 mol)\n",
"Bonds SMIRKS=[#6X3:1]=[#6X3:2] ATTR=k INITIAL=857.1115548611 kcal/(A**2 mol) FINAL=856.6721659734958 kcal/(A**2 mol)\n",
"Bonds SMIRKS=[#6X4:1]-[#6X3:2]=[#8X1+0] ATTR=k INITIAL=612.0537081219 kcal/(A**2 mol) FINAL=612.5316377660014 kcal/(A**2 mol)\n",
"Bonds SMIRKS=[#6X4:1]-[#6X3:2] ATTR=k INITIAL=612.5097961064 kcal/(A**2 mol) FINAL=612.0317954609721 kcal/(A**2 mol)\n",
"Bonds SMIRKS=[#6:1]=[#8X1+0,#8X2+1:2] ATTR=k INITIAL=1135.595318618 kcal/(A**2 mol) FINAL=1135.1173284290276 kcal/(A**2 mol)\n",
"Bonds SMIRKS=[#6X3:1]-[#7X3:2] ATTR=k INITIAL=719.219372554 kcal/(A**2 mol) FINAL=718.7531744829888 kcal/(A**2 mol)\n",
"Bonds SMIRKS=[#6X3:1]-[#1:2] ATTR=k INITIAL=808.1394472833 kcal/(A**2 mol) FINAL=807.661493535217 kcal/(A**2 mol)\n",
"Bonds SMIRKS=[#6X3:1]-[#7X2:2] ATTR=k INITIAL=837.2647972972 kcal/(A**2 mol) FINAL=837.640644214469 kcal/(A**2 mol)\n",
"Bonds SMIRKS=[#6:1]-[#7:2] ATTR=k INITIAL=719.6326854584 kcal/(A**2 mol) FINAL=719.1547107119311 kcal/(A**2 mol)\n",
"Bonds SMIRKS=[#6X4:1]-[#7X3:2]-[#6X3]=[#8X1+0] ATTR=k INITIAL=764.7120801727 kcal/(A**2 mol) FINAL=765.1886159741499 kcal/(A**2 mol)\n",
"Bonds SMIRKS=[#6X3:1]-[#1:2] ATTR=k INITIAL=808.1394472833 kcal/(A**2 mol) FINAL=807.6614935637086 kcal/(A**2 mol)\n",
"Bonds SMIRKS=[#6X3:1]=[#7X2,#7X3+1:2] ATTR=k INITIAL=882.4191878243 kcal/(A**2 mol) FINAL=881.9268301522415 kcal/(A**2 mol)\n",
"Bonds SMIRKS=[#6X3:1](=[#8X1+0])-[#7X3:2] ATTR=k INITIAL=1053.970761594 kcal/(A**2 mol) FINAL=1053.4928622650716 kcal/(A**2 mol)\n",
"Bonds SMIRKS=[#6:1]-[#7:2] ATTR=k INITIAL=719.6326854584 kcal/(A**2 mol) FINAL=719.1547107689146 kcal/(A**2 mol)\n",
"Bonds SMIRKS=[#6X4:1]-[#6X3:2]=[#8X1+0] ATTR=k INITIAL=612.0537081219 kcal/(A**2 mol) FINAL=612.531637709018 kcal/(A**2 mol)\n",
"Angles SMIRKS=[#1:1]-[#6X4:2]-[#1:3] ATTR=k INITIAL=74.28701527177 kcal/(mol rad**2) FINAL=3.156618157298837 kcal/(mol rad**2)\n",
"Angles SMIRKS=[*:1]~[#7X3$(*~[#6X3,#6X2,#7X2+0]):2]~[*:3] ATTR=k INITIAL=112.545110149 kcal/(mol rad**2) FINAL=94.1024970424084 kcal/(mol rad**2)\n",
"Angles SMIRKS=[*:1]~[#7X2+0:2]~[*:3] ATTR=k INITIAL=226.9001499199 kcal/(mol rad**2) FINAL=31.499088021470783 kcal/(mol rad**2)\n",
"Angles SMIRKS=[*:1]~[#6X3:2]~[*:3] ATTR=k INITIAL=153.5899485526 kcal/(mol rad**2) FINAL=40.37185554971532 kcal/(mol rad**2)\n",
"Angles SMIRKS=[*:1]~[#6X4:2]-[*:3] ATTR=k INITIAL=101.7373362367 kcal/(mol rad**2) FINAL=26.406389308094433 kcal/(mol rad**2)\n",
"Bonds SMIRKS=[#6X3:1](=[#8X1+0])-[#7X3:2] ATTR=k INITIAL=1053.970761594 kcal/(A**2 mol) FINAL=1053.4928622935631 kcal/(A**2 mol)\n",
"Bonds SMIRKS=[#6X4:1]-[#1:2] ATTR=k INITIAL=758.0931772913 kcal/(A**2 mol) FINAL=757.6152659958559 kcal/(A**2 mol)\n",
"Angles SMIRKS=[#1:1]-[#7X3$(*~[#6X3,#6X2,#7X2+0]):2]-[*:3] ATTR=k INITIAL=77.52610202633 kcal/(mol rad**2) FINAL=191.91446506590137 kcal/(mol rad**2)\n",
"Angles SMIRKS=[*:1]~;!@[*;X3;r5:2]~;@[*;r5:3] ATTR=k INITIAL=70.66802196994 kcal/(mol rad**2) FINAL=7.836212593216928 kcal/(mol rad**2)\n"
"Angles SMIRKS=[#1:1]-[#6X4:2]-[#1:3] ATTR=k INITIAL=74.28701527177 kcal/(mol rad**2) FINAL=3.1566064657054795 kcal/(mol rad**2)\n",
"Angles SMIRKS=[*:1]~[#7X2+0:2]~[*:3] ATTR=k INITIAL=226.9001499199 kcal/(mol rad**2) FINAL=31.499088021470783 kcal/(mol rad**2)\n",
"Angles SMIRKS=[*:1]~[#6X4:2]-[*:3] ATTR=k INITIAL=101.7373362367 kcal/(mol rad**2) FINAL=26.406365924907732 kcal/(mol rad**2)\n",
"Angles SMIRKS=[*:1]~[#7X3$(*~[#6X3,#6X2,#7X2+0]):2]~[*:3] ATTR=k INITIAL=112.545110149 kcal/(mol rad**2) FINAL=94.10249558095923 kcal/(mol rad**2)\n",
"Angles SMIRKS=[*:1]~;!@[*;X3;r5:2]~;@[*;r5:3] ATTR=k INITIAL=70.66802196994 kcal/(mol rad**2) FINAL=7.836206747420249 kcal/(mol rad**2)\n",
"Angles SMIRKS=[*:1]~[#6X3:2]~[*:3] ATTR=k INITIAL=153.5899485526 kcal/(mol rad**2) FINAL=40.37185554971532 kcal/(mol rad**2)\n"
]
}
],
Expand Down

0 comments on commit c23dcc0

Please sign in to comment.