Skip to content

Commit

Permalink
use env to use default stream
Browse files Browse the repository at this point in the history
  • Loading branch information
chhwang committed Jan 12, 2025
1 parent fdac9a2 commit 3a70ecd
Showing 1 changed file with 4 additions and 3 deletions.
7 changes: 4 additions & 3 deletions src/connection.cc
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
#include <mscclpp/npkit/npkit.hpp>
#endif

+#include <mscclpp/env.hpp>
#include <mscclpp/utils.hpp>
#include <sstream>
#include <thread>
Expand Down Expand Up @@ -75,7 +76,7 @@ void CudaIpcConnection::write(RegisteredMemory dst, uint64_t dstOffset, Register
char* dstPtr = (char*)dst.data();
char* srcPtr = (char*)src.data();

-if (stream_->empty()) stream_->set(cudaStreamNonBlocking);
+if (!env().cudaIpcUseDefaultStream && stream_->empty()) stream_->set(cudaStreamNonBlocking);

MSCCLPP_CUDATHROW(cudaMemcpyAsync(dstPtr + dstOffset, srcPtr + srcOffset, size, cudaMemcpyDeviceToDevice, *stream_));
INFO(MSCCLPP_P2P, "CudaIpcConnection write: from %p to %p, size %lu", srcPtr + srcOffset, dstPtr + dstOffset, size);
Expand All @@ -95,7 +96,7 @@ void CudaIpcConnection::updateAndSync(RegisteredMemory dst, uint64_t dstOffset,
*src = newValue;
uint64_t* dstPtr = reinterpret_cast<uint64_t*>(reinterpret_cast<char*>(dst.data()) + dstOffset);

-if (stream_->empty()) stream_->set(cudaStreamNonBlocking);
+if (!env().cudaIpcUseDefaultStream && stream_->empty()) stream_->set(cudaStreamNonBlocking);

MSCCLPP_CUDATHROW(cudaMemcpyAsync(dstPtr, src, sizeof(uint64_t), cudaMemcpyHostToDevice, *stream_));
INFO(MSCCLPP_P2P, "CudaIpcConnection atomic write: from %p to %p, %lu -> %lu", src, dstPtr + dstOffset, oldValue,
Expand All @@ -115,7 +116,7 @@ void CudaIpcConnection::flush(int64_t timeoutUsec) {
INFO(MSCCLPP_P2P, "CudaIpcConnection flush: timeout is not supported, ignored");
}

-if (stream_->empty()) stream_->set(cudaStreamNonBlocking);
+if (!env().cudaIpcUseDefaultStream && stream_->empty()) stream_->set(cudaStreamNonBlocking);

AvoidCudaGraphCaptureGuard guard;
MSCCLPP_CUDATHROW(cudaStreamSynchronize(*stream_));
Expand Down

0 comments on commit 3a70ecd

Please sign in to comment.