From 974f1a8915e7f15f8d8ba37ed8ef9277bb8d86f3 Mon Sep 17 00:00:00 2001
From: Takuya Narihira
Date: Thu, 28 Nov 2019 14:51:32 +0900
Subject: [PATCH] Properly handle MPI errors

---
 ...ulti_process_data_parallel_communicator.cu | 74 +++++++++++++------
 1 file changed, 50 insertions(+), 24 deletions(-)

diff --git a/src/nbla/cuda/communicator/multi_process_data_parallel_communicator.cu b/src/nbla/cuda/communicator/multi_process_data_parallel_communicator.cu
index 0092f7a78..ffde707d4 100644
--- a/src/nbla/cuda/communicator/multi_process_data_parallel_communicator.cu
+++ b/src/nbla/cuda/communicator/multi_process_data_parallel_communicator.cu
@@ -34,6 +34,27 @@ using std::vector;
 using std::make_shared;
 using std::unordered_set;
 
+/** Get MPI error string from error code.
+ */
+std::string get_mpi_error_string(int error) {
+  int len;
+  char estring[MPI_MAX_ERROR_STRING];
+  MPI_Error_string(error, estring, &len);
+  return std::string(estring);
+}
+
+/** MPI error handler which throws an exception
+*/
+#define NBLA_MPI_CHECK(condition) \
+  { \
+    int error = condition; \
+    if (error != MPI_SUCCESS) { \
+      auto estring = get_mpi_error_string(error); \
+      NBLA_ERROR(error_code::runtime, "`" #condition "` failed by `%s`.", \
+                 estring.c_str()); \
+    } \
+  }
+
 /** MPI singleton class that manages lifetime of MPI.
@@ -71,8 +92,8 @@ public:
       */
      return;
    }
-    MPI_Group_free(&world_group_);
-    MPI_Finalize();
+    NBLA_MPI_CHECK(MPI_Group_free(&world_group_));
+    NBLA_MPI_CHECK(MPI_Finalize());
   }
 
   /**
@@ -80,7 +101,7 @@ public:
    */
   static bool initialized() {
     int flag = 1;
-    MPI_Initialized(&flag);
+    NBLA_MPI_CHECK(MPI_Initialized(&flag));
     return bool(flag);
   }
 
@@ -89,7 +110,7 @@ public:
    */
   static bool finalized() {
     int flag = 1;
-    MPI_Finalized(&flag);
+    NBLA_MPI_CHECK(MPI_Finalized(&flag));
     return bool(flag);
   }
 
@@ -115,7 +136,8 @@ private:
       char **argv = nullptr;
       int requiredThreadLevelSupport = MPI_THREAD_SERIALIZED;
       int provided;
-      MPI_Init_thread(&argc, &argv, requiredThreadLevelSupport, &provided);
+      NBLA_MPI_CHECK(
+          MPI_Init_thread(&argc, &argv, requiredThreadLevelSupport, &provided));
       if (provided != requiredThreadLevelSupport) {
         NBLA_ERROR(error_code::target_specific,
                    "MPI_Init_thread failed since provided (%d) is not equal to "
@@ -123,7 +145,7 @@ private:
                    provided, requiredThreadLevelSupport);
       }
     }
-    MPI_Comm_group(MPI_COMM_WORLD, &world_group_);
+    NBLA_MPI_CHECK(MPI_Comm_group(MPI_COMM_WORLD, &world_group_));
   }
   MPI_Group world_group_;
 };
@@ -161,8 +183,10 @@ public:
       MPI_COMM_WORLD.
    */
   MpiCommWrapper(std::vector<int> ranks) : mpi_(Mpi::get()), own_(true) {
-    MPI_Group_incl(Mpi::world_group(), ranks.size(), ranks.data(), &group_);
-    MPI_Comm_create_group(MPI_COMM_WORLD, group_, 0, &this->comm_);
+    NBLA_MPI_CHECK(MPI_Group_incl(Mpi::world_group(), ranks.size(),
+                                  ranks.data(), &group_));
+    NBLA_MPI_CHECK(
+        MPI_Comm_create_group(MPI_COMM_WORLD, group_, 0, &this->comm_));
   }
 
   /* Deletes MPI_Group and MPI_Comm objects when it owns.
@@ -233,8 +257,8 @@ template <typename T>
 bool MultiProcessDataParallelCommunicatorNccl<T>::mpi_check_any(
     bool condition, const string &group) {
   bool result;
-  MPI_Allreduce(&condition, &result, 1, MPI_C_BOOL, MPI_LOR,
-                this->mpi_comms_[group]->comm());
+  NBLA_MPI_CHECK(MPI_Allreduce(&condition, &result, 1, MPI_C_BOOL, MPI_LOR,
+                               this->mpi_comms_[group]->comm()));
   return result;
 }
 
@@ -242,8 +266,8 @@ template <typename T>
 bool MultiProcessDataParallelCommunicatorNccl<T>::mpi_check_all(
     bool condition, const string &group) {
   bool result;
-  MPI_Allreduce(&condition, &result, 1, MPI_C_BOOL, MPI_LAND,
-                this->mpi_comms_[group]->comm());
+  NBLA_MPI_CHECK(MPI_Allreduce(&condition, &result, 1, MPI_C_BOOL, MPI_LAND,
+                               this->mpi_comms_[group]->comm()));
   return result;
 }
 
@@ -289,8 +313,8 @@ template <typename T>
 void MultiProcessDataParallelCommunicatorNccl<T>::init() {
   this->mpi_comms_["world"] = make_shared<MpiCommWrapper>();
   // Create comm, set size, and rank
-  MPI_Comm_size(MPI_COMM_WORLD, &this->size_);
-  MPI_Comm_rank(MPI_COMM_WORLD, &this->rank_);
+  NBLA_MPI_CHECK(MPI_Comm_size(MPI_COMM_WORLD, &this->size_));
+  NBLA_MPI_CHECK(MPI_Comm_rank(MPI_COMM_WORLD, &this->rank_));
 
   // Set local rank and device id
   uint64_t host_hashs[this->size_];
@@ -298,9 +322,9 @@ template <typename T> void MultiProcessDataParallelCommunicatorNccl<T>::init() {
   get_host_name(hostname, 1024);
   host_hashs[this->rank_] = get_host_hash(hostname);
 
-  MPI_Allgather(MPI_IN_PLACE, 0, MPI_DATATYPE_NULL, host_hashs,
-                sizeof(uint64_t), MPI_BYTE, MPI_COMM_WORLD);
-  MPI_Barrier(MPI_COMM_WORLD);
+  NBLA_MPI_CHECK(MPI_Allgather(MPI_IN_PLACE, 0, MPI_DATATYPE_NULL, host_hashs,
+                               sizeof(uint64_t), MPI_BYTE, MPI_COMM_WORLD));
+  NBLA_MPI_CHECK(MPI_Barrier(MPI_COMM_WORLD));
 
   int local_rank = 0;
   for (int i = 0; i < this->size_; ++i) {
@@ -320,8 +344,9 @@ template <typename T> void MultiProcessDataParallelCommunicatorNccl<T>::init() {
   if (this->rank_ == 0) {
     ncclGetUniqueId(&comm_id);
   }
-  MPI_Bcast(&comm_id, sizeof(comm_id), MPI_BYTE, 0, MPI_COMM_WORLD);
-  MPI_Barrier(MPI_COMM_WORLD);
+  NBLA_MPI_CHECK(
+      MPI_Bcast(&comm_id, sizeof(comm_id), MPI_BYTE, 0, MPI_COMM_WORLD));
+  NBLA_MPI_CHECK(MPI_Barrier(MPI_COMM_WORLD));
 
   // NCCL Init
   cuda_set_device(device_id_);
@@ -352,12 +377,12 @@ template <typename T> void MultiProcessDataParallelCommunicatorNccl<T>::init() {
 
 template <typename T>
 void MultiProcessDataParallelCommunicatorNccl<T>::barrier() {
-  MPI_Barrier(MPI_COMM_WORLD);
+  NBLA_MPI_CHECK(MPI_Barrier(MPI_COMM_WORLD));
 }
 
 template <typename T>
 void MultiProcessDataParallelCommunicatorNccl<T>::abort() {
-  MPI_Abort(MPI_COMM_WORLD, -1);
+  NBLA_MPI_CHECK(MPI_Abort(MPI_COMM_WORLD, -1));
 }
 
 template <typename T>
@@ -402,9 +427,10 @@ string MultiProcessDataParallelCommunicatorNccl<T>::new_group(
     ncclGetUniqueId(&comm_id);
   }
   int rank;
-  MPI_Comm_rank(group_mpi_comm->comm(), &rank);
-  MPI_Bcast(&comm_id, sizeof(comm_id), MPI_BYTE, 0, group_mpi_comm->comm());
-  MPI_Barrier(group_mpi_comm->comm());
+  NBLA_MPI_CHECK(MPI_Comm_rank(group_mpi_comm->comm(), &rank));
+  NBLA_MPI_CHECK(MPI_Bcast(&comm_id, sizeof(comm_id), MPI_BYTE, 0,
+                           group_mpi_comm->comm()));
+  NBLA_MPI_CHECK(MPI_Barrier(group_mpi_comm->comm()));
 
   // NCCL Comm Init
   cuda_set_device(device_id_);
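
Note for reviewers (not part of the patch): the change is mechanical. Every raw MPI call is wrapped in NBLA_MPI_CHECK, so a return code other than MPI_SUCCESS becomes an nnabla exception that carries the stringified call plus the MPI_Error_string message, instead of being silently dropped. The stand-alone sketch below shows the same check-and-throw pattern outside the nnabla tree; MPI_CHECK_DEMO and mpi_error_string are illustrative names invented for this note, and std::runtime_error stands in for NBLA_ERROR.

// demo_mpi_check.cpp -- illustrative only; these names are not in the patch.
// Build with an MPI toolchain, e.g.: mpicxx -std=c++11 demo_mpi_check.cpp -o demo
#include <mpi.h>

#include <cstdio>
#include <stdexcept>
#include <string>

// Same idea as get_mpi_error_string() in the patch: map an error code to text.
static std::string mpi_error_string(int error) {
  int len = 0;
  char estring[MPI_MAX_ERROR_STRING];
  MPI_Error_string(error, estring, &len);
  return std::string(estring, len);
}

// Stand-in for NBLA_MPI_CHECK: evaluate the call once, throw on failure.
// std::runtime_error replaces nnabla's NBLA_ERROR here.
#define MPI_CHECK_DEMO(condition) \
  do { \
    int error_ = (condition); \
    if (error_ != MPI_SUCCESS) { \
      throw std::runtime_error("`" #condition "` failed by `" + \
                               mpi_error_string(error_) + "`."); \
    } \
  } while (0)

int main(int argc, char **argv) {
  MPI_CHECK_DEMO(MPI_Init(&argc, &argv));
  // Communicators default to MPI_ERRORS_ARE_FATAL; install MPI_ERRORS_RETURN
  // so subsequent calls report failures through their return code.
  MPI_CHECK_DEMO(MPI_Comm_set_errhandler(MPI_COMM_WORLD, MPI_ERRORS_RETURN));
  int rank = 0;
  MPI_CHECK_DEMO(MPI_Comm_rank(MPI_COMM_WORLD, &rank));
  std::printf("rank %d sees no MPI errors\n", rank);
  MPI_CHECK_DEMO(MPI_Finalize());
  return 0;
}

The do { ... } while (0) wrapper in the demo makes the macro behave as a single statement inside if/else chains; the patch uses a bare block, which is equivalent at its current call sites.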