From 86d3d1d561d5b5a995d9d11262467e33e9a598ff Mon Sep 17 00:00:00 2001 From: Logan Ward Date: Fri, 12 Apr 2024 17:32:51 -0400 Subject: [PATCH] Offload assembly to workers --- mofa/assembly/assemble.py | 37 ++++++++++++++++++ mofa/db.py | 2 +- run_parallel_workflow.py | 80 ++++++++++++++++++++++----------------- tests/test_assemble.py | 6 ++- 4 files changed, 89 insertions(+), 36 deletions(-) diff --git a/mofa/assembly/assemble.py b/mofa/assembly/assemble.py index 61d19cb8..50180f55 100644 --- a/mofa/assembly/assemble.py +++ b/mofa/assembly/assemble.py @@ -2,6 +2,7 @@ from tempfile import TemporaryDirectory from typing import Sequence from pathlib import Path +from random import choice import itertools import os import io @@ -16,6 +17,42 @@ _bond_length_path = Path(__file__).parent / "OChemDB_bond_threshold.csv" +def assemble_many(ligand_options: dict[str, list[LigandDescription]], nodes: list[NodeDescription], to_make: int, attempts: int) -> list[MOFRecord]: + """Make many MOFs + + Args: + ligand_options: Many choices for each type of ligand + nodes: List of nodes used for assembly + to_make: Target number to make + attempts: Number of times to attempt per MOF before giving up + Returns: + Up to the target number of MOFs + """ + + output = [] + attempts_remaining = to_make * attempts + while len(output) < to_make and attempts_remaining > 0: + attempts_remaining -= 1 + + # Get a sample of ligands + ligand_choices = {} + requirements = {'COO': 2, 'cyano': 1} # TODO (wardlt): Do not hard code this + for anchor_type, count in requirements.items(): + ligand_choices[anchor_type] = [choice(ligand_options[anchor_type])] * count + + # Attempt assembly + try: + new_mof = assemble_mof( + nodes=nodes, + ligands=ligand_choices, + topology='pcu' + ) + except (ValueError, KeyError, IndexError): + continue + output.append(new_mof) + return output + + def readPillaredPaddleWheelXYZ(fpath, dummyElementCOO="At", dummyElementPillar="Fr"): """Read xyz file of a paddlewheel node diff --git a/mofa/db.py b/mofa/db.py index 5eeca5ec..923c5d1b 100644 --- a/mofa/db.py +++ b/mofa/db.py @@ -20,7 +20,7 @@ def initialize_database(client: MongoClient) -> Collection: collection = client.get_database('mofa').get_collection('mofs') collection.create_index([ - ("name", ASCENDING) + ("name", ASCENDING), ]) return collection diff --git a/run_parallel_workflow.py b/run_parallel_workflow.py index 55ba76d1..d9559228 100644 --- a/run_parallel_workflow.py +++ b/run_parallel_workflow.py @@ -35,7 +35,7 @@ from colmena.queue.redis import RedisQueues from colmena.thinker import BaseThinker, result_processor, task_submitter, ResourceCounter, event_responder, agent -from mofa.assembly.assemble import assemble_mof +from mofa.assembly.assemble import assemble_mof, assemble_many from mofa.assembly.validate import process_ligands from mofa.generator import run_generator, train_generator from mofa.model import MOFRecord, NodeDescription, LigandTemplate, LigandDescription @@ -106,7 +106,8 @@ def __init__(self, node_template: NodeDescription): if hpc_config.num_workers < 2: raise ValueError(f'There must be at least two workers. Supplied: {hpc_config}') - super().__init__(queues, ResourceCounter(hpc_config.num_workers, task_types=['generation', 'lammps', 'cp2k'])) + self.assemble_workers = max(1, hpc_config.num_lammps_workers // 1000) # We may need + super().__init__(queues, ResourceCounter(hpc_config.num_workers + self.assemble_workers, task_types=['generation', 'lammps', 'cp2k', 'assembly'])) self.generator_config = generator_config self.trainer_config = trainer_config self.node_template = node_template @@ -128,12 +129,13 @@ def __init__(self, self.post_md_queue: Queue[Result] = Queue() # Holds MD results ready to be stored # Database of completed MOFs - self.database: dict[str, MOFRecord] = {} + self.database: set[str] = set() # Set aside one GPU for generation self.rec.reallocate(None, 'generation', self.hpc_config.number_inf_workers) self.rec.reallocate(None, 'lammps', self.hpc_config.num_lammps_workers) self.rec.reallocate(None, 'cp2k', self.hpc_config.num_cp2k_workers) + self.rec.reallocate(None, 'assembly', self.assemble_workers) # Settings associated with MOF assembly self.mofs_per_call = hpc_config.num_lammps_workers + 4 @@ -158,7 +160,7 @@ def __init__(self, # Output files self._output_files: dict[str, Path | TextIO] = {} self.generate_write_lock: Lock = Lock() # Two threads write to the same generation output - for name in ['generation-results', 'simulation-results', 'training-results']: + for name in ['generation-results', 'simulation-results', 'training-results', 'assemble-results']: self._output_files[name] = run_dir / f'{name}.json' def __enter__(self): @@ -263,51 +265,60 @@ def process_ligands(self): with self.generate_write_lock: print(result.json(exclude={'inputs', 'value'}), file=self._output_files['generation-results'], flush=True) - @event_responder(event_name='make_mofs') - def assemble_new_mofs(self): - """Pull from the list of ligands and create MOFs. Runs when new MOFs are available""" + @task_submitter(task_type='assembly') + def submit_assembly(self): + """Pull from the list of ligands and create MOFs""" # Check that we have enough ligands to start assembly - for anchor_type in self.generator_config.anchor_types: - have = len(self.ligand_assembly_queue[anchor_type]) - if have < self.generator_config.min_ligand_candidates: - self.logger.info(f'Too few candidate for anchor_type={anchor_type}. have={have}, need={self.generator_config.min_ligand_candidates}') - return + while True: + self.make_mofs.wait() + for anchor_type in self.generator_config.anchor_types: + have = len(self.ligand_assembly_queue[anchor_type]) + if have < self.generator_config.min_ligand_candidates: + self.logger.info(f'Too few candidate for anchor_type={anchor_type}. have={have}, need={self.generator_config.min_ligand_candidates}') + break + else: + break - # Make a certain number of attempts - num_added = 0 - attempts_remaining = self.mofs_per_call * 4 - while num_added < self.mofs_per_call and attempts_remaining > 0: - attempts_remaining -= 1 + # Submit the assembly task + self.queues.send_inputs( + self.ligand_process_queue, + [self.node_template], + self.mofs_per_call, + 4, + method='assemble_many', + topic='assembly', + ) - # Get a sample of ligands - ligand_choices = {} - requirements = {'COO': 2, 'cyano': 1} # TODO (wardlt): Do not hard code this - for anchor_type, count in requirements.items(): - ligand_choices[anchor_type] = [choice(self.ligand_assembly_queue[anchor_type])] * count + @result_processor(topic='assembly') + def store_assembly(self, result: Result): + """Store the MOFs in the ready for LAMMPS queue""" - # Attempt assembly - try: - new_mof = assemble_mof( - nodes=[self.node_template], - ligands=ligand_choices, - topology='pcu' - ) - except (ValueError, KeyError, IndexError): - continue + # Trigger a new one to run + self.rec.release('assembly') + + # Skip if it failed + if not result.success: + self.logger.warning(f'Assembly task failed: {result.failure_info.exception}\nStack: {result.failure_info.traceback}') + # Add them to the database + num_added = 0 + for new_mof in result.value: # Check if a duplicate if new_mof.name in self.database: continue # Add it to the database and work queue num_added += 1 - self.database[new_mof.name] = new_mof + self.database.add(new_mof.name) self.mof_queue.append(new_mof) self.mofs_available.set() self.logger.info(f'Created {num_added} new MOFs. Current queue depth: {len(self.mof_queue)}') + # Save the result + print(result.json(exclude={'inputs', 'value'}), file=self._output_files['assembly-results'], flush=True) + @task_submitter(task_type='lammps') def submit_lammps(self): """Submit an MD simulation""" @@ -610,7 +621,7 @@ def store_cp2k(self, result: Result): # Configure to a use Redis queue, which allows streaming results form other nodes queues = RedisQueues( hostname=args.redis_host, - topics=['generation', 'lammps', 'cp2k', 'training'], + topics=['generation', 'lammps', 'cp2k', 'training', 'assembly'], proxystore_name='redis', proxystore_threshold=10000 ) @@ -723,7 +734,8 @@ def store_cp2k(self, result: Result): (cp2k_fun, {'executors': hpc_config.cp2k_executors}), (compute_partial_charges, {'executors': hpc_config.helper_executors}), (process_ligands, {'executors': hpc_config.helper_executors}), - (raspa_fun, {'executors': hpc_config.helper_executors}) + (raspa_fun, {'executors': hpc_config.helper_executors}), + (assemble_many, {'executors': hpc_config.helper_executors}) ], queues=queues, config=config diff --git a/tests/test_assemble.py b/tests/test_assemble.py index bd95ad83..11dc9900 100644 --- a/tests/test_assemble.py +++ b/tests/test_assemble.py @@ -4,7 +4,7 @@ from pytest import mark from ase.io import read -from mofa.assembly.assemble import assemble_pillaredPaddleWheel_pcuMOF, assemble_mof +from mofa.assembly.assemble import assemble_pillaredPaddleWheel_pcuMOF, assemble_mof, assemble_many from mofa.model import NodeDescription, LigandDescription _files_dir = Path(__file__).parent / 'files' / 'assemble' @@ -58,3 +58,7 @@ def test_assemble(node_name, topology, ligand_counts, file_path): for ligand in mof_record.ligands: assert ligand.xyz is not None assert mof_record.name is not None + + # Test making many assemblies + records = assemble_many(ligands, [node], 4, 1) + assert len(records) == 4