MpiWorld: add logger as class member + increase pool size #101

Merged: 1 commit, May 28, 2021
include/faabric/scheduler/MpiWorld.h: 2 additions & 0 deletions
@@ -217,6 +217,8 @@ class MpiWorld
std::string thisHost;
faabric::util::TimePoint creationTime;

+ const std::shared_ptr<spdlog::logger> logger;
+
std::shared_mutex worldMutex;
std::atomic_flag isDestroyed = false;

src/scheduler/MpiWorld.cpp: 12 additions & 24 deletions
Expand Up @@ -22,6 +22,7 @@ MpiWorld::MpiWorld()
, size(-1)
, thisHost(faabric::util::getSystemConfig().endpointHost)
, creationTime(faabric::util::startTimer())
+ , logger(faabric::util::getLogger())
, cartProcsPerDim(2)
{}
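
The change above fetches the logger once in the constructor initializer list and stores it as a const class member, so the many per-method `getLogger()` calls removed below become plain uses of `logger`. The sketch that follows is a minimal illustration of that pattern, not faabric code: `MpiWorldLike` and the local `getLogger()` helper are invented stand-ins, and only standard spdlog usage is assumed.

```cpp
#include <memory>

#include <spdlog/sinks/stdout_color_sinks.h>
#include <spdlog/spdlog.h>

// Hypothetical stand-in for faabric::util::getLogger()
std::shared_ptr<spdlog::logger> getLogger()
{
    static auto logger = spdlog::stdout_color_mt("mpi-demo");
    return logger;
}

class MpiWorldLike
{
  public:
    MpiWorldLike()
      : logger(getLogger()) // fetched once; every member function reuses it
    {
        logger->set_level(spdlog::level::trace);
    }

    void send(int sendRank, int recvRank)
    {
        // No per-call logger lookup needed any more
        logger->trace("MPI - send {} -> {}", sendRank, recvRank);
    }

  private:
    // const: the shared_ptr is never reseated after construction
    const std::shared_ptr<spdlog::logger> logger;
};

int main()
{
    MpiWorldLike world;
    world.send(0, 1);
    return 0;
}
```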

@@ -84,7 +85,6 @@ faabric::scheduler::FunctionCallClient& MpiWorld::getFunctionCallClient(

int MpiWorld::getMpiThreadPoolSize()
{
- auto logger = faabric::util::getLogger();
int usableCores = faabric::util::getUsableCores();
int worldSize = size;

@@ -93,7 +93,15 @@ int MpiWorld::getMpiThreadPoolSize()
logger->warn("To avoid this, set an MPI world size multiple of the "
"number of cores per machine.");
}
- return std::min<int>(worldSize, usableCores);
+ // Note - adding one to the worldSize to prevent deadlocking in certain
+ // corner-cases.
+ // For instance, if issuing `worldSize` non-blocking recvs, followed by
+ // `worldSize` non-blocking sends, and nothing else, the application will
+ // deadlock as all worker threads will be blocking on `recv` calls. This
+ // scenario is remote, but feasible. We _assume_ that following the same
+ // pattern but doing `worldSize + 1` calls is deliberately malicious, and
+ // we can confidently fail and deadlock.
+ return std::min<int>(worldSize + 1, usableCores);
}
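
The comment above can be illustrated with a small standalone sketch of the pool-size rule. The `worldSize` and `usableCores` names come from the diff; the helper function and the sample numbers below are assumptions for illustration only.

```cpp
#include <algorithm>
#include <cstdio>

// Sketch of the rule from the diff: keep one spare thread so that
// `worldSize` simultaneously blocking recvs cannot exhaust the pool.
int mpiThreadPoolSize(int worldSize, int usableCores)
{
    return std::min(worldSize + 1, usableCores);
}

int main()
{
    // Example: a 4-rank world on an 8-core host.
    int worldSize = 4;
    int usableCores = 8;

    int oldSize = std::min(worldSize, usableCores);          // previous rule: 4
    int newSize = mpiThreadPoolSize(worldSize, usableCores); // new rule: 5

    // With the old rule, 4 outstanding recvs occupy all 4 worker threads;
    // the matching sends can never run and the world deadlocks. The spare
    // fifth thread breaks that cycle.
    std::printf("old pool size: %d, new pool size: %d\n", oldSize, newSize);
    return 0;
}
```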

void MpiWorld::create(const faabric::Message& call, int newId, int newSize)
@@ -451,8 +459,6 @@ void MpiWorld::send(int sendRank,
int count,
faabric::MPIMessage::MPIMessageType messageType)
{
- const std::shared_ptr<spdlog::logger>& logger = faabric::util::getLogger();
-
if (recvRank > this->size - 1) {
throw std::runtime_error(fmt::format(
"Rank {} bigger than world size {}", recvRank, this->size));
@@ -503,8 +509,6 @@ void MpiWorld::recv(int sendRank,
MPI_Status* status,
faabric::MPIMessage::MPIMessageType messageType)
{
- const std::shared_ptr<spdlog::logger>& logger = faabric::util::getLogger();
-
// Listen to the in-memory queue for this rank and message type
logger->trace("MPI - recv {} -> {}", sendRank, recvRank);
std::shared_ptr<faabric::MPIMessage> m =
@@ -556,7 +560,6 @@ void MpiWorld::sendRecv(uint8_t* sendBuffer,
int myRank,
MPI_Status* status)
{
- auto logger = faabric::util::getLogger();
logger->trace(
"MPI - Sendrecv. Rank {}. Sending to: {} - Receiving from: {}",
myRank,
@@ -596,7 +599,6 @@ void MpiWorld::broadcast(int sendRank,
int count,
faabric::MPIMessage::MPIMessageType messageType)
{
- const std::shared_ptr<spdlog::logger>& logger = faabric::util::getLogger();
logger->trace("MPI - bcast {} -> all", sendRank);

for (int r = 0; r < size; r++) {
@@ -636,7 +638,6 @@ void MpiWorld::scatter(int sendRank,
faabric_datatype_t* recvType,
int recvCount)
{
- const std::shared_ptr<spdlog::logger>& logger = faabric::util::getLogger();
checkSendRecvMatch(sendType, sendCount, recvType, recvCount);

size_t sendOffset = sendCount * sendType->size;
@@ -683,7 +684,6 @@ void MpiWorld::gather(int sendRank,
faabric_datatype_t* recvType,
int recvCount)
{
- const std::shared_ptr<spdlog::logger>& logger = faabric::util::getLogger();
checkSendRecvMatch(sendType, sendCount, recvType, recvCount);

size_t sendOffset = sendCount * sendType->size;
@@ -792,7 +792,7 @@ void MpiWorld::allGather(int rank,

void MpiWorld::awaitAsyncRequest(int requestId)
{
- faabric::util::getLogger()->trace("MPI - await {}", requestId);
+ logger->trace("MPI - await {}", requestId);

auto it = futureMap.find(requestId);
if (it == futureMap.end()) {
@@ -804,8 +804,7 @@ void MpiWorld::awaitAsyncRequest(int requestId)
it->second.wait();
futureMap.erase(it);

- faabric::util::getLogger()->debug("Finished awaitAsyncRequest on {}",
-                                   requestId);
+ logger->debug("Finished awaitAsyncRequest on {}", requestId);
}
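
The `awaitAsyncRequest` body shown above follows a map-of-futures pattern: look up the request id in `futureMap`, wait on its future, then erase the entry. The sketch below is a rough, self-contained illustration of that pattern; apart from the `futureMap` name, every identifier is hypothetical and not faabric's API.

```cpp
#include <cstdio>
#include <future>
#include <stdexcept>
#include <unordered_map>

// Hypothetical registry of in-flight async requests, keyed by request id
std::unordered_map<int, std::future<void>> futureMap;

void awaitAsyncRequest(int requestId)
{
    auto it = futureMap.find(requestId);
    if (it == futureMap.end()) {
        throw std::runtime_error("Async request not found");
    }

    // Block until the producer fulfils the promise, then forget the entry
    it->second.wait();
    futureMap.erase(it);
}

int main()
{
    std::promise<void> p;
    futureMap.emplace(42, p.get_future());

    p.set_value();         // complete request 42
    awaitAsyncRequest(42); // returns immediately, entry is erased

    std::printf("request 42 awaited, %zu entries left\n", futureMap.size());
    return 0;
}
```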

void MpiWorld::reduce(int sendRank,
@@ -816,8 +815,6 @@ void MpiWorld::reduce(int sendRank,
int count,
faabric_op_t* operation)
{
- const std::shared_ptr<spdlog::logger>& logger = faabric::util::getLogger();
-
// If we're the receiver, await inputs
if (sendRank == recvRank) {
logger->trace("MPI - reduce ({}) all -> {}", operation->id, recvRank);
@@ -901,8 +898,6 @@ void MpiWorld::op_reduce(faabric_op_t* operation,
uint8_t* inBuffer,
uint8_t* outBuffer)
{
- const std::shared_ptr<spdlog::logger>& logger = faabric::util::getLogger();
-
logger->trace(
"MPI - reduce op: {} datatype {}", operation->id, datatype->id);
if (operation->id == faabric_op_max.id) {
@@ -1005,7 +1000,6 @@ void MpiWorld::scan(int rank,
int count,
faabric_op_t* operation)
{
- auto logger = faabric::util::getLogger();
logger->trace("MPI - scan");

if (rank > this->size - 1) {
@@ -1115,8 +1109,6 @@ void MpiWorld::probe(int sendRank, int recvRank, MPI_Status* status)

void MpiWorld::barrier(int thisRank)
{
- const std::shared_ptr<spdlog::logger>& logger = faabric::util::getLogger();
-
if (thisRank == 0) {
// This is the root, hence just does the waiting

@@ -1150,8 +1142,6 @@ void MpiWorld::barrier(int thisRank)

void MpiWorld::enqueueMessage(faabric::MPIMessage& msg)
{
- const std::shared_ptr<spdlog::logger>& logger = faabric::util::getLogger();
-
if (msg.worldid() != id) {
logger->error(
"Queueing message not meant for this world (msg={}, this={})",
@@ -1285,8 +1275,6 @@ long MpiWorld::getLocalQueueSize(int sendRank, int recvRank)

void MpiWorld::checkRankOnThisHost(int rank)
{
- const std::shared_ptr<spdlog::logger>& logger = faabric::util::getLogger();
-
// Check if we know about this rank on this host
if (rankHostMap.count(rank) == 0) {
logger->error("No mapping found for rank {} on this host", rank);