Skip to content

Commit

Permalink
adds precision dispatch to distributed matrix apply
Browse files Browse the repository at this point in the history
  • Loading branch information
MarcelKoch committed Apr 22, 2022
1 parent 61faa81 commit b432578
Show file tree
Hide file tree
Showing 3 changed files with 162 additions and 129 deletions.
110 changes: 58 additions & 52 deletions core/distributed/matrix.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#include <ginkgo/core/distributed/matrix.hpp>


#include <ginkgo/core/base/precision_dispatch.hpp>
#include <ginkgo/core/distributed/vector.hpp>
#include <ginkgo/core/matrix/csr.hpp>

Expand Down Expand Up @@ -286,65 +287,70 @@ template <typename ValueType, typename LocalIndexType, typename GlobalIndexType>
void Matrix<ValueType, LocalIndexType, GlobalIndexType>::apply_impl(
const LinOp* b, LinOp* x) const
{
auto dense_b = as<global_vector_type>(b);
auto dense_x = as<global_vector_type>(x);
auto x_exec = x->get_executor();
auto local_x = gko::matrix::Dense<ValueType>::create(
x_exec, dense_x->get_local_vector()->get_size(),
gko::make_array_view(
x_exec, dense_x->get_local_vector()->get_num_stored_elements(),
dense_x->get_local_values()),
dense_x->get_local_vector()->get_stride());
if (this->get_const_local_offdiag()->get_size()) {
auto req = this->communicate(dense_b->get_local_vector());
diag_mtx_->apply(dense_b->get_local_vector(), local_x.get());
req.wait();
auto exec = this->get_executor();
auto needs_host_buffer =
exec->get_master() != exec && !gko::mpi::is_gpu_aware();
if (needs_host_buffer) {
recv_buffer_->copy_from(host_recv_buffer_.get());
}
offdiag_mtx_->apply(one_scalar_.get(), recv_buffer_.get(),
one_scalar_.get(), local_x.get());
} else {
diag_mtx_->apply(dense_b->get_local_vector(), local_x.get());
}
distributed::precision_dispatch_real_complex<ValueType>(
[this](const auto dense_b, auto dense_x) {
auto x_exec = dense_x->get_executor();
auto local_x = gko::matrix::Dense<ValueType>::create(
x_exec, dense_x->get_local_vector()->get_size(),
gko::make_array_view(
x_exec,
dense_x->get_local_vector()->get_num_stored_elements(),
dense_x->get_local_values()),
dense_x->get_local_vector()->get_stride());
if (this->get_const_local_offdiag()->get_size()) {
auto req = this->communicate(dense_b->get_local_vector());
diag_mtx_->apply(dense_b->get_local_vector(), local_x.get());
req.wait();
auto exec = this->get_executor();
auto needs_host_buffer =
exec->get_master() != exec && !gko::mpi::is_gpu_aware();
if (needs_host_buffer) {
recv_buffer_->copy_from(host_recv_buffer_.get());
}
offdiag_mtx_->apply(one_scalar_.get(), recv_buffer_.get(),
one_scalar_.get(), local_x.get());
} else {
diag_mtx_->apply(dense_b->get_local_vector(), local_x.get());
}
},
b, x);
}


template <typename ValueType, typename LocalIndexType, typename GlobalIndexType>
void Matrix<ValueType, LocalIndexType, GlobalIndexType>::apply_impl(
const LinOp* alpha, const LinOp* b, const LinOp* beta, LinOp* x) const
{
auto dense_b = as<global_vector_type>(b);
auto dense_x = as<global_vector_type>(x);
const auto x_exec = x->get_executor();
auto local_x = gko::matrix::Dense<ValueType>::create(
x_exec, dense_x->get_local_vector()->get_size(),
gko::make_array_view(
x_exec, dense_x->get_local_vector()->get_num_stored_elements(),
dense_x->get_local_values()),
dense_x->get_local_vector()->get_stride());
auto local_alpha = as<local_vector_type>(alpha);
auto local_beta = as<local_vector_type>(beta);
if (this->get_const_local_offdiag()->get_size()) {
auto req = this->communicate(dense_b->get_local_vector());
diag_mtx_->apply(local_alpha, dense_b->get_local_vector(), local_beta,
local_x.get());
req.wait();
auto exec = this->get_executor();
auto needs_host_buffer =
exec->get_master() != exec && !gko::mpi::is_gpu_aware();
if (needs_host_buffer) {
recv_buffer_->copy_from(host_recv_buffer_.get());
}
offdiag_mtx_->apply(local_alpha, recv_buffer_.get(), one_scalar_.get(),
local_x.get());
} else {
diag_mtx_->apply(local_alpha, dense_b->get_local_vector(), local_beta,
local_x.get());
}
distributed::precision_dispatch_real_complex<ValueType>(
[this](const auto local_alpha, const auto dense_b,
const auto local_beta, auto dense_x) {
const auto x_exec = dense_x->get_executor();
auto local_x = gko::matrix::Dense<ValueType>::create(
x_exec, dense_x->get_local_vector()->get_size(),
gko::make_array_view(
x_exec,
dense_x->get_local_vector()->get_num_stored_elements(),
dense_x->get_local_values()),
dense_x->get_local_vector()->get_stride());
if (this->get_const_local_offdiag()->get_size()) {
auto req = this->communicate(dense_b->get_local_vector());
diag_mtx_->apply(local_alpha, dense_b->get_local_vector(),
local_beta, local_x.get());
req.wait();
auto exec = this->get_executor();
auto needs_host_buffer =
exec->get_master() != exec && !gko::mpi::is_gpu_aware();
if (needs_host_buffer) {
recv_buffer_->copy_from(host_recv_buffer_.get());
}
offdiag_mtx_->apply(local_alpha, recv_buffer_.get(),
one_scalar_.get(), local_x.get());
} else {
diag_mtx_->apply(local_alpha, dense_b->get_local_vector(),
local_beta, local_x.get());
}
},
alpha, b, beta, x);
}


Expand Down
4 changes: 2 additions & 2 deletions core/stop/residual_norm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -106,10 +106,10 @@ void norm_dispatch(Function&& fn, LinOps*... linops)
{
if (use_distributed(linops...)) {
if (any_is_complex<ValueType>(linops...)) {
precision_dispatch_distributed<to_complex<ValueType>>(
distributed::precision_dispatch<to_complex<ValueType>>(
std::forward<Function>(fn), linops...);
} else {
precision_dispatch_distributed<ValueType>(
distributed::precision_dispatch<ValueType>(
std::forward<Function>(fn), linops...);
}
} else {
Expand Down
177 changes: 102 additions & 75 deletions include/ginkgo/core/base/precision_dispatch.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -335,9 +335,12 @@ void mixed_precision_dispatch_real_complex(Function fn, const LinOp* in,
#if GINKGO_BUILD_MPI


namespace distributed {


template <typename ValueType>
detail::temporary_conversion<distributed::Vector<ValueType>>
make_temporary_conversion_distributed(LinOp* matrix)
make_temporary_conversion(LinOp* matrix)
{
auto result = detail::temporary_conversion<distributed::Vector<ValueType>>::
template create<distributed::Vector<next_precision<ValueType>>>(matrix);
Expand All @@ -350,7 +353,7 @@ make_temporary_conversion_distributed(LinOp* matrix)

template <typename ValueType>
detail::temporary_conversion<const distributed::Vector<ValueType>>
make_temporary_conversion_distributed(const LinOp* matrix)
make_temporary_conversion(const LinOp* matrix)
{
auto result =
detail::temporary_conversion<const distributed::Vector<ValueType>>::
Expand All @@ -364,38 +367,106 @@ make_temporary_conversion_distributed(const LinOp* matrix)


template <typename ValueType, typename Function, typename... Args>
void precision_dispatch_distributed(Function fn, Args*... linops)
void precision_dispatch(Function fn, Args*... linops)
{
fn(make_temporary_conversion_distributed<ValueType>(linops).get()...);
fn(distributed::make_temporary_conversion<ValueType>(linops).get()...);
}


template <typename ValueType, typename Function>
void precision_dispatch_real_complex(Function fn, const LinOp* in, LinOp* out)
{
auto complex_to_real =
!(is_complex<ValueType>() ||
dynamic_cast<const ConvertibleTo<distributed::Vector<>>*>(in));
if (complex_to_real) {
auto dense_in =
distributed::make_temporary_conversion<to_complex<ValueType>>(in);
auto dense_out =
distributed::make_temporary_conversion<to_complex<ValueType>>(out);
using Vector = distributed::Vector<ValueType>;
// These dynamic_casts are only needed to make the code compile
// If ValueType is complex, this branch will never be taken
// If ValueType is real, the cast is a no-op
fn(dynamic_cast<const Vector*>(dense_in->create_real_view().get()),
dynamic_cast<Vector*>(dense_out->create_real_view().get()));
} else {
distributed::precision_dispatch<ValueType>(fn, in, out);
}
}


template <typename ValueType, typename Function>
void precision_dispatch_real_complex(Function fn, const LinOp* alpha,
const LinOp* in, LinOp* out)
{
auto complex_to_real =
!(is_complex<ValueType>() ||
dynamic_cast<const ConvertibleTo<distributed::Vector<>>*>(in));
if (complex_to_real) {
auto dense_in =
distributed::make_temporary_conversion<to_complex<ValueType>>(in);
auto dense_out =
distributed::make_temporary_conversion<to_complex<ValueType>>(out);
auto dense_alpha = gko::make_temporary_conversion<ValueType>(alpha);
using Vector = distributed::Vector<ValueType>;
// These dynamic_casts are only needed to make the code compile
// If ValueType is complex, this branch will never be taken
// If ValueType is real, the cast is a no-op
fn(dense_alpha.get(),
dynamic_cast<const Vector*>(dense_in->create_real_view().get()),
dynamic_cast<Vector*>(dense_out->create_real_view().get()));
} else {
fn(gko::make_temporary_conversion<ValueType>(alpha).get(),
distributed::make_temporary_conversion<ValueType>(in).get(),
distributed::make_temporary_conversion<ValueType>(out).get());
}
}


template <typename ValueType, typename Function>
void precision_dispatch_real_complex(Function fn, const LinOp* alpha,
const LinOp* in, const LinOp* beta,
LinOp* out)
{
auto complex_to_real =
!(is_complex<ValueType>() ||
dynamic_cast<const ConvertibleTo<distributed::Vector<>>*>(in));
if (complex_to_real) {
auto dense_in =
distributed::make_temporary_conversion<to_complex<ValueType>>(in);
auto dense_out =
distributed::make_temporary_conversion<to_complex<ValueType>>(out);
auto dense_alpha = gko::make_temporary_conversion<ValueType>(alpha);
auto dense_beta = gko::make_temporary_conversion<ValueType>(beta);
using Vector = distributed::Vector<ValueType>;
// These dynamic_casts are only needed to make the code compile
// If ValueType is complex, this branch will never be taken
// If ValueType is real, the cast is a no-op
fn(dense_alpha.get(),
dynamic_cast<const Vector*>(dense_in->create_real_view().get()),
dense_beta.get(),
dynamic_cast<Vector*>(dense_out->create_real_view().get()));
} else {
fn(gko::make_temporary_conversion<ValueType>(alpha).get(),
distributed::make_temporary_conversion<ValueType>(in).get(),
gko::make_temporary_conversion<ValueType>(beta).get(),
distributed::make_temporary_conversion<ValueType>(out).get());
}
}


} // namespace distributed


template <typename ValueType, typename Function>
void precision_dispatch_real_complex_distributed(Function fn, const LinOp* in,
LinOp* out)
{
if (dynamic_cast<const distributed::DistributedBase*>(in)) {
auto complex_to_real =
!(is_complex<ValueType>() ||
dynamic_cast<const ConvertibleTo<distributed::Vector<>>*>(in));
if (complex_to_real) {
auto dense_in =
make_temporary_conversion_distributed<to_complex<ValueType>>(
in);
auto dense_out =
make_temporary_conversion_distributed<to_complex<ValueType>>(
out);
using Vector = distributed::Vector<ValueType>;
// These dynamic_casts are only needed to make the code compile
// If ValueType is complex, this branch will never be taken
// If ValueType is real, the cast is a no-op
fn(dynamic_cast<const Vector*>(dense_in->create_real_view().get()),
dynamic_cast<Vector*>(dense_out->create_real_view().get()));
} else {
precision_dispatch_distributed<ValueType>(fn, in, out);
}
distributed::precision_dispatch_real_complex<ValueType>(fn, in, out);
} else {
precision_dispatch_real_complex<ValueType>(fn, in, out);
gko::precision_dispatch_real_complex<ValueType>(fn, in, out);
}
}

Expand All @@ -406,31 +477,10 @@ void precision_dispatch_real_complex_distributed(Function fn,
const LinOp* in, LinOp* out)
{
if (dynamic_cast<const distributed::DistributedBase*>(in)) {
auto complex_to_real =
!(is_complex<ValueType>() ||
dynamic_cast<const ConvertibleTo<distributed::Vector<>>*>(in));
if (complex_to_real) {
auto dense_in =
make_temporary_conversion_distributed<to_complex<ValueType>>(
in);
auto dense_out =
make_temporary_conversion_distributed<to_complex<ValueType>>(
out);
auto dense_alpha = make_temporary_conversion<ValueType>(alpha);
using Vector = distributed::Vector<ValueType>;
// These dynamic_casts are only needed to make the code compile
// If ValueType is complex, this branch will never be taken
// If ValueType is real, the cast is a no-op
fn(dense_alpha.get(),
dynamic_cast<const Vector*>(dense_in->create_real_view().get()),
dynamic_cast<Vector*>(dense_out->create_real_view().get()));
} else {
fn(make_temporary_conversion<ValueType>(alpha).get(),
make_temporary_conversion_distributed<ValueType>(in).get(),
make_temporary_conversion_distributed<ValueType>(out).get());
}
distributed::precision_dispatch_real_complex<ValueType>(fn, alpha, in,
out);
} else {
precision_dispatch_real_complex<ValueType>(fn, alpha, in, out);
gko::precision_dispatch_real_complex<ValueType>(fn, alpha, in, out);
}
}

Expand All @@ -442,34 +492,11 @@ void precision_dispatch_real_complex_distributed(Function fn,
const LinOp* beta, LinOp* out)
{
if (dynamic_cast<const distributed::DistributedBase*>(in)) {
auto complex_to_real =
!(is_complex<ValueType>() ||
dynamic_cast<const ConvertibleTo<distributed::Vector<>>*>(in));
if (complex_to_real) {
auto dense_in =
make_temporary_conversion_distributed<to_complex<ValueType>>(
in);
auto dense_out =
make_temporary_conversion_distributed<to_complex<ValueType>>(
out);
auto dense_alpha = make_temporary_conversion<ValueType>(alpha);
auto dense_beta = make_temporary_conversion<ValueType>(beta);
using Vector = distributed::Vector<ValueType>;
// These dynamic_casts are only needed to make the code compile
// If ValueType is complex, this branch will never be taken
// If ValueType is real, the cast is a no-op
fn(dense_alpha.get(),
dynamic_cast<const Vector*>(dense_in->create_real_view().get()),
dense_beta.get(),
dynamic_cast<Vector*>(dense_out->create_real_view().get()));
} else {
fn(make_temporary_conversion<ValueType>(alpha).get(),
make_temporary_conversion_distributed<ValueType>(in).get(),
make_temporary_conversion<ValueType>(beta).get(),
make_temporary_conversion_distributed<ValueType>(out).get());
}
distributed::precision_dispatch_real_complex<ValueType>(fn, alpha, in,
beta, out);
} else {
precision_dispatch_real_complex<ValueType>(fn, alpha, in, beta, out);
gko::precision_dispatch_real_complex<ValueType>(fn, alpha, in, beta,
out);
}
}

Expand Down

0 comments on commit b432578

Please sign in to comment.