diff --git a/Src/Base/AMReX_GpuLaunchFunctsG.H b/Src/Base/AMReX_GpuLaunchFunctsG.H index 78e9e856535..dfbb1092697 100644 --- a/Src/Base/AMReX_GpuLaunchFunctsG.H +++ b/Src/Base/AMReX_GpuLaunchFunctsG.H @@ -867,6 +867,28 @@ namespace detail { { for (T n = 0; n < ncomp; ++n) f(i,j,k,n,Gpu::Handler(amrex::min(nleft,(int)blockDim.x))); } + + template + void parfor (Box const& box, ExecutionConfigconst& ec, F const& f) noexcept + { + const auto lo = amrex::lbound(box); + const auto len = amrex::length(box); + const auto ncells = T(box.numPts()); + const auto lenxy = Long(len.x)*Long(len.y); + const auto lenx = Long(len.x); + AMREX_LAUNCH_KERNEL(MT, ec.numBlocks, ec.numThreads, 0, Gpu::gpuStream(), + [=] AMREX_GPU_DEVICE () noexcept { + for (T icell = blockDim.x*blockIdx.x+threadIdx.x, stride = blockDim.x*gridDim.x; + icell < ncells; icell += stride) + { + T k = icell / lenxy; + T j = (icell - k*lenxy) / lenx; + T i = (icell - k*lenxy) - j*lenx; + detail::call_f(f, int(i)+lo.x, int(j)+lo.y, int(k)+lo.z, + (ncells-icell+(int)threadIdx.x)); + } + }); + } } template ::value> > @@ -890,26 +912,15 @@ std::enable_if_t::value> ParallelFor (Gpu::KernelInfo const&, Box const& box, L&& f) noexcept { if (amrex::isEmpty(box)) { return; } - int ncells = box.numPts(); - const auto lo = amrex::lbound(box); - const auto len = amrex::length(box); - const auto lenxy = len.x*len.y; - const auto lenx = len.x; - const auto ec = Gpu::makeExecutionConfig(ncells); - AMREX_LAUNCH_KERNEL(MT, ec.numBlocks, ec.numThreads, 0, Gpu::gpuStream(), - [=] AMREX_GPU_DEVICE () noexcept { - for (int icell = blockDim.x*blockIdx.x+threadIdx.x, stride = blockDim.x*gridDim.x; - icell < ncells; icell += stride) - { - int k = icell / lenxy; - int j = (icell - k*lenxy) / lenx; - int i = (icell - k*lenxy) - j*lenx; - i += lo.x; - j += lo.y; - k += lo.z; - detail::call_f(f, i, j, k, (ncells-icell+(int)threadIdx.x)); - } - }); + auto ncells = box.numPts(); + auto const& ec = Gpu::makeExecutionConfig(ncells); + auto const nthreads = Long(ec.numBlocks.x) * Long(ec.numThreads.x); + Long icell_max = std::max(nthreads,ncells) + nthreads; + if (icells_max <= std::numeric_limits::max()) { + detail::parfor_box(box, ec, f); + } else { + detail::parfor_box(box, ec, f); + } AMREX_GPU_ERROR_CHECK(); }