Skip to content

Commit

Permalink
taking a few steps
Browse files Browse the repository at this point in the history
  • Loading branch information
lroberts36 committed Oct 22, 2024
1 parent 160c77f commit 6355185
Show file tree
Hide file tree
Showing 4 changed files with 49 additions and 20 deletions.
10 changes: 7 additions & 3 deletions src/bvals/comms/boundary_communication.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -221,13 +221,17 @@ TaskStatus ReceiveBoundBufs(std::shared_ptr<MeshData<Real>> &md) {
false);

// Receive any messages that are around
pmesh->pcombined_buffers->TryReceiveAny(bound_type);
pmesh->pcombined_buffers->TryReceiveAny(pmesh, bound_type);

bool all_received = true;
int nreceived = 0;
std::for_each(
std::begin(cache.buf_vec), std::end(cache.buf_vec),
[&all_received](auto pbuf) { all_received = pbuf->TryReceiveLocal() && all_received; });

[&all_received, &nreceived](auto pbuf) {
bool received = pbuf->TryReceiveLocal();
nreceived += received;
all_received = received && all_received; });
printf("All receive = %i on rank %i (%i received, %i expected)\n", all_received, Globals::my_rank, nreceived, cache.buf_vec.size());
int ibound = 0;
if (Globals::sparse_config.enabled && all_received) {
ForEachBoundary<bound_type>(
Expand Down
33 changes: 22 additions & 11 deletions src/bvals/comms/combined_buffers.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ bool CombinedBuffersRank::TryReceiveBufInfo(Mesh *pmesh) {
PARTHENON_REQUIRE(pmesh->boundary_comm_map.count(GetChannelKey(buf)),
"Buffer doesn't exist.");
buf.buf = pmesh->boundary_comm_map[GetChannelKey(buf)];
bufs.push_back(pmesh->boundary_comm_map[GetChannelKey(buf)]);
bufs.push_back(&(pmesh->boundary_comm_map[GetChannelKey(buf)]));
buf.pcombined_buf = &(combined_buffers[partition].buffer());
idx += BndId::NDAT;
}
Expand Down Expand Up @@ -127,7 +127,7 @@ void CombinedBuffersRank::ResolveSendBuffersAndSendInfo(Mesh *pmesh) {
PARTHENON_REQUIRE(pmesh->boundary_comm_map.count(GetChannelKey(buf_struct)),
"Buffer doesn't exist.");

bufs.push_back(pmesh->boundary_comm_map[GetChannelKey(buf_struct)]);
bufs.push_back(&(pmesh->boundary_comm_map[GetChannelKey(buf_struct)]));
idx += BndId::NDAT;
}
}
Expand All @@ -152,19 +152,18 @@ void CombinedBuffersRank::ResolveSendBuffersAndSendInfo(Mesh *pmesh) {
}

void CombinedBuffersRank::RepointBuffers(Mesh *pmesh, int partition) {
printf("Repointing buffers\n");
// Pull out the buffers and point them to the buf_struct
auto &buf_struct_vec = combined_info[partition];
for (auto &buf_struct : buf_struct_vec) {
buf_struct.buf = pmesh->boundary_comm_map[GetChannelKey(buf_struct)];
}

// Get the BndId objects on device
auto &buf_vec = combined_info[partition];

combined_info_device[partition] = ParArray1D<BndId>("bnd_id", buf_vec.size());
// Get the BndId objects on device
combined_info_device[partition] = ParArray1D<BndId>("bnd_id", buf_struct_vec.size());
auto ci_host = Kokkos::create_mirror_view(combined_info_device[partition]);
for (int i = 0; i < ci_host.size(); ++i)
ci_host[i] = buf_vec[i];
ci_host[i] = buf_struct_vec[i];
Kokkos::deep_copy(combined_info_device[partition], ci_host);
}

Expand All @@ -191,14 +190,26 @@ void CombinedBuffersRank::PackAndSend(int partition) {
combined_buffers[partition].Send();
// Information in these send buffers is no longer required
for (auto &buf : buffers[partition])
buf.Stale();
buf->Stale();
}

bool CombinedBuffersRank::TryReceiveAndUnpack(int partition) {
bool CombinedBuffersRank::TryReceiveAndUnpack(Mesh *pmesh, int partition) {
PARTHENON_REQUIRE(buffers_built, "Trying to recv combined buffers before they have been built")
auto &comb_info = combined_info_device[partition];
auto received = combined_buffers[partition].TryReceive();
if (!received) return false;

// TODO(LFR): Fix this so it works in the more general case
bool all_allocated = true;
for (auto &buf : buffers[partition]) {
if (!buf->IsActive()) {
all_allocated = false;
buf->Allocate();
}
}
if (!all_allocated)
RepointBuffers(pmesh, partition);

auto &comb_info = combined_info_device[partition];
Kokkos::parallel_for(
PARTHENON_AUTO_LABEL,
Kokkos::TeamPolicy<>(parthenon::DevExecSpace(), combined_info[partition].size(), Kokkos::AUTO),
Expand All @@ -215,7 +226,7 @@ bool CombinedBuffersRank::TryReceiveAndUnpack(int partition) {
});
combined_buffers[partition].Stale();
for (auto &buf : buffers[partition])
buf.SetReceived();
buf->SetReceived();
return true;
}

Expand Down
8 changes: 4 additions & 4 deletions src/bvals/comms/combined_buffers.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ struct CombinedBuffersRank {
// partition id of the sender will be the mpi tag we use
bool buffers_built{false};
std::map<int, coalesced_message_structure_t> combined_info;
std::map<int, std::vector<CommBuffer<buf_pool_t<Real>::weak_t>>> buffers;
std::map<int, std::vector<CommBuffer<buf_pool_t<Real>::owner_t>*>> buffers;
std::map<int, ParArray1D<BndId>> combined_info_device;
std::map<int, CommBuffer<buf_t>> combined_buffers;
std::map<int, int> current_size;
Expand Down Expand Up @@ -74,7 +74,7 @@ struct CombinedBuffersRank {

void PackAndSend(int partition);

bool TryReceiveAndUnpack(int partition);
bool TryReceiveAndUnpack(Mesh *pmesh, int partition);

void RepointBuffers(Mesh *pmesh, int partition);
};
Expand Down Expand Up @@ -151,7 +151,7 @@ struct CombinedBuffers {
}
}

void TryReceiveAny(BoundaryType b_type) {
void TryReceiveAny(Mesh *pmesh, BoundaryType b_type) {
#ifdef MPI_PARALLEL
MPI_Status status;
int flag;
Expand All @@ -161,7 +161,7 @@ struct CombinedBuffers {
if (flag) {
const int rank = status.MPI_SOURCE;
const int partition = status.MPI_TAG;
combined_recv_buffers[{rank, b_type}].TryReceiveAndUnpack(partition);
combined_recv_buffers[{rank, b_type}].TryReceiveAndUnpack(pmesh, partition);
}
} while(flag);
#endif
Expand Down
18 changes: 16 additions & 2 deletions src/utils/communication_buffer.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,9 @@ class CommBuffer {
bool TryReceive() noexcept;
bool TryReceiveLocal() noexcept;
void SetReceived() noexcept {
PARTHENON_REQUIRE(*comm_type_ == BuffCommType::receiver, "This doesn't make sense for a non-receiver.");
PARTHENON_REQUIRE(*comm_type_ == BuffCommType::receiver ||
*comm_type_ == BuffCommType::sparse_receiver,
"This doesn't make sense for a non-receiver.");
*state_ = BufferState::received;
}
bool IsSafeToDelete() {
Expand Down Expand Up @@ -180,7 +182,7 @@ CommBuffer<T>::CommBuffer(const CommBuffer<U> &in)
: buf_(in.buf_), state_(in.state_), comm_type_(in.comm_type_),
started_irecv_(in.started_irecv_), nrecv_tries_(in.nrecv_tries_),
my_request_(in.my_request_), tag_(in.tag_), send_rank_(in.send_rank_),
recv_rank_(in.recv_rank_), comm_(in.comm_), active_(in.active_) {
recv_rank_(in.recv_rank_), comm_(in.comm_), active_(in.active_), get_resource_(in.get_resource_) {
my_rank = Globals::my_rank;
}

Expand Down Expand Up @@ -220,6 +222,7 @@ CommBuffer<T> &CommBuffer<T>::operator=(const CommBuffer<U> &in) {
comm_ = in.comm_;
active_ = in.active_;
my_rank = Globals::my_rank;
get_resource_ = in.get_resource_;
return *this;
}

Expand Down Expand Up @@ -356,6 +359,17 @@ template <class T>
bool CommBuffer<T>::TryReceiveLocal() noexcept {
if (*state_ == BufferState::received || *state_ == BufferState::received_null)
return true;
if (*comm_type_ == BuffCommType::both) {
if (*state_ == BufferState::sending) {
*state_ = BufferState::received;
// Memory should already be available, since both
// send and receive rank point at the same memory
return true;
} else if (*state_ == BufferState::sending_null) {
*state_ = BufferState::received_null;
return true;
}
}
return false;
}

Expand Down

0 comments on commit 6355185

Please sign in to comment.