diff --git a/run_parallel_workflow.py b/run_parallel_workflow.py
index d9559228..9d6015e7 100644
--- a/run_parallel_workflow.py
+++ b/run_parallel_workflow.py
@@ -13,7 +13,7 @@
 from collections import deque
 from queue import Queue, Empty
 from platform import node
-from random import shuffle, choice
+from random import shuffle
 from pathlib import Path
 from threading import Event, Lock
 import logging
@@ -35,7 +35,7 @@
 from colmena.queue.redis import RedisQueues
 from colmena.thinker import BaseThinker, result_processor, task_submitter, ResourceCounter, event_responder, agent
 
-from mofa.assembly.assemble import assemble_mof, assemble_many
+from mofa.assembly.assemble import assemble_many
 from mofa.assembly.validate import process_ligands
 from mofa.generator import run_generator, train_generator
 from mofa.model import MOFRecord, NodeDescription, LigandTemplate, LigandDescription
@@ -128,8 +128,9 @@ def __init__(self,
 
         self.post_md_queue: Queue[Result] = Queue()  # Holds MD results ready to be stored
 
-        # Database of completed MOFs
-        self.database: set[str] = set()
+        # Collections used to track in-progress MOFs and avoid duplicates
+        self.in_progress: dict[str, MOFRecord] = {}
+        self.seen: set[str] = set()
 
         # Set aside one GPU for generation
         self.rec.reallocate(None, 'generation', self.hpc_config.number_inf_workers)
@@ -304,13 +305,13 @@ def store_assembly(self, result: Result):
         # Add them to the database
         num_added = 0
         for new_mof in result.value:
-            # Check if a duplicate
-            if new_mof.name in self.database:
+            # Avoid duplicates
+            if new_mof.name in self.seen:
                 continue
 
             # Add it to the database and work queue
             num_added += 1
-            self.database.add(new_mof.name)
+            self.seen.add(new_mof.name)
             self.mof_queue.append(new_mof)
             self.mofs_available.set()
 
@@ -340,6 +341,7 @@ def submit_lammps(self):
                 topic='lammps',
                 task_info={'name': to_run.name}
             )
+            self.in_progress[to_run.name] = to_run  # Store the MOF record for later use
             self.logger.info(f'Started MD simulation for mof={to_run.name}. '
                              f'Simulation queue depth: {len(self.mof_queue)}.')
 
@@ -357,6 +359,8 @@ def store_lammps(self, result: Result):
         # Retrieve the results
         if not result.success:
            self.logger.warning(f'MD task failed: {result.failure_info.exception}\nStack: {result.failure_info.traceback}')
+            name = result.task_info['name']
+            self.in_progress.pop(name)  # Failed MOFs are no longer in progress
         else:
             self.post_md_queue.put(result)
 
@@ -378,7 +382,7 @@ def process_md_results(self):
             # Store the trajectory
             traj = result.value
             name = result.task_info['name']
-            record = self.database[name]
+            record = self.in_progress.pop(name)
             self.logger.info(f'Received a trajectory of {len(traj)} frames for mof={name}. Backlog: {self.post_md_queue.qsize()}')
 
             # Compute the lattice strain
@@ -387,6 +391,7 @@
             record.md_trajectory['uff'] = traj_vasp
             strain = scorer.score_mof(record)
             record.structure_stability['uff'] = strain
+            record.times['md-done'] = datetime.now()
             self.logger.info(f'Lattice change after MD simulation for mof={name}: {strain * 100:.1f}%')
 
             # Store the result in MongoDB
@@ -436,8 +441,8 @@ def retrain(self):
                 record['md_trajectory'] = {}
                 examples.append(MOFRecord(**record))
             if (len(examples) == 0 or len(examples) == last_train_size) and len(examples) < self.trainer_config.maximum_train_size:
-                self.logger.info(f'The number of training examples for {sort_field} with strain below {self.trainer_config.maximum_strain:.2f} ({len(examples)} is the same '
-                                 f'as the last time we trained DiffLinker ({last_train_size}). Waiting for more data')
+                self.logger.info(f'The number of training examples for {sort_field} with strain below {self.trainer_config.maximum_strain:.2f}'
+                                 f' ({len(examples)}) is the same as the last time we trained DiffLinker ({last_train_size}). Waiting for more data')
                 self.start_train.clear()
                 self.start_train.wait()
                 continue
@@ -544,6 +549,7 @@ def store_cp2k(self, result: Result):
         storage_mean, storage_std = result.value
         record = mofadb.get_records(self.collection, [mof_name])[0]
         record.gas_storage['CO2'] = (1e4, storage_mean)
+        record.times['raspa-done'] = datetime.now()
         mofadb.update_records(self.collection, [record])
 
         # Update and trigger training, in case it's blocked