Skip to content

Commit

Permalink
add multi-gpu support for cupy scheme
Browse files Browse the repository at this point in the history
  • Loading branch information
mj-will committed Jan 17, 2025
1 parent 32d3a1f commit 50a85d8
Showing 1 changed file with 27 additions and 9 deletions.
36 changes: 27 additions & 9 deletions pycbc/scheme.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
from functools import wraps
import logging
from .libutils import get_ctypes_library
from .pool import use_mpi

logger = logging.getLogger('pycbc.scheme')

Expand Down Expand Up @@ -117,26 +118,43 @@ def __init__(self, device_num=0):


class CUPYScheme(Scheme):
    """Scheme for using CUPY.

    Supports using CUPY with MPI. If MPI is enabled, will use all available
    devices. The environment variable `CUDA_VISIBLE_DEVICES` can be used to
    restrict the devices used.

    Parameters
    ----------
    device_num : int, optional
        The device number to use. If not provided and MPI is not enabled,
        ``None`` is passed through to ``cupy.cuda.Device``, which then
        selects CuPy's default device. Should not be provided when using
        MPI to parallelize across devices; in that case the device is
        derived from the MPI rank.
    """
    def __init__(self, device_num=None):
        import cupy  # Fail now if cupy is not there.
        import cupy.cuda

        # use_mpi reports whether MPI is active and, if so, this process'
        # rank; require_mpi=False keeps the non-MPI case working.
        do_mpi, _, rank = use_mpi(require_mpi=False, log=False)

        if device_num is not None and do_mpi:
            # The explicit device is still honoured; the caller is only
            # warned that every rank will share it.
            logger.warning("MPI is enabled, but a device number was provided.")

        if device_num is None and do_mpi:
            # Logical device numbers will always be 0, 1, 2, ... etc.
            # irrespective of the physical device numbers, so ranks can be
            # assigned round-robin over the visible device count.
            device_num = rank % cupy.cuda.runtime.getDeviceCount()
            logger.debug("MPI enabled, using CUDA device %s", device_num)

        self.device_num = device_num
        self.cuda_device = cupy.cuda.Device(self.device_num)

    def __enter__(self):
        """Activate the scheme and make its CUDA device current."""
        super().__enter__()
        self.cuda_device.__enter__()
        # Emitted through the module logger; ``logging.warn`` is a
        # deprecated alias that newer Python versions no longer provide.
        logger.warning(
            "You are using the CUPY GPU backend for PyCBC. This backend is "
            "still only a prototype. It may be useful for your application "
            "but it may fail unexpectedly, run slowly, or not give correct "
            "output. Please do contribute to the effort to develop this "
            "further."
        )

    def __exit__(self, *args):
        """Deactivate the scheme and release the CUDA device context."""
        super().__exit__(*args)
        # Exit the same Device object that __enter__ entered, exactly once.
        self.cuda_device.__exit__(*args)


class CPUScheme(Scheme):
Expand Down

0 comments on commit 50a85d8

Please sign in to comment.