diff --git a/Src/Base/AMReX_NonLocalBCImpl.H b/Src/Base/AMReX_NonLocalBCImpl.H index 0ecfe80b3bb..432c298c48b 100644 --- a/Src/Base/AMReX_NonLocalBCImpl.H +++ b/Src/Base/AMReX_NonLocalBCImpl.H @@ -363,13 +363,13 @@ AMREX_STATIC_ASSERT_NO_MESSAGE((SwapComponents{0, 1}(2) == 2)); template EnableIf_t() && IsCallableR() && IsFabProjection()> -local_copy_cpu(FabArray& dest, const FabArray& src, int dcomp, int scomp, int ncomp, - FabArrayBase::CopyComTagsContainer const& local_tags, DTOS dtos = DTOS{}, - Proj proj = Proj{}) noexcept { +local_copy_cpu (FabArray& dest, const FabArray& src, int dcomp, int scomp, int ncomp, + FabArrayBase::CopyComTagsContainer const& local_tags, DTOS dtos = DTOS{}, + Proj proj = Proj{}) noexcept { #ifdef AMREX_USE_OMP #pragma omp parallel for #endif - for (const FabArrayBase::CopyComTag& tag : local_tags) { + for (auto const& tag : local_tags) { auto const& sfab = src.const_array(tag.srcIndex); auto const& dfab = dest.array(tag.dstIndex); amrex::LoopConcurrentOnCpu(tag.dbox, ncomp, [=](int i, int j, int k, int n) noexcept { @@ -381,10 +381,10 @@ local_copy_cpu(FabArray& dest, const FabArray& src, int dcomp, int sco template EnableIf_t() && IsCallableR() && IsFabProjection()> -unpack_recv_buffer_cpu(FabArray& mf, int scomp, int ncomp, Vector const& recv_data, - Vector const& recv_size, - Vector const& recv_cctc, - DTOS dtos = DTOS{}, Proj proj = Proj{}) noexcept { +unpack_recv_buffer_cpu (FabArray& mf, int scomp, int ncomp, Vector const& recv_data, + Vector const& recv_size, + Vector const& recv_cctc, + DTOS dtos = DTOS{}, Proj proj = Proj{}) noexcept { amrex::ignore_unused(recv_size); const int N_rcvs = recv_cctc.size(); @@ -527,6 +527,11 @@ struct CommHandler { #endif }; +#ifdef AMREX_USE_MPI +void PostRecvs(CommData& comm, int mpi_tag); +void PostSends(CommData& comm, int mpi_tag); +#endif + //////////////////////////////////////////////////////////////////////////////////// // [concept.DataPacking] // @@ -632,16 +637,20 @@ static_assert(IsDataPacking(), // template struct ApplyDtosAndProjectionOnReciever : PackComponents { + constexpr ApplyDtosAndProjectionOnReciever() = default; + constexpr ApplyDtosAndProjectionOnReciever(const PackComponents& components, DTOS dtos_, FabProj proj_) + : PackComponents(components), dtos(std::move(dtos_)), proj(std::move(proj_)) {} + AMREX_NO_UNIQUE_ADDRESS DTOS dtos; AMREX_NO_UNIQUE_ADDRESS FabProj proj; - static_assert(IsIndexMapping(), "DTOS needs to be an index map"); + static_assert(IsIndexMapping(), "DTOS needs to be an index map"); }; template amrex::EnableIf_t() && IsIndexMapping() && IsFabProjection()> -LocalCopy(const ApplyDtosAndProjectionOnReciever& packing, FabArray& dest, - const FabArray& src, const FabArrayBase::CopyComTagsContainer& local_tags) { +LocalCopy (const ApplyDtosAndProjectionOnReciever& packing, FabArray& dest, + const FabArray& src, const FabArrayBase::CopyComTagsContainer& local_tags) { local_copy_cpu(dest, src, packing.dest_component, packing.src_component, packing.n_components, local_tags, packing.dtos, packing.proj); } @@ -649,20 +658,15 @@ LocalCopy(const ApplyDtosAndProjectionOnReciever& packing, FabArr #ifdef AMREX_USE_MPI template amrex::EnableIf_t() && IsIndexMapping() && IsFabProjection()> -UnpackRecvBuffers(const ApplyDtosAndProjectionOnReciever& packing, - FabArray& dest, const CommData& comm) { +UnpackRecvBuffers (const ApplyDtosAndProjectionOnReciever& packing, + FabArray& dest, const CommData& comm) { unpack_recv_buffer_cpu(dest, packing.dest_component, packing.n_components, comm.data, comm.size, comm.cctc, packing.dtos, packing.proj); } #endif static_assert(IsDataPacking, FArrayBox>(), - "ApplyDtosAndProjectionOnReciever is expected to satisfy the DataPacking concept."); - -#ifdef AMREX_USE_MPI -void PostRecvs(CommData& comm, int mpi_tag); -void PostSends(CommData& comm, int mpi_tag); -#endif + "ApplyDtosAndProjectionOnReciever<> is expected to satisfy the DataPacking concept."); /// Initiate recv and send calls for MPI and immediately return without doing any work. /// @@ -678,9 +682,9 @@ void PostSends(CommData& comm, int mpi_tag); template ::value>, typename = EnableIf_t::value>> -AMREX_NODISCARD CommHandler ParallelCopy_nowait(FabArray& dest, const FabArray& src, - const FabArrayBase::CommMetaData& cmd, - const DataPacking& data_packing) { +AMREX_NODISCARD CommHandler +ParallelCopy_nowait (FabArray& dest, const FabArray& src, + const FabArrayBase::CommMetaData& cmd, const DataPacking& data_packing) { CommHandler handler{}; #ifdef AMREX_USE_MPI if (ParallelContext::NProcsSub() == 1) { @@ -710,8 +714,8 @@ AMREX_NODISCARD CommHandler ParallelCopy_nowait(FabArray& dest, const FabAr template EnableIf_t() && IsDataPacking()> -ParallelCopy_finish(FabArray& dest, const FabArray& src, CommHandler handler, - const FabArrayBase::CommMetaData& cmd, const DataPacking& data_packing) { +ParallelCopy_finish (FabArray& dest, const FabArray& src, CommHandler handler, + const FabArrayBase::CommMetaData& cmd, const DataPacking& data_packing) { // If any FabArray is empty we have nothing to do. if (dest.size() == 0) { return; @@ -803,11 +807,24 @@ struct MultiBlockCommMetaData : FabArrayBase::CommMetaData { //! @} }; -template +template EnableIf_t() && IsIndexMapping() && IsFabProjection()> -ParallelCopy(FabArray& dest, const Box& destbox, const FabArray& src, int destcomp, int srccomp, int numcomp, DTOS dtos = DTOS{}, Proj proj = Proj{}) -{ +ParallelCopy (FabArray& dest, const Box& destbox, const FabArray& src, + const MultiBlockCommMetaData& cmd, int destcomp, int srccomp, int numcomp, + DTOS dtos = DTOS{}, Proj proj = Proj{}) { + ApplyDtosAndProjectionOnReciever packing{PackComponents{destcomp, srccomp, numcomp}, + std::move(dtos), std::move(proj)}; + CommHandler handler = ParallelCopy_nowait(dest, src, cmd, packing); + ParallelCopy_finish(dest, src, std::move(handler), cmd, packing); +} +template +EnableIf_t() && IsIndexMapping() && IsFabProjection()> +ParallelCopy (FabArray& dest, const Box& destbox, const FabArray& src, int destcomp, + int srccomp, int numcomp, DTOS dtos = DTOS{}, Proj proj = Proj{}) { + MultiBlockCommMetaData cmd(dest, destbox, src, dtos); + ParallelCopy(dest, destbox, src, cmd, destcomp, srccomp, numcomp, std::move(dtos), + std::move(proj)); } template @@ -1070,6 +1087,10 @@ void FillPolar(FabArray& mf, int scomp, int ncomp, IntVect const& ngh extern template void FillPolar (FabArray& mf, Box const& domain); +extern template +void ParallelCopy(FabArray& dest, const Box& destbox, const FabArray& src, int destcomp, + int srccomp, int numcomp, Identity, Identity); + }}