Skip to content

Commit

Permalink
implement a number of different receive strategies and use issend
Browse files Browse the repository at this point in the history
  • Loading branch information
lroberts36 committed Oct 30, 2024
1 parent cff847b commit 8cc982c
Show file tree
Hide file tree
Showing 5 changed files with 93 additions and 45 deletions.
103 changes: 71 additions & 32 deletions src/bvals/comms/combined_buffers.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,7 @@ void CombinedBuffersRank::PackAndSend(int partition) {
"Trying to send combined buffers before they have been built");
if (combined_info_device.count(partition) == 0) return; // There is nothing to send here
auto &comb_info = combined_info_device[partition];
PARTHENON_REQUIRE(combined_buffers[partition].IsAvailableForWrite(), "Trying to write to a buffer that is in use.");
Kokkos::parallel_for(
PARTHENON_AUTO_LABEL,
Kokkos::TeamPolicy<>(parthenon::DevExecSpace(), combined_info[partition].size(),
Expand Down Expand Up @@ -218,15 +219,15 @@ bool CombinedBuffersRank::IsAvailableForWrite(int partition) {
return combined_buffers[partition].IsAvailableForWrite();
}

bool CombinedBuffersRank::TryReceiveAndUnpack(Mesh *pmesh, int partition) {
bool CombinedBuffersRank::TryReceiveAndUnpack(Mesh *pmesh, int partition, MPI_Message *message) {
PARTHENON_REQUIRE(buffers_built,
"Trying to recv combined buffers before they have been built");
PARTHENON_REQUIRE(combined_buffers.count(partition) > 0,
"Trying to receive on a non-existent combined receive buffer.");
for (auto &buf : buffers[partition]) {
if (buf->GetState() != BufferState::stale) return false;
}
auto received = combined_buffers[partition].TryReceive();
auto received = combined_buffers[partition].TryReceive(message);
if (!received) return false;

// TODO(LFR): Fix this so it works in the more general case
Expand Down Expand Up @@ -369,40 +370,78 @@ void CombinedBuffers::TryReceiveAny(Mesh *pmesh, BoundaryType b_type) {
#ifdef MPI_PARALLEL
// This was an attempt at another method for receiving. It seemed to work
// but was subject to the same problems as the Iprobe-based code.
// for (int rank = 0; rank < Globals::nranks; ++rank) {
// if (combined_recv_buffers.count({rank, b_type})) {
// auto &comb_bufs = combined_recv_buffers.at({rank, b_type});
// for (auto &[partition, buf] : comb_bufs.buffers) {
// comb_bufs.TryReceiveAndUnpack(pmesh, partition);
// }
// }
//}

MPI_Status status;
int flag;
do {
MPI_Iprobe(MPI_ANY_SOURCE, MPI_ANY_TAG, comms_[GetAssociatedSender(b_type)], &flag,
&status);
if (flag) {
const int rank = status.MPI_SOURCE;
const int partition = status.MPI_TAG;
if (pmesh->receive_type == "old") {
for (int rank = 0; rank < Globals::nranks; ++rank) {
if (combined_recv_buffers.count({rank, b_type})) {
auto &comb_bufs = combined_recv_buffers.at({rank, b_type});
for (auto &[partition, buf] : comb_bufs.buffers) {
comb_bufs.TryReceiveAndUnpack(pmesh, partition, nullptr);
}
}
}
} else if (pmesh->receive_type == "iprobe") {
MPI_Status status;
int flag;
int iters{0};
do {
MPI_Message message;
MPI_Iprobe(MPI_ANY_SOURCE, MPI_ANY_TAG, comms_[GetAssociatedSender(b_type)], &flag,
&status);
if (flag) {
const int rank = status.MPI_SOURCE;
const int partition = status.MPI_TAG;
bool finished =
combined_recv_buffers.at({rank, b_type}).TryReceiveAndUnpack(pmesh, partition, nullptr);
if (!finished) processing_messages.insert(std::make_pair(std::pair<int, int>{rank, partition}, message));
}
++iters;
} while (flag || iters < 10);

// Process in-flight messages
std::vector<std::pair<int, int>> finished_messages;
for (auto &[p, message] : processing_messages) {
int rank = p.first;
int partition = p.second;
bool finished =
combined_recv_buffers.at({rank, b_type}).TryReceiveAndUnpack(pmesh, partition);
if (!finished) processing_messages.insert({rank, partition});
combined_recv_buffers.at({rank, b_type}).TryReceiveAndUnpack(pmesh, partition, nullptr);
if (finished) finished_messages.push_back({rank, partition});
}
} while (flag);

// Process in flight messages
std::vector<std::pair<int, int>> finished_messages;
for (auto &[rank, partition] : processing_messages) {
bool finished =
combined_recv_buffers.at({rank, b_type}).TryReceiveAndUnpack(pmesh, partition);
if (finished) finished_messages.push_back({rank, partition});
}

for (auto &m : finished_messages)
processing_messages.erase(m);
for (auto &m : finished_messages)
processing_messages.erase(m);
} else if (pmesh->receive_type == "improbe") {
MPI_Status status;
int flag;
int iters{0};
do {
MPI_Message message;
MPI_Improbe(MPI_ANY_SOURCE, MPI_ANY_TAG, comms_[GetAssociatedSender(b_type)], &flag,
&message, &status);
if (flag) {
const int rank = status.MPI_SOURCE;
const int partition = status.MPI_TAG;
bool finished =
combined_recv_buffers.at({rank, b_type}).TryReceiveAndUnpack(pmesh, partition, &message);
if (!finished) processing_messages.insert(std::make_pair(std::pair<int, int>{rank, partition}, message));
}
++iters;
} while (flag || iters < 10);

// Process in-flight messages
std::vector<std::pair<int, int>> finished_messages;
for (auto &[p, message] : processing_messages) {
int rank = p.first;
int partition = p.second;
bool finished =
combined_recv_buffers.at({rank, b_type}).TryReceiveAndUnpack(pmesh, partition, &message);
if (finished) finished_messages.push_back({rank, partition});
}

for (auto &m : finished_messages)
processing_messages.erase(m);
} else {
PARTHENON_FAIL("Unknown receiving strategy.");
}
#endif
}

Expand Down
4 changes: 2 additions & 2 deletions src/bvals/comms/combined_buffers.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ struct CombinedBuffersRank {

void PackAndSend(int partition);

bool TryReceiveAndUnpack(Mesh *pmesh, int partition);
bool TryReceiveAndUnpack(Mesh *pmesh, int partition, MPI_Message *message);

void RepointBuffers(Mesh *pmesh, int partition);

Expand All @@ -88,7 +88,7 @@ struct CombinedBuffers {
std::map<std::pair<int, BoundaryType>, CombinedBuffersRank> combined_send_buffers;
std::map<std::pair<int, BoundaryType>, CombinedBuffersRank> combined_recv_buffers;

std::set<std::pair<int, int>> processing_messages;
std::map<std::pair<int, int>, MPI_Message> processing_messages;

std::map<BoundaryType, mpi_comm_t> comms_;
CombinedBuffers() {
Expand Down
3 changes: 2 additions & 1 deletion src/mesh/mesh.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,8 @@ Mesh::Mesh(ParameterInput *pin, ApplicationInput *app_in, Packages_t &packages,
nref(Globals::nranks), nderef(Globals::nranks), rdisp(Globals::nranks),
ddisp(Globals::nranks), bnref(Globals::nranks), bnderef(Globals::nranks),
brdisp(Globals::nranks), bddisp(Globals::nranks),
pcombined_buffers(std::make_shared<CombinedBuffers>()) {
pcombined_buffers(std::make_shared<CombinedBuffers>()),
receive_type{pin->GetOrAddString("parthenon/mesh", "receive_type", "iprobe")} {
// Allow for user overrides to default Parthenon functions
if (app_in->InitUserMeshData != nullptr) {
InitUserMeshData = app_in->InitUserMeshData;
Expand Down
3 changes: 2 additions & 1 deletion src/mesh/mesh.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -229,7 +229,8 @@ class Mesh {
std::unordered_map<channel_key_t, comm_buf_t, tuple_hash<channel_key_t>>;
comm_buf_map_t boundary_comm_map;
TagMap tag_map;


std::string receive_type; // Defines how to structure the MPI receives for combined buffers
std::shared_ptr<CombinedBuffers> pcombined_buffers;

#ifdef MPI_PARALLEL
Expand Down
25 changes: 16 additions & 9 deletions src/utils/communication_buffer.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -129,8 +129,8 @@ class CommBuffer {

bool IsAvailableForWrite();

void TryStartReceive() noexcept;
bool TryReceive() noexcept;
void TryStartReceive(MPI_Message * message_id = nullptr) noexcept;
bool TryReceive(MPI_Message * message_id = nullptr) noexcept;
bool TryReceiveLocal() noexcept;
void SetReceived() noexcept {
PARTHENON_REQUIRE(*comm_type_ == BuffCommType::receiver ||
Expand Down Expand Up @@ -247,7 +247,7 @@ void CommBuffer<T>::Send() noexcept {
buf_.size() > 0,
"Trying to send zero size buffer, which will be interpreted as sending_null.");
PARTHENON_MPI_CHECK(MPI_Wait(my_request_.get(), MPI_STATUS_IGNORE));
PARTHENON_MPI_CHECK(MPI_Isend(buf_.data(), buf_.size(),
PARTHENON_MPI_CHECK(MPI_Issend(buf_.data(), buf_.size(),
MPITypeMap<buf_base_t>::type(), recv_rank_, tag_, comm_,
my_request_.get()));
#endif
Expand Down Expand Up @@ -319,7 +319,7 @@ bool CommBuffer<T>::IsAvailableForWrite() {
}

template <class T>
void CommBuffer<T>::TryStartReceive() noexcept {
void CommBuffer<T>::TryStartReceive(MPI_Message* message_id) noexcept {
#ifdef MPI_PARALLEL
if (*comm_type_ == BuffCommType::receiver && !*started_irecv_) {
PARTHENON_REQUIRE(
Expand All @@ -328,11 +328,18 @@ void CommBuffer<T>::TryStartReceive() noexcept {
if (!IsActive())
Allocate(
-1); // For early start of Irecv, always need storage space even if not used
PARTHENON_MPI_CHECK(MPI_Irecv(buf_.data(), buf_.size(),
MPITypeMap<buf_base_t>::type(), send_rank_, tag_, comm_,
my_request_.get()));
if (message_id != nullptr) {
PARTHENON_MPI_CHECK(MPI_Imrecv(buf_.data(), buf_.size(),
MPITypeMap<buf_base_t>::type(), message_id,
my_request_.get()));
} else {
PARTHENON_MPI_CHECK(MPI_Irecv(buf_.data(), buf_.size(),
MPITypeMap<buf_base_t>::type(), send_rank_, tag_, comm_,
my_request_.get()));
}
*started_irecv_ = true;
} else if (*comm_type_ == BuffCommType::sparse_receiver && !*started_irecv_) {
PARTHENON_REQUIRE(message_id == nullptr, "Imrecv not yet implemented for sparse buffers.");
int test;
MPI_Status status;
// Check if our message is available so that we can use the correct buffer size
Expand Down Expand Up @@ -377,7 +384,7 @@ bool CommBuffer<T>::TryReceiveLocal() noexcept {
}

template <class T>
bool CommBuffer<T>::TryReceive() noexcept {
bool CommBuffer<T>::TryReceive(MPI_Message* message_id) noexcept {
if (*state_ == BufferState::received || *state_ == BufferState::received_null)
return true;

Expand All @@ -388,7 +395,7 @@ bool CommBuffer<T>::TryReceive() noexcept {
PARTHENON_REQUIRE(*nrecv_tries_ < 1e8,
"MPI probably hanging after 1e8 receive tries.");

TryStartReceive();
TryStartReceive(message_id);

if (*started_irecv_) {
MPI_Status status;
Expand Down

0 comments on commit 8cc982c

Please sign in to comment.