Skip to content

Commit

Permalink
Use separate lists for duplicates, in progress
Browse files Browse the repository at this point in the history
  • Loading branch information
WardLT committed Apr 12, 2024
1 parent 86d3d1d commit 80ca85c
Showing 1 changed file with 16 additions and 10 deletions.
26 changes: 16 additions & 10 deletions run_parallel_workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from collections import deque
from queue import Queue, Empty
from platform import node
from random import shuffle, choice
from random import shuffle
from pathlib import Path
from threading import Event, Lock
import logging
Expand All @@ -35,7 +35,7 @@
from colmena.queue.redis import RedisQueues
from colmena.thinker import BaseThinker, result_processor, task_submitter, ResourceCounter, event_responder, agent

from mofa.assembly.assemble import assemble_mof, assemble_many
from mofa.assembly.assemble import assemble_many
from mofa.assembly.validate import process_ligands
from mofa.generator import run_generator, train_generator
from mofa.model import MOFRecord, NodeDescription, LigandTemplate, LigandDescription
Expand Down Expand Up @@ -128,8 +128,9 @@ def __init__(self,

self.post_md_queue: Queue[Result] = Queue() # Holds MD results ready to be stored

# Database of completed MOFs
self.database: set[str] = set()
# Collections used to avoid duplicates: MOF records with MD in progress (by name) and the names already seen
self.in_progress: dict[str, MOFRecord] = {}
self.seen: set[str] = set()

# Set aside one GPU for generation
self.rec.reallocate(None, 'generation', self.hpc_config.number_inf_workers)
Expand Down Expand Up @@ -304,13 +305,13 @@ def store_assembly(self, result: Result):
# Add them to the database
num_added = 0
for new_mof in result.value:
# Check if a duplicate
if new_mof.name in self.database:
# Avoid duplicates
if new_mof.name in self.seen:
continue

# Mark it as seen and add it to the work queue
num_added += 1
self.database.add(new_mof.name)
self.seen.add(new_mof.name)
self.mof_queue.append(new_mof)
self.mofs_available.set()

Expand Down Expand Up @@ -340,6 +341,7 @@ def submit_lammps(self):
topic='lammps',
task_info={'name': to_run.name}
)
self.in_progress[to_run.name] = to_run # Store the MOF record for later use
self.logger.info(f'Started MD simulation for mof={to_run.name}. '
f'Simulation queue depth: {len(self.mof_queue)}.')

Expand All @@ -357,6 +359,8 @@ def store_lammps(self, result: Result):
# Retrieve the results
if not result.success:
self.logger.warning(f'MD task failed: {result.failure_info.exception}\nStack: {result.failure_info.traceback}')
name = result.task_info['name']
self.in_progress.pop(name)
else:
self.post_md_queue.put(result)

Expand All @@ -378,7 +382,7 @@ def process_md_results(self):
# Store the trajectory
traj = result.value
name = result.task_info['name']
record = self.database[name]
record = self.in_progress.pop(name)
self.logger.info(f'Received a trajectory of {len(traj)} frames for mof={name}. Backlog: {self.post_md_queue.qsize()}')

# Compute the lattice strain
Expand All @@ -387,6 +391,7 @@ def process_md_results(self):
record.md_trajectory['uff'] = traj_vasp
strain = scorer.score_mof(record)
record.structure_stability['uff'] = strain
record.times['md-done'] = datetime.now()
self.logger.info(f'Lattice change after MD simulation for mof={name}: {strain * 100:.1f}%')

# Store the result in MongoDB
Expand Down Expand Up @@ -436,8 +441,8 @@ def retrain(self):
record['md_trajectory'] = {}
examples.append(MOFRecord(**record))
if (len(examples) == 0 or len(examples) == last_train_size) and len(examples) < self.trainer_config.maximum_train_size:
self.logger.info(f'The number of training examples for {sort_field} with strain below {self.trainer_config.maximum_strain:.2f} ({len(examples)} is the same '
f'as the last time we trained DiffLinker ({last_train_size}). Waiting for more data')
self.logger.info(f'The number of training examples for {sort_field} with strain below {self.trainer_config.maximum_strain:.2f}'
f' ({len(examples)} is the same as the last time we trained DiffLinker ({last_train_size}). Waiting for more data')
self.start_train.clear()
self.start_train.wait()
continue
Expand Down Expand Up @@ -544,6 +549,7 @@ def store_cp2k(self, result: Result):
storage_mean, storage_std = result.value
record = mofadb.get_records(self.collection, [mof_name])[0]
record.gas_storage['CO2'] = (1e4, storage_mean)
record.times['raspa-done'] = datetime.now()
mofadb.update_records(self.collection, [record])

# Update and trigger training, in case it's blocked
Expand Down

0 comments on commit 80ca85c

Please sign in to comment.