Skip to content

Commit

Permalink
Add MT-REXEE analysis functions
Browse files Browse the repository at this point in the history
  • Loading branch information
ajfriedman22 committed Jan 17, 2025
1 parent cb8a343 commit ab9f394
Show file tree
Hide file tree
Showing 3 changed files with 116 additions and 0 deletions.
86 changes: 86 additions & 0 deletions ensemble_md/analysis/analyze_traj.py
Original file line number Diff line number Diff line change
Expand Up @@ -1330,3 +1330,89 @@ def get_delta_w_updates(log_file, plot=False):
plt.savefig('delta_w_updates.png', dpi=600)

return t_updates, delta_w_updates, equil


def end_states_only_traj(working_dir, n_sim, n_iter, l0_states, l1_states, swap_rep_pattern, ps_per_frame):
import pandas as pd
import os
import mdtraj as md

# Determine how many end states are present, which simulations and lambdas those end states correspond to
state_name = ['A']
considered_swaps = [[0,0]]
cat = ord('A') + 1
for swap in swap_rep_pattern:
part_1, part_2 = swap
if part_1 in considered_swaps and part_2 in considered_swaps:
continue
elif part_1 in considered_swaps:
index = considered_swaps.index(part_1)
state_name.append(state_name[index])
considered_swaps.append(part_2)
elif part_2 in considered_swaps:
index = considered_swaps.index(part_2)
state_name.append(state_name[index])
considered_swaps.append(part_1)
else:
state_name.append(chr(cat))
state_name.append(chr(cat))
considered_swaps.append(part_1)
considered_swaps.append(part_2)
cat += 1
for i in range(n_sim):
for j in [0, 1]:
if [i, j] not in considered_swaps:
state_name.append(chr(cat))
considered_swaps.append([i, j])
cat += 1

# Determine which frames correspond to which end states
state_frame_df = pd.DataFrame()
for n in range(n_sim):
for i in range(n_iter):
l0_frame, l1_frame = [],[]
dhdl_file = open(f'{working_dir}/sim_{n}/iteration_{i}/dhdl.xvg', 'r').readlines()
start = True
for line in dhdl_file:
split_line = line.split(' ')
while '' in split_line:
split_line.remove('')
if '#' not in split_line[0] and '@' not in split_line[0]:
time = float(split_line[0])
if start:
start_time = time
start = False
state = float(split_line[1])
if time%ps_per_frame == 0:
if state in l0_states:
l0_frame.append(int((time-start_time)/ps_per_frame))
elif state in l1_states:
l1_frame.append(int((time-start_time)/ps_per_frame))
if len(l0_frame) != 0:
df_0 = pd.DataFrame({'Sim': n, 'Iteration': i, 'Frame': l0_frame, 'Lambda': 0})
state_frame_df = pd.concat([state_frame_df, df_0])
if len(l1_frame) != 0:
df_1 = pd.DataFrame({'Sim': n, 'Iteration': i, 'Frame': l1_frame, 'Lambda': 1})
state_frame_df = pd.concat([state_frame_df, df_1])

# Concatenate all frames from each set of trajectories for each end state
unique_states = list(set(state_name))
for state in unique_states:
indices = [i for i, value in enumerate(state_name) if value == state]
for i, index in enumerate(indices):
rep, l = considered_swaps[index]
started = False
if os.path.exists(f'{working_dir}/sim_{rep}/iteration_0/confout_backup.gro'):
name = 'confout_backup'
else:
name = 'confout'
for iteration in range(n_iter):
frames_select = state_frame_df[(state_frame_df['Sim'] == rep) & (state_frame_df['Iteration'] == iteration) & (state_frame_df['Lambda'] == l)]['Frame'].to_numpy()
if len(frames_select) != 0:
if not started:
traj = md.load(f'{working_dir}/sim_{rep}/iteration_{iteration}/traj.trr', top=f'{working_dir}/sim_{rep}/iteration_0/{name}.gro')
started = True
else:
traj_add = md.load(f'{working_dir}/sim_{rep}/iteration_{iteration}/traj.trr', top=f'{working_dir}/sim_{rep}/iteration_0/{name}.gro')
traj = md.join(traj, traj_add)
traj.save_xtc(f'{working_dir}/analysis/{state}_{rep}.xtc')
12 changes: 12 additions & 0 deletions ensemble_md/cli/analyze_REXEE.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
warnings.simplefilter(action='ignore', category=UserWarning)

from ensemble_md.utils import utils # noqa: E402
from ensemble_md.utils import gmx_parser
from ensemble_md.analysis import analyze_traj # noqa: E402
from ensemble_md.analysis import analyze_matrix # noqa: E402
from ensemble_md.analysis import msm_analysis # noqa: E402
Expand Down Expand Up @@ -454,6 +455,17 @@ def main():
for i in range(REXEE.n_sim):
print(f'RMSE of the free energy profile for alchemical range {i} (states {REXEE.state_ranges[i][0]} to {REXEE.state_ranges[i][-1]}): {rmse_list[i]:.2f} kT') # noqa: E501

# Section 5. Process trajecotries for MT-REXEE
if REXEE.modify_coords is not None:
# Section 5.1. Create end-state trajecotries for each simulation
l0, l1, ps_per_frame = gmx_parser.get_end_states(f'{REXEE.working_dir}/sim_0/iteration_0/expanded.mdp')
n_sim, n_iter = np.shape(rep_trajs)
if REXEE.swap_rep_pattern is None:
raise Exception('MT-REXEE trajectory analysis requires swap_rep_pattern to be defined')
analyze_traj.end_states_only_traj(REXEE.working_dir, n_sim, n_iter, l0, l1, REXEE.swap_rep_pattern, ps_per_frame)

# Section 5.2. Create concatenated trajectories for each individual simulation

# Section 4. Calculate the time spent in GROMACS (This could take a while.)
t_wall_tot, t_sync, _ = utils.analyze_REXEE_time()
print(f'\nTotal wall time GROMACS spent to finish all iterations: {utils.format_time(t_wall_tot)}')
Expand Down
18 changes: 18 additions & 0 deletions ensemble_md/utils/gmx_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -444,3 +444,21 @@ def deter_atom_order(mol_file, resname):
break

return atom_order

def get_end_states(mdp_path):
mdp = MDP(mdp_path)
end_0, end_1 = [], []
coul_lambda = mdp['coul_lambdas']
vdw_lambda = mdp['vdw_lambdas']
n = 0
for vdw, coul in zip(coul_lambda, vdw_lambda):
if vdw == 0.0 and coul == 0.0:
end_0.append(n)
elif vdw == 1.0 and coul == 1.0:
end_1.append(n)
n += 1
dt = mdp['dt']
steps_per_frame = mdp['nstxout']
ps_per_frame = dt*steps_per_frame

return end_0, end_1, ps_per_frame

0 comments on commit ab9f394

Please sign in to comment.