diff --git a/ensemble_md/analysis/analyze_traj.py b/ensemble_md/analysis/analyze_traj.py index 45af0bb..8285862 100644 --- a/ensemble_md/analysis/analyze_traj.py +++ b/ensemble_md/analysis/analyze_traj.py @@ -1330,3 +1330,89 @@ def get_delta_w_updates(log_file, plot=False): plt.savefig('delta_w_updates.png', dpi=600) return t_updates, delta_w_updates, equil + + +def end_states_only_traj(working_dir, n_sim, n_iter, l0_states, l1_states, swap_rep_pattern, ps_per_frame): + import pandas as pd + import os + import mdtraj as md + + # Determine how many end states are present, which simulations and lambdas those end states correspond to + state_name = ['A'] + considered_swaps = [[0,0]] + cat = ord('A') + 1 + for swap in swap_rep_pattern: + part_1, part_2 = swap + if part_1 in considered_swaps and part_2 in considered_swaps: + continue + elif part_1 in considered_swaps: + index = considered_swaps.index(part_1) + state_name.append(state_name[index]) + considered_swaps.append(part_2) + elif part_2 in considered_swaps: + index = considered_swaps.index(part_2) + state_name.append(state_name[index]) + considered_swaps.append(part_1) + else: + state_name.append(chr(cat)) + state_name.append(chr(cat)) + considered_swaps.append(part_1) + considered_swaps.append(part_2) + cat += 1 + for i in range(n_sim): + for j in [0, 1]: + if [i, j] not in considered_swaps: + state_name.append(chr(cat)) + considered_swaps.append([i, j]) + cat += 1 + + # Determine which frames correspond to which end states + state_frame_df = pd.DataFrame() + for n in range(n_sim): + for i in range(n_iter): + l0_frame, l1_frame = [],[] + dhdl_file = open(f'{working_dir}/sim_{n}/iteration_{i}/dhdl.xvg', 'r').readlines() + start = True + for line in dhdl_file: + split_line = line.split(' ') + while '' in split_line: + split_line.remove('') + if '#' not in split_line[0] and '@' not in split_line[0]: + time = float(split_line[0]) + if start: + start_time = time + start = False + state = float(split_line[1]) + if time%ps_per_frame == 0: + if state in l0_states: + l0_frame.append(int((time-start_time)/ps_per_frame)) + elif state in l1_states: + l1_frame.append(int((time-start_time)/ps_per_frame)) + if len(l0_frame) != 0: + df_0 = pd.DataFrame({'Sim': n, 'Iteration': i, 'Frame': l0_frame, 'Lambda': 0}) + state_frame_df = pd.concat([state_frame_df, df_0]) + if len(l1_frame) != 0: + df_1 = pd.DataFrame({'Sim': n, 'Iteration': i, 'Frame': l1_frame, 'Lambda': 1}) + state_frame_df = pd.concat([state_frame_df, df_1]) + + # Concatenate all frames from each set of trajectories for each end state + unique_states = list(set(state_name)) + for state in unique_states: + indices = [i for i, value in enumerate(state_name) if value == state] + for i, index in enumerate(indices): + rep, l = considered_swaps[index] + started = False + if os.path.exists(f'{working_dir}/sim_{rep}/iteration_0/confout_backup.gro'): + name = 'confout_backup' + else: + name = 'confout' + for iteration in range(n_iter): + frames_select = state_frame_df[(state_frame_df['Sim'] == rep) & (state_frame_df['Iteration'] == iteration) & (state_frame_df['Lambda'] == l)]['Frame'].to_numpy() + if len(frames_select) != 0: + if not started: + traj = md.load(f'{working_dir}/sim_{rep}/iteration_{iteration}/traj.trr', top=f'{working_dir}/sim_{rep}/iteration_0/{name}.gro') + started = True + else: + traj_add = md.load(f'{working_dir}/sim_{rep}/iteration_{iteration}/traj.trr', top=f'{working_dir}/sim_{rep}/iteration_0/{name}.gro') + traj = md.join(traj, traj_add) + traj.save_xtc(f'{working_dir}/analysis/{state}_{rep}.xtc') diff --git a/ensemble_md/cli/analyze_REXEE.py b/ensemble_md/cli/analyze_REXEE.py index 62dfbd6..37e0df3 100644 --- a/ensemble_md/cli/analyze_REXEE.py +++ b/ensemble_md/cli/analyze_REXEE.py @@ -25,6 +25,7 @@ warnings.simplefilter(action='ignore', category=UserWarning) from ensemble_md.utils import utils # noqa: E402 +from ensemble_md.utils import gmx_parser from ensemble_md.analysis import analyze_traj # noqa: E402 from ensemble_md.analysis import analyze_matrix # noqa: E402 from ensemble_md.analysis import msm_analysis # noqa: E402 @@ -454,6 +455,17 @@ def main(): for i in range(REXEE.n_sim): print(f'RMSE of the free energy profile for alchemical range {i} (states {REXEE.state_ranges[i][0]} to {REXEE.state_ranges[i][-1]}): {rmse_list[i]:.2f} kT') # noqa: E501 + # Section 5. Process trajecotries for MT-REXEE + if REXEE.modify_coords is not None: + # Section 5.1. Create end-state trajecotries for each simulation + l0, l1, ps_per_frame = gmx_parser.get_end_states(f'{REXEE.working_dir}/sim_0/iteration_0/expanded.mdp') + n_sim, n_iter = np.shape(rep_trajs) + if REXEE.swap_rep_pattern is None: + raise Exception('MT-REXEE trajectory analysis requires swap_rep_pattern to be defined') + analyze_traj.end_states_only_traj(REXEE.working_dir, n_sim, n_iter, l0, l1, REXEE.swap_rep_pattern, ps_per_frame) + + # Section 5.2. Create concatenated trajectories for each individual simulation + # Section 4. Calculate the time spent in GROMACS (This could take a while.) t_wall_tot, t_sync, _ = utils.analyze_REXEE_time() print(f'\nTotal wall time GROMACS spent to finish all iterations: {utils.format_time(t_wall_tot)}') diff --git a/ensemble_md/utils/gmx_parser.py b/ensemble_md/utils/gmx_parser.py index 4ee8eb3..256e791 100644 --- a/ensemble_md/utils/gmx_parser.py +++ b/ensemble_md/utils/gmx_parser.py @@ -444,3 +444,21 @@ def deter_atom_order(mol_file, resname): break return atom_order + +def get_end_states(mdp_path): + mdp = MDP(mdp_path) + end_0, end_1 = [], [] + coul_lambda = mdp['coul_lambdas'] + vdw_lambda = mdp['vdw_lambdas'] + n = 0 + for vdw, coul in zip(coul_lambda, vdw_lambda): + if vdw == 0.0 and coul == 0.0: + end_0.append(n) + elif vdw == 1.0 and coul == 1.0: + end_1.append(n) + n += 1 + dt = mdp['dt'] + steps_per_frame = mdp['nstxout'] + ps_per_frame = dt*steps_per_frame + + return end_0, end_1, ps_per_frame \ No newline at end of file