Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Axpby using less deep copy (solves issue #2080) #2081

Merged
merged 4 commits into from
Jan 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 44 additions & 21 deletions blas/impl/KokkosBlas1_axpby_unification_attempt_traits.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -104,23 +104,31 @@ struct AxpbyUnificationAttemptTraits {
// - variable names begin with lower case letters
// - type names begin with upper case letters
// ********************************************************************
private:
public:
static constexpr bool onDevice =
KokkosKernels::Impl::kk_is_gpu_exec_space<tExecSpace>();

private:
static constexpr bool onHost = !onDevice;

public:
static constexpr bool a_is_scalar = !Kokkos::is_view_v<AV>;
static constexpr bool a_is_r0 = Tr0_val<AV>();
static constexpr bool a_is_r1s = Tr1s_val<AV>();
static constexpr bool a_is_r1d = Tr1d_val<AV>();

private:
static constexpr bool a_is_r0 = Tr0_val<AV>();
static constexpr bool a_is_r1s = Tr1s_val<AV>();
static constexpr bool a_is_r1d = Tr1d_val<AV>();

static constexpr bool x_is_r1 = Kokkos::is_view_v<XMV> && (XMV::rank == 1);
static constexpr bool x_is_r2 = Kokkos::is_view_v<XMV> && (XMV::rank == 2);

public:
static constexpr bool b_is_scalar = !Kokkos::is_view_v<BV>;
static constexpr bool b_is_r0 = Tr0_val<BV>();
static constexpr bool b_is_r1s = Tr1s_val<BV>();
static constexpr bool b_is_r1d = Tr1d_val<BV>();

private:
static constexpr bool b_is_r0 = Tr0_val<BV>();
static constexpr bool b_is_r1s = Tr1s_val<BV>();
static constexpr bool b_is_r1d = Tr1d_val<BV>();

static constexpr bool y_is_r1 = Kokkos::is_view_v<YMV> && (YMV::rank == 1);
static constexpr bool y_is_r2 = Kokkos::is_view_v<YMV> && (YMV::rank == 2);
Expand Down Expand Up @@ -220,10 +228,12 @@ struct AxpbyUnificationAttemptTraits {
// 'AtInputScalarTypeA_nonConst'
>;

using InternalTypeA_onDevice =
using InternalTypeA_onDevice = std::conditional_t<
a_is_scalar && b_is_scalar && onDevice, // Keep 'a' as scalar
InternalScalarTypeA,
Kokkos::View<const InternalScalarTypeA*, InternalLayoutA,
typename XMV::device_type,
Kokkos::MemoryTraits<Kokkos::Unmanaged>>;
Kokkos::MemoryTraits<Kokkos::Unmanaged>>>;

using InternalTypeA_onHost = std::conditional_t<
(a_is_r1d || a_is_r1s) && xyRank2Case && onHost,
Expand Down Expand Up @@ -276,13 +286,15 @@ struct AxpbyUnificationAttemptTraits {
// 'AtInputScalarTypeB_nonConst'
>;

using InternalTypeB_onDevice =
using InternalTypeB_onDevice = std::conditional_t<
a_is_scalar && b_is_scalar && onDevice, // Keep 'b' as scalar
InternalScalarTypeB,
Kokkos::View<const InternalScalarTypeB*, InternalLayoutB,
typename YMV::device_type,
Kokkos::MemoryTraits<Kokkos::Unmanaged>>;
Kokkos::MemoryTraits<Kokkos::Unmanaged>>>;

using InternalTypeB_onHost = std::conditional_t<
((b_is_r1d || b_is_r1s) && xyRank2Case && onHost),
(b_is_r1d || b_is_r1s) && xyRank2Case && onHost,
Kokkos::View<const InternalScalarTypeB*, InternalLayoutB,
typename YMV::device_type,
Kokkos::MemoryTraits<Kokkos::Unmanaged>>,
Expand Down Expand Up @@ -614,7 +626,9 @@ struct AxpbyUnificationAttemptTraits {
}
} else {
if constexpr (xyRank1Case) {
constexpr bool internalTypeA_isOk = internalTypeA_is_r1d;
constexpr bool internalTypeA_isOk =
internalTypeA_is_r1d ||
(a_is_scalar && b_is_scalar && internalTypeA_is_scalar);
static_assert(
internalTypeA_isOk,
"KokkosBlas::Impl::AxpbyUnificationAttemptTraits::performChecks()"
Expand All @@ -630,7 +644,9 @@ struct AxpbyUnificationAttemptTraits {
"KokkosBlas::Impl::AxpbyUnificationAttemptTraits::performChecks()"
", onDevice, xyRank1Case: InternalTypeX is wrong");

constexpr bool internalTypeB_isOk = internalTypeB_is_r1d;
constexpr bool internalTypeB_isOk =
internalTypeB_is_r1d ||
(a_is_scalar && b_is_scalar && internalTypeA_is_scalar);
static_assert(
internalTypeB_isOk,
"KokkosBlas::Impl::AxpbyUnificationAttemptTraits::performChecks()"
Expand All @@ -646,7 +662,9 @@ struct AxpbyUnificationAttemptTraits {
"KokkosBlas::Impl::AxpbyUnificationAttemptTraits::performChecks()"
", onDevice, xyRank1Case: InternalTypeY is wrong");
} else {
constexpr bool internalTypeA_isOk = internalTypeA_is_r1d;
constexpr bool internalTypeA_isOk =
internalTypeA_is_r1d ||
(a_is_scalar && b_is_scalar && internalTypeA_is_scalar);
static_assert(
internalTypeA_isOk,
"KokkosBlas::Impl::AxpbyUnificationAttemptTraits::performChecks()"
Expand All @@ -662,7 +680,9 @@ struct AxpbyUnificationAttemptTraits {
"KokkosBlas::Impl::AxpbyUnificationAttemptTraits::performChecks()"
", onDevice, xyRank2Case: InternalTypeX is wrong");

constexpr bool internalTypeB_isOk = internalTypeB_is_r1d;
constexpr bool internalTypeB_isOk =
internalTypeB_is_r1d ||
(a_is_scalar && b_is_scalar && internalTypeB_is_scalar);
static_assert(
internalTypeB_isOk,
"KokkosBlas::Impl::AxpbyUnificationAttemptTraits::performChecks()"
Expand Down Expand Up @@ -703,16 +723,19 @@ struct AxpbyUnificationAttemptTraits {
// ****************************************************************
// We are in the 'onDevice' case, with 2 possible subcases:
//
// 1) xyRank1Case, with only one possible situation:
// - [InternalTypeA / B] = [view<S_a*,1>, view<S_b*,1>]
// 1) xyRank1Case, with the following possible situations:
// - [InternalTypeA, B] = [S_a, S_b], or
// - [InternalTypeA, B] = [view<S_a*,1>, view<S_b*,1>]
//
// or
//
// 2) xyRank2Case, with only one possible situation:
// - [InternalTypeA / B] = [view<S_a*,1 / m>, view<S_b*,1 / m>]
// 2) xyRank2Case, with the following possible situations:
// - [InternalTypeA, B] = [S_a, S_b], or
// - [InternalTypeA, B] = [view<S_a*,1 / m>, view<S_b*,1 / m>]
// ****************************************************************
static_assert(
internalTypesAB_bothViews,
internalTypesAB_bothViews ||
(a_is_scalar && b_is_scalar && internalTypesAB_bothScalars),
"KokkosBlas::Impl::AxpbyUnificationAttemptTraits::performChecks()"
", onDevice, invalid combination of types");
}
Expand Down
41 changes: 34 additions & 7 deletions blas/src/KokkosBlas1_axpby.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -84,14 +84,41 @@ void axpby(const execution_space& exec_space, const AV& a, const XMV& X,
InternalTypeY internal_Y = Y;

if constexpr (AxpbyTraits::internalTypesAB_bothScalars) {
InternalTypeA internal_a(Impl::getScalarValueFromVariableAtHost<
AV, Impl::typeRank<AV>()>::getValue(a));
InternalTypeB internal_b(Impl::getScalarValueFromVariableAtHost<
BV, Impl::typeRank<BV>()>::getValue(b));
// ********************************************************************
// The unification logic applies the following general rules:
// 1) In an 'onHost' case, it makes the internal types for 'a' and 'b'
// to be both scalars (hence the name 'internalTypesAB_bothScalars')
// 2) In an 'onDevice' case, it makes the internal types for 'a' and 'b'
// to be Kokkos views. For performance reasons in Trilinos, the only
// exception to this rule is when the input types for both 'a' and
// 'b' are already scalars, in which case the internal types for 'a'
// and 'b' become scalars as well, possibly changing precision in
// order to match the precisions of 'X' and 'Y'.
// ********************************************************************
if constexpr (AxpbyTraits::a_is_scalar && AxpbyTraits::b_is_scalar &&
AxpbyTraits::onDevice) {
// ******************************************************************
// We are in the exception situation for rule 2
// ******************************************************************
InternalTypeA internal_a(a);
InternalTypeA internal_b(b);

Impl::Axpby<execution_space, InternalTypeA, InternalTypeX, InternalTypeB,
InternalTypeY>::axpby(exec_space, internal_a, internal_X,
internal_b, internal_Y);
Impl::Axpby<execution_space, InternalTypeA, InternalTypeX, InternalTypeB,
InternalTypeY>::axpby(exec_space, internal_a, internal_X,
internal_b, internal_Y);
} else {
// ******************************************************************
// Rule 1 applies, that is, we are in an 'onHost' case now
// ******************************************************************
InternalTypeA internal_a(Impl::getScalarValueFromVariableAtHost<
AV, Impl::typeRank<AV>()>::getValue(a));
InternalTypeB internal_b(Impl::getScalarValueFromVariableAtHost<
BV, Impl::typeRank<BV>()>::getValue(b));

Impl::Axpby<execution_space, InternalTypeA, InternalTypeX, InternalTypeB,
InternalTypeY>::axpby(exec_space, internal_a, internal_X,
internal_b, internal_Y);
}
} else if constexpr (AxpbyTraits::internalTypesAB_bothViews) {
constexpr bool internalLayoutA_isStride(
std::is_same_v<typename InternalTypeA::array_layout,
Expand Down