diff --git a/src/connection.cc b/src/connection.cc index fef7ca77..23c744db 100644 --- a/src/connection.cc +++ b/src/connection.cc @@ -7,6 +7,7 @@ #include #endif +#include #include #include #include @@ -75,7 +76,7 @@ void CudaIpcConnection::write(RegisteredMemory dst, uint64_t dstOffset, Register char* dstPtr = (char*)dst.data(); char* srcPtr = (char*)src.data(); - if (stream_->empty()) stream_->set(cudaStreamNonBlocking); + if (!env().cudaIpcUseDefaultStream && stream_->empty()) stream_->set(cudaStreamNonBlocking); MSCCLPP_CUDATHROW(cudaMemcpyAsync(dstPtr + dstOffset, srcPtr + srcOffset, size, cudaMemcpyDeviceToDevice, *stream_)); INFO(MSCCLPP_P2P, "CudaIpcConnection write: from %p to %p, size %lu", srcPtr + srcOffset, dstPtr + dstOffset, size); @@ -95,7 +96,7 @@ void CudaIpcConnection::updateAndSync(RegisteredMemory dst, uint64_t dstOffset, *src = newValue; uint64_t* dstPtr = reinterpret_cast(reinterpret_cast(dst.data()) + dstOffset); - if (stream_->empty()) stream_->set(cudaStreamNonBlocking); + if (!env().cudaIpcUseDefaultStream && stream_->empty()) stream_->set(cudaStreamNonBlocking); MSCCLPP_CUDATHROW(cudaMemcpyAsync(dstPtr, src, sizeof(uint64_t), cudaMemcpyHostToDevice, *stream_)); INFO(MSCCLPP_P2P, "CudaIpcConnection atomic write: from %p to %p, %lu -> %lu", src, dstPtr + dstOffset, oldValue, @@ -115,7 +116,7 @@ void CudaIpcConnection::flush(int64_t timeoutUsec) { INFO(MSCCLPP_P2P, "CudaIpcConnection flush: timeout is not supported, ignored"); } - if (stream_->empty()) stream_->set(cudaStreamNonBlocking); + if (!env().cudaIpcUseDefaultStream && stream_->empty()) stream_->set(cudaStreamNonBlocking); AvoidCudaGraphCaptureGuard guard; MSCCLPP_CUDATHROW(cudaStreamSynchronize(*stream_));