diff --git a/apps/nccl/src/nccl.cu b/apps/nccl/src/nccl.cu index fe240de7..a35f055b 100644 --- a/apps/nccl/src/nccl.cu +++ b/apps/nccl/src/nccl.cu @@ -217,6 +217,9 @@ static std::shared_ptr> setupSmChannel static ncclResult_t ncclAllReduceFallback(const void* sendbuff, void* recvbuff, size_t count, ncclDataType_t datatype, ncclRedOp_t, ncclComm_t comm, cudaStream_t stream) { + // Fallback for single node + if (comm->comm->bootstrap()->getNranks() != comm->comm->bootstrap()->getNranksPerNode()) return ncclInvalidUsage; + // Checking if the parameters are valid if (sendbuff == nullptr || recvbuff == nullptr || count == 0 || ncclTypeSize(datatype) == 0 || comm == nullptr) return ncclInvalidArgument; @@ -302,6 +305,10 @@ static ncclResult_t ncclAllReduceFallback(const void* sendbuff, void* recvbuff, static ncclResult_t ncclAllGatherFallback(const void* sendbuff, void* recvbuff, size_t sendcount, ncclDataType_t datatype, ncclComm_t comm, cudaStream_t stream) { + // Fallback for single node + if (comm->comm->bootstrap()->getNranks() != comm->comm->bootstrap()->getNranksPerNode()) return ncclInvalidUsage; + + // Checking if the parameters are valid size_t bytes = sendcount * ncclTypeSize(datatype); if (sendbuff == nullptr || recvbuff == nullptr || bytes == 0 || comm == nullptr) return ncclInvalidArgument; @@ -340,34 +347,8 @@ static ncclResult_t ncclAllGatherFallback(const void* sendbuff, void* recvbuff, return ncclSuccess; } -NCCL_API ncclResult_t ncclGetVersion(int* version) { - if (version == nullptr) return ncclInvalidArgument; - *version = MSCCLPP_VERSION; - return ncclSuccess; -} - -NCCL_API ncclResult_t ncclGetUniqueId(ncclUniqueId* uniqueId) { - if (uniqueId == nullptr) return ncclInvalidArgument; - if (MSCCLPP_UNIQUE_ID_BYTES != NCCL_UNIQUE_ID_BYTES) return ncclInternalError; - mscclpp::UniqueId id = mscclpp::TcpBootstrap::createUniqueId(); - memcpy(uniqueId, &id, sizeof(ncclUniqueId)); - return ncclSuccess; -} - -NCCL_API ncclResult_t 
ncclCommInitRankConfig(ncclComm_t* comm, int nranks, ncclUniqueId commId, int rank, - ncclConfig_t*) { - // TODO: implement config - return ncclCommInitRank(comm, nranks, commId, rank); -} - -NCCL_API ncclResult_t ncclCommInitRank(ncclComm_t* comm, int nranks, ncclUniqueId commId, int rank) { - if (comm == nullptr) return ncclInvalidArgument; - if (nranks < 0 || rank < 0 || rank >= nranks) return ncclInvalidArgument; - std::shared_ptr bootstrap = std::make_shared(rank, nranks); - mscclpp::UniqueId id; - memcpy(id.data(), &commId, sizeof(ncclUniqueId)); - bootstrap->initialize(id); - std::shared_ptr mscclppComm = std::make_shared(bootstrap); +static void ncclCommInitRankFallbackSingleNode(ncclComm* commPtr, std::shared_ptr mscclppComm, + int rank) { std::vector>> connectionFutures; for (int i = 0; i < mscclppComm->bootstrap()->getNranks(); i++) { @@ -390,10 +371,8 @@ NCCL_API ncclResult_t ncclCommInitRank(ncclComm_t* comm, int nranks, ncclUniqueI } } } - mscclppComm->setup(); - ncclComm* commPtr = new ncclComm(); - commPtr->comm = mscclppComm; + mscclppComm->setup(); commPtr->connections = std::move(connections); commPtr->smSemaphores = std::move(smSemaphores); commPtr->buffFlag = 0; @@ -401,8 +380,45 @@ NCCL_API ncclResult_t ncclCommInitRank(ncclComm_t* comm, int nranks, ncclUniqueI commPtr->scratchBuff = mscclpp::allocExtSharedCuda(SCRATCH_SIZE); commPtr->remoteScratchRegMemories = setupRemoteMemories(commPtr->comm, rank, commPtr->scratchBuff.get(), SCRATCH_SIZE, mscclpp::Transport::CudaIpc); +} + +NCCL_API ncclResult_t ncclGetVersion(int* version) { + if (version == nullptr) return ncclInvalidArgument; + *version = MSCCLPP_VERSION; + return ncclSuccess; +} + +NCCL_API ncclResult_t ncclGetUniqueId(ncclUniqueId* uniqueId) { + if (uniqueId == nullptr) return ncclInvalidArgument; + if (MSCCLPP_UNIQUE_ID_BYTES != NCCL_UNIQUE_ID_BYTES) return ncclInternalError; + mscclpp::UniqueId id = mscclpp::TcpBootstrap::createUniqueId(); + memcpy(uniqueId, &id, 
sizeof(ncclUniqueId)); + return ncclSuccess; +} + +NCCL_API ncclResult_t ncclCommInitRankConfig(ncclComm_t* comm, int nranks, ncclUniqueId commId, int rank, + ncclConfig_t*) { + // TODO: implement config + return ncclCommInitRank(comm, nranks, commId, rank); +} + +NCCL_API ncclResult_t ncclCommInitRank(ncclComm_t* comm, int nranks, ncclUniqueId commId, int rank) { + if (comm == nullptr) return ncclInvalidArgument; + if (nranks < 0 || rank < 0 || rank >= nranks) return ncclInvalidArgument; + std::shared_ptr bootstrap = std::make_shared(rank, nranks); + mscclpp::UniqueId id; + memcpy(id.data(), &commId, sizeof(ncclUniqueId)); + bootstrap->initialize(id); + std::shared_ptr mscclppComm = std::make_shared(bootstrap); + ncclComm* commPtr = new ncclComm(); + + commPtr->comm = mscclppComm; commPtr->executor = std::make_shared(mscclppComm); + // Fallback for single node + if (mscclppComm->bootstrap()->getNranks() == mscclppComm->bootstrap()->getNranksPerNode()) + ncclCommInitRankFallbackSingleNode(commPtr, mscclppComm, rank); + if (getenv("MSCCLPP_EXECUTION_PLAN_DIR")) { + std::string collectiveDir = getenv("MSCCLPP_EXECUTION_PLAN_DIR"); + if (!std::filesystem::is_directory(collectiveDir)) {