diff --git a/src/registered_memory.cc b/src/registered_memory.cc index 84b0ccc4..50305caf 100644 --- a/src/registered_memory.cc +++ b/src/registered_memory.cc @@ -24,28 +24,6 @@ } while (false) namespace { -// Get the recommended granularity for cuMemAddressReserve -size_t getRecommendedGranularity() { -#if (CUDA_NVLS_SUPPORTED) - size_t gran = 0; - int deviceId = -1; - int currentDevice = -1; - MSCCLPP_CUDATHROW(cudaGetDevice(&deviceId)); - MSCCLPP_CUTHROW(cuDeviceGet(&currentDevice, deviceId)); - - CUmemAllocationProp prop = {}; - prop.type = CU_MEM_ALLOCATION_TYPE_PINNED; - prop.location.type = CU_MEM_LOCATION_TYPE_DEVICE; - prop.requestedHandleTypes = - (CUmemAllocationHandleType)(CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR | CU_MEM_HANDLE_TYPE_FABRIC); - prop.location.id = currentDevice; - MSCCLPP_CUTHROW(cuMemGetAllocationGranularity(&gran, &prop, CU_MEM_ALLOC_GRANULARITY_RECOMMENDED)); - return gran; -#else - throw mscclpp::Error("Only support GPU with NVLS support", mscclpp::ErrorCode::InvalidUsage); -#endif -} - CUmemAllocationHandleType getNvlsCompatibleMemHandleType() { #if (CUDA_NVLS_SUPPORTED) return CU_MEM_HANDLE_TYPE_FABRIC; @@ -239,13 +217,21 @@ RegisteredMemory::Impl::Impl(const std::vector& serialization) { auto entry = getTransportInfo(Transport::CudaIpc); void* base; if (this->isCuMemMapAlloc) { +#if (CUDA_NVLS_SUPPORTED) CUmemGenericAllocationHandle handle; MSCCLPP_CUTHROW(cuMemImportFromShareableHandle(&handle, entry.shareableHandle, getNvlsCompatibleMemHandleType())); - size_t gran = getRecommendedGranularity(); - MSCCLPP_CUTHROW(cuMemAddressReserve((CUdeviceptr*)&base, this->size, gran, 0, 0)); - MSCCLPP_CUTHROW(cuMemMap((CUdeviceptr)base, this->size, 0, handle, 0)); - detail::setReadWriteMemoryAccess(base, this->size); + size_t minGran = detail::getMulticastGranularity(size, CU_MULTICAST_GRANULARITY_MINIMUM); + size_t recommendedGran = detail::getMulticastGranularity(size, CU_MULTICAST_GRANULARITY_RECOMMENDED); + size_t size = (this->size + 
recommendedGran - 1) / recommendedGran * recommendedGran; + MSCCLPP_CUTHROW(cuMemAddressReserve((CUdeviceptr*)&base, size, minGran, 0, 0)); + MSCCLPP_CUTHROW(cuMemMap((CUdeviceptr)base, size, 0, handle, 0)); + detail::setReadWriteMemoryAccess(base, size); this->data = static_cast(base) + entry.offsetFromBase; +#else + throw mscclpp::Error( + "CUDA does not support NVLS. Please ensure your CUDA version supports NVLS to use this feature.", + mscclpp::ErrorCode::InvalidUsage); +#endif } else { MSCCLPP_CUDATHROW(cudaIpcOpenMemHandle(&base, entry.cudaIpcBaseHandle, cudaIpcMemLazyEnablePeerAccess)); this->data = static_cast(base) + entry.cudaIpcOffsetFromBase;