Properly handle MPI errors #209

Merged: 1 commit, merged on Mar 2, 2020
@@ -34,6 +34,27 @@ using std::vector;
 using std::make_shared;
 using std::unordered_set;
 
+/** Get MPI error string from error code.
+*/
+std::string get_mpi_error_string(int error) {
+  int len;
+  char estring[MPI_MAX_ERROR_STRING];
+  MPI_Error_string(error, estring, &len);
+  return std::string(estring);
+}
+
+/** MPI error handler which throws an exception.
+*/
+#define NBLA_MPI_CHECK(condition)                                             \
+  {                                                                           \
+    int error = condition;                                                    \
+    if (error != MPI_SUCCESS) {                                               \
+      auto estring = get_mpi_error_string(error);                             \
+      NBLA_ERROR(error_code::runtime, "`" #condition "` failed by `%s`.",     \
+                 estring.c_str());                                            \
+    }                                                                         \
+  }
+
 /**
 MPI singleton class that manages lifetime of MPI.
 
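The new NBLA_MPI_CHECK macro converts MPI return codes into exceptions: it evaluates the wrapped call once, and on anything other than MPI_SUCCESS it fetches the readable message with MPI_Error_string and raises through NBLA_ERROR. A minimal standalone sketch of the same pattern is shown below, with std::runtime_error standing in for NBLA_ERROR and MPI_CHECK_THROW as an illustrative name; note that MPI only hands error codes back to the caller when the communicator's error handler is MPI_ERRORS_RETURN, since the default MPI_ERRORS_ARE_FATAL aborts before any check runs.

#include <mpi.h>
#include <stdexcept>
#include <string>

// Standalone variant of the check-and-throw pattern; not the nnabla code.
static std::string mpi_error_string(int error) {
  int len = 0;
  char buf[MPI_MAX_ERROR_STRING];
  MPI_Error_string(error, buf, &len);
  return std::string(buf, len);
}

#define MPI_CHECK_THROW(condition)                                            \
  {                                                                           \
    int error_ = (condition);                                                 \
    if (error_ != MPI_SUCCESS) {                                              \
      throw std::runtime_error("`" #condition "` failed: " +                  \
                               mpi_error_string(error_));                     \
    }                                                                         \
  }

int main(int argc, char **argv) {
  MPI_CHECK_THROW(MPI_Init(&argc, &argv));
  // Opt in to receiving error codes instead of the default abort-on-error.
  MPI_CHECK_THROW(MPI_Comm_set_errhandler(MPI_COMM_WORLD, MPI_ERRORS_RETURN));
  int rank = 0;
  MPI_CHECK_THROW(MPI_Comm_rank(MPI_COMM_WORLD, &rank));
  MPI_CHECK_THROW(MPI_Finalize());
  return 0;
}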
@@ -71,16 +92,16 @@ public:
       */
      return;
    }
-    MPI_Group_free(&world_group_);
-    MPI_Finalize();
+    NBLA_MPI_CHECK(MPI_Group_free(&world_group_));
+    NBLA_MPI_CHECK(MPI_Finalize());
  }
 
  /**
     Returns whether MPI is initialized or not.
  */
  static bool initialized() {
    int flag = 1;
-    MPI_Initialized(&flag);
+    NBLA_MPI_CHECK(MPI_Initialized(&flag));
    return bool(flag);
  }
 
@@ -89,7 +110,7 @@ public:
  */
  static bool finalized() {
    int flag = 1;
-    MPI_Finalized(&flag);
+    NBLA_MPI_CHECK(MPI_Finalized(&flag));
    return bool(flag);
  }
 
@@ -115,15 +136,16 @@ private:
      char **argv = nullptr;
      int requiredThreadLevelSupport = MPI_THREAD_SERIALIZED;
      int provided;
-      MPI_Init_thread(&argc, &argv, requiredThreadLevelSupport, &provided);
+      NBLA_MPI_CHECK(
+          MPI_Init_thread(&argc, &argv, requiredThreadLevelSupport, &provided));
      if (provided != requiredThreadLevelSupport) {
        NBLA_ERROR(error_code::target_specific,
                   "MPI_Init_thread failed since provided (%d) is not equal to "
                   "requiredThreadLevelSupport (%d)",
                   provided, requiredThreadLevelSupport);
      }
    }
-    MPI_Comm_group(MPI_COMM_WORLD, &world_group_);
+    NBLA_MPI_CHECK(MPI_Comm_group(MPI_COMM_WORLD, &world_group_));
  }
  MPI_Group world_group_;
};
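The Mpi singleton asks for MPI_THREAD_SERIALIZED and refuses to continue when a different level comes back. The MPI thread levels are ordered (SINGLE < FUNNELED < SERIALIZED < MULTIPLE) and an implementation is allowed to report a higher level than requested, so a common alternative to the exact-equality test is to accept anything at or above the requested level, as in the illustrative sketch below (not the nnabla code).

#include <mpi.h>
#include <cstdio>
#include <cstdlib>

int main(int argc, char **argv) {
  int provided = MPI_THREAD_SINGLE;
  if (MPI_Init_thread(&argc, &argv, MPI_THREAD_SERIALIZED, &provided) !=
      MPI_SUCCESS) {
    std::fprintf(stderr, "MPI_Init_thread failed\n");
    return EXIT_FAILURE;
  }
  if (provided < MPI_THREAD_SERIALIZED) {
    // Not enough thread support for serialized calls from multiple threads.
    std::fprintf(stderr, "got thread level %d, need MPI_THREAD_SERIALIZED\n",
                 provided);
    MPI_Abort(MPI_COMM_WORLD, EXIT_FAILURE);
  }
  MPI_Finalize();
  return 0;
}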
@@ -161,8 +183,10 @@ public:
     MPI_COMM_WORLD.
   */
  MpiCommWrapper(std::vector<int> ranks) : mpi_(Mpi::get()), own_(true) {
-    MPI_Group_incl(Mpi::world_group(), ranks.size(), ranks.data(), &group_);
-    MPI_Comm_create_group(MPI_COMM_WORLD, group_, 0, &this->comm_);
+    NBLA_MPI_CHECK(MPI_Group_incl(Mpi::world_group(), ranks.size(),
+                                  ranks.data(), &group_));
+    NBLA_MPI_CHECK(
+        MPI_Comm_create_group(MPI_COMM_WORLD, group_, 0, &this->comm_));
  }
  /*
     Deletes the MPI_Group and MPI_Comm objects when it owns them.
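MpiCommWrapper builds a sub-communicator by including the listed world ranks in an MPI_Group and then calling MPI_Comm_create_group, which is collective only over that group. A self-contained sketch of the pattern, with an illustrative helper name and explicit cleanup of the temporary groups, could look like the following; only the ranks named in the list should call it.

#include <mpi.h>
#include <vector>

MPI_Comm comm_from_ranks(const std::vector<int> &ranks) {
  MPI_Group world_group, group;
  MPI_Comm_group(MPI_COMM_WORLD, &world_group);
  MPI_Group_incl(world_group, static_cast<int>(ranks.size()), ranks.data(),
                 &group);

  // Collective over the new group only; tag 0 distinguishes concurrent
  // creations on overlapping groups.
  MPI_Comm comm = MPI_COMM_NULL;
  MPI_Comm_create_group(MPI_COMM_WORLD, group, /*tag=*/0, &comm);

  MPI_Group_free(&group);
  MPI_Group_free(&world_group);
  return comm;
}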
@@ -233,17 +257,17 @@ template <typename T>
 bool MultiProcessDataParallelCommunicatorNccl<T>::mpi_check_any(
     bool condition, const string &group) {
   bool result;
-  MPI_Allreduce(&condition, &result, 1, MPI_C_BOOL, MPI_LOR,
-                this->mpi_comms_[group]->comm());
+  NBLA_MPI_CHECK(MPI_Allreduce(&condition, &result, 1, MPI_C_BOOL, MPI_LOR,
+                               this->mpi_comms_[group]->comm()));
   return result;
 }
 
 template <typename T>
 bool MultiProcessDataParallelCommunicatorNccl<T>::mpi_check_all(
     bool condition, const string &group) {
   bool result;
-  MPI_Allreduce(&condition, &result, 1, MPI_C_BOOL, MPI_LAND,
-                this->mpi_comms_[group]->comm());
+  NBLA_MPI_CHECK(MPI_Allreduce(&condition, &result, 1, MPI_C_BOOL, MPI_LAND,
+                               this->mpi_comms_[group]->comm()));
   return result;
 }
 
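mpi_check_any and mpi_check_all reduce a per-rank boolean across the group's communicator with a logical OR or AND, so every rank learns whether the condition held anywhere or everywhere. Stripped of the nnabla wrapper and of error checking, the pattern over MPI_COMM_WORLD is simply the sketch below; like the PR, it passes a C++ bool where MPI_C_BOOL is expected, which assumes both use a one-byte representation.

#include <mpi.h>

// True on every rank if the condition is true on at least one rank.
bool any_rank(bool local) {
  bool global = false;
  MPI_Allreduce(&local, &global, 1, MPI_C_BOOL, MPI_LOR, MPI_COMM_WORLD);
  return global;
}

// True on every rank only if the condition is true on all ranks.
bool all_ranks(bool local) {
  bool global = false;
  MPI_Allreduce(&local, &global, 1, MPI_C_BOOL, MPI_LAND, MPI_COMM_WORLD);
  return global;
}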
@@ -289,18 +313,18 @@ template <typename T> void MultiProcessDataParallelCommunicatorNccl<T>::init() {
   this->mpi_comms_["world"] = make_shared<MpiCommWrapper>();
 
   // Create comm, set size, and rank
-  MPI_Comm_size(MPI_COMM_WORLD, &this->size_);
-  MPI_Comm_rank(MPI_COMM_WORLD, &this->rank_);
+  NBLA_MPI_CHECK(MPI_Comm_size(MPI_COMM_WORLD, &this->size_));
+  NBLA_MPI_CHECK(MPI_Comm_rank(MPI_COMM_WORLD, &this->rank_));
 
   // Set local rank and device id
   uint64_t host_hashs[this->size_];
   char hostname[1024];
   get_host_name(hostname, 1024);
   host_hashs[this->rank_] = get_host_hash(hostname);
 
-  MPI_Allgather(MPI_IN_PLACE, 0, MPI_DATATYPE_NULL, host_hashs,
-                sizeof(uint64_t), MPI_BYTE, MPI_COMM_WORLD);
-  MPI_Barrier(MPI_COMM_WORLD);
+  NBLA_MPI_CHECK(MPI_Allgather(MPI_IN_PLACE, 0, MPI_DATATYPE_NULL, host_hashs,
+                               sizeof(uint64_t), MPI_BYTE, MPI_COMM_WORLD));
+  NBLA_MPI_CHECK(MPI_Barrier(MPI_COMM_WORLD));
 
   int local_rank = 0;
   for (int i = 0; i < this->size_; ++i) {
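init() derives a per-host local rank by hashing the hostname, gathering every rank's hash with MPI_Allgather, and counting how many lower-numbered ranks share the same host; the counting loop itself is cut off by the hunk boundary above. The sketch below is a reconstruction of that common idiom, not the exact hidden code: std::hash stands in for nnabla's get_host_hash and the helper name is hypothetical.

#include <mpi.h>
#include <cstdint>
#include <functional>
#include <string>
#include <vector>

// Hypothetical helper: the number of ranks on my host with a smaller
// global rank, usable as a CUDA device index.
int local_rank_from_hostnames(MPI_Comm comm) {
  int size = 0, rank = 0;
  MPI_Comm_size(comm, &size);
  MPI_Comm_rank(comm, &rank);

  char hostname[MPI_MAX_PROCESSOR_NAME];
  int len = 0;
  MPI_Get_processor_name(hostname, &len);

  std::vector<uint64_t> hashes(size);
  hashes[rank] = std::hash<std::string>{}(std::string(hostname, len));
  MPI_Allgather(MPI_IN_PLACE, 0, MPI_DATATYPE_NULL, hashes.data(),
                sizeof(uint64_t), MPI_BYTE, comm);

  // Ranks sharing a hash share a host.
  int local_rank = 0;
  for (int i = 0; i < rank; ++i) {
    if (hashes[i] == hashes[rank])
      ++local_rank;
  }
  return local_rank;
}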
@@ -320,8 +344,9 @@ template <typename T> void MultiProcessDataParallelCommunicatorNccl<T>::init() {
   if (this->rank_ == 0) {
     ncclGetUniqueId(&comm_id);
   }
-  MPI_Bcast(&comm_id, sizeof(comm_id), MPI_BYTE, 0, MPI_COMM_WORLD);
-  MPI_Barrier(MPI_COMM_WORLD);
+  NBLA_MPI_CHECK(
+      MPI_Bcast(&comm_id, sizeof(comm_id), MPI_BYTE, 0, MPI_COMM_WORLD));
+  NBLA_MPI_CHECK(MPI_Barrier(MPI_COMM_WORLD));
 
   // NCCL Init
   cuda_set_device(device_id_);
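The NCCL bootstrap here follows the usual recipe: rank 0 generates an ncclUniqueId, MPI_Bcast ships the same bytes to every rank, and all ranks then select their device and enter ncclCommInitRank. A compact sketch of that flow is shown below, with an illustrative helper name and NCCL/CUDA error checks omitted.

#include <mpi.h>
#include <nccl.h>
#include <cuda_runtime.h>

ncclComm_t bootstrap_nccl(MPI_Comm mpi_comm, int device_id) {
  int rank = 0, size = 0;
  MPI_Comm_rank(mpi_comm, &rank);
  MPI_Comm_size(mpi_comm, &size);

  // Rank 0 creates the id; every other rank receives the same bytes.
  ncclUniqueId id;
  if (rank == 0) {
    ncclGetUniqueId(&id);
  }
  MPI_Bcast(&id, sizeof(id), MPI_BYTE, 0, mpi_comm);

  cudaSetDevice(device_id);
  ncclComm_t comm;
  ncclCommInitRank(&comm, size, id, rank);
  return comm;
}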
@@ -352,12 +377,12 @@ template <typename T> void MultiProcessDataParallelCommunicatorNccl<T>::init() {
 
 template <typename T>
 void MultiProcessDataParallelCommunicatorNccl<T>::barrier() {
-  MPI_Barrier(MPI_COMM_WORLD);
+  NBLA_MPI_CHECK(MPI_Barrier(MPI_COMM_WORLD));
 }
 
 template <typename T>
 void MultiProcessDataParallelCommunicatorNccl<T>::abort() {
-  MPI_Abort(MPI_COMM_WORLD, -1);
+  NBLA_MPI_CHECK(MPI_Abort(MPI_COMM_WORLD, -1));
 }
 
 template <typename T>
@@ -402,9 +427,10 @@ string MultiProcessDataParallelCommunicatorNccl<T>::new_group(
     ncclGetUniqueId(&comm_id);
   }
   int rank;
-  MPI_Comm_rank(group_mpi_comm->comm(), &rank);
-  MPI_Bcast(&comm_id, sizeof(comm_id), MPI_BYTE, 0, group_mpi_comm->comm());
-  MPI_Barrier(group_mpi_comm->comm());
+  NBLA_MPI_CHECK(MPI_Comm_rank(group_mpi_comm->comm(), &rank));
+  NBLA_MPI_CHECK(MPI_Bcast(&comm_id, sizeof(comm_id), MPI_BYTE, 0,
+                           group_mpi_comm->comm()));
+  NBLA_MPI_CHECK(MPI_Barrier(group_mpi_comm->comm()));
 
   // NCCL Comm Init
   cuda_set_device(device_id_);