Allow MSCCL++ CommGroup to take PyTorch tensors in args (#255)
Obtain data_ptr and tensor_size appropriately for torch.Tensor inputs

Co-authored-by: Binyang Li <binyli@microsoft.com>
aashaka and Binyang2014 authored Feb 7, 2024
1 parent 6a19b19 commit 2101f52
Showing 2 changed files with 45 additions and 10 deletions.
38 changes: 29 additions & 9 deletions python/mscclpp/comm.py
@@ -23,6 +23,8 @@
 import mpi4py
 import numpy as np
 
+from mscclpp.utils import is_torch_tensor
+
 
 class CommGroup:
     def __init__(
@@ -108,8 +110,15 @@ def register_tensor_with_connections(
         transport_flags = TransportFlags()
         for rank in connections:
             transport_flags |= connections[rank].transport()
-        data_ptr = tensor.data.ptr if isinstance(tensor, cp.ndarray) else tensor.ctypes.data
-        local_reg_memory = self.communicator.register_memory(data_ptr, tensor.size * tensor.itemsize, transport_flags)
+        data_ptr = (
+            tensor.data.ptr
+            if isinstance(tensor, cp.ndarray)
+            else tensor.data_ptr() if is_torch_tensor(tensor) else tensor.ctypes.data
+        )
+        tensor_size = (
+            tensor.numel() * tensor.element_size() if is_torch_tensor(tensor) else tensor.size * tensor.itemsize
+        )
+        local_reg_memory = self.communicator.register_memory(data_ptr, tensor_size, transport_flags)
         all_registered_memories = {}
         all_registered_memories[self.my_rank] = local_reg_memory
         future_memories = {}
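Read in isolation, the new dispatch amounts to two small helpers. The sketch below is an illustration only; get_data_ptr and get_tensor_size are hypothetical names, not library API:

import cupy as cp
from mscclpp.utils import is_torch_tensor

def get_data_ptr(tensor) -> int:
    # Mirrors the diff: cupy exposes the device pointer as .data.ptr,
    # torch as .data_ptr(), and numpy as .ctypes.data.
    if isinstance(tensor, cp.ndarray):
        return tensor.data.ptr
    if is_torch_tensor(tensor):
        return tensor.data_ptr()
    return tensor.ctypes.data

def get_tensor_size(tensor) -> int:
    # torch needs numel() * element_size(): Tensor.size is a method,
    # so the numpy/cupy expression size * itemsize would fail on it.
    if is_torch_tensor(tensor):
        return tensor.numel() * tensor.element_size()
    return tensor.size * tensor.itemsize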
@@ -136,20 +145,24 @@ def make_sm_channels(self, tensor: cp.ndarray, connections: dict[int, Connection
         semaphores = self.make_semaphore(connections, SmDevice2DeviceSemaphore)
         registered_memories = self.register_tensor_with_connections(tensor, connections)
         channels = {}
+        tensor_data_ptr = tensor.data_ptr() if is_torch_tensor(tensor) else tensor.data.ptr
         for rank in connections:
-            channels[rank] = SmChannel(semaphores[rank], registered_memories[rank], tensor.data.ptr)
+            channels[rank] = SmChannel(semaphores[rank], registered_memories[rank], tensor_data_ptr)
         return channels
 
     def make_sm_channels_with_scratch(
-        self, tensor: cp.ndarray, scratchTensor: cp.ndarray, connections: dict[int, Connection]
+        self,
+        tensor: cp.ndarray,
+        scratchTensor: cp.ndarray,
+        connections: dict[int, Connection],
     ) -> dict[int, SmChannel]:
         semaphores = self.make_semaphore(connections, SmDevice2DeviceSemaphore)
         registered_memories = self.register_tensor_with_connections(scratchTensor, connections)
         channels = {}
+        tensor_data_ptr = tensor.data_ptr() if is_torch_tensor(tensor) else tensor.data.ptr
+        scratch_data_ptr = scratchTensor.data_ptr() if is_torch_tensor(scratchTensor) else scratchTensor.data.ptr
         for rank in connections:
-            channels[rank] = SmChannel(
-                semaphores[rank], registered_memories[rank], tensor.data.ptr, scratchTensor.data.ptr
-            )
+            channels[rank] = SmChannel(semaphores[rank], registered_memories[rank], tensor_data_ptr, scratch_data_ptr)
         return channels
 
     def make_proxy_channels(
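With the pointer resolved per tensor type, a torch tensor can now be passed wherever a cupy array was expected. A minimal usage sketch, assuming an initialized CommGroup `group` and a `connections` dict already built through its connection setup (names illustrative):

import torch

x = torch.zeros(1024, dtype=torch.float16, device="cuda")
# Same call as with a cupy array; the channel receives x.data_ptr() internally.
sm_channels = group.make_sm_channels(x, connections)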
@@ -180,8 +193,15 @@ def make_proxy_channels_with_scratch(
         transport_flags = TransportFlags()
         for rank in connections:
             transport_flags |= connections[rank].transport()
-        data_ptr = tensor.data.ptr if isinstance(tensor, cp.ndarray) else tensor.ctypes.data
-        local_reg_memory = self.communicator.register_memory(data_ptr, tensor.size * tensor.itemsize, transport_flags)
+        data_ptr = (
+            tensor.data.ptr
+            if isinstance(tensor, cp.ndarray)
+            else tensor.data_ptr() if is_torch_tensor(tensor) else tensor.ctypes.data
+        )
+        tensor_size = (
+            tensor.numel() * tensor.element_size() if is_torch_tensor(tensor) else tensor.size * tensor.itemsize
+        )
+        local_reg_memory = self.communicator.register_memory(data_ptr, tensor_size, transport_flags)
 
         semaphores = self.make_semaphore(connections, Host2DeviceSemaphore)
         registered_memories = self.register_tensor_with_connections(scratchTensor, connections)
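The with-scratch variants register the data tensor and the scratch buffer independently, so per the signature shown above the two may even be different array types. A sketch under the same assumptions as the previous one:

import cupy as cp
import torch

data = torch.arange(4096, dtype=torch.float32, device="cuda")
scratch = cp.zeros(8192, dtype=cp.float32)  # scratch may stay a cupy array
channels = group.make_sm_channels_with_scratch(data, scratch, connections)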
17 changes: 16 additions & 1 deletion python/mscclpp/utils.py
@@ -6,12 +6,21 @@
 import struct
 import subprocess
 import tempfile
-from typing import Type
+from typing import Any, Type
 
 from cuda import cuda, nvrtc, cudart
 import cupy as cp
 import numpy as np
 
+try:
+    import torch
+
+    _use_torch = True
+    torchTensor = torch.Tensor
+except ImportError:
+    _use_torch = False
+    torchTensor = Type[Any]
+
 
 def _check_cuda_errors(result):
     if result[0].value:
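The try/except keeps torch an optional dependency: when it is missing, torchTensor is only a typing placeholder, and is_torch_tensor (added at the bottom of this file) short-circuits on _use_torch before isinstance can ever see that placeholder. A self-contained sketch of the pattern, with illustrative names:

from typing import Any, Type

try:
    import torch
    _HAVE_TORCH = True           # mirrors _use_torch
    _TENSOR_TYPE = torch.Tensor
except ImportError:
    _HAVE_TORCH = False
    _TENSOR_TYPE = Type[Any]     # placeholder; never reaches isinstance (see below)

def looks_like_torch_tensor(obj: Any) -> bool:
    # `and` short-circuits: with torch absent, isinstance() is skipped,
    # so the Type[Any] placeholder is never evaluated as a class.
    return _HAVE_TORCH and isinstance(obj, _TENSOR_TYPE)

print(looks_like_torch_tensor([1, 2, 3]))  # False whether or not torch is installed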
@@ -145,6 +154,8 @@ def pack(*args):
             res += struct.pack("P", arg.ctypes.data)
         elif isinstance(arg, cp.ndarray):
             res += struct.pack("P", arg.data.ptr)
+        elif is_torch_tensor(arg):
+            res += struct.pack("P", arg.data_ptr())
         # use int to represent bool, which can avoid CUDA_ERROR_LAUNCH_OUT_OF_RESOURCES error
         elif isinstance(arg, bool):
             res += struct.pack("i", arg)
@@ -153,3 +164,7 @@ def pack(*args):
         else:
             raise RuntimeError(f"Unsupported type: {type(arg)}")
     return res
+
+
+def is_torch_tensor(tensor: Any) -> bool:
+    return _use_torch and isinstance(tensor, torchTensor)
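Taken together, pack can now mix torch tensors with numpy/cupy arrays when assembling kernel-argument buffers. A usage sketch (assumes a CUDA device is available; pack is the function from this file):

import struct

import cupy as cp
import torch
from mscclpp.utils import pack

a = cp.zeros(16, dtype=cp.float32)
b = torch.ones(16, device="cuda")
buf = pack(a, b, True)  # two native pointers ("P") plus one 4-byte int ("i") for the bool
assert len(buf) == 2 * struct.calcsize("P") + struct.calcsize("i")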
