Skip to content

Commit

Permalink
adds complex-to-real dispatch for distributed
Browse files Browse the repository at this point in the history
  • Loading branch information
MarcelKoch committed Mar 1, 2022
1 parent f98622c commit 4f09d10
Showing 1 changed file with 80 additions and 5 deletions.
85 changes: 80 additions & 5 deletions include/ginkgo/core/base/precision_dispatch.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -375,24 +375,99 @@ void precision_dispatch_real_complex_distributed(Function fn, const LinOp* in,
LinOp* out)
{
if (dynamic_cast<const distributed::DistributedBase*>(in)) {
precision_dispatch_distributed<ValueType>(fn, in, out);
auto complex_to_real =
!(is_complex<ValueType>() ||
dynamic_cast<const ConvertibleTo<distributed::Vector<>>*>(in));
if (complex_to_real) {
auto dense_in =
make_temporary_conversion_distributed<to_complex<ValueType>>(
in);
auto dense_out =
make_temporary_conversion_distributed<to_complex<ValueType>>(
out);
using Vector = distributed::Vector<ValueType>;
// These dynamic_casts are only needed to make the code compile
// If ValueType is complex, this branch will never be taken
// If ValueType is real, the cast is a no-op
fn(dynamic_cast<const Vector*>(dense_in->create_real_view().get()),
dynamic_cast<Vector*>(dense_out->create_real_view().get()));
} else {
precision_dispatch_distributed<ValueType>(fn, in, out);
}
} else {
precision_dispatch_real_complex<ValueType>(fn, in, out);
}
}


template <typename ValueType, typename Function>
void precision_dispatch_real_complex_distributed(Function fn,
const LinOp* alpha,
const LinOp* in, LinOp* out)
{
if (dynamic_cast<const distributed::DistributedBase*>(in)) {
auto complex_to_real =
!(is_complex<ValueType>() ||
dynamic_cast<const ConvertibleTo<distributed::Vector<>>*>(in));
if (complex_to_real) {
auto dense_in =
make_temporary_conversion_distributed<to_complex<ValueType>>(
in);
auto dense_out =
make_temporary_conversion_distributed<to_complex<ValueType>>(
out);
auto dense_alpha = make_temporary_conversion<ValueType>(alpha);
using Vector = distributed::Vector<ValueType>;
// These dynamic_casts are only needed to make the code compile
// If ValueType is complex, this branch will never be taken
// If ValueType is real, the cast is a no-op
fn(dense_alpha.get(),
dynamic_cast<const Vector*>(dense_in->create_real_view().get()),
dynamic_cast<Vector*>(dense_out->create_real_view().get()));
} else {
fn(make_temporary_conversion<ValueType>(alpha).get(),
make_temporary_conversion_distributed<ValueType>(in).get(),
make_temporary_conversion_distributed<ValueType>(out).get());
}
} else {
precision_dispatch_real_complex<ValueType>(fn, alpha, in, out);
}
}


template <typename ValueType, typename Function>
void precision_dispatch_real_complex_distributed(Function fn,
const LinOp* alpha,
const LinOp* in,
const LinOp* beta, LinOp* out)
{
if (dynamic_cast<const distributed::DistributedBase*>(in)) {
fn(make_temporary_conversion<ValueType>(alpha).get(),
make_temporary_conversion_distributed<ValueType>(in).get(),
make_temporary_conversion<ValueType>(beta).get(),
make_temporary_conversion_distributed<ValueType>(out).get());
auto complex_to_real =
!(is_complex<ValueType>() ||
dynamic_cast<const ConvertibleTo<distributed::Vector<>>*>(in));
if (complex_to_real) {
auto dense_in =
make_temporary_conversion_distributed<to_complex<ValueType>>(
in);
auto dense_out =
make_temporary_conversion_distributed<to_complex<ValueType>>(
out);
auto dense_alpha = make_temporary_conversion<ValueType>(alpha);
auto dense_beta = make_temporary_conversion<ValueType>(beta);
using Vector = distributed::Vector<ValueType>;
// These dynamic_casts are only needed to make the code compile
// If ValueType is complex, this branch will never be taken
// If ValueType is real, the cast is a no-op
fn(dense_alpha.get(),
dynamic_cast<const Vector*>(dense_in->create_real_view().get()),
dense_beta.get(),
dynamic_cast<Vector*>(dense_out->create_real_view().get()));
} else {
fn(make_temporary_conversion<ValueType>(alpha).get(),
make_temporary_conversion_distributed<ValueType>(in).get(),
make_temporary_conversion<ValueType>(beta).get(),
make_temporary_conversion_distributed<ValueType>(out).get());
}
} else {
precision_dispatch_real_complex<ValueType>(fn, alpha, in, beta, out);
}
Expand Down

0 comments on commit 4f09d10

Please sign in to comment.