diff --git a/core/distributed/matrix.cpp b/core/distributed/matrix.cpp index fb656f78684..5808cba692f 100644 --- a/core/distributed/matrix.cpp +++ b/core/distributed/matrix.cpp @@ -33,6 +33,7 @@ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. #include +#include #include #include @@ -286,30 +287,33 @@ template void Matrix::apply_impl( const LinOp* b, LinOp* x) const { - auto dense_b = as(b); - auto dense_x = as(x); - auto x_exec = x->get_executor(); - auto local_x = gko::matrix::Dense::create( - x_exec, dense_x->get_local_vector()->get_size(), - gko::make_array_view( - x_exec, dense_x->get_local_vector()->get_num_stored_elements(), - dense_x->get_local_values()), - dense_x->get_local_vector()->get_stride()); - if (this->get_const_local_offdiag()->get_size()) { - auto req = this->communicate(dense_b->get_local_vector()); - diag_mtx_->apply(dense_b->get_local_vector(), local_x.get()); - req.wait(); - auto exec = this->get_executor(); - auto needs_host_buffer = - exec->get_master() != exec && !gko::mpi::is_gpu_aware(); - if (needs_host_buffer) { - recv_buffer_->copy_from(host_recv_buffer_.get()); - } - offdiag_mtx_->apply(one_scalar_.get(), recv_buffer_.get(), - one_scalar_.get(), local_x.get()); - } else { - diag_mtx_->apply(dense_b->get_local_vector(), local_x.get()); - } + distributed::precision_dispatch_real_complex( + [this](const auto dense_b, auto dense_x) { + auto x_exec = dense_x->get_executor(); + auto local_x = gko::matrix::Dense::create( + x_exec, dense_x->get_local_vector()->get_size(), + gko::make_array_view( + x_exec, + dense_x->get_local_vector()->get_num_stored_elements(), + dense_x->get_local_values()), + dense_x->get_local_vector()->get_stride()); + if (this->get_const_local_offdiag()->get_size()) { + auto req = this->communicate(dense_b->get_local_vector()); + diag_mtx_->apply(dense_b->get_local_vector(), local_x.get()); + req.wait(); + auto exec = this->get_executor(); + auto needs_host_buffer = + exec->get_master() != exec && !gko::mpi::is_gpu_aware(); + if (needs_host_buffer) { + recv_buffer_->copy_from(host_recv_buffer_.get()); + } + offdiag_mtx_->apply(one_scalar_.get(), recv_buffer_.get(), + one_scalar_.get(), local_x.get()); + } else { + diag_mtx_->apply(dense_b->get_local_vector(), local_x.get()); + } + }, + b, x); } @@ -317,34 +321,36 @@ template void Matrix::apply_impl( const LinOp* alpha, const LinOp* b, const LinOp* beta, LinOp* x) const { - auto dense_b = as(b); - auto dense_x = as(x); - const auto x_exec = x->get_executor(); - auto local_x = gko::matrix::Dense::create( - x_exec, dense_x->get_local_vector()->get_size(), - gko::make_array_view( - x_exec, dense_x->get_local_vector()->get_num_stored_elements(), - dense_x->get_local_values()), - dense_x->get_local_vector()->get_stride()); - auto local_alpha = as(alpha); - auto local_beta = as(beta); - if (this->get_const_local_offdiag()->get_size()) { - auto req = this->communicate(dense_b->get_local_vector()); - diag_mtx_->apply(local_alpha, dense_b->get_local_vector(), local_beta, - local_x.get()); - req.wait(); - auto exec = this->get_executor(); - auto needs_host_buffer = - exec->get_master() != exec && !gko::mpi::is_gpu_aware(); - if (needs_host_buffer) { - recv_buffer_->copy_from(host_recv_buffer_.get()); - } - offdiag_mtx_->apply(local_alpha, recv_buffer_.get(), one_scalar_.get(), - local_x.get()); - } else { - diag_mtx_->apply(local_alpha, dense_b->get_local_vector(), local_beta, - local_x.get()); - } + distributed::precision_dispatch_real_complex( + [this](const auto local_alpha, const auto dense_b, + const auto local_beta, auto dense_x) { + const auto x_exec = dense_x->get_executor(); + auto local_x = gko::matrix::Dense::create( + x_exec, dense_x->get_local_vector()->get_size(), + gko::make_array_view( + x_exec, + dense_x->get_local_vector()->get_num_stored_elements(), + dense_x->get_local_values()), + dense_x->get_local_vector()->get_stride()); + if (this->get_const_local_offdiag()->get_size()) { + auto req = this->communicate(dense_b->get_local_vector()); + diag_mtx_->apply(local_alpha, dense_b->get_local_vector(), + local_beta, local_x.get()); + req.wait(); + auto exec = this->get_executor(); + auto needs_host_buffer = + exec->get_master() != exec && !gko::mpi::is_gpu_aware(); + if (needs_host_buffer) { + recv_buffer_->copy_from(host_recv_buffer_.get()); + } + offdiag_mtx_->apply(local_alpha, recv_buffer_.get(), + one_scalar_.get(), local_x.get()); + } else { + diag_mtx_->apply(local_alpha, dense_b->get_local_vector(), + local_beta, local_x.get()); + } + }, + alpha, b, beta, x); } diff --git a/core/stop/residual_norm.cpp b/core/stop/residual_norm.cpp index 429e1a20c06..9fe89447f93 100644 --- a/core/stop/residual_norm.cpp +++ b/core/stop/residual_norm.cpp @@ -106,10 +106,10 @@ void norm_dispatch(Function&& fn, LinOps*... linops) { if (use_distributed(linops...)) { if (any_is_complex(linops...)) { - precision_dispatch_distributed>( + distributed::precision_dispatch>( std::forward(fn), linops...); } else { - precision_dispatch_distributed( + distributed::precision_dispatch( std::forward(fn), linops...); } } else { diff --git a/include/ginkgo/core/base/precision_dispatch.hpp b/include/ginkgo/core/base/precision_dispatch.hpp index 8c422450dc7..ac40914be89 100644 --- a/include/ginkgo/core/base/precision_dispatch.hpp +++ b/include/ginkgo/core/base/precision_dispatch.hpp @@ -335,9 +335,12 @@ void mixed_precision_dispatch_real_complex(Function fn, const LinOp* in, #if GINKGO_BUILD_MPI +namespace distributed { + + template detail::temporary_conversion> -make_temporary_conversion_distributed(LinOp* matrix) +make_temporary_conversion(LinOp* matrix) { auto result = detail::temporary_conversion>:: template create>>(matrix); @@ -350,7 +353,7 @@ make_temporary_conversion_distributed(LinOp* matrix) template detail::temporary_conversion> -make_temporary_conversion_distributed(const LinOp* matrix) +make_temporary_conversion(const LinOp* matrix) { auto result = detail::temporary_conversion>:: @@ -364,38 +367,106 @@ make_temporary_conversion_distributed(const LinOp* matrix) template -void precision_dispatch_distributed(Function fn, Args*... linops) +void precision_dispatch(Function fn, Args*... linops) { - fn(make_temporary_conversion_distributed(linops).get()...); + fn(distributed::make_temporary_conversion(linops).get()...); } +template +void precision_dispatch_real_complex(Function fn, const LinOp* in, LinOp* out) +{ + auto complex_to_real = + !(is_complex() || + dynamic_cast>*>(in)); + if (complex_to_real) { + auto dense_in = + distributed::make_temporary_conversion>(in); + auto dense_out = + distributed::make_temporary_conversion>(out); + using Vector = distributed::Vector; + // These dynamic_casts are only needed to make the code compile + // If ValueType is complex, this branch will never be taken + // If ValueType is real, the cast is a no-op + fn(dynamic_cast(dense_in->create_real_view().get()), + dynamic_cast(dense_out->create_real_view().get())); + } else { + distributed::precision_dispatch(fn, in, out); + } +} + + +template +void precision_dispatch_real_complex(Function fn, const LinOp* alpha, + const LinOp* in, LinOp* out) +{ + auto complex_to_real = + !(is_complex() || + dynamic_cast>*>(in)); + if (complex_to_real) { + auto dense_in = + distributed::make_temporary_conversion>(in); + auto dense_out = + distributed::make_temporary_conversion>(out); + auto dense_alpha = gko::make_temporary_conversion(alpha); + using Vector = distributed::Vector; + // These dynamic_casts are only needed to make the code compile + // If ValueType is complex, this branch will never be taken + // If ValueType is real, the cast is a no-op + fn(dense_alpha.get(), + dynamic_cast(dense_in->create_real_view().get()), + dynamic_cast(dense_out->create_real_view().get())); + } else { + fn(gko::make_temporary_conversion(alpha).get(), + distributed::make_temporary_conversion(in).get(), + distributed::make_temporary_conversion(out).get()); + } +} + + +template +void precision_dispatch_real_complex(Function fn, const LinOp* alpha, + const LinOp* in, const LinOp* beta, + LinOp* out) +{ + auto complex_to_real = + !(is_complex() || + dynamic_cast>*>(in)); + if (complex_to_real) { + auto dense_in = + distributed::make_temporary_conversion>(in); + auto dense_out = + distributed::make_temporary_conversion>(out); + auto dense_alpha = gko::make_temporary_conversion(alpha); + auto dense_beta = gko::make_temporary_conversion(beta); + using Vector = distributed::Vector; + // These dynamic_casts are only needed to make the code compile + // If ValueType is complex, this branch will never be taken + // If ValueType is real, the cast is a no-op + fn(dense_alpha.get(), + dynamic_cast(dense_in->create_real_view().get()), + dense_beta.get(), + dynamic_cast(dense_out->create_real_view().get())); + } else { + fn(gko::make_temporary_conversion(alpha).get(), + distributed::make_temporary_conversion(in).get(), + gko::make_temporary_conversion(beta).get(), + distributed::make_temporary_conversion(out).get()); + } +} + + +} // namespace distributed + + template void precision_dispatch_real_complex_distributed(Function fn, const LinOp* in, LinOp* out) { if (dynamic_cast(in)) { - auto complex_to_real = - !(is_complex() || - dynamic_cast>*>(in)); - if (complex_to_real) { - auto dense_in = - make_temporary_conversion_distributed>( - in); - auto dense_out = - make_temporary_conversion_distributed>( - out); - using Vector = distributed::Vector; - // These dynamic_casts are only needed to make the code compile - // If ValueType is complex, this branch will never be taken - // If ValueType is real, the cast is a no-op - fn(dynamic_cast(dense_in->create_real_view().get()), - dynamic_cast(dense_out->create_real_view().get())); - } else { - precision_dispatch_distributed(fn, in, out); - } + distributed::precision_dispatch_real_complex(fn, in, out); } else { - precision_dispatch_real_complex(fn, in, out); + gko::precision_dispatch_real_complex(fn, in, out); } } @@ -406,31 +477,10 @@ void precision_dispatch_real_complex_distributed(Function fn, const LinOp* in, LinOp* out) { if (dynamic_cast(in)) { - auto complex_to_real = - !(is_complex() || - dynamic_cast>*>(in)); - if (complex_to_real) { - auto dense_in = - make_temporary_conversion_distributed>( - in); - auto dense_out = - make_temporary_conversion_distributed>( - out); - auto dense_alpha = make_temporary_conversion(alpha); - using Vector = distributed::Vector; - // These dynamic_casts are only needed to make the code compile - // If ValueType is complex, this branch will never be taken - // If ValueType is real, the cast is a no-op - fn(dense_alpha.get(), - dynamic_cast(dense_in->create_real_view().get()), - dynamic_cast(dense_out->create_real_view().get())); - } else { - fn(make_temporary_conversion(alpha).get(), - make_temporary_conversion_distributed(in).get(), - make_temporary_conversion_distributed(out).get()); - } + distributed::precision_dispatch_real_complex(fn, alpha, in, + out); } else { - precision_dispatch_real_complex(fn, alpha, in, out); + gko::precision_dispatch_real_complex(fn, alpha, in, out); } } @@ -442,34 +492,11 @@ void precision_dispatch_real_complex_distributed(Function fn, const LinOp* beta, LinOp* out) { if (dynamic_cast(in)) { - auto complex_to_real = - !(is_complex() || - dynamic_cast>*>(in)); - if (complex_to_real) { - auto dense_in = - make_temporary_conversion_distributed>( - in); - auto dense_out = - make_temporary_conversion_distributed>( - out); - auto dense_alpha = make_temporary_conversion(alpha); - auto dense_beta = make_temporary_conversion(beta); - using Vector = distributed::Vector; - // These dynamic_casts are only needed to make the code compile - // If ValueType is complex, this branch will never be taken - // If ValueType is real, the cast is a no-op - fn(dense_alpha.get(), - dynamic_cast(dense_in->create_real_view().get()), - dense_beta.get(), - dynamic_cast(dense_out->create_real_view().get())); - } else { - fn(make_temporary_conversion(alpha).get(), - make_temporary_conversion_distributed(in).get(), - make_temporary_conversion(beta).get(), - make_temporary_conversion_distributed(out).get()); - } + distributed::precision_dispatch_real_complex(fn, alpha, in, + beta, out); } else { - precision_dispatch_real_complex(fn, alpha, in, beta, out); + gko::precision_dispatch_real_complex(fn, alpha, in, beta, + out); } }