diff --git a/Src/Base/AMReX_CTOParallelForImpl.H b/Src/Base/AMReX_CTOParallelForImpl.H index 8f7e8ce567f..d4431d7899b 100644 --- a/Src/Base/AMReX_CTOParallelForImpl.H +++ b/Src/Base/AMReX_CTOParallelForImpl.H @@ -44,24 +44,33 @@ namespace detail } }; - template + template bool - AnyCTO_helper2 (const L& l, const F& f, TypeList, - std::array const& runtime_options) + AnyCTO_helper2 (const L& l, TypeList, + std::array const& runtime_options, const Fs&...cto_functs) { if (runtime_options == std::array{As::value...}) { - l(CTOWrapper{f}); + if constexpr (sizeof...(cto_functs) != 0) { + // Apply the CTOWrapper to each function that was given in cto_functs + // and call the CPU function l with all of them + l(CTOWrapper{cto_functs}...); + } else { + // No functions in cto_functs so we call l directly with the compile time arguments + l(As{}...); + } return true; } else { return false; } } - template + template void - AnyCTO_helper1 (const L& l, const F& f, TypeList, RO const& runtime_options) + AnyCTO_helper1 (const L& l, TypeList, + RO const& runtime_options, const Fs&...cto_functs) { - bool found_option = (false || ... || AnyCTO_helper2(l, f, PPs{}, runtime_options)); + bool found_option = (false || ... || + AnyCTO_helper2(l, PPs{}, runtime_options, cto_functs...)); amrex::ignore_unused(found_option); AMREX_ASSERT(found_option); } @@ -168,17 +177,18 @@ namespace detail * \param list_of_compile_time_options list of all possible values of the parameters. * \param runtime_options the run time parameters. * \param l a callable object containing a CPU function that launches the provided GPU kernel. - * \param f a callable object containing the GPU kernel with optimizations. + * \param cto_functs a callable object containing the GPU kernel with optimizations. */ -template +template void AnyCTO ([[maybe_unused]] TypeList list_of_compile_time_options, std::array const& runtime_options, - L&& l, F&& f) + L&& l, Fs&&...cto_functs) { #if (__cplusplus >= 201703L) - detail::AnyCTO_helper1(std::forward(l), std::forward(f), + detail::AnyCTO_helper1(std::forward(l), CartesianProduct(typename CTOs::list_type{}...), - runtime_options); + runtime_options, + std::forward(cto_functs)...); #else amrex::ignore_unused(runtime_options, l, f); static_assert(std::is_integral::value, "This requires C++17");