Skip to content

Commit

Permalink
Switch to training on RASPA later in run (#124)
Browse files Browse the repository at this point in the history
  • Loading branch information
WardLT authored Apr 10, 2024
1 parent bf13db8 commit 86319dd
Showing 1 changed file with 32 additions and 13 deletions.
45 changes: 32 additions & 13 deletions run_parallel_workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,8 @@ def __init__(self,
# Settings related for training
self.start_train = Event()
self.initial_weights = self.generator_config.generator_path # Store the starting weights, which we'll always use as a starting point for training
self.num_completed = 0 # Number of MOFs which have finished training
self.num_lammps_completed = 0 # Number of MOFs which have finished stability
self.num_raspa_completed = 0 # Number for which we have gas storage
self.model_iteration = 0 # Which version of the model we used for generating a ligand

# Settings related to scheduling CP2K
Expand Down Expand Up @@ -382,8 +383,8 @@ def process_md_results(self):
self.cp2k_ready.set()

# Determine if we should retrain
self.num_completed += 1
if self.num_completed >= self.trainer_config.minimum_train_size:
self.num_lammps_completed += 1
if self.num_lammps_completed >= self.trainer_config.minimum_train_size:
self.start_train.set() # Either starts or indicates that we have new data

@event_responder(event_name='start_train')
Expand All @@ -393,16 +394,28 @@ def retrain(self):
self.logger.info('Started to retrain DiffLinker')
last_train_size = 0
while not self.done.is_set():
# Get the top MOFs
sort_field = 'structure_stability.uff'
to_include = min(int(self.collection.estimated_document_count() * self.trainer_config.best_fraction), self.trainer_config.maximum_train_size)
# Determine how to select the best MOFs
if self.num_raspa_completed < self.trainer_config.minimum_train_size:
sort_field = 'structure_stability.uff'
to_include = min(int(self.num_lammps_completed * self.trainer_config.best_fraction), self.trainer_config.maximum_train_size)
sort_order = pymongo.ASCENDING
else:
sort_field = 'gas_storage.CO2'
to_include = min(int(self.num_raspa_completed * self.trainer_config.best_fraction), self.trainer_config.maximum_train_size)
sort_order = pymongo.DESCENDING

# Build the query
self.collection.create_index(sort_field)
query = defaultdict(dict)
query[sort_field] = {'$exists': True}
query['structure_stability.uff'] = {'$lt': self.trainer_config.maximum_strain}

cursor = (
self.collection.find(
{sort_field: {'$exists': True, '$lt': self.trainer_config.maximum_strain}},
query,
{'md_trajectory': 0} # Filter out the trajectory to save I/O
)
.sort(sort_field, pymongo.ASCENDING)
.sort(sort_field, sort_order)
.limit(to_include)
)
examples = []
Expand All @@ -411,13 +424,13 @@ def retrain(self):
record['times'] = {}
record['md_trajectory'] = {}
examples.append(MOFRecord(**record))
if len(examples) == 0 or len(examples) == last_train_size:
self.logger.info(f'The number of training examples with strain below {self.trainer_config.maximum_strain:.2f} is the same '
if (len(examples) == 0 or len(examples) == last_train_size) and len(examples) < self.trainer_config.maximum_train_size:
self.logger.info(f'The number of training examples for {sort_field} with strain below {self.trainer_config.maximum_strain:.2f} is the same '
f'as the last time we trained DiffLinker ({last_train_size}). Waiting for more data')
self.start_train.clear()
self.start_train.wait()
continue
self.logger.info(f'Gathered the top {len(examples)} records with strain below {self.trainer_config.maximum_strain:.2f} based on stability')
self.logger.info(f'Gathered the top {len(examples)} with strain below {self.trainer_config.maximum_strain:.2f} records based on {sort_field}')
last_train_size = len(examples) # So we know what the training set size was for the next iteration

# Determine the run directory
Expand All @@ -432,7 +445,7 @@ def retrain(self):
input_kwargs={'examples': examples, 'run_directory': train_dir},
method='train_generator',
topic='training',
task_info={'train_size': len(examples)}
task_info={'train_size': len(examples), 'sort_field': sort_field}
)
self.logger.info('Submitted training. Waiting until complete')

Expand Down Expand Up @@ -516,11 +529,17 @@ def store_cp2k(self, result: Result):
)
self.logger.info(f'Partial charges are complete for {mof_name}. Submitted RASPA')
elif result.method == 'run_GCMC_single':
# Store result
storage_mean, storage_std = result.value
record = mofadb.get_records(self.collection, [mof_name])[0]
record.gas_storage['CO2'] = (1e4, storage_mean)
mofadb.update_records(self.collection, [record])
self.logger.info(f'Stored gas storage capacity for {mof_name}: {storage_mean:.3e} +/- {storage_std:.3e}')

# Update and trigger training, in case it's blocked
self.num_raspa_completed += 1
if self.num_raspa_completed > self.trainer_config.minimum_train_size:
self.start_train.set()
self.logger.info(f'Stored gas storage capacity for {mof_name}: {storage_mean:.3e} +/- {storage_std:.3e}. Completed {self.num_raspa_completed}')
else:
raise ValueError(f'Method not supported: {result.method}')
print(result.json(exclude={'inputs', 'value'}), file=self._output_files['simulation-results'], flush=True)
Expand Down

0 comments on commit 86319dd

Please sign in to comment.