From c6b4fde07ae1a553f17a864c839b40f24107ac28 Mon Sep 17 00:00:00 2001
From: Alexander Sinn <64009254+AlexanderSinn@users.noreply.github.com>
Date: Sun, 8 Sep 2024 22:45:28 +0200
Subject: [PATCH] AnyCTO with arbitrary number of functions (#4135)

This PR extends AnyCTO to support an arbitrary number of optimized functions. Some constructs such as PrefixSum need multiple functions that may need to be optimized. It is also possible to provide no GPU function, in which case the compile time parameters will be directly given to the CPU function.
---
 Src/Base/AMReX_CTOParallelForImpl.H | 34 +++++++++++++++++++----------
 1 file changed, 22 insertions(+), 12 deletions(-)

diff --git a/Src/Base/AMReX_CTOParallelForImpl.H b/Src/Base/AMReX_CTOParallelForImpl.H
index 8f7e8ce567f..d4431d7899b 100644
--- a/Src/Base/AMReX_CTOParallelForImpl.H
+++ b/Src/Base/AMReX_CTOParallelForImpl.H
@@ -44,24 +44,33 @@ namespace detail
         }
     };
 
-    template <class L, class F, typename... As>
+    template <class L, typename... As, class... Fs>
     bool
-    AnyCTO_helper2 (const L& l, const F& f, TypeList<As...>,
-                    std::array<int,sizeof...(As)> const& runtime_options)
+    AnyCTO_helper2 (const L& l, TypeList<As...>,
+                    std::array<int,sizeof...(As)> const& runtime_options, const Fs&...cto_functs)
     {
         if (runtime_options == std::array<int,sizeof...(As)>{As::value...}) {
-            l(CTOWrapper<F, As...>{f});
+            if constexpr (sizeof...(cto_functs) != 0) {
+                // Apply the CTOWrapper to each function that was given in cto_functs
+                // and call the CPU function l with all of them
+                l(CTOWrapper<Fs, As...>{cto_functs}...);
+            } else {
+                // No functions in cto_functs so we call l directly with the compile time arguments
+                l(As{}...);
+            }
             return true;
         } else {
             return false;
         }
     }
 
-    template <class L, class F, typename... PPs, typename RO>
+    template <class L, typename... PPs, typename RO, class... Fs>
     void
-    AnyCTO_helper1 (const L& l, const F& f, TypeList<PPs...>, RO const& runtime_options)
+    AnyCTO_helper1 (const L& l, TypeList<PPs...>,
+                    RO const& runtime_options, const Fs&...cto_functs)
     {
-        bool found_option = (false || ... || AnyCTO_helper2(l, f, PPs{}, runtime_options));
+        bool found_option = (false || ... ||
+                             AnyCTO_helper2(l, PPs{}, runtime_options, cto_functs...));
         amrex::ignore_unused(found_option);
         AMREX_ASSERT(found_option);
     }
@@ -168,17 +177,18 @@ namespace detail
  * \param list_of_compile_time_options list of all possible values of the parameters.
  * \param runtime_options the run time parameters.
  * \param l a callable object containing a CPU function that launches the provided GPU kernel.
- * \param f a callable object containing the GPU kernel with optimizations.
+ * \param cto_functs a callable object containing the GPU kernel with optimizations.
  */
-template <typename... CTOs, class L, class F>
+template <typename... CTOs, class L, class... Fs>
 void AnyCTO ([[maybe_unused]] TypeList<CTOs...> list_of_compile_time_options,
              std::array<int,sizeof...(CTOs)> const& runtime_options,
-             L&& l, F&& f)
+             L&& l, Fs&&...cto_functs)
 {
 #if (__cplusplus >= 201703L)
-    detail::AnyCTO_helper1(std::forward<L>(l), std::forward<F>(f),
+    detail::AnyCTO_helper1(std::forward<L>(l),
                            CartesianProduct(typename CTOs::list_type{}...),
-                           runtime_options);
+                           runtime_options,
+                           std::forward<Fs>(cto_functs)...);
 #else
     amrex::ignore_unused(runtime_options, l, f);
     static_assert(std::is_integral<F>::value, "This requires C++17");