WIP: added mpi4py
mpvanderschelling committed Feb 10, 2025
1 parent 29f4a0a commit d2e0f67
Showing 4 changed files with 265 additions and 147 deletions.
3 changes: 2 additions & 1 deletion requirements.txt
@@ -7,4 +7,5 @@ hydra-core
pathos>=0.3.0
autograd
SALib
filelock
mpi4py
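
Note: mpi4py links against whatever MPI implementation the system provides (e.g. Open MPI or MPICH, typically loaded via a module environment on a cluster). A quick sanity check that the runtime is usable, as a sketch:

from mpi4py import MPI  # fails at import time if no MPI library is found

print(MPI.Get_library_version())  # e.g. "Open MPI v4.1.x, ..."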
269 changes: 123 additions & 146 deletions src/f3dasm/_src/core.py
@@ -7,6 +7,8 @@
from __future__ import annotations

# Standard
import functools
import multiprocessing
import traceback
from abc import ABC, abstractmethod
from functools import partial
@@ -20,9 +22,12 @@
from filelock import FileLock
from pathos.helpers import mp

# Local
from ._io import EXPERIMENTDATA_SUBFOLDER, LOCK_FILENAME, MAX_TRIES
from .errors import TimeOutError
from .logger import logger
from .mpi_utils import (mpi_get_open_job, mpi_lock_manager,
mpi_store_experiment_sample, mpi_terminate_worker)

# Authorship & Credits
# =============================================================================
@@ -274,13 +279,15 @@ def call(self, data: ExperimentData | str,
return self._evaluate_multiprocessing(data=data, **kwargs)
elif mode.lower() == "cluster":
return self._evaluate_cluster(data=data, **kwargs)
elif mode.lower() == "mpi":
return self._evaluate_mpi(data=data, **kwargs)
else:
raise ValueError(f"Invalid parallelization mode specified: {mode}")

# =========================================================================

def _evaluate_sequential(self, data: ExperimentData, timeout: int = 0,
**kwargs) -> ExperimentData:
"""Run the operation sequentially
Parameters
@@ -294,14 +301,19 @@ def _evaluate_sequential(self, data: ExperimentData, **kwargs
Raised when there are no open jobs left
"""

if timeout > 0:
execute_fn = timeout_wrapper(timeout=timeout)(self.execute)
else:
execute_fn = self.execute

while True:
job_number, experiment_sample = data.get_open_job()
if job_number is None:
logger.debug("No Open jobs left!")
break

try:
experiment_sample: ExperimentSample = execute_fn(
experiment_sample=experiment_sample, **kwargs)

experiment_sample.store_experimentsample_references(
@@ -322,45 +334,6 @@
)
return data

# while True:

# job_number, experiment_sample = data.get_open_job()
# logger.debug(
# f"Accessed experiment_sample \
# {job_number}")
# if job_number is None:
# logger.debug("No Open Jobs left")
# break

# try:

# # # If kwargs is empty dict
# # if not kwargs:
# # logger.debug(
# # f"Running experiment_sample "
# # f"{job_number}")
# # else:
# # logger.debug(
# # f"Running experiment_sample "
# # f"{job_number} with kwargs {kwargs}")
# experiment_sample: ExperimentSample = self.execute(
# experiment_sample=experiment_sample, **kwargs)

# # _experiment_sample = self._run(
# # experiment_sample, **kwargs) # no *args!

# data.store_experimentsample(
# experiment_sample=experiment_sample,
# idx=job_number)
# except Exception as e:
# error_msg = f"Error in experiment_sample \
# {job_number}: {e}"
# error_traceback = traceback.format_exc()
# logger.error(f"{error_msg}\n{error_traceback}")
# data.mark(indices=job_number, status='error')

# return data

def _evaluate_multiprocessing(
self, data: ExperimentData,
nodes: int = mp.cpu_count(), **kwargs) -> ExperimentData:
@@ -398,81 +371,23 @@ def f(options: Dict[str, Any]) -> Tuple[int, ExperimentSample, int]:
finally:
return (job_number, experiment_sample)

# except Exception as e:
# error_msg = f"Error in experiment_sample \
# {options['_job_number']}: {e}"
# error_traceback = traceback.format_exc()
# logger.error(f"{error_msg}\n{error_traceback}")
# return (options['_job_number'],
# options['experiment_sample'], 1)

with mp.Pool(nodes) as pool:  # pass `nodes` so the requested pool size is respected
# maybe implement pool.starmap_async ?
_experiment_samples: List[
Tuple[int, ExperimentSample]] = pool.starmap(f, options)

for job_number, experiment_sample in _experiment_samples:
# if exit_code == 0:
data.store_experimentsample(
experiment_sample=experiment_sample,
idx=job_number)
# else:
# data.mark(indices=job_number, status='error')

return data
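
On the in-line question about pool.starmap_async above: a sketch of what the non-blocking variant could look like with the same data flow (names taken from the surrounding diff; this is not part of the commit):

with mp.Pool(nodes) as pool:
    # dispatch all samples without blocking the main process
    async_result = pool.starmap_async(f, options)
    # ... other work could happen here while the samples run ...
    _experiment_samples = async_result.get()  # blocks until all are done

for job_number, experiment_sample in _experiment_samples:
    data.store_experimentsample(
        experiment_sample=experiment_sample, idx=job_number)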

# def _evaluate_cluster(
# self, data: ExperimentData, **kwargs) -> ExperimentData:
# """Run the operation on the cluster

# Parameters
# ----------
# operation : ExperimentSampleCallable
# function execution for every entry in the ExperimentData object
# kwargs : dict
# Any keyword arguments that need to be supplied to the function

# Raises
# ------
# NoOpenJobsError
# Raised when there are no open jobs left
# """
# data = type(data).from_file(data.project_dir)

# get_open_job = data.access_file(type(data).get_open_job)
# store_experiment_sample = data.access_file(
# type(data).store_experimentsample)
# mark = data.access_file(type(data).mark)

# while True:
# job_number, experiment_sample = get_open_job()
# if job_number is None:
# logger.debug("No Open jobs left!")
# break

# try:
# _experiment_sample = self._run(
# experiment_sample, **kwargs)
# store_experiment_sample(experiment_sample=_experiment_sample,
# id=job_number)
# except Exception:
# # n = experiment_sample.job_number
# error_msg = f"Error in experiment_sample {job_number}: "
# error_traceback = traceback.format_exc()
# logger.error(f"{error_msg}\n{error_traceback}")
# mark(indices=job_number, status='error')
# continue

# data = type(data).from_file(data.project_dir)

# # Remove the lockfile from disk
# data.remove_lockfile()
# return data

def _evaluate_cluster(
self, data: ExperimentData,
wait_for_creation: bool = False,
max_tries: int = MAX_TRIES,
timeout: int = 0, **kwargs
) -> None:

# Create lockfile
@@ -491,14 +406,19 @@ def _evaluate_cluster(
wait_for_creation=wait_for_creation, max_tries=max_tries,
lockfile=lockfile)

if timeout > 0:
execute_fn = timeout_wrapper(timeout=timeout)(self.execute)
else:
execute_fn = self.execute

while True:
job_number, experiment_sample = cluster_get_open_job()
if job_number is None:
logger.debug("No Open jobs left!")
break

try:
experiment_sample: ExperimentSample = execute_fn(
experiment_sample=experiment_sample, **kwargs)

experiment_sample.store_experimentsample_references(
Expand All @@ -521,7 +441,24 @@ def _evaluate_cluster(
(data.project_dir / EXPERIMENTDATA_SUBFOLDER / LOCK_FILENAME
).with_suffix('.lock').unlink(missing_ok=True)

def _evaluate_mpi(
self, comm, data: ExperimentData,
wait_for_creation: bool = False,
max_tries: int = MAX_TRIES,
timeout: int = 0, **kwargs
) -> None:
rank = comm.Get_rank()
size = comm.Get_size()

if rank == 0:
mpi_lock_manager(comm=comm, size=size)
else:
mpi_worker(comm=comm, data=data, execute_fn=self.execute,
wait_for_creation=wait_for_creation,
max_tries=max_tries,
timeout=timeout, **kwargs)
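
For context, a sketch of how this mode would presumably be driven: every rank executes the same script, rank 0 becomes the lock manager and all other ranks become workers. Passing the communicator through call() as a keyword argument is an assumption here, since the call-site changes are not among the loaded hunks:

# launch_mpi.py -- hypothetical driver script; run as:
#   mpirun -n 4 python launch_mpi.py
from mpi4py import MPI

comm = MPI.COMM_WORLD
# `data_generator` and `experiment_data` are placeholders for user objects
data_generator.call(data=experiment_data, mode='mpi', comm=comm)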

# =========================================================================

@abstractmethod
def execute(self, experiment_sample: ExperimentSample,
@@ -548,56 +485,16 @@ def execute(self, experiment_sample: ExperimentSample,
"""
...

# def _run(
# self, experiment_sample: ExperimentSample,
# **kwargs) -> ExperimentSample:
# """
# Run the data generator.

# The function also caches the experiment_sample in the data generator.
# This allows the user to access the experiment_sample in the
# execute function as a class variable
# called self.experiment_sample.

# Parameters
# ----------
# ExperimentSample : ExperimentSample
# The design to run the data generator on

# kwargs : dict
# The keyword arguments to pass to the pre_process, execute \
# and post_process

# Returns
# -------
# ExperimentSample
# Processed design with the response of the data generator \
# saved in the experiment_sample
# """
# self.experiment_sample = experiment_sample

# self.experiment_sample.mark('in_progress')

# self.execute(**kwargs)

# self.experiment_sample.mark('finished')

# return self.experiment_sample

# =============================================================================

# lockfile = FileLock(
# (project_dir / EXPERIMENTDATA_SUBFOLDER / LOCK_FILENAME
# ).with_suffix('.lock'))


def get_open_job(experiment_data_type: Type[ExperimentData],
project_dir: Path, lockfile: FileLock,
wait_for_creation: bool, max_tries: int,
) -> Tuple[int, ExperimentSample]:

with lockfile:
data: ExperimentData = experiment_data_type.from_file(
project_dir=project_dir, wait_for_creation=wait_for_creation,
max_tries=max_tries)

@@ -621,3 +518,83 @@ def store_experiment_sample(
data.store_experimentsample(experiment_sample=experiment_sample,
idx=idx)
data.store(project_dir)
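

Both helpers above follow the same lock -> load -> mutate -> persist cycle, so concurrent cluster processes never read or write the experiment data mid-update. Reduced to its essentials with the filelock API (load_state and save_state are hypothetical stand-ins):

from filelock import FileLock

lock = FileLock("experimentdata.lock")
with lock:                   # blocks until no other process holds the lock
    state = load_state()     # hypothetical: read the shared state from disk
    state["jobs_done"] += 1  # mutate while exclusive access is guaranteed
    save_state(state)        # hypothetical: write the updated state back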


def timeout_wrapper(timeout):
"""Decorator to enforce a timeout on a JAX function."""
def decorator(func):
@functools.wraps(func)
def wrapper(*args, **kwargs):
def target_func(queue, *args, **kwargs):
try:
result = func(*args, **kwargs)
queue.put(result)
except Exception as e:
queue.put(e)

queue = multiprocessing.Queue()
process = multiprocessing.Process(
target=target_func, args=(queue, *args), kwargs=kwargs)
process.start()
process.join(timeout)

if process.is_alive():
process.terminate()
process.join()
raise TimeOutError(timeout=timeout)

result = queue.get()
if isinstance(result, Exception):
raise result
return result

return wrapper
return decorator
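
A usage sketch of timeout_wrapper (slow_simulation is made up; TimeOutError is the import from .errors above). Because the wrapped call runs in a separate multiprocessing.Process, arguments and return values must be picklable; note also that the nested target_func is only compatible with fork-based start methods (the Linux default), since closures cannot be pickled under spawn:

import time

@timeout_wrapper(timeout=2)
def slow_simulation(x: float) -> float:
    time.sleep(10)  # stands in for a hanging solver
    return x ** 2

try:
    slow_simulation(3.0)
except TimeOutError:
    print("sample exceeded the 2 s budget and was terminated")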


def mpi_worker(
comm, data: ExperimentData,
execute_fn: Callable,
wait_for_creation: bool = False,
max_tries: int = MAX_TRIES,
timeout: int = 0, **kwargs
) -> None:

cluster_get_open_job = partial(
mpi_get_open_job, experiment_data_type=type(data),
project_dir=data.project_dir,
wait_for_creation=wait_for_creation, max_tries=max_tries,
comm=comm)
cluster_store_experiment_sample = partial(
mpi_store_experiment_sample, experiment_data_type=type(data),
project_dir=data.project_dir,
wait_for_creation=wait_for_creation, max_tries=max_tries,
comm=comm)

# Honour the timeout argument, mirroring the sequential and cluster paths
if timeout > 0:
    execute_fn = timeout_wrapper(timeout=timeout)(execute_fn)

while True:
job_number, experiment_sample = cluster_get_open_job()
if job_number is None:
logger.debug("No Open jobs left!")
break

try:
experiment_sample: ExperimentSample = execute_fn(
experiment_sample=experiment_sample, **kwargs)

experiment_sample.store_experimentsample_references(
idx=job_number)

experiment_sample.mark('finished')

except Exception:
error_msg = f"Error in experiment_sample {job_number}: "
error_traceback = traceback.format_exc()
logger.error(f"{error_msg}\n{error_traceback}")
experiment_sample.mark('error')
continue

finally:
cluster_store_experiment_sample(
idx=job_number, experiment_sample=experiment_sample)

mpi_terminate_worker(comm)
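
The mpi_utils helpers used throughout (mpi_get_open_job, mpi_lock_manager, mpi_store_experiment_sample, mpi_terminate_worker) live in one of the changed files not rendered on this page. A rank-0 lock manager in this style is usually a small message loop; a speculative sketch, with message payloads that are illustrative only:

from mpi4py import MPI

def lock_manager_sketch(comm, size):
    # Grant an exclusive lock to one worker at a time, in FIFO order
    waiting, holder, terminated = [], None, 0
    while terminated < size - 1:  # rank 0 itself is not a worker
        status = MPI.Status()
        msg = comm.recv(source=MPI.ANY_SOURCE, tag=MPI.ANY_TAG, status=status)
        src = status.Get_source()
        if msg == 'acquire':
            if holder is None:
                holder = src
                comm.send('granted', dest=src)
            else:
                waiting.append(src)
        elif msg == 'release':
            holder = waiting.pop(0) if waiting else None
            if holder is not None:
                comm.send('granted', dest=holder)
        elif msg == 'terminate':
            terminated += 1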
