add logger as member + increase by one the pool size
csegarragonz committed May 28, 2021
1 parent c5b87ef commit dd04828
Showing 2 changed files with 14 additions and 24 deletions.
2 changes: 2 additions & 0 deletions include/faabric/scheduler/MpiWorld.h
@@ -217,6 +217,8 @@ class MpiWorld
     std::string thisHost;
     faabric::util::TimePoint creationTime;
 
+    const std::shared_ptr<spdlog::logger> logger;
+
     std::shared_mutex worldMutex;
     std::atomic_flag isDestroyed = false;
36 changes: 12 additions & 24 deletions src/scheduler/MpiWorld.cpp
@@ -22,6 +22,7 @@ MpiWorld::MpiWorld()
   , size(-1)
   , thisHost(faabric::util::getSystemConfig().endpointHost)
   , creationTime(faabric::util::startTimer())
+  , logger(faabric::util::getLogger())
   , cartProcsPerDim(2)
 {}
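
Aside on the pattern: the constructor change above caches the result of faabric::util::getLogger() in a const member once, so the methods changed below can log without fetching the logger on every call. A minimal self-contained sketch of the same idea in plain spdlog (the Worker class and the "worker" logger name are illustrative, not faabric's code):

#include <memory>
#include <spdlog/sinks/stdout_color_sinks.h>
#include <spdlog/spdlog.h>

// Illustrative sketch: cache a shared logger as a const member at
// construction instead of fetching it from the registry in every call.
class Worker
{
  public:
    Worker()
      : logger(spdlog::stdout_color_mt("worker"))
    {}

    void doWork()
    {
        // Uses the cached member; no registry lookup per call
        logger->info("doing work");
    }

  private:
    const std::shared_ptr<spdlog::logger> logger;
};

Note that spdlog throws if the same logger name is registered twice, so constructing two Workers as written would fail; a shared helper like faabric::util::getLogger(), which presumably hands out one cached logger, avoids that.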

@@ -84,7 +85,6 @@ faabric::scheduler::FunctionCallClient& MpiWorld::getFunctionCallClient(

 int MpiWorld::getMpiThreadPoolSize()
 {
-    auto logger = faabric::util::getLogger();
     int usableCores = faabric::util::getUsableCores();
     int worldSize = size;

@@ -93,7 +93,15 @@ int MpiWorld::getMpiThreadPoolSize()
logger->warn("To avoid this, set an MPI world size multiple of the "
"number of cores per machine.");
}
return std::min<int>(worldSize, usableCores);
// Note - adding one to the worldSize to prevent deadlocking in certain
// corner-cases.
// For instance, if issuing `worldSize` non-blocking recvs, followed by
// `worldSize` non-blocking sends, and nothing else, the application will
// deadlock as all worker threads will be blocking on `recv` calls. This
// scenario is remote, but feasible. We _assume_ that following the same
// pattern but doing `worldSize + 1` calls is deliberately malicious, and
// we can confidently fail and deadlock.
return std::min<int>(worldSize + 1, usableCores);
}

 void MpiWorld::create(const faabric::Message& call, int newId, int newSize)
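
Why the extra thread matters: with a pool of exactly worldSize threads, worldSize blocking receives submitted ahead of their matching sends occupy every worker, and the sends queued behind them can never run; one spare thread breaks the cycle. Below is a toy, self-contained C++ sketch of that failure mode. The Pool class and the delivered counter standing in for a message queue are invented for the demo; they are not faabric's implementation.

#include <condition_variable>
#include <functional>
#include <iostream>
#include <mutex>
#include <queue>
#include <thread>
#include <vector>

// Toy fixed-size thread pool, for illustration only
class Pool
{
  public:
    explicit Pool(int n)
    {
        for (int i = 0; i < n; i++) {
            workers.emplace_back([this] {
                while (true) {
                    std::function<void()> task;
                    {
                        std::unique_lock<std::mutex> lock(mx);
                        cv.wait(lock,
                                [this] { return stop || !tasks.empty(); });
                        if (stop && tasks.empty()) {
                            return;
                        }
                        task = std::move(tasks.front());
                        tasks.pop();
                    }
                    task();
                }
            });
        }
    }

    void submit(std::function<void()> task)
    {
        {
            std::lock_guard<std::mutex> lock(mx);
            tasks.push(std::move(task));
        }
        cv.notify_one();
    }

    // Drains the queue, then joins every worker
    ~Pool()
    {
        {
            std::lock_guard<std::mutex> lock(mx);
            stop = true;
        }
        cv.notify_all();
        for (auto& w : workers) {
            w.join();
        }
    }

  private:
    std::vector<std::thread> workers;
    std::queue<std::function<void()>> tasks;
    std::mutex mx;
    std::condition_variable cv;
    bool stop = false;
};

int main()
{
    const int worldSize = 4;

    // Stand-in for the message queue that a recv blocks on
    std::mutex msgMx;
    std::condition_variable msgCv;
    int delivered = 0;

    {
        Pool pool(worldSize + 1);

        // worldSize "recvs" first, each blocking until a message arrives
        for (int i = 0; i < worldSize; i++) {
            pool.submit([&] {
                std::unique_lock<std::mutex> lock(msgMx);
                msgCv.wait(lock, [&] { return delivered > 0; });
                delivered--;
            });
        }

        // The worldSize matching "sends" are queued behind them
        for (int i = 0; i < worldSize; i++) {
            pool.submit([&] {
                {
                    std::lock_guard<std::mutex> lock(msgMx);
                    delivered++;
                }
                msgCv.notify_all();
            });
        }
    } // ~Pool waits for all tasks to finish

    std::cout << "all " << worldSize << " recvs were matched" << std::endl;
}

Changing Pool pool(worldSize + 1) to Pool pool(worldSize) makes the program hang, which is exactly the corner case the new comment guards against.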
@@ -451,8 +459,6 @@ void MpiWorld::send(int sendRank,
                     int count,
                     faabric::MPIMessage::MPIMessageType messageType)
 {
-    const std::shared_ptr<spdlog::logger>& logger = faabric::util::getLogger();
-
     if (recvRank > this->size - 1) {
         throw std::runtime_error(fmt::format(
           "Rank {} bigger than world size {}", recvRank, this->size));
@@ -503,8 +509,6 @@ void MpiWorld::recv(int sendRank,
                     MPI_Status* status,
                     faabric::MPIMessage::MPIMessageType messageType)
 {
-    const std::shared_ptr<spdlog::logger>& logger = faabric::util::getLogger();
-
     // Listen to the in-memory queue for this rank and message type
     logger->trace("MPI - recv {} -> {}", sendRank, recvRank);
     std::shared_ptr<faabric::MPIMessage> m =
@@ -556,7 +560,6 @@ void MpiWorld::sendRecv(uint8_t* sendBuffer,
                         int myRank,
                         MPI_Status* status)
 {
-    auto logger = faabric::util::getLogger();
     logger->trace(
       "MPI - Sendrecv. Rank {}. Sending to: {} - Receiving from: {}",
       myRank,
@@ -596,7 +599,6 @@ void MpiWorld::broadcast(int sendRank,
                          int count,
                          faabric::MPIMessage::MPIMessageType messageType)
 {
-    const std::shared_ptr<spdlog::logger>& logger = faabric::util::getLogger();
     logger->trace("MPI - bcast {} -> all", sendRank);
 
     for (int r = 0; r < size; r++) {
@@ -636,7 +638,6 @@ void MpiWorld::scatter(int sendRank,
                        faabric_datatype_t* recvType,
                        int recvCount)
 {
-    const std::shared_ptr<spdlog::logger>& logger = faabric::util::getLogger();
     checkSendRecvMatch(sendType, sendCount, recvType, recvCount);
 
     size_t sendOffset = sendCount * sendType->size;
@@ -683,7 +684,6 @@ void MpiWorld::gather(int sendRank,
                       faabric_datatype_t* recvType,
                       int recvCount)
 {
-    const std::shared_ptr<spdlog::logger>& logger = faabric::util::getLogger();
     checkSendRecvMatch(sendType, sendCount, recvType, recvCount);
 
     size_t sendOffset = sendCount * sendType->size;
@@ -792,7 +792,7 @@ void MpiWorld::allGather(int rank,

 void MpiWorld::awaitAsyncRequest(int requestId)
 {
-    faabric::util::getLogger()->trace("MPI - await {}", requestId);
+    logger->trace("MPI - await {}", requestId);
 
     auto it = futureMap.find(requestId);
     if (it == futureMap.end()) {
@@ -804,8 +804,7 @@ void MpiWorld::awaitAsyncRequest(int requestId)
     it->second.wait();
     futureMap.erase(it);
 
-    faabric::util::getLogger()->debug("Finished awaitAsyncRequest on {}",
-                                      requestId);
+    logger->debug("Finished awaitAsyncRequest on {}", requestId);
 }

 void MpiWorld::reduce(int sendRank,
@@ -816,8 +815,6 @@ void MpiWorld::reduce(int sendRank,
                       int count,
                       faabric_op_t* operation)
 {
-    const std::shared_ptr<spdlog::logger>& logger = faabric::util::getLogger();
-
     // If we're the receiver, await inputs
     if (sendRank == recvRank) {
         logger->trace("MPI - reduce ({}) all -> {}", operation->id, recvRank);
@@ -901,8 +898,6 @@ void MpiWorld::op_reduce(faabric_op_t* operation,
                          uint8_t* inBuffer,
                          uint8_t* outBuffer)
 {
-    const std::shared_ptr<spdlog::logger>& logger = faabric::util::getLogger();
-
     logger->trace(
       "MPI - reduce op: {} datatype {}", operation->id, datatype->id);
     if (operation->id == faabric_op_max.id) {
@@ -1005,7 +1000,6 @@ void MpiWorld::scan(int rank,
                     int count,
                     faabric_op_t* operation)
 {
-    auto logger = faabric::util::getLogger();
     logger->trace("MPI - scan");
 
     if (rank > this->size - 1) {
@@ -1115,8 +1109,6 @@ void MpiWorld::probe(int sendRank, int recvRank, MPI_Status* status)

 void MpiWorld::barrier(int thisRank)
 {
-    const std::shared_ptr<spdlog::logger>& logger = faabric::util::getLogger();
-
     if (thisRank == 0) {
         // This is the root, hence just does the waiting

@@ -1150,8 +1142,6 @@

 void MpiWorld::enqueueMessage(faabric::MPIMessage& msg)
 {
-    const std::shared_ptr<spdlog::logger>& logger = faabric::util::getLogger();
-
     if (msg.worldid() != id) {
         logger->error(
           "Queueing message not meant for this world (msg={}, this={})",
@@ -1285,8 +1275,6 @@ long MpiWorld::getLocalQueueSize(int sendRank, int recvRank)

 void MpiWorld::checkRankOnThisHost(int rank)
 {
-    const std::shared_ptr<spdlog::logger>& logger = faabric::util::getLogger();
-
     // Check if we know about this rank on this host
     if (rankHostMap.count(rank) == 0) {
         logger->error("No mapping found for rank {} on this host", rank);