diff --git a/ensemble_md/cli/analyze_REXEE.py b/ensemble_md/cli/analyze_REXEE.py index b694d17e..693f0065 100644 --- a/ensemble_md/cli/analyze_REXEE.py +++ b/ensemble_md/cli/analyze_REXEE.py @@ -29,7 +29,7 @@ from ensemble_md.analysis import analyze_matrix # noqa: E402 from ensemble_md.analysis import msm_analysis # noqa: E402 from ensemble_md.analysis import analyze_free_energy # noqa: E402 -from ensemble_md.ensemble_EXE import EnsembleEXE # noqa: E402 +from ensemble_md.replica_exchange_EE import ReplicaExchangeEE # noqa: E402 from ensemble_md.utils.exceptions import ParameterError # noqa: E402 @@ -41,13 +41,13 @@ def initialize(args): '--yaml', type=str, default='params.yaml', - help='The input YAML file used to run the EEXE simulation. (Default: params.yaml)') + help='The input YAML file used to run the REXEE simulation. (Default: params.yaml)') parser.add_argument('-o', '--output', type=str, - default='analyze_EEXE_log.txt', - help='The output log file that contains the analysis results of EEXE. \ - (Default: analyze_EEXE_log.txt)') + default='analyze_REXEE_log.txt', + help='The output log file that contains the analysis results of REXEE. \ + (Default: analyze_REXEE_log.txt)') parser.add_argument('-rt', '--rep_trajs', type=str, @@ -101,15 +101,15 @@ def main(): print(f'Current time: {datetime.now().strftime("%d/%m/%Y %H:%M:%S")}') print(f'Command line: {" ".join(sys.argv)}') - EEXE = EnsembleEXE(args.yaml) - EEXE.print_params(params_analysis=True) + REXEE = ReplicaExchangeEE(args.yaml) + REXEE.print_params(params_analysis=True) - for i in EEXE.warnings: + for i in REXEE.warnings: print() print(f'{i}') print() - if len(EEXE.warnings) > args.maxwarn: + if len(REXEE.warnings) > args.maxwarn: raise ParameterError( f"The execution failed due to warning(s) about parameter spcificaiton. Consider setting maxwarn in the input YAML file if you want to ignore them.") # noqa: E501, F541 @@ -130,12 +130,12 @@ def main(): # 1-1. Plot the replica-sapce trajectory print('1-1. Plotting transitions between alchemical ranges ...') - dt_swap = EEXE.nst_sim * EEXE.dt # dt for swapping replicas + dt_swap = REXEE.nst_sim * REXEE.dt # dt for swapping replicas analyze_traj.plot_rep_trajs(rep_trajs, f'{args.dir}/rep_trajs.png', dt_swap) # 1-2. Plot the replica-space transition matrix print('1-2. Plotting the replica-space transition matrix (considering all continuous trajectories) ...') - counts = [analyze_traj.traj2transmtx(rep_trajs[i], EEXE.n_sim, normalize=False) for i in range(len(rep_trajs))] + counts = [analyze_traj.traj2transmtx(rep_trajs[i], REXEE.n_sim, normalize=False) for i in range(len(rep_trajs))] reps_mtx = np.sum(counts, axis=0) # First sum up the counts. This should be symmetric if n_ex=1. Otherwise it might not be. # noqa: E501 reps_mtx /= np.sum(reps_mtx, axis=1)[:, None] # and then normalize each row analyze_matrix.plot_matrix(reps_mtx, f'{args.dir}/rep_transmtx_allconfigs.png') @@ -155,16 +155,16 @@ def main(): else: # This may take a while. print('2-1. Stitching trajectories for each starting configuration from dhdl files ...') - dhdl_files = [natsort.natsorted(glob.glob(f'sim_{i}/iteration_*/*dhdl*xvg')) for i in range(EEXE.n_sim)] - shifts = np.arange(EEXE.n_sim) * EEXE.s + dhdl_files = [natsort.natsorted(glob.glob(f'sim_{i}/iteration_*/*dhdl*xvg')) for i in range(REXEE.n_sim)] + shifts = np.arange(REXEE.n_sim) * REXEE.s state_trajs = analyze_traj.stitch_time_series(dhdl_files, rep_trajs, shifts=shifts, save_npy=True, save_xvg=True) # length: the number of replicas # noqa: E501 # 2-2. Plot the state-space trajectory print('\n2-2. Plotting transitions between different alchemical states ...') - dt_traj = EEXE.dt * EEXE.template['nstdhdl'] # in ps + dt_traj = REXEE.dt * REXEE.template['nstdhdl'] # in ps analyze_traj.plot_state_trajs( state_trajs, - EEXE.state_ranges, + REXEE.state_ranges, f'{args.dir}/state_trajs.png', dt_traj ) @@ -173,10 +173,10 @@ def main(): print('\n2-3. Plotting the histograms of the state index for different trajectories ...') hist_data = analyze_traj.plot_state_hist( state_trajs, - EEXE.state_ranges, + REXEE.state_ranges, f'{args.dir}/state_hist.png' ) - rmse = analyze_traj.calculate_hist_rmse(hist_data, EEXE.state_ranges) + rmse = analyze_traj.calculate_hist_rmse(hist_data, REXEE.state_ranges) print(f'The RMSE of accumulated histogram counts of the state index: {rmse:.0f}') # 2-4. Stitch the time series of state index for different alchemical ranges @@ -186,15 +186,15 @@ def main(): else: # This may take a while. print('2-4. Stitching time series of state index for each alchemical range ...') - dhdl_files = [natsort.natsorted(glob.glob(f'sim_{i}/iteration_*/*dhdl*xvg')) for i in range(EEXE.n_sim)] - shifts = np.arange(EEXE.n_sim) * EEXE.s + dhdl_files = [natsort.natsorted(glob.glob(f'sim_{i}/iteration_*/*dhdl*xvg')) for i in range(REXEE.n_sim)] + shifts = np.arange(REXEE.n_sim) * REXEE.s state_trajs_for_sim = analyze_traj.stitch_time_series_for_sim(dhdl_files, shifts) # 2-5. Plot the time series of state index for different alchemical ranges print('\n2-5. Plotting the time series of state index for different alchemical ranges ...') analyze_traj.plot_state_trajs( state_trajs_for_sim, - EEXE.state_ranges, + REXEE.state_ranges, f'{args.dir}/state_trajs_for_sim.png', title_prefix='Alchemical range' ) @@ -203,7 +203,7 @@ def main(): print('\n2-6. Plotting the histograms of state index for different alchemical ranges') analyze_traj.plot_state_hist( state_trajs_for_sim, - EEXE.state_ranges, + REXEE.state_ranges, f'{args.dir}/state_hist_for_sim.png', prefix='Alchemical range', subplots=True @@ -212,8 +212,8 @@ def main(): # 2-7. Plot the overall state transition matrices calculated from the state-space trajectories print('\n2-7. Plotting the overall state transition matrices from different trajectories ...') mtx_list = [] - for i in range(EEXE.n_sim): - mtx = analyze_traj.traj2transmtx(state_trajs[i], EEXE.n_tot) + for i in range(REXEE.n_sim): + mtx = analyze_traj.traj2transmtx(state_trajs[i], REXEE.n_tot) mtx_list.append(mtx) analyze_matrix.plot_matrix(mtx, f'{args.dir}/traj_{i}_state_transmtx.png') @@ -223,7 +223,7 @@ def main(): spectral_gaps = [results[i][0] if None not in results else None for i in range(len(results))] eig_vals = [results[i][1] if None not in results else None for i in range(len(results))] if None not in spectral_gaps: - for i in range(EEXE.n_sim): + for i in range(REXEE.n_sim): print(f' - Trajectory {i}: {spectral_gaps[i]:.3f} (λ_1: {eig_vals[i][0]:.5f}, λ_2: {eig_vals[i][1]:.5f})') # noqa: E501 print(f' - Average of the above: {np.mean(spectral_gaps):.3f} (std: {np.std(spectral_gaps, ddof=1):.3f})') @@ -233,21 +233,21 @@ def main(): if any([x is None for x in pi_list]): pass # None is in the list else: - for i in range(EEXE.n_sim): + for i in range(REXEE.n_sim): print(f' - Trajectory {i}: {", ".join([f"{j:.3f}" for j in pi_list[i].reshape(-1)])}') if len({len(i) for i in pi_list}) == 1: # all lists in pi_list have the same length print(f' - Average of the above: {", ".join([f"{i:.3f}" for i in np.mean(pi_list, axis=0).reshape(-1)])}') # noqa: E501 # 2-10. Calculate the state index correlation time for each trajectory (this step is more time-consuming one) print('\n2-10. Calculating the state index correlation time ...') - tau_list = [(pymbar.timeseries.statistical_inefficiency(state_trajs[i], fast=True) - 1) / 2 * dt_traj for i in range(EEXE.n_sim)] # noqa: E501 - for i in range(EEXE.n_sim): + tau_list = [(pymbar.timeseries.statistical_inefficiency(state_trajs[i], fast=True) - 1) / 2 * dt_traj for i in range(REXEE.n_sim)] # noqa: E501 + for i in range(REXEE.n_sim): print(f' - Trajectory {i}: {tau_list[i]:.1f} ps') print(f' - Average of the above: {np.mean(tau_list):.1f} ps (std: {np.std(tau_list, ddof=1):.1f} ps)') # 2-11. Calculate transit times for each trajectory print('\n2-11. Plotting the average transit times ...') - t_0k_list, t_k0_list, t_roundtrip_list, units = analyze_traj.plot_transit_time(state_trajs, EEXE.n_tot, dt=dt_traj, folder=args.dir) # noqa: E501 + t_0k_list, t_k0_list, t_roundtrip_list, units = analyze_traj.plot_transit_time(state_trajs, REXEE.n_tot, dt=dt_traj, folder=args.dir) # noqa: E501 meta_list = [t_0k_list, t_k0_list, t_roundtrip_list] t_names = [ '\n - Average transit time from states 0 to k', @@ -267,7 +267,7 @@ def main(): if np.sum(np.isnan([np.mean(i) for i in t_list])) != 0: poor_sampling = True - if EEXE.msm is True: + if REXEE.msm is True: section_idx += 1 # Section 3. Analysis based on Markov state models @@ -275,7 +275,7 @@ def main(): # 3-1. Plot the implied timescale as a function of lag time print('\n3-1. Plotting the implied timescale as a function of lag time for all trajectories ...') - lags = np.arange(EEXE.lag_spacing, EEXE.lag_max + EEXE.lag_spacing, EEXE.lag_spacing) + lags = np.arange(REXEE.lag_spacing, REXEE.lag_max + REXEE.lag_spacing, REXEE.lag_spacing) # lags could also be None and decided automatically. Could consider using that. ts_list = msm_analysis.plot_its(state_trajs, lags, fig_name=f'{args.dir}/implied_timescales.png', dt=dt_traj, units='ps') # noqa: E501 @@ -289,7 +289,7 @@ def main(): # 3-3. Build a Bayesian MSM and perform a CK test for each trajectory to validate the models print('\n3-3. Building Bayesian MSMs for the state-space trajectory for each trajectory ...') print(' Performing a Chapman-Kolmogorov test on each trajectory ...') - models = [pyemma.msm.bayesian_markov_model(state_trajs[i], chosen_lags[i], dt_traj=f'{dt_traj} ps', show_progress=False) for i in range(EEXE.n_sim)] # noqa: E501 + models = [pyemma.msm.bayesian_markov_model(state_trajs[i], chosen_lags[i], dt_traj=f'{dt_traj} ps', show_progress=False) for i in range(REXEE.n_sim)] # noqa: E501 for i in range(len(models)): print(f' Plotting the CK-test results for trajectory {i} ...') @@ -300,7 +300,7 @@ def main(): # not be counted as involved in the transition matrix (i.e. not in the active set). To check the # active states, use models[i].active_set. If the system sampled all states frequently, # models[i].active_set should be equal to np.unique(state_trajs[i]) and both lengths should be - # EEXE.n_tot. I'm not sure why the attribute nstates_full is not always EEXE.n_tot but is less + # REXEE.n_tot. I'm not sure why the attribute nstates_full is not always REXEE.n_tot but is less # relevant here. cktest = models[i].cktest(nsets=nsets, mlags=mlags, show_progress=False) pyemma.plots.plot_cktest(cktest, dt=dt_traj, units='ps') @@ -309,7 +309,7 @@ def main(): # Additionally, check if the sampling is poor for each trajectory for i in range(len(models)): - if models[i].nstates != EEXE.n_tot: + if models[i].nstates != REXEE.n_tot: print(f' Note: The sampling of trajectory {i} was poor.') # 3-4. Plot the state transition matrices estimated with the specified lag times in MSMs @@ -318,13 +318,13 @@ def main(): mtx_list_modified = [] # just for plotting (if all trajs sampled the fulle range frequently, this will be the same as mtx_list) # noqa: E501 for i in range(len(mtx_list)): # check if each mtx in mtx_list spans the full alchemical range. (If the system did not visit - # certain states, the dimension will be less than EEXE.n_tot * EEXE.n_tot. In this case, we + # certain states, the dimension will be less than REXEE.n_tot * REXEE.n_tot. In this case, we # add rows and columns of 0. Note that the modified matrix will not be a transition matrix, # so this is only for plotting. For later analysis such as spectral gap calculation, we # will just use the unmodified matrices. - if mtx_list[i].shape != (EEXE.n_tot, EEXE.n_tot): # add rows and columns of 0 + if mtx_list[i].shape != (REXEE.n_tot, REXEE.n_tot): # add rows and columns of 0 sampled = models[i].active_set - missing = list(set(range(EEXE.n_tot)) - set(sampled)) # states not visited + missing = list(set(range(REXEE.n_tot)) - set(sampled)) # states not visited # figure out which end we should stack rows/columns to n_1 = sum(missing > max(sampled)) # add rows/columns to the end of large state indices @@ -348,7 +348,7 @@ def main(): print(' Saving transmtx.npy (plotted transition matrices)...') np.save(f'{args.dir}/transmtx.npy', mtx_list_modified) - for i in range(EEXE.n_sim): + for i in range(REXEE.n_sim): analyze_matrix.plot_matrix(mtx_list[i], f'{args.dir}/traj_{i}_state_transmtx_msm.png') analyze_matrix.plot_matrix(avg_mtx, f'{args.dir}/state_transmtx_avg_msm.png') @@ -363,13 +363,13 @@ def main(): # 3-5. Calculate the spectral gap from the transition matrix of each trajectory print('\n3-5. Calculating the spectral gap of the state transition matrices obtained from MSMs ...') spectral_gaps, eig_vals = [analyze_matrix.calc_spectral_gap(mtx) for mtx in mtx_list] - for i in range(EEXE.n_sim): + for i in range(REXEE.n_sim): print(f' - Trajectory {i}: {spectral_gaps[i]:.3f}') # 3-6. Calculate the stationary distribution for each trajectory print('\n3-6. Calculating the stationary distributions from the transition matrices obtained from MSMs ...') pi_list = [m.pi for m in models] - for i in range(EEXE.n_sim): + for i in range(REXEE.n_sim): print(f' - Trajectory {i}: {", ".join([f"{j:.3f}" for j in pi_list[i]])}') if len({len(i) for i in pi_list}) == 1: # all lists in pi_list have the same length print(f' - Average of the above: {", ".join([f"{i:.3f}" for i in np.mean(pi_list, axis=0)])}') @@ -379,16 +379,16 @@ def main(): # note that it's not m.mfpt(min(m.active_set), max(m.active_set)) as the input to mfpt should be indices # though sometimes these two could be same. mfpt_list = [m.mfpt(0, m.nstates - 1) for m in models] - for i in range(EEXE.n_sim): + for i in range(REXEE.n_sim): print(f' - Trajectory {i}: {mfpt_list[i]:.1f} ps') print(f' - Average of the above: {np.mean(mfpt_list):.1f} ps (std: {np.std(mfpt_list, ddof=1):.1f} ps)') # 3-8. Calculate the state index correlation time for each trajectory print('\n3-8. Plotting the state index correlation times for all trajectories ...') - msm_analysis.plot_acf(models, EEXE.n_tot, f'{args.dir}/state_ACF.png') + msm_analysis.plot_acf(models, REXEE.n_tot, f'{args.dir}/state_ACF.png') # Section 4 (or Section 3). Free energy calculations - if EEXE.free_energy is True: + if REXEE.free_energy is True: if poor_sampling is True: print('\nFree energy calculation is not performed since the sampling appears poor.') sys.exit() @@ -397,7 +397,7 @@ def main(): # 4-1. Subsampling the data data_list = [] # either a list of u_nk or a list of dhdl - if EEXE.df_data_type == 'u_nk': + if REXEE.df_data_type == 'u_nk': if os.path.isfile(f'{args.dir}/u_nk_data.pickle') is True: print('Loading the preprocessed data u_nk ...') with open(f'{args.dir}/u_nk_data.pickle', 'rb') as handle: @@ -409,54 +409,54 @@ def main(): data_list = pickle.load(handle) if data_list == []: - files_list = [natsort.natsorted(glob.glob(f'sim_{i}/iteration_*/*dhdl*xvg')) for i in range(EEXE.n_sim)] - data_list, t_list, g_list = analyze_free_energy.preprocess_data(files_list, EEXE.temp, EEXE.df_data_type, EEXE.df_spacing) # noqa: E501 + files_list = [natsort.natsorted(glob.glob(f'sim_{i}/iteration_*/*dhdl*xvg')) for i in range(REXEE.n_sim)] + data_list, t_list, g_list = analyze_free_energy.preprocess_data(files_list, REXEE.temp, REXEE.df_data_type, REXEE.df_spacing) # noqa: E501 - with open(f'{args.dir}/{EEXE.df_data_type}_data.pickle', 'wb') as handle: + with open(f'{args.dir}/{REXEE.df_data_type}_data.pickle', 'wb') as handle: pickle.dump(data_list, handle, protocol=pickle.HIGHEST_PROTOCOL) # 4-2. Calculate the free energy profile - f, f_err, estimators = analyze_free_energy.calculate_free_energy(data_list, EEXE.state_ranges, EEXE.df_method, EEXE.err_method, EEXE.n_bootstrap, EEXE.seed) # noqa: E501 + f, f_err, estimators = analyze_free_energy.calculate_free_energy(data_list, REXEE.state_ranges, REXEE.df_method, REXEE.err_method, REXEE.n_bootstrap, REXEE.seed) # noqa: E501 print('Plotting the full-range free energy profile ...') analyze_free_energy.plot_free_energy(f, f_err, f'{args.dir}/free_energy_profile.png') print('The full-range free energy profile averaged over all replicas:') - print(f" {', '.join(f'{f[i]:.3f} +/- {f_err[i]:.3f} kT' for i in range(EEXE.n_tot))}") + print(f" {', '.join(f'{f[i]:.3f} +/- {f_err[i]:.3f} kT' for i in range(REXEE.n_tot))}") print(f'The free energy difference between the coupled and decoupled states: {f[-1]:.3f} +/- {f_err[-1]:.3f} kT') # noqa: E501 - if EEXE.df_ref is not None: - rmse_list = analyze_free_energy.calculate_df_rmse(estimators, EEXE.df_ref, EEXE.state_ranges) - for i in range(EEXE.n_sim): - print(f'RMSE of the free energy profile for alchemical range {i} (states {EEXE.state_ranges[i][0]} to {EEXE.state_ranges[i][-1]}): {rmse_list[i]:.2f} kT') # noqa: E501 + if REXEE.df_ref is not None: + rmse_list = analyze_free_energy.calculate_df_rmse(estimators, REXEE.df_ref, REXEE.state_ranges) + for i in range(REXEE.n_sim): + print(f'RMSE of the free energy profile for alchemical range {i} (states {REXEE.state_ranges[i][0]} to {REXEE.state_ranges[i][-1]}): {rmse_list[i]:.2f} kT') # noqa: E501 # 4-3. Recalculate the free energy profile if subsampling_avg is True - if EEXE.subsampling_avg is True: + if REXEE.subsampling_avg is True: print('\nUsing averaged start index of the equilibrated data and the avearged statistic inefficiency to re-perform free energy calculations ...') # noqa: E501 t_avg = int(np.mean(t_list)) + 1 # Using the ceiling function to be a little more conservative g_avg = np.array(g_list).prod() ** (1/len(g_list)) # geometric mean print(f'Averaged start index: {t_avg}') print(f'Averaged statistical inefficiency: {g_avg:.2f}') - data_list, _, _ = analyze_free_energy.preprocess_data(files_list, EEXE.temp, EEXE.df_data_type, EEXE.df_spacing, t_avg, g_avg) # noqa: E501 - with open(f'{args.dir}/{EEXE.df_data_type}_data_avg_subsampling.pickle', 'wb') as handle: + data_list, _, _ = analyze_free_energy.preprocess_data(files_list, REXEE.temp, REXEE.df_data_type, REXEE.df_spacing, t_avg, g_avg) # noqa: E501 + with open(f'{args.dir}/{REXEE.df_data_type}_data_avg_subsampling.pickle', 'wb') as handle: pickle.dump(data_list, handle, protocol=pickle.HIGHEST_PROTOCOL) - f, f_err, estimators = analyze_free_energy.calculate_free_energy(data_list, EEXE.state_ranges, EEXE.df_method, EEXE.err_method, EEXE.n_bootstrap, EEXE.seed) # noqa: E501 + f, f_err, estimators = analyze_free_energy.calculate_free_energy(data_list, REXEE.state_ranges, REXEE.df_method, REXEE.err_method, REXEE.n_bootstrap, REXEE.seed) # noqa: E501 print('Plotting the full-range free energy profile ...') analyze_free_energy.plot_free_energy(f, f_err, f'{args.dir}/free_energy_profile_avg_subsampling.png') print('The full-range free energy profile averaged over all replicas:') - print(f" {', '.join(f'{f[i]:.3f} +/- {f_err[i]:.3f} kT' for i in range(EEXE.n_tot))}") + print(f" {', '.join(f'{f[i]:.3f} +/- {f_err[i]:.3f} kT' for i in range(REXEE.n_tot))}") print(f'The free energy difference between the coupled and decoupled states: {f[-1]:.3f} +/- {f_err[-1]:.3f} kT') # noqa: E501 - if EEXE.df_ref is not None: - rmse_list = analyze_free_energy.calculate_df_rmse(estimators, EEXE.df_ref, EEXE.state_ranges) - for i in range(EEXE.n_sim): - print(f'RMSE of the free energy profile for alchemical range {i} (states {EEXE.state_ranges[i][0]} to {EEXE.state_ranges[i][-1]}): {rmse_list[i]:.2f} kT') # noqa: E501 + if REXEE.df_ref is not None: + rmse_list = analyze_free_energy.calculate_df_rmse(estimators, REXEE.df_ref, REXEE.state_ranges) + for i in range(REXEE.n_sim): + print(f'RMSE of the free energy profile for alchemical range {i} (states {REXEE.state_ranges[i][0]} to {REXEE.state_ranges[i][-1]}): {rmse_list[i]:.2f} kT') # noqa: E501 # Section 4. Calculate the time spent in GROMACS (This could take a while.) - t_wall_tot, t_sync, _ = utils.analyze_EEXE_time() + t_wall_tot, t_sync, _ = utils.analyze_REXEE_time() print(f'\nTotal wall time GROMACS spent to finish all iterations: {utils.format_time(t_wall_tot)}') print(f'Total time spent in syncrhonizing all replicas: {utils.format_time(t_sync)}') diff --git a/ensemble_md/cli/explore_REXEE.py b/ensemble_md/cli/explore_REXEE.py index 23fe5c1c..e83334f4 100644 --- a/ensemble_md/cli/explore_REXEE.py +++ b/ensemble_md/cli/explore_REXEE.py @@ -12,23 +12,23 @@ import argparse import numpy as np import pandas as pd -from ensemble_md.ensemble_EXE import EnsembleEXE +from ensemble_md.replica_exchange_EE import ReplicaExchangeEE def initialize(args): parser = argparse.ArgumentParser( - description='This code explores the parameter space of homogenous EEXE to help you figure \ + description='This code explores the parameter space of homogenous REXEE to help you figure \ out all possible combinations of the number of replicas, the number of \ states in each replica, and the number of overlapping states, and the total number states.') parser.add_argument('-N', '--N', required=True, type=int, - help='The total number of states of the EEXE simulation.') + help='The total number of states of the REXEE simulation.') parser.add_argument('-r', '--r', type=int, - help='The number of replicas that compose the EEXE simulation.') + help='The number of replicas that compose the REXEE simulation.') parser.add_argument('-n', '--n', type=int, @@ -54,9 +54,9 @@ def initialize(args): return args_parse -def solv_EEXE_diophantine(N, constraint=False): +def solv_REXEE_diophantine(N, constraint=False): """ - Solves the general nonlinear Diophantine equation associated with the homogeneous EEXE + Solves the general nonlinear Diophantine equation associated with the homogeneous REXEE parameters. Specifically, given the total number of states :math:`N` and the number of replicas r, the states for each replica n and the state shift s can be expressed as: n = N + (r-1)(t-1), and s = 1 - t, with the range of t being either the following: @@ -66,7 +66,7 @@ def solv_EEXE_diophantine(N, constraint=False): Parameters ---------- N : int - The total number of states of the homogeneous EEXE of interesst. + The total number of states of the homogeneous REXEE of interesst. constraint : bool Whether to apply additional constraints such that n-s <= 1/2n. @@ -115,7 +115,7 @@ def estimate_swapless_rate(state_ranges, N=1000000): n = 0 # number of times of not having any swappable pairs for i in range(N): rands = [random.choice(state_ranges[i]) for i in range(len(state_ranges))] - swappables = EnsembleEXE.identify_swappable_pairs(rands, state_ranges, False) + swappables = ReplicaExchangeEE.identify_swappable_pairs(rands, state_ranges, False) if swappables == []: n += 1 @@ -125,18 +125,18 @@ def estimate_swapless_rate(state_ranges, N=1000000): def main(): - # For now, we only consider homogenous EEXE simulations + # For now, we only consider homogenous REXEE simulations args = initialize(sys.argv[1:]) - print('Exploration of the EEXE parameter space') + print('Exploration of the REXEE parameter space') print('=======================================') - print('[ EEXE parameters of interest ]') + print('[ REXEE parameters of interest ]') print('- N: The total number of states') print('- r: The number of replicas') print('- n: The number of states for each replica') print('- s: The state shift between adjacent replicas') # Enuerate all possible combinations of (N, r, n, s) even if any of r, n, s is given - it's easy/fast anyway. - soln_all = solv_EEXE_diophantine(args.N, constraint=args.cnst) + soln_all = solv_REXEE_diophantine(args.N, constraint=args.cnst) # Now filter the solutions if args.r is not None: diff --git a/ensemble_md/cli/run_REXEE.py b/ensemble_md/cli/run_REXEE.py index cfc69c82..239bc910 100644 --- a/ensemble_md/cli/run_REXEE.py +++ b/ensemble_md/cli/run_REXEE.py @@ -19,7 +19,7 @@ from datetime import datetime from ensemble_md.utils import utils -from ensemble_md.ensemble_EXE import EnsembleEXE +from ensemble_md.replica_exchange_EE import ReplicaExchangeEE def initialize(args): @@ -29,7 +29,7 @@ def initialize(args): '--yaml', type=str, default='params.yaml', - help='The input YAML file that contains EEXE parameters. (Default: params.yaml)') + help='The input YAML file that contains REXEE parameters. (Default: params.yaml)') parser.add_argument('-c', '--ckpt', type=str, @@ -46,9 +46,9 @@ def initialize(args): parser.add_argument('-o', '--output', type=str, - default='run_EEXE_log.txt', + default='run_REXEE_log.txt', help='The output file for logging how replicas interact with each other. \ - (Default: run_EEXE_log.txt)') + (Default: run_REXEE_log.txt)') parser.add_argument('-m', '--maxwarn', type=int, @@ -65,7 +65,7 @@ def main(): sys.stdout = utils.Logger(logfile=args.output) sys.stderr = utils.Logger(logfile=args.output) - # Step 1: Set up MPI rank and instantiate EnsembleEXE to set up EEXE parameters + # Step 1: Set up MPI rank and instantiate ReplicaExchangeEE to set up REXEE parameters comm = MPI.COMM_WORLD rank = comm.Get_rank() # Note that this is a GLOBAL variable @@ -73,17 +73,17 @@ def main(): print(f'Current time: {datetime.now().strftime("%d/%m/%Y %H:%M:%S")}') print(f'Command line: {" ".join(sys.argv)}\n') - EEXE = EnsembleEXE(args.yaml) + REXEE = ReplicaExchangeEE(args.yaml) if rank == 0: # Print out simulation parameters - EEXE.print_params() + REXEE.print_params() # Print out warnings and fail if needed - for i in EEXE.warnings: + for i in REXEE.warnings: print(f'\n{i}\n') - if len(EEXE.warnings) > args.maxwarn: + if len(REXEE.warnings) > args.maxwarn: print(f"The execution failed due to warning(s) about parameter spcificaiton. Check the warnings, or consider setting maxwarn in the input YAML file if you find them harmless.") # noqa: E501, F541 comm.Abort(101) @@ -93,28 +93,28 @@ def main(): # 2-1. Set up input files for all simulations if rank == 0: - for i in range(EEXE.n_sim): + for i in range(REXEE.n_sim): os.mkdir(f'sim_{i}') os.mkdir(f'sim_{i}/iteration_0') - MDP = EEXE.initialize_MDP(i) + MDP = REXEE.initialize_MDP(i) MDP.write(f"sim_{i}/iteration_0/expanded.mdp", skipempty=True) # 2-2. Run the first ensemble of simulations - EEXE.run_EEXE(0) + REXEE.run_REXEE(0) else: if rank == 0: - # If there is a checkpoint file, we see the execution as an extension of an EEXE simulation + # If there is a checkpoint file, we see the execution as an extension of an REXEE simulation ckpt_data = np.load(args.ckpt) start_idx = len(ckpt_data[0]) # The length should be the same for the same axis - print(f'\nGetting prepared to extend the EEXE simulation from iteration {start_idx} ...') + print(f'\nGetting prepared to extend the REXEE simulation from iteration {start_idx} ...') - if start_idx == EEXE.n_iter: + if start_idx == REXEE.n_iter: print('Extension aborted: The expected number of iterations have been completed!') MPI.COMM_WORLD.Abort(1) else: print('Deleting data generated after the checkpoint ...') - for i in range(EEXE.n_sim): + for i in range(REXEE.n_sim): n_finished = len(next(os.walk(f'sim_{i}'))[1]) # number of finished iterations for j in range(start_idx, n_finished): print(f' Deleting the folder sim_{i}/iteration_{j}') @@ -123,18 +123,18 @@ def main(): # Read g_vecs.npy and rep_trajs.npy so that new data can be appended, if any. # Note that these two arrays are created in rank 0 and should always be operated in rank 0, # or broadcasting is required. - EEXE.rep_trajs = [list(i) for i in ckpt_data] + REXEE.rep_trajs = [list(i) for i in ckpt_data] if os.path.isfile(args.g_vecs) is True: - EEXE.g_vecs = [list(i) for i in np.load(args.g_vecs)] + REXEE.g_vecs = [list(i) for i in np.load(args.g_vecs)] else: start_idx = None start_idx = comm.bcast(start_idx, root=0) # so that all the ranks are aware of start_idx # 2-3. Get the reference distance for the distance restraint specified in the pull code, if any. - EEXE.get_ref_dist() + REXEE.get_ref_dist() - for i in range(start_idx, EEXE.n_iter): + for i in range(start_idx, REXEE.n_iter): # For a large code block like below executed on rank 0, we try to catch any exception and abort the simulation. # So if there is bug, the execution will be terminated and no computation time will be wasted. try: @@ -143,10 +143,10 @@ def main(): # 3-1. For all the replica simulations, # (1) Find the last sampled state and the corresponding lambda values from the DHDL files. # (2) Find the final Wang-Landau incrementors and weights from the LOG files. - dhdl_files = [f'sim_{j}/iteration_{i - 1}/dhdl.xvg' for j in range(EEXE.n_sim)] - log_files = [f'sim_{j}/iteration_{i - 1}/md.log' for j in range(EEXE.n_sim)] - states_ = EEXE.extract_final_dhdl_info(dhdl_files) - wl_delta, weights_, counts_ = EEXE.extract_final_log_info(log_files) + dhdl_files = [f'sim_{j}/iteration_{i - 1}/dhdl.xvg' for j in range(REXEE.n_sim)] + log_files = [f'sim_{j}/iteration_{i - 1}/md.log' for j in range(REXEE.n_sim)] + states_ = REXEE.extract_final_dhdl_info(dhdl_files) + wl_delta, weights_, counts_ = REXEE.extract_final_log_info(log_files) print() # 3-2. Identify swappable pairs, propose swap(s), calculate P_acc, and accept/reject swap(s) @@ -159,68 +159,68 @@ def main(): states = copy.deepcopy(states_) weights = copy.deepcopy(weights_) counts = copy.deepcopy(counts_) - swap_pattern, swap_list = EEXE.get_swapping_pattern(dhdl_files, states_, weights_) # swap_list will only be used for modify_coords # noqa: E501 + swap_pattern, swap_list = REXEE.get_swapping_pattern(dhdl_files, states_, weights_) # swap_list will only be used for modify_coords # noqa: E501 # 3-3. Perform weight correction/weight combination - if wl_delta != [None for i in range(EEXE.n_sim)]: # weight-updating + if wl_delta != [None for i in range(REXEE.n_sim)]: # weight-updating print(f'\nCurrent Wang-Landau incrementors: {wl_delta}\n') # (1) First we prepare the weights to be combined. # Note that although averaged weights are sometimes used for weight correction/weight combination, # the final weights are always used for calculating the acceptance ratio. - if EEXE.N_cutoff != -1 or EEXE.w_combine is not None: + if REXEE.N_cutoff != -1 or REXEE.w_combine is not None: # Only when weight correction/weight combination is needed. - weights_avg, weights_err = EEXE.get_averaged_weights(log_files) - weights_input = EEXE.prepare_weights(weights_avg, weights) # weights_input is for weight combination # noqa: E501 + weights_avg, weights_err = REXEE.get_averaged_weights(log_files) + weights_input = REXEE.prepare_weights(weights_avg, weights) # weights_input is for weight combination # noqa: E501 # (2) Now we perform weight correction/weight combination. # The product of this step should always be named as "weights" to be used in update_MDP - if EEXE.N_cutoff != -1 and EEXE.w_combine is not None: + if REXEE.N_cutoff != -1 and REXEE.w_combine is not None: # perform both if weights_input is None: # Then only weight correction will be performed print('Note: Weight combination is deactivated because the weights are too noisy.') - weights = EEXE.weight_correction(weights, counts) - _ = EEXE.combine_weights(counts_, weights, print_values=False)[1] # just to print the combiend weights # noqa: E501 + weights = REXEE.weight_correction(weights, counts) + _ = REXEE.combine_weights(counts_, weights, print_values=False)[1] # just to print the combiend weights # noqa: E501 else: - weights_preprocessed = EEXE.weight_correction(weights_input, counts) - if EEXE.verbose is True: + weights_preprocessed = REXEE.weight_correction(weights_input, counts) + if REXEE.verbose is True: print('Performing weight combination ...') else: print('Performing weight combination ...', end='') - counts, weights, g_vec = EEXE.combine_weights(counts_, weights_preprocessed) # inverse-variance weighting seems worse # noqa: E501 - EEXE.g_vecs.append(g_vec) - elif EEXE.N_cutoff == -1 and EEXE.w_combine is not None: + counts, weights, g_vec = REXEE.combine_weights(counts_, weights_preprocessed) # inverse-variance weighting seems worse # noqa: E501 + REXEE.g_vecs.append(g_vec) + elif REXEE.N_cutoff == -1 and REXEE.w_combine is not None: # only perform weight combination print('Note: No weight correction will be performed.') if weights_input is None: print('Note: Weight combination is deactivated because the weights are too noisy.') - _ = EEXE.combine_weights(counts_, weights, print_values=False)[1] # just to print the combined weights # noqa: E501 + _ = REXEE.combine_weights(counts_, weights, print_values=False)[1] # just to print the combined weights # noqa: E501 else: - if EEXE.verbose is True: + if REXEE.verbose is True: print('Performing weight combination ...') else: print('Performing weight combination ...', end='') - counts, weights, g_vec = EEXE.combine_weights(counts_, weights_input) # inverse-variance weighting seems worse # noqa: E501 - EEXE.g_vecs.append(g_vec) - elif EEXE.N_cutoff != -1 and EEXE.w_combine is None: + counts, weights, g_vec = REXEE.combine_weights(counts_, weights_input) # inverse-variance weighting seems worse # noqa: E501 + REXEE.g_vecs.append(g_vec) + elif REXEE.N_cutoff != -1 and REXEE.w_combine is None: # only perform weight correction print('Note: No weight combination will be performed.') - weights = EEXE.histogram_correction(weights_input, counts) - _ = EEXE.combine_weights(counts_, weights, print_values=False)[1] # just to print the combined weights # noqa: E501 + weights = REXEE.histogram_correction(weights_input, counts) + _ = REXEE.combine_weights(counts_, weights, print_values=False)[1] # just to print the combined weights # noqa: E501 else: print('Note: No weight correction will be performed.') print('Note: No weight combination will be performed.') - _ = EEXE.combine_weights(counts_, weights, print_values=False)[1] # just to print the combiend weights # noqa: E501 + _ = REXEE.combine_weights(counts_, weights, print_values=False)[1] # just to print the combiend weights # noqa: E501 # 3-5. Modify the MDP files and swap out the GRO files (if needed) # Here we keep the lambda range set in mdp the same across different iterations in the same folder but swap out the gro file # noqa: E501 # Note we use states (copy of states_) instead of states_ in update_MDP. - for j in list(range(EEXE.n_sim)): + for j in list(range(REXEE.n_sim)): os.mkdir(f'sim_{j}/iteration_{i}') - MDP = EEXE.update_MDP(f"sim_{j}/iteration_{i - 1}/expanded.mdp", j, i, states, wl_delta, weights, counts) # modify with a new template # noqa: E501 + MDP = REXEE.update_MDP(f"sim_{j}/iteration_{i - 1}/expanded.mdp", j, i, states, wl_delta, weights, counts) # modify with a new template # noqa: E501 MDP.write(f"sim_{j}/iteration_{i}/expanded.mdp", skipempty=True) - # In run_EEXE(i, swap_pattern), where the tpr files will be generated, we use the top file at the + # In run_REXEE(i, swap_pattern), where the tpr files will be generated, we use the top file at the # level of the simulation (the file that will be shared by all simulations). For the gro file, we # pass swap_pattern to the function to figure it out internally. else: @@ -231,13 +231,13 @@ def main(): print(f'An error occurred on rank 0:\n{traceback.format_exc()}') MPI.COMM_WORLD.Abort(1) - if -1 not in EEXE.equil and 0 not in EEXE.equil: + if -1 not in REXEE.equil and 0 not in REXEE.equil: # This is the case where the weights are equilibrated in a weight-updating simulation. - # As a remidner, EEXE.equil should be a list of 0 after extract_final_log_info in a + # As a remidner, REXEE.equil should be a list of 0 after extract_final_log_info in a # fixed-weight simulation, and a list of -1 for a weight-updating simulation with unequilibrated weights. print('\nSimulation terminated: The weights have been equilibrated for all replicas.') # this will only be printed in rank 0 # noqa: E501 - # Note that EEXE.equil is avaiable for all ranks but only updated in rank 0. So the if condition here + # Note that REXEE.equil is avaiable for all ranks but only updated in rank 0. So the if condition here # can only be satisfied in rank 0. We broadcast exit_loop to all ranks so that all ranks can exit the # simulation at the same time, if the weights get equilibrated. exit_loop = True @@ -257,7 +257,7 @@ def main(): if len(swap_list) == 0: pass else: - if EEXE.modify_coords_fn is not None: + if REXEE.modify_coords_fn is not None: try: if rank == 0: for j in range(len(swap_list)): @@ -274,53 +274,53 @@ def main(): os.rename(gro_2, gro_2_backup) # Here we input gro_1_backup and gro_2_backup and modify_coords_fn will save the modified gro files as gro_1 and gro_2 # noqa: E501 - EEXE.modify_coords_fn(gro_1_backup, gro_2_backup) # the order should not matter + REXEE.modify_coords_fn(gro_1_backup, gro_2_backup) # the order should not matter except Exception: print('\n--------------------------------------------------------------------------\n') print(f'\nAn error occurred on rank 0:\n{traceback.format_exc()}') MPI.COMM_WORLD.Abort(1) # 4-2. Run another ensemble of simulations - EEXE.run_EEXE(i, swap_pattern) + REXEE.run_REXEE(i, swap_pattern) # 4-3. Save data if rank == 0: - if (i + 1) % EEXE.n_ckpt == 0: - if len(EEXE.g_vecs) != 0: + if (i + 1) % REXEE.n_ckpt == 0: + if len(REXEE.g_vecs) != 0: # Save g_vec as a function of time if weight combination was used. - np.save('g_vecs.npy', EEXE.g_vecs) + np.save('g_vecs.npy', REXEE.g_vecs) print('\n----- Saving .npy files to checkpoint the simulation ---') - np.save('rep_trajs.npy', EEXE.rep_trajs) + np.save('rep_trajs.npy', REXEE.rep_trajs) # Save the npy files at the end of the simulation anyway. if rank == 0: - if len(EEXE.g_vecs) != 0: # The length will be 0 only if there is no weight combination. - np.save('g_vecs.npy', EEXE.g_vecs) - np.save('rep_trajs.npy', EEXE.rep_trajs) + if len(REXEE.g_vecs) != 0: # The length will be 0 only if there is no weight combination. + np.save('g_vecs.npy', REXEE.g_vecs) + np.save('rep_trajs.npy', REXEE.rep_trajs) # Step 5: Write a summary for the simulation ensemble if rank == 0: print('\nSummary of the simulation ensemble') print('==================================') print('Simulation status:') - for i in range(EEXE.n_sim): - if EEXE.fixed_weights is True: + for i in range(REXEE.n_sim): + if REXEE.fixed_weights is True: print(f'- Rep {i}: The weights were fixed throughout the simulation.') - elif EEXE.equil[i] == -1: + elif REXEE.equil[i] == -1: print(f' - Rep {i}: The weights have not been equilibrated.') else: - idx = int(np.floor(EEXE.equil[i] / (EEXE.dt * EEXE.nst_sim))) - if EEXE.equil[i] > 1000: + idx = int(np.floor(REXEE.equil[i] / (REXEE.dt * REXEE.nst_sim))) + if REXEE.equil[i] > 1000: units = 'ns' - EEXE.equil[i] /= 1000 + REXEE.equil[i] /= 1000 else: units = 'ps' - print(f' - Rep {i}: The weights have been equilibrated at {EEXE.equil[i]:.2f} {units} (iteration {idx}).') # noqa: E501 + print(f' - Rep {i}: The weights have been equilibrated at {REXEE.equil[i]:.2f} {units} (iteration {idx}).') # noqa: E501 - print(f'\n{EEXE.n_empty_swappable} out of {EEXE.n_iter}, or {EEXE.n_empty_swappable / EEXE.n_iter * 100:.1f}% iterations had an empty list of swappable pairs.') # noqa: E501 - if EEXE.n_swap_attempts != 0: - print(f'{EEXE.n_rejected} out of {EEXE.n_swap_attempts}, or {EEXE.n_rejected / EEXE.n_swap_attempts * 100:.1f}% of attempted exchanges were rejected.') # noqa: E501 + print(f'\n{REXEE.n_empty_swappable} out of {REXEE.n_iter}, or {REXEE.n_empty_swappable / REXEE.n_iter * 100:.1f}% iterations had an empty list of swappable pairs.') # noqa: E501 + if REXEE.n_swap_attempts != 0: + print(f'{REXEE.n_rejected} out of {REXEE.n_swap_attempts}, or {REXEE.n_rejected / REXEE.n_swap_attempts * 100:.1f}% of attempted exchanges were rejected.') # noqa: E501 print(f'\nTime elapsed: {utils.format_time(time.time() - t1)}')