Skip to content

Commit

Permalink
Offload assembly to workers
Browse files Browse the repository at this point in the history
  • Loading branch information
WardLT committed Apr 12, 2024
1 parent 9652594 commit 86d3d1d
Show file tree
Hide file tree
Showing 4 changed files with 89 additions and 36 deletions.
37 changes: 37 additions & 0 deletions mofa/assembly/assemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from tempfile import TemporaryDirectory
from typing import Sequence
from pathlib import Path
from random import choice
import itertools
import os
import io
Expand All @@ -16,6 +17,42 @@
_bond_length_path = Path(__file__).parent / "OChemDB_bond_threshold.csv"


def assemble_many(ligand_options: dict[str, list[LigandDescription]],
                  nodes: list[NodeDescription],
                  to_make: int,
                  attempts: int,
                  requirements: dict[str, int] | None = None) -> list[MOFRecord]:
    """Make many MOFs by repeatedly sampling ligands and attempting assembly

    Args:
        ligand_options: Many choices for each type of ligand, keyed by anchor type
        nodes: List of nodes used for assembly
        to_make: Target number of MOFs to make
        attempts: Number of times to attempt per MOF before giving up
        requirements: Number of ligands of each anchor type needed per MOF.
            Defaults to two carboxylic (COO) ligands and one cyano ligand.
    Returns:
        Up to the target number of MOFs
    """
    if requirements is None:
        requirements = {'COO': 2, 'cyano': 1}

    output = []
    # Budget the total number of tries up front so a streak of failures terminates
    attempts_remaining = to_make * attempts
    while len(output) < to_make and attempts_remaining > 0:
        attempts_remaining -= 1

        # Pick one ligand per anchor type, repeated as many times as required
        ligand_choices = {
            anchor_type: [choice(ligand_options[anchor_type])] * count
            for anchor_type, count in requirements.items()
        }

        # Attempt assembly; failures are expected and simply consume an attempt
        try:
            new_mof = assemble_mof(
                nodes=nodes,
                ligands=ligand_choices,
                topology='pcu'
            )
        except (ValueError, KeyError, IndexError):
            continue
        output.append(new_mof)
    return output


def readPillaredPaddleWheelXYZ(fpath, dummyElementCOO="At", dummyElementPillar="Fr"):
"""Read xyz file of a paddlewheel node
Expand Down
2 changes: 1 addition & 1 deletion mofa/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ def initialize_database(client: MongoClient) -> Collection:

collection = client.get_database('mofa').get_collection('mofs')
collection.create_index([
("name", ASCENDING)
("name", ASCENDING),
])
return collection

Expand Down
80 changes: 46 additions & 34 deletions run_parallel_workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
from colmena.queue.redis import RedisQueues
from colmena.thinker import BaseThinker, result_processor, task_submitter, ResourceCounter, event_responder, agent

from mofa.assembly.assemble import assemble_mof
from mofa.assembly.assemble import assemble_mof, assemble_many
from mofa.assembly.validate import process_ligands
from mofa.generator import run_generator, train_generator
from mofa.model import MOFRecord, NodeDescription, LigandTemplate, LigandDescription
Expand Down Expand Up @@ -106,7 +106,8 @@ def __init__(self,
node_template: NodeDescription):
if hpc_config.num_workers < 2:
raise ValueError(f'There must be at least two workers. Supplied: {hpc_config}')
super().__init__(queues, ResourceCounter(hpc_config.num_workers, task_types=['generation', 'lammps', 'cp2k']))
self.assemble_workers = max(1, hpc_config.num_lammps_workers // 1000) # We may need
super().__init__(queues, ResourceCounter(hpc_config.num_workers + self.assemble_workers, task_types=['generation', 'lammps', 'cp2k', 'assembly']))
self.generator_config = generator_config
self.trainer_config = trainer_config
self.node_template = node_template
Expand All @@ -128,12 +129,13 @@ def __init__(self,
self.post_md_queue: Queue[Result] = Queue() # Holds MD results ready to be stored

# Database of completed MOFs
self.database: dict[str, MOFRecord] = {}
self.database: set[str] = set()

# Set aside one GPU for generation
self.rec.reallocate(None, 'generation', self.hpc_config.number_inf_workers)
self.rec.reallocate(None, 'lammps', self.hpc_config.num_lammps_workers)
self.rec.reallocate(None, 'cp2k', self.hpc_config.num_cp2k_workers)
self.rec.reallocate(None, 'assembly', self.assemble_workers)

# Settings associated with MOF assembly
self.mofs_per_call = hpc_config.num_lammps_workers + 4
Expand All @@ -158,7 +160,7 @@ def __init__(self,
# Output files
self._output_files: dict[str, Path | TextIO] = {}
self.generate_write_lock: Lock = Lock() # Two threads write to the same generation output
for name in ['generation-results', 'simulation-results', 'training-results']:
for name in ['generation-results', 'simulation-results', 'training-results', 'assemble-results']:
self._output_files[name] = run_dir / f'{name}.json'

def __enter__(self):
Expand Down Expand Up @@ -263,51 +265,60 @@ def process_ligands(self):
with self.generate_write_lock:
print(result.json(exclude={'inputs', 'value'}), file=self._output_files['generation-results'], flush=True)

@event_responder(event_name='make_mofs')
def assemble_new_mofs(self):
"""Pull from the list of ligands and create MOFs. Runs when new MOFs are available"""
@task_submitter(task_type='assembly')
def submit_assembly(self):
    """Pull from the list of ligands and create MOFs

    Blocks until every anchor type has at least ``min_ligand_candidates``
    ligands available, then submits a single ``assemble_many`` task.
    """

    # Wait until there are enough candidate ligands of every anchor type
    while True:
        self.make_mofs.wait()
        for anchor_type in self.generator_config.anchor_types:
            have = len(self.ligand_assembly_queue[anchor_type])
            if have < self.generator_config.min_ligand_candidates:
                self.logger.info(f'Too few candidates for anchor_type={anchor_type}. have={have}, need={self.generator_config.min_ligand_candidates}')
                break  # Not enough of this type yet; keep waiting
        else:
            break  # Every anchor type has enough ligands

    # Submit the assembly task
    # NOTE(review): the availability check above reads `ligand_assembly_queue`;
    # confirm that `ligand_process_queue` is the intended payload here
    self.queues.send_inputs(
        self.ligand_process_queue,
        [self.node_template],
        self.mofs_per_call,
        4,  # Attempts allowed per requested MOF
        method='assemble_many',
        topic='assembly',
    )

# Get a sample of ligands
ligand_choices = {}
requirements = {'COO': 2, 'cyano': 1} # TODO (wardlt): Do not hard code this
for anchor_type, count in requirements.items():
ligand_choices[anchor_type] = [choice(self.ligand_assembly_queue[anchor_type])] * count
@result_processor(topic='assembly')
def store_assembly(self, result: Result):
    """Store the MOFs in the ready-for-LAMMPS queue

    Args:
        result: Output of an ``assemble_many`` task: a list of MOF records
    """

    # Release the worker slot so a new assembly task can run
    self.rec.release('assembly')

    # Skip if it failed
    if not result.success:
        self.logger.warning(f'Assembly task failed: {result.failure_info.exception}\nStack: {result.failure_info.traceback}')
        return  # Bug fix: without this, accessing result.value below raises for failed tasks

    # Add the new MOFs to the database and the simulation queue
    num_added = 0
    for new_mof in result.value:
        # Check if a duplicate
        if new_mof.name in self.database:
            continue

        # Add it to the database and work queue
        num_added += 1
        self.database.add(new_mof.name)
        self.mof_queue.append(new_mof)
        self.mofs_available.set()

    self.logger.info(f'Created {num_added} new MOFs. Current queue depth: {len(self.mof_queue)}')

    # Save the result. Bug fix: the key must be 'assemble-results', the name
    # registered in self._output_files during __init__; 'assembly-results'
    # would raise KeyError on every successful task
    print(result.json(exclude={'inputs', 'value'}), file=self._output_files['assemble-results'], flush=True)

@task_submitter(task_type='lammps')
def submit_lammps(self):
"""Submit an MD simulation"""
Expand Down Expand Up @@ -610,7 +621,7 @@ def store_cp2k(self, result: Result):
# Configure to a use Redis queue, which allows streaming results form other nodes
queues = RedisQueues(
hostname=args.redis_host,
topics=['generation', 'lammps', 'cp2k', 'training'],
topics=['generation', 'lammps', 'cp2k', 'training', 'assembly'],
proxystore_name='redis',
proxystore_threshold=10000
)
Expand Down Expand Up @@ -723,7 +734,8 @@ def store_cp2k(self, result: Result):
(cp2k_fun, {'executors': hpc_config.cp2k_executors}),
(compute_partial_charges, {'executors': hpc_config.helper_executors}),
(process_ligands, {'executors': hpc_config.helper_executors}),
(raspa_fun, {'executors': hpc_config.helper_executors})
(raspa_fun, {'executors': hpc_config.helper_executors}),
(assemble_many, {'executors': hpc_config.helper_executors})
],
queues=queues,
config=config
Expand Down
6 changes: 5 additions & 1 deletion tests/test_assemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from pytest import mark
from ase.io import read

from mofa.assembly.assemble import assemble_pillaredPaddleWheel_pcuMOF, assemble_mof
from mofa.assembly.assemble import assemble_pillaredPaddleWheel_pcuMOF, assemble_mof, assemble_many
from mofa.model import NodeDescription, LigandDescription

_files_dir = Path(__file__).parent / 'files' / 'assemble'
Expand Down Expand Up @@ -58,3 +58,7 @@ def test_assemble(node_name, topology, ligand_counts, file_path):
for ligand in mof_record.ligands:
assert ligand.xyz is not None
assert mof_record.name is not None

# Test making many assemblies
records = assemble_many(ligands, [node], 4, 1)
assert len(records) == 4

0 comments on commit 86d3d1d

Please sign in to comment.