Skip to content

Commit

Permalink
Added a unit test for plot_dg_evolution
Browse files Browse the repository at this point in the history
  • Loading branch information
wehs7661 committed Apr 9, 2024
1 parent f18c1ad commit 0765a20
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 3 deletions.
2 changes: 1 addition & 1 deletion ensemble_md/analysis/analyze_traj.py
Original file line number Diff line number Diff line change
Expand Up @@ -1116,7 +1116,7 @@ def get_dg_evolution(log_files, start_state, end_state):
return dg


def plot_dg_evolution(log_files, start_state, end_state, start_idx=0, end_idx=-1, dt_log=2):
def plot_dg_evolution(log_files, start_state, end_state, start_idx=None, end_idx=None, dt_log=2):
"""
For weight-updating simulations, plots the time series of the weight
difference (:math:`Δg = g_2-g_1`) between the specified states.
Expand Down
28 changes: 26 additions & 2 deletions ensemble_md/tests/test_analyze_traj.py
Original file line number Diff line number Diff line change
Expand Up @@ -1023,8 +1023,32 @@ def test_get_dg_evoluation(mock_fn):
np.testing.assert_array_almost_equal(dg, np.array([1.80707, 1.80707, 2.60707, 2.60707, 2.60707]))


def test_plot_dg_evolution():
pass
@patch('ensemble_md.analysis.analyze_traj.plt')
@patch('ensemble_md.analysis.analyze_traj.get_dg_evolution')
def test_plot_dg_evolution(mock_fn, mock_plt): # the outer decorator mock_plt should be the second parameter
# Test 1: Short dg
mock_fn.return_value = np.arange(10)
dg = analyze_traj.plot_dg_evolution(['log_0.log'], 1, 3) # the values of log_files does not matter since the mocked value of dg is specified anyway # noqa: E501
mock_fn.assert_called_once_with(['log_0.log'], 1, 3)
np.testing.assert_array_equal(dg, np.arange(10))
t = np.array([0, 2, 4, 6, 8, 10, 12, 14, 16, 18])

mock_plt.figure.assert_called()
np.testing.assert_array_equal(mock_plt.plot.call_args_list[0][0][0], t)
np.testing.assert_array_equal(mock_plt.plot.call_args_list[0][0][1], dg)
mock_plt.xlabel.assert_called_once_with('Time (ps)')
mock_plt.ylabel.assert_called_once_with(r'$\Delta g$')
mock_plt.grid.assert_called_once()
mock_plt.savefig.assert_called_once_with('dg_evolution.png', dpi=600)

# Test 2: Long dg
mock_fn.reset_mock()
mock_plt.reset_mock()
mock_fn.return_value = np.arange(20000)
dg = analyze_traj.plot_dg_evolution(['log_0.log'], 1, 3, start_idx=100)
mock_fn.assert_called_once_with(['log_0.log'], 1, 3)
np.testing.assert_array_equal(dg, np.arange(20000)[100:])
mock_plt.xlabel.assert_called_once_with('Time (ns)')


def test_get_delta_w_updates():
Expand Down

0 comments on commit 0765a20

Please sign in to comment.