Skip to content

Commit

Permalink
Supporting Executor multi node in NCCL API (#412)
Browse files Browse the repository at this point in the history
Co-authored-by: Binyang Li <binyli@microsoft.com>
  • Loading branch information
caiomcbr and Binyang2014 authored Dec 18, 2024
1 parent fcb2e46 commit 774602d
Showing 1 changed file with 47 additions and 31 deletions.
78 changes: 47 additions & 31 deletions apps/nccl/src/nccl.cu
Original file line number Diff line number Diff line change
Expand Up @@ -217,6 +217,9 @@ static std::shared_ptr<mscclpp::DeviceHandle<mscclpp::SmChannel>> setupSmChannel

static ncclResult_t ncclAllReduceFallback(const void* sendbuff, void* recvbuff, size_t count, ncclDataType_t datatype,
ncclRedOp_t, ncclComm_t comm, cudaStream_t stream) {
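// The fallback path only supports single-node communicators.
// NOTE(review): comm is dereferenced on the next line, before the comm == nullptr check further down.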
if (comm->comm->bootstrap()->getNranks() != comm->comm->bootstrap()->getNranksPerNode()) return ncclInvalidUsage;

// Check that the parameters are valid
if (sendbuff == nullptr || recvbuff == nullptr || count == 0 || ncclTypeSize(datatype) == 0 || comm == nullptr)
return ncclInvalidArgument;
Expand Down Expand Up @@ -302,6 +305,10 @@ static ncclResult_t ncclAllReduceFallback(const void* sendbuff, void* recvbuff,

static ncclResult_t ncclAllGatherFallback(const void* sendbuff, void* recvbuff, size_t sendcount,
ncclDataType_t datatype, ncclComm_t comm, cudaStream_t stream) {
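// The fallback path only supports single-node communicators.
// NOTE(review): comm is dereferenced on the next line, before the comm == nullptr check further down.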
if (comm->comm->bootstrap()->getNranks() != comm->comm->bootstrap()->getNranksPerNode()) return ncclInvalidUsage;

// Check that the parameters are valid
size_t bytes = sendcount * ncclTypeSize(datatype);
if (sendbuff == nullptr || recvbuff == nullptr || bytes == 0 || comm == nullptr) return ncclInvalidArgument;

Expand Down Expand Up @@ -340,34 +347,8 @@ static ncclResult_t ncclAllGatherFallback(const void* sendbuff, void* recvbuff,
return ncclSuccess;
}

// Stores MSCCLPP_VERSION into *version.
// Returns ncclInvalidArgument when `version` is null (nothing is written), ncclSuccess otherwise.
NCCL_API ncclResult_t ncclGetVersion(int* version) {
  const bool hasOutput = (version != nullptr);
  if (hasOutput) {
    *version = MSCCLPP_VERSION;
  }
  return hasOutput ? ncclSuccess : ncclInvalidArgument;
}

// Creates an id with mscclpp::TcpBootstrap::createUniqueId() and copies sizeof(ncclUniqueId)
// bytes of it into *uniqueId.
// Returns ncclInvalidArgument when `uniqueId` is null, ncclInternalError when
// MSCCLPP_UNIQUE_ID_BYTES differs from NCCL_UNIQUE_ID_BYTES, and ncclSuccess after the copy.
NCCL_API ncclResult_t ncclGetUniqueId(ncclUniqueId* uniqueId) {
  if (!uniqueId) {
    return ncclInvalidArgument;
  }
  // Refuse the raw byte copy below unless the two byte-count constants agree.
  if (NCCL_UNIQUE_ID_BYTES != MSCCLPP_UNIQUE_ID_BYTES) {
    return ncclInternalError;
  }
  const mscclpp::UniqueId freshId = mscclpp::TcpBootstrap::createUniqueId();
  memcpy(uniqueId, &freshId, sizeof(ncclUniqueId));
  return ncclSuccess;
}

// Variant of ncclCommInitRank that also accepts an ncclConfig_t pointer.
// The config argument is unnamed and not used; the remaining arguments are forwarded
// unchanged to ncclCommInitRank and its result is returned as-is.
NCCL_API ncclResult_t ncclCommInitRankConfig(ncclComm_t* comm, int nranks, ncclUniqueId commId, int rank,
                                             ncclConfig_t*) {
  // TODO: implement config
  const ncclResult_t status = ncclCommInitRank(comm, nranks, commId, rank);
  return status;
}

NCCL_API ncclResult_t ncclCommInitRank(ncclComm_t* comm, int nranks, ncclUniqueId commId, int rank) {
if (comm == nullptr) return ncclInvalidArgument;
if (nranks < 0 || rank < 0 || rank >= nranks) return ncclInvalidArgument;
std::shared_ptr<mscclpp::TcpBootstrap> bootstrap = std::make_shared<mscclpp::TcpBootstrap>(rank, nranks);
mscclpp::UniqueId id;
memcpy(id.data(), &commId, sizeof(ncclUniqueId));
bootstrap->initialize(id);
std::shared_ptr<mscclpp::Communicator> mscclppComm = std::make_shared<mscclpp::Communicator>(bootstrap);
static void ncclCommInitRankFallbackSingleNode(ncclComm* commPtr, std::shared_ptr<mscclpp::Communicator> mscclppComm,
int rank) {
std::vector<mscclpp::NonblockingFuture<std::shared_ptr<mscclpp::Connection>>> connectionFutures;

for (int i = 0; i < mscclppComm->bootstrap()->getNranks(); i++) {
Expand All @@ -390,19 +371,54 @@ NCCL_API ncclResult_t ncclCommInitRank(ncclComm_t* comm, int nranks, ncclUniqueI
}
}
}
mscclppComm->setup();

ncclComm* commPtr = new ncclComm();
commPtr->comm = mscclppComm;
mscclppComm->setup();
commPtr->connections = std::move(connections);
commPtr->smSemaphores = std::move(smSemaphores);
commPtr->buffFlag = 0;
commPtr->numScratchBuff = 2;
commPtr->scratchBuff = mscclpp::allocExtSharedCuda<char>(SCRATCH_SIZE);
commPtr->remoteScratchRegMemories =
setupRemoteMemories(commPtr->comm, rank, commPtr->scratchBuff.get(), SCRATCH_SIZE, mscclpp::Transport::CudaIpc);
}

// Stores MSCCLPP_VERSION into *version.
// Returns ncclInvalidArgument when `version` is null (nothing is written), ncclSuccess otherwise.
NCCL_API ncclResult_t ncclGetVersion(int* version) {
  const bool hasOutput = (version != nullptr);
  if (hasOutput) {
    *version = MSCCLPP_VERSION;
  }
  return hasOutput ? ncclSuccess : ncclInvalidArgument;
}

// Creates an id with mscclpp::TcpBootstrap::createUniqueId() and copies sizeof(ncclUniqueId)
// bytes of it into *uniqueId.
// Returns ncclInvalidArgument when `uniqueId` is null, ncclInternalError when
// MSCCLPP_UNIQUE_ID_BYTES differs from NCCL_UNIQUE_ID_BYTES, and ncclSuccess after the copy.
NCCL_API ncclResult_t ncclGetUniqueId(ncclUniqueId* uniqueId) {
  if (!uniqueId) {
    return ncclInvalidArgument;
  }
  // Refuse the raw byte copy below unless the two byte-count constants agree.
  if (NCCL_UNIQUE_ID_BYTES != MSCCLPP_UNIQUE_ID_BYTES) {
    return ncclInternalError;
  }
  const mscclpp::UniqueId freshId = mscclpp::TcpBootstrap::createUniqueId();
  memcpy(uniqueId, &freshId, sizeof(ncclUniqueId));
  return ncclSuccess;
}

// Variant of ncclCommInitRank that also accepts an ncclConfig_t pointer.
// The config argument is unnamed and not used; the remaining arguments are forwarded
// unchanged to ncclCommInitRank and its result is returned as-is.
NCCL_API ncclResult_t ncclCommInitRankConfig(ncclComm_t* comm, int nranks, ncclUniqueId commId, int rank,
                                             ncclConfig_t*) {
  // TODO: implement config
  const ncclResult_t status = ncclCommInitRank(comm, nranks, commId, rank);
  return status;
}

NCCL_API ncclResult_t ncclCommInitRank(ncclComm_t* comm, int nranks, ncclUniqueId commId, int rank) {
if (comm == nullptr) return ncclInvalidArgument;
if (nranks < 0 || rank < 0 || rank >= nranks) return ncclInvalidArgument;
std::shared_ptr<mscclpp::TcpBootstrap> bootstrap = std::make_shared<mscclpp::TcpBootstrap>(rank, nranks);
mscclpp::UniqueId id;
memcpy(id.data(), &commId, sizeof(ncclUniqueId));
bootstrap->initialize(id);
std::shared_ptr<mscclpp::Communicator> mscclppComm = std::make_shared<mscclpp::Communicator>(bootstrap);
ncclComm* commPtr = new ncclComm();

commPtr->comm = mscclppComm;
commPtr->executor = std::make_shared<mscclpp::Executor>(mscclppComm);

// Single-node case (all ranks on one node): also run the fallback initialization
if (mscclppComm->bootstrap()->getNranks() == mscclppComm->bootstrap()->getNranksPerNode())
ncclCommInitRankFallbackSingleNode(commPtr, mscclppComm, rank);

if (getenv("MSCCLPP_EXECUTION_PLAN_DIR")) {
std::string collectiveDir = getenv("MSCCLPP_EXECUTION_PLAN_DIR");
if (!std::filesystem::is_directory(collectiveDir)) {
Expand Down

0 comments on commit 774602d

Please sign in to comment.