Skip to content

Commit

Permalink
Modified test_mpi_func.py and .circleci/config.yml
Browse files Browse the repository at this point in the history
  • Loading branch information
wehs7661 committed Apr 9, 2024
1 parent a0f93b7 commit 9a6e574
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 11 deletions.
8 changes: 4 additions & 4 deletions .circleci/config.yml
Original file line number Diff line number Diff line change
Expand Up @@ -56,10 +56,10 @@ jobs:
pip3 install pytest
pip3 install pytest-mpi
pip3 install pytest-cov
pytest -vv --disable-pytest-warnings --cov=ensemble_md --cov-report=xml --color=yes ensemble_md/tests/
# COVERAGE_FILE=.coverage_1 pytest -vv --disable-pytest-warnings --cov=ensemble_md --cov-report=xml --color=yes ensemble_md/tests/
# COVERAGE_FILE=.coverage_2 mpirun -np 4 pytest -vv --disable-pytest-warnings --cov=ensemble_md --cov-report=xml --color=yes ensemble_md/tests/test_mpi_func.py --with-mpi
# coverage combine .coverage_*
# pytest -vv --disable-pytest-warnings --cov=ensemble_md --cov-report=xml --color=yes ensemble_md/tests/
COVERAGE_FILE=.coverage_1 pytest -vv --disable-pytest-warnings --cov=ensemble_md --cov-report=xml --color=yes ensemble_md/tests/
COVERAGE_FILE=.coverage_2 mpirun -np 4 pytest -vv --disable-pytest-warnings --cov=ensemble_md --cov-report=xml --color=yes ensemble_md/tests/test_mpi_func.py --with-mpi
coverage combine .coverage_*
- run:
name: CodeCov
Expand Down
25 changes: 18 additions & 7 deletions ensemble_md/tests/test_mpi_func.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,10 @@
"""
import os
import yaml
import glob
import shutil
import pytest
from mpi4py import MPI
from ensemble_md.replica_exchange_EE import ReplicaExchangeEE


Expand All @@ -22,11 +24,14 @@ def params_dict():
"""
Generates a dictionary containing the required REXEE parameters.
"""
comm = MPI.COMM_WORLD
rank = comm.Get_rank()

REXEE_dict = {
'gmx_executable': 'gmx',
'gro': 'ensemble_md/tests/data/sys.gro',
'top': 'ensemble_md/tests/data/sys.top',
'mdp': 'ensemble_md/tests/data/mdp/expanded.mdp',
'mdp': 'ensemble_md/tests/data/expanded.mdp',
'n_sim': 4,
'n_iter': 10,
's': 1,
Expand All @@ -35,14 +40,21 @@ def params_dict():
yield REXEE_dict

# Remove the file after the unit test is done.
if os.path.isfile('params.yaml') is True:
os.remove('params.yaml')
yml_file = f'params_{rank}.yaml'
if os.path.isfile(yml_file):
os.remove(yml_file)


def get_REXEE_instance(input_dict, yml_file='params.yaml'):
def get_REXEE_instance(input_dict, rank, yml_file=None):
"""
Saves a dictionary as a yaml file and use it to instantiate the ReplicaExchangeEE class.
This version of the function creates a unique YAML file for each MPI process. This could
avoid race conditions between MPI processes where one process reads the file before another
finished writing it, or even worse, tries to read it while it's being written, leading to
inconsistent or incomplete data being read.
"""
if yml_file is None:
yml_file = f'params_{rank}.yaml'
with open(yml_file, 'w') as f:
yaml.dump(input_dict, f)
REXEE = ReplicaExchangeEE(yml_file)
Expand Down Expand Up @@ -94,7 +106,6 @@ def get_gmx_cmd_from_output(output):

@pytest.mark.mpi
def test_run_grompp(params_dict):
from mpi4py import MPI
comm = MPI.COMM_WORLD
rank = comm.Get_rank()

Expand All @@ -103,7 +114,7 @@ def test_run_grompp(params_dict):
# Case 1: The first iteration, i.e., n = 0
n = 0
swap_pattern = [1, 0, 2, 3]
REXEE = get_REXEE_instance(params_dict)
REXEE = get_REXEE_instance(params_dict, rank)

if rank == 0:
for i in range(params_dict['n_sim']):
Expand Down Expand Up @@ -133,7 +144,7 @@ def test_run_grompp(params_dict):

# Case 2: Other iterations, i.e., n != 0
n = 1 # For swap_pattern, we stick with [1, 0, 2, 3]
REXEE = get_REXEE_instance(params_dict)
REXEE = get_REXEE_instance(params_dict, rank)
if rank == 0:
for i in range(params_dict['n_sim']):
os.makedirs(f'{REXEE.working_dir}/sim_{i}/iteration_{n}')
Expand Down

0 comments on commit 9a6e574

Please sign in to comment.