Skip to content

Commit

Permalink
Backup
Browse files Browse the repository at this point in the history
  • Loading branch information
eeprude committed Dec 25, 2023
1 parent 5b75a1a commit 11d369b
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 18 deletions.
35 changes: 22 additions & 13 deletions blas/impl/KokkosBlas1_axpby_unification_attempt_traits.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -115,9 +115,9 @@ struct AxpbyUnificationAttemptTraits {
static constexpr bool a_is_scalar = !Kokkos::is_view_v<AV>;

private:
static constexpr bool a_is_r0 = Tr0_val<AV>();
static constexpr bool a_is_r1s = Tr1s_val<AV>();
static constexpr bool a_is_r1d = Tr1d_val<AV>();
static constexpr bool a_is_r0 = Tr0_val<AV>();
static constexpr bool a_is_r1s = Tr1s_val<AV>();
static constexpr bool a_is_r1d = Tr1d_val<AV>();

static constexpr bool x_is_r1 = Kokkos::is_view_v<XMV> && (XMV::rank == 1);
static constexpr bool x_is_r2 = Kokkos::is_view_v<XMV> && (XMV::rank == 2);
Expand All @@ -126,9 +126,9 @@ struct AxpbyUnificationAttemptTraits {
static constexpr bool b_is_scalar = !Kokkos::is_view_v<BV>;

private:
static constexpr bool b_is_r0 = Tr0_val<BV>();
static constexpr bool b_is_r1s = Tr1s_val<BV>();
static constexpr bool b_is_r1d = Tr1d_val<BV>();
static constexpr bool b_is_r0 = Tr0_val<BV>();
static constexpr bool b_is_r1s = Tr1s_val<BV>();
static constexpr bool b_is_r1d = Tr1d_val<BV>();

static constexpr bool y_is_r1 = Kokkos::is_view_v<YMV> && (YMV::rank == 1);
static constexpr bool y_is_r2 = Kokkos::is_view_v<YMV> && (YMV::rank == 2);
Expand Down Expand Up @@ -229,7 +229,7 @@ struct AxpbyUnificationAttemptTraits {
>;

using InternalTypeA_onDevice = std::conditional_t<
a_is_scalar && b_is_scalar && onDevice, // Keep 'a' as scalar
a_is_scalar && b_is_scalar && onDevice, // Keep 'a' as scalar
InternalScalarTypeA,
Kokkos::View<const InternalScalarTypeA*, InternalLayoutA,
typename XMV::device_type,
Expand Down Expand Up @@ -287,7 +287,7 @@ struct AxpbyUnificationAttemptTraits {
>;

using InternalTypeB_onDevice = std::conditional_t<
a_is_scalar && b_is_scalar && onDevice, // Keep 'b' as scalar
a_is_scalar && b_is_scalar && onDevice, // Keep 'b' as scalar
InternalScalarTypeB,
Kokkos::View<const InternalScalarTypeB*, InternalLayoutB,
typename YMV::device_type,
Expand Down Expand Up @@ -626,7 +626,9 @@ struct AxpbyUnificationAttemptTraits {
}
} else {
if constexpr (xyRank1Case) {
constexpr bool internalTypeA_isOk = internalTypeA_is_r1d || (a_is_scalar && b_is_scalar && internalTypeA_is_scalar);
constexpr bool internalTypeA_isOk =
internalTypeA_is_r1d ||
(a_is_scalar && b_is_scalar && internalTypeA_is_scalar);
static_assert(
internalTypeA_isOk,
"KokkosBlas::Impl::AxpbyUnificationAttemptTraits::performChecks()"
Expand All @@ -642,7 +644,9 @@ struct AxpbyUnificationAttemptTraits {
"KokkosBlas::Impl::AxpbyUnificationAttemptTraits::performChecks()"
", onDevice, xyRank1Case: InternalTypeX is wrong");

constexpr bool internalTypeB_isOk = internalTypeB_is_r1d || (a_is_scalar && b_is_scalar && internalTypeA_is_scalar);
constexpr bool internalTypeB_isOk =
internalTypeB_is_r1d ||
(a_is_scalar && b_is_scalar && internalTypeA_is_scalar);
static_assert(
internalTypeB_isOk,
"KokkosBlas::Impl::AxpbyUnificationAttemptTraits::performChecks()"
Expand All @@ -658,7 +662,9 @@ struct AxpbyUnificationAttemptTraits {
"KokkosBlas::Impl::AxpbyUnificationAttemptTraits::performChecks()"
", onDevice, xyRank1Case: InternalTypeY is wrong");
} else {
constexpr bool internalTypeA_isOk = internalTypeA_is_r1d || (a_is_scalar && b_is_scalar && internalTypeA_is_scalar);
constexpr bool internalTypeA_isOk =
internalTypeA_is_r1d ||
(a_is_scalar && b_is_scalar && internalTypeA_is_scalar);
static_assert(
internalTypeA_isOk,
"KokkosBlas::Impl::AxpbyUnificationAttemptTraits::performChecks()"
Expand All @@ -674,7 +680,9 @@ struct AxpbyUnificationAttemptTraits {
"KokkosBlas::Impl::AxpbyUnificationAttemptTraits::performChecks()"
", onDevice, xyRank2Case: InternalTypeX is wrong");

constexpr bool internalTypeB_isOk = internalTypeB_is_r1d || (a_is_scalar && b_is_scalar && internalTypeB_is_scalar);
constexpr bool internalTypeB_isOk =
internalTypeB_is_r1d ||
(a_is_scalar && b_is_scalar && internalTypeB_is_scalar);
static_assert(
internalTypeB_isOk,
"KokkosBlas::Impl::AxpbyUnificationAttemptTraits::performChecks()"
Expand Down Expand Up @@ -726,7 +734,8 @@ struct AxpbyUnificationAttemptTraits {
// - [InternalTypeA, B] = [view<S_a*,1 / m>, view<S_b*,1 / m>]
// ****************************************************************
static_assert(
internalTypesAB_bothViews || (a_is_scalar && b_is_scalar && internalTypesAB_bothScalars),
internalTypesAB_bothViews ||
(a_is_scalar && b_is_scalar && internalTypesAB_bothScalars),
"KokkosBlas::Impl::AxpbyUnificationAttemptTraits::performChecks()"
", onDevice, invalid combination of types");
}
Expand Down
23 changes: 18 additions & 5 deletions blas/src/KokkosBlas1_axpby.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -84,19 +84,32 @@ void axpby(const execution_space& exec_space, const AV& a, const XMV& X,
InternalTypeY internal_Y = Y;

if constexpr (AxpbyTraits::internalTypesAB_bothScalars) {
if constexpr (AxpbyTraits::a_is_scalar && AxpbyTraits::b_is_scalar && AxpbyTraits::onDevice) {
// ********************************************************************
// The unification logic applies the following general rules:
// 1) In a 'onHost' case, it makes the internal types for 'a' and 'b'
// to be both scalars (hence the name 'internalTypesAB_bothScalars')
// 2) In a 'onDevice' case, it makes the internal types for 'a' and 'b'
// to be Kokkos views. For performance reasons in Trilinos, the only
// exception for this rule is when the input types for both 'a' and
// 'b' are already scalars, in which case the internal types for 'a'
// and 'b' become scalars as well, eventually changing precision in
// order to match the precisions of 'X' and 'Y'.
// ********************************************************************
if constexpr (AxpbyTraits::a_is_scalar && AxpbyTraits::b_is_scalar &&
AxpbyTraits::onDevice) {
// ******************************************************************
// In this special case, 'a' and 'b' are kept as scalar, evantually
// changing precision to match the precisions of 'X' and 'Y'
// We are in the exception situation for rule 2
// ******************************************************************
InternalTypeA internal_a(a);
InternalTypeA internal_b(b);

Impl::Axpby<execution_space, InternalTypeA, InternalTypeX, InternalTypeB,
InternalTypeY>::axpby(exec_space, internal_a, internal_X,
internal_b, internal_Y);
}
else {
} else {
// ******************************************************************
// We are in rule 1, that is, we are in a 'onHost' case now
// ******************************************************************
InternalTypeA internal_a(Impl::getScalarValueFromVariableAtHost<
AV, Impl::typeRank<AV>()>::getValue(a));
InternalTypeB internal_b(Impl::getScalarValueFromVariableAtHost<
Expand Down

0 comments on commit 11d369b

Please sign in to comment.