From 86319ddac9fb7ac5ac0e6cb6ede71e2838829416 Mon Sep 17 00:00:00 2001 From: Logan Ward Date: Wed, 10 Apr 2024 13:46:24 -0400 Subject: [PATCH] Switch to training on RASPA later in run (#124) --- run_parallel_workflow.py | 45 ++++++++++++++++++++++++++++------------ 1 file changed, 32 insertions(+), 13 deletions(-) diff --git a/run_parallel_workflow.py b/run_parallel_workflow.py index 045b16a3..33bce130 100644 --- a/run_parallel_workflow.py +++ b/run_parallel_workflow.py @@ -143,7 +143,8 @@ def __init__(self, # Settings related for training self.start_train = Event() self.initial_weights = self.generator_config.generator_path # Store the starting weights, which we'll always use as a starting point for training - self.num_completed = 0 # Number of MOFs which have finished training + self.num_lammps_completed = 0 # Number of MOFs which have finished stability + self.num_raspa_completed = 0 # Number for which we have gas storage self.model_iteration = 0 # Which version of the model we used for generating a ligand # Settings related to scheduling CP2K @@ -382,8 +383,8 @@ def process_md_results(self): self.cp2k_ready.set() # Determine if we should retrain - self.num_completed += 1 - if self.num_completed >= self.trainer_config.minimum_train_size: + self.num_lammps_completed += 1 + if self.num_lammps_completed >= self.trainer_config.minimum_train_size: self.start_train.set() # Either starts or indicates that we have new data @event_responder(event_name='start_train') @@ -393,16 +394,28 @@ def retrain(self): self.logger.info('Started to retrain DiffLinker') last_train_size = 0 while not self.done.is_set(): - # Get the top MOFs - sort_field = 'structure_stability.uff' - to_include = min(int(self.collection.estimated_document_count() * self.trainer_config.best_fraction), self.trainer_config.maximum_train_size) + # Determine how to select the best MOFs + if self.num_raspa_completed < self.trainer_config.minimum_train_size: + sort_field = 'structure_stability.uff' + to_include = min(int(self.num_lammps_completed * self.trainer_config.best_fraction), self.trainer_config.maximum_train_size) + sort_order = pymongo.ASCENDING + else: + sort_field = 'gas_storage.CO2' + to_include = min(int(self.num_raspa_completed * self.trainer_config.best_fraction), self.trainer_config.maximum_train_size) + sort_order = pymongo.DESCENDING + + # Build the query self.collection.create_index(sort_field) + query = defaultdict(dict) + query[sort_field] = {'$exists': True} + query['structure_stability.uff'] = {'$lt': self.trainer_config.maximum_strain} + cursor = ( self.collection.find( - {sort_field: {'$exists': True, '$lt': self.trainer_config.maximum_strain}}, + query, {'md_trajectory': 0} # Filter out the trajectory to save I/O ) - .sort(sort_field, pymongo.ASCENDING) + .sort(sort_field, sort_order) .limit(to_include) ) examples = [] @@ -411,13 +424,13 @@ def retrain(self): record['times'] = {} record['md_trajectory'] = {} examples.append(MOFRecord(**record)) - if len(examples) == 0 or len(examples) == last_train_size: - self.logger.info(f'The number of training examples with strain below {self.trainer_config.maximum_strain:.2f} is the same ' + if (len(examples) == 0 or len(examples) == last_train_size) and len(examples) < self.trainer_config.maximum_train_size: + self.logger.info(f'The number of training examples for {sort_field} with strain below {self.trainer_config.maximum_strain:.2f} is the same ' f'as the last time we trained DiffLinker ({last_train_size}). Waiting for more data') self.start_train.clear() self.start_train.wait() continue - self.logger.info(f'Gathered the top {len(examples)} records with strain below {self.trainer_config.maximum_strain:.2f} based on stability') + self.logger.info(f'Gathered the top {len(examples)} with strain below {self.trainer_config.maximum_strain:.2f} records based on {sort_field}') last_train_size = len(examples) # So we know what the training set size was for the next iteration # Determine the run directory @@ -432,7 +445,7 @@ def retrain(self): input_kwargs={'examples': examples, 'run_directory': train_dir}, method='train_generator', topic='training', - task_info={'train_size': len(examples)} + task_info={'train_size': len(examples), 'sort_field': sort_field} ) self.logger.info('Submitted training. Waiting until complete') @@ -516,11 +529,17 @@ def store_cp2k(self, result: Result): ) self.logger.info(f'Partial charges are complete for {mof_name}. Submitted RASPA') elif result.method == 'run_GCMC_single': + # Store result storage_mean, storage_std = result.value record = mofadb.get_records(self.collection, [mof_name])[0] record.gas_storage['CO2'] = (1e4, storage_mean) mofadb.update_records(self.collection, [record]) - self.logger.info(f'Stored gas storage capacity for {mof_name}: {storage_mean:.3e} +/- {storage_std:.3e}') + + # Update and trigger training, in case it's blocked + self.num_raspa_completed += 1 + if self.num_raspa_completed > self.trainer_config.minimum_train_size: + self.start_train.set() + self.logger.info(f'Stored gas storage capacity for {mof_name}: {storage_mean:.3e} +/- {storage_std:.3e}. Completed {self.num_raspa_completed}') else: raise ValueError(f'Method not supported: {result.method}') print(result.json(exclude={'inputs', 'value'}), file=self._output_files['simulation-results'], flush=True)