Skip to content

Commit

Permalink
Add verbose option to energy objectives from QC results (#8)
Browse files Browse the repository at this point in the history
  • Loading branch information
SimonBoothroyd authored Sep 10, 2021
1 parent 77abd56 commit e875ae2
Show file tree
Hide file tree
Showing 3 changed files with 51 additions and 34 deletions.
80 changes: 49 additions & 31 deletions descent/objectives/energy.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from smirnoffee.smirnoff import vectorize_system
from torch._vmap_internals import vmap
from torch.autograd import grad
from tqdm import tqdm
from typing_extensions import Literal

from descent import metrics, transforms
Expand Down Expand Up @@ -656,6 +657,7 @@ def _retrieve_gradient_and_hessians(
optimization_results: "OptimizationResultCollection",
include_gradients: bool,
include_hessians: bool,
verbose: bool = True,
) -> Tuple[
Dict[Tuple[str, "ObjectId"], torch.Tensor],
Dict[Tuple[str, "ObjectId"], torch.Tensor],
Expand All @@ -668,6 +670,7 @@ def _retrieve_gradient_and_hessians(
gradients and hessians should be retrieved where available.
include_gradients: Whether to retrieve gradient values.
include_hessians: Whether to retrieve hessian values.
verbose: Whether to log progress to the terminal.
Returns:
The values of the gradients and hessians (if requested) stored in
Expand All @@ -689,7 +692,11 @@ def _retrieve_gradient_and_hessians(

qc_gradients, qc_hessians = {}, {}

for qc_record, _ in basic_result_collection.to_records():
for qc_record, _ in tqdm(
basic_result_collection.to_records(),
desc="Pulling gradient / hessian data",
disable=not verbose,
):

address = qc_record.client.address

Expand Down Expand Up @@ -762,6 +769,7 @@ def from_optimization_results(
hessian_metric: Optional[LossMetric] = None,
hessian_coordinate_system: Literal["cartesian", "ric"] = "cartesian",
n_processes: int = 1,
verbose: bool = True,
) -> List["EnergyObjective"]:
"""Creates a list of energy objective contribution terms (one per unique
molecule) from the **final** structures a set of QC optimization results.
Expand Down Expand Up @@ -789,6 +797,7 @@ def from_optimization_results(
hessian_coordinate_system: The coordinate system to project the QM and MM
hessians to before computing the loss metric.
n_processes: The number of processes to parallelize this function across.
verbose: Whether to log progress to the terminal.
Returns:
A list of the energy objective terms.
Expand All @@ -799,7 +808,11 @@ def from_optimization_results(
# Group the results by molecule ignoring stereochemistry
per_molecule_records = defaultdict(list)

for qc_record, molecule in optimization_results.to_records():
for qc_record, molecule in tqdm(
optimization_results.to_records(),
desc="Pulling main optimisation records",
disable=not verbose,
):

molecule: Molecule = molecule.canonical_order_atoms()
conformer = molecule.conformers[0].value_in_unit(simtk_unit.angstrom)
Expand All @@ -811,7 +824,7 @@ def from_optimization_results(
per_molecule_records[smiles].append((qc_record, conformer))

qc_gradients, qc_hessians = cls._retrieve_gradient_and_hessians(
optimization_results, include_gradients, include_hessians
optimization_results, include_gradients, include_hessians, verbose
)

result_tensors = []
Expand Down Expand Up @@ -855,36 +868,41 @@ def from_optimization_results(
# We need to dill any potential lambda functions as the default
# multiprocessing pickler cannot handle these by default.
contributions = list(
pool.imap(
functools.partial(
cls._from_grouped_results,
force_field=initial_force_field,
energy_transforms=dill.dumps(
energy_transforms if include_energies else None
),
energy_metric=dill.dumps(
energy_metric if include_energies else None
),
gradient_transforms=dill.dumps(
gradient_transforms if include_gradients else None
),
gradient_metric=dill.dumps(
gradient_metric if include_gradients else None
),
gradient_coordinate_system=gradient_coordinate_system
if include_gradients
else None,
hessian_transforms=dill.dumps(
hessian_transforms if include_hessians else None
),
hessian_metric=dill.dumps(
hessian_metric if include_hessians else None
tqdm(
pool.imap(
functools.partial(
cls._from_grouped_results,
force_field=initial_force_field,
energy_transforms=dill.dumps(
energy_transforms if include_energies else None
),
energy_metric=dill.dumps(
energy_metric if include_energies else None
),
gradient_transforms=dill.dumps(
gradient_transforms if include_gradients else None
),
gradient_metric=dill.dumps(
gradient_metric if include_gradients else None
),
gradient_coordinate_system=gradient_coordinate_system
if include_gradients
else None,
hessian_transforms=dill.dumps(
hessian_transforms if include_hessians else None
),
hessian_metric=dill.dumps(
hessian_metric if include_hessians else None
),
hessian_coordinate_system=hessian_coordinate_system
if include_hessians
else None,
),
hessian_coordinate_system=hessian_coordinate_system
if include_hessians
else None,
result_tensors,
),
result_tensors,
total=len(result_tensors),
disable=not verbose,
desc="Building energy contribution objects.",
)
)

Expand Down
3 changes: 1 addition & 2 deletions devtools/conda-envs/meta.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,7 @@ dependencies:
- python
- pip
- dill

- click
- tqdm

- openff-toolkit-base >=0.9.2
- openff-interchange ==0.1.0
Expand Down
2 changes: 1 addition & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -28,13 +28,13 @@ force_grid_wrap=0
use_parentheses=True
line_length=88
known_third_party=
click
dill
geometric
openff
pydantic
smirnoffee
torch
tqdm

[versioneer]
# Automatic version numbering scheme
Expand Down

0 comments on commit e875ae2

Please sign in to comment.