Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor batched serial pbtrs implementation details and tests #2504

Open
wants to merge 1 commit into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 10 additions & 4 deletions batched/dense/src/KokkosBatched_Pbtrs.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,12 @@ namespace KokkosBatched {
/// where U is an upper triangular matrix, U**H is the conjugate transpose of U, and
/// L is a lower triangular matrix, L**H is the conjugate transpose of L.
///
/// \tparam ABViewType: Input type for a banded matrix, needs to be a 2D
/// view
/// \tparam BViewType: Input type for a right-hand side and the solution,
/// needs to be a 1D view
/// \tparam ArgUplo: Type indicating whether A is the upper (Uplo::Upper) or lower (Uplo::Lower) triangular matrix
/// \tparam ArgAlgo: Type indicating the blocked (KokkosBatched::Algo::Pbtrs::Blocked) or unblocked
/// (KokkosBatched::Algo::Pbtrs::Unblocked) algorithm to be used
///
/// \tparam ABViewType: Input type for a banded matrix, needs to be a 2D view
/// \tparam BViewType: Input type for a right-hand side and the solution, needs to be a 1D view
///
/// \param ab [in]: ab is a ldab by n banded matrix, with ( kd + 1 ) diagonals
/// \param b [inout]: right-hand side and the solution, a rank 1 view
Expand All @@ -45,6 +47,10 @@ namespace KokkosBatched {

template <typename ArgUplo, typename ArgAlgo>
struct SerialPbtrs {
// Compile-time guard: ArgUplo must be exactly Uplo::Upper or Uplo::Lower;
// any other tag type fails the build with the message below.
static_assert(
std::is_same_v<ArgUplo, Uplo::Upper> || std::is_same_v<ArgUplo, Uplo::Lower>,
"KokkosBatched::pbtrs: Use Uplo::Upper for upper triangular matrix or Uplo::Lower for lower triangular matrix");
// Compile-time guard: Algo::Pbtrs::Unblocked is the only algorithm tag accepted here.
static_assert(std::is_same_v<ArgAlgo, Algo::Pbtrs::Unblocked>, "KokkosBatched::pbtrs: Use Algo::Pbtrs::Unblocked");
// Static entry point taking the banded matrix view `ab` and the view `b`
// (documented above as right-hand side on input, solution on output).
// Only the declaration is visible here; the meaning of the returned int is
// defined by the implementation (NOTE(review): presumably 0 on success - confirm).
template <typename ABViewType, typename BViewType>
KOKKOS_INLINE_FUNCTION static int invoke(const ABViewType &ab, const BViewType &b);
};
Expand Down
2 changes: 0 additions & 2 deletions batched/dense/unit_test/Test_Batched_Dense.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -57,8 +57,6 @@
#include "Test_Batched_SerialPttrs_Complex.hpp"
#include "Test_Batched_SerialPbtrf.hpp"
#include "Test_Batched_SerialPbtrs.hpp"
#include "Test_Batched_SerialPbtrs_Real.hpp"
#include "Test_Batched_SerialPbtrs_Complex.hpp"
#include "Test_Batched_SerialLaswp.hpp"
#include "Test_Batched_SerialIamax.hpp"
#include "Test_Batched_SerialGetrf.hpp"
Expand Down
163 changes: 117 additions & 46 deletions batched/dense/unit_test/Test_Batched_SerialPbtrs.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,6 @@
#include "KokkosBatched_Pbtrs.hpp"
#include "Test_Batched_DenseUtils.hpp"

using namespace KokkosBatched;

namespace Test {
namespace Pbtrs {

Expand All @@ -36,14 +34,14 @@ struct ParamTag {
template <typename DeviceType, typename ABViewType, typename ParamTagType, typename AlgoTagType>
struct Functor_BatchedSerialPbtrf {
using execution_space = typename DeviceType::execution_space;
ABViewType _ab;
ABViewType m_ab;

KOKKOS_INLINE_FUNCTION
Functor_BatchedSerialPbtrf(const ABViewType &ab) : _ab(ab) {}
Functor_BatchedSerialPbtrf(const ABViewType &ab) : m_ab(ab) {}

KOKKOS_INLINE_FUNCTION
void operator()(const ParamTagType &, const int k) const {
auto sub_ab = Kokkos::subview(_ab, k, Kokkos::ALL(), Kokkos::ALL());
auto sub_ab = Kokkos::subview(m_ab, k, Kokkos::ALL(), Kokkos::ALL());

KokkosBatched::SerialPbtrf<typename ParamTagType::uplo, AlgoTagType>::invoke(sub_ab);
}
Expand All @@ -53,24 +51,24 @@ struct Functor_BatchedSerialPbtrf {
std::string name_region("KokkosBatched::Test::SerialPbtrs");
const std::string name_value_type = Test::value_type_name<value_type>();
std::string name = name_region + name_value_type;
Kokkos::RangePolicy<execution_space, ParamTagType> policy(0, _ab.extent(0));
Kokkos::RangePolicy<execution_space, ParamTagType> policy(0, m_ab.extent(0));
Kokkos::parallel_for(name.c_str(), policy, *this);
}
};

template <typename DeviceType, typename ABViewType, typename BViewType, typename ParamTagType, typename AlgoTagType>
struct Functor_BatchedSerialPbtrs {
using execution_space = typename DeviceType::execution_space;
ABViewType _ab;
BViewType _b;
ABViewType m_ab;
BViewType m_b;

KOKKOS_INLINE_FUNCTION
Functor_BatchedSerialPbtrs(const ABViewType &ab, const BViewType &b) : _ab(ab), _b(b) {}
Functor_BatchedSerialPbtrs(const ABViewType &ab, const BViewType &b) : m_ab(ab), m_b(b) {}

KOKKOS_INLINE_FUNCTION
void operator()(const ParamTagType &, const int k, int &info) const {
auto sub_ab = Kokkos::subview(_ab, k, Kokkos::ALL(), Kokkos::ALL());
auto bb = Kokkos::subview(_b, k, Kokkos::ALL());
auto sub_ab = Kokkos::subview(m_ab, k, Kokkos::ALL(), Kokkos::ALL());
auto bb = Kokkos::subview(m_b, k, Kokkos::ALL());

info += KokkosBatched::SerialPbtrs<typename ParamTagType::uplo, AlgoTagType>::invoke(sub_ab, bb);
}
Expand All @@ -82,7 +80,7 @@ struct Functor_BatchedSerialPbtrs {
std::string name = name_region + name_value_type;
int info_sum = 0;
Kokkos::Profiling::pushRegion(name.c_str());
Kokkos::RangePolicy<execution_space, ParamTagType> policy(0, _b.extent(0));
Kokkos::RangePolicy<execution_space, ParamTagType> policy(0, m_b.extent(0));
Kokkos::parallel_reduce(name.c_str(), policy, *this, info_sum);
Kokkos::Profiling::popRegion();
return info_sum;
Expand All @@ -92,70 +90,88 @@ struct Functor_BatchedSerialPbtrs {
template <typename DeviceType, typename ScalarType, typename AViewType, typename xViewType, typename yViewType>
struct Functor_BatchedSerialGemv {
using execution_space = typename DeviceType::execution_space;
AViewType _a;
xViewType _x;
yViewType _y;
ScalarType _alpha, _beta;
AViewType m_a;
xViewType m_x;
yViewType m_y;
ScalarType m_alpha, m_beta;

KOKKOS_INLINE_FUNCTION
Functor_BatchedSerialGemv(const ScalarType alpha, const AViewType &a, const xViewType &x, const ScalarType beta,
const yViewType &y)
: _a(a), _x(x), _y(y), _alpha(alpha), _beta(beta) {}
: m_a(a), m_x(x), m_y(y), m_alpha(alpha), m_beta(beta) {}

KOKKOS_INLINE_FUNCTION
void operator()(const int k) const {
auto aa = Kokkos::subview(_a, k, Kokkos::ALL(), Kokkos::ALL());
auto xx = Kokkos::subview(_x, k, Kokkos::ALL());
auto yy = Kokkos::subview(_y, k, Kokkos::ALL());
auto aa = Kokkos::subview(m_a, k, Kokkos::ALL(), Kokkos::ALL());
auto xx = Kokkos::subview(m_x, k, Kokkos::ALL());
auto yy = Kokkos::subview(m_y, k, Kokkos::ALL());

KokkosBlas::SerialGemv<Trans::NoTranspose, Algo::Gemv::Unblocked>::invoke(_alpha, aa, xx, _beta, yy);
KokkosBlas::SerialGemv<Trans::NoTranspose, Algo::Gemv::Unblocked>::invoke(m_alpha, aa, xx, m_beta, yy);
}

inline void run() {
using value_type = typename AViewType::non_const_value_type;
std::string name_region("KokkosBatched::Test::SerialPbtrs");
const std::string name_value_type = Test::value_type_name<value_type>();
std::string name = name_region + name_value_type;
Kokkos::RangePolicy<execution_space> policy(0, _x.extent(0));
Kokkos::RangePolicy<execution_space> policy(0, m_x.extent(0));
Kokkos::parallel_for(name.c_str(), policy, *this);
}
};

template <typename DeviceType, typename ScalarType, typename LayoutType, typename ParamTagType, typename AlgoTagType>
/// \brief Implementation details of batched pbtrs test
/// Confirm A * x = b, where
/// Confirm A * x = b, where
/// A: [[4, 1, 0],
/// [1, 4, 1],
/// [0, 1, 4]]
/// b: [1, 1, 1]
/// x: [3/14, 1/7, 3/14]
///
/// This corresponds to the following system of equations:
/// This corresponds to the following system of equations:
/// 4 x0 + x1 = 1
/// x0 + 4 x1 + x2 = 1
/// x1 + 4 x2 = 1
///
/// We confirm this with the factorized band matrix Ub or Lb.
/// For upper banded storage, Ab = Ub**H * Ub
/// Ub: [[0, 1/sqrt(4), 1/sqrt(4 - (1/sqrt(4))**2)],
/// [sqrt(4), sqrt(4 - (1/sqrt(4))**2), sqrt(4 - (1/sqrt(4 - (1/sqrt(4))**2))**2)]]
/// For lower banded storage, Ab = Lb * Lb**H
/// Lb: [[sqrt(4), sqrt(4 - (1/sqrt(4))**2), sqrt(4 - (1/sqrt(4 - (1/sqrt(4))**2))**2)],
/// [1/sqrt(4), 1/sqrt(4 - (1/sqrt(4))**2), 0]]
///
/// \param N [in] Batch size of RHS (banded matrix can also be batched matrix)
///
/// The number of superdiagonals or subdiagonals of matrix A (k = 1) and the
/// block size of matrix A (BlkSize = 3) are fixed inside this test.
template <typename DeviceType, typename ScalarType, typename LayoutType, typename ParamTagType, typename AlgoTagType>
void impl_test_batched_pbtrs_analytical(const int N) {
using ats = typename Kokkos::ArithTraits<ScalarType>;
using RealType = typename ats::mag_type;
using View2DType = Kokkos::View<ScalarType **, LayoutType, DeviceType>;
using View3DType = Kokkos::View<ScalarType ***, LayoutType, DeviceType>;

constexpr int BlkSize = 3, k = 1;
View3DType A("A", N, BlkSize, BlkSize), A_reconst("A_reconst", N, BlkSize, BlkSize);
View3DType Ab("Ab", N, k + 1, BlkSize); // Banded matrix
View2DType x0("x0", N, BlkSize), x_ref("x_ref", N, BlkSize), y0("y0", N, BlkSize); // Solutions
const int BlkSize = 3, k = 1;
View3DType Ab("Ab", N, k + 1, BlkSize); // In band storage
View2DType x0("x0", N, BlkSize), x_ref("x_ref", N, BlkSize); // Solutions

auto h_A_reconst = Kokkos::create_mirror_view(A_reconst);
auto h_x_ref = Kokkos::create_mirror_view(x_ref);
auto h_Ab = Kokkos::create_mirror_view(Ab);
auto h_x_ref = Kokkos::create_mirror_view(x_ref);

for (int ib = 0; ib < N; ib++) {
for (int i = 0; i < BlkSize; i++) {
for (int j = 0; j < BlkSize; j++) {
h_A_reconst(ib, i, j) = i == j ? 4.0 : 1.0;
}
if (std::is_same_v<typename ParamTagType::uplo, KokkosBatched::Uplo::Upper>) {
// Ub
h_Ab(ib, 1, 0) = Kokkos::sqrt(4.0);
h_Ab(ib, 0, 1) = 1.0 / h_Ab(ib, 1, 0);
h_Ab(ib, 1, 1) = Kokkos::sqrt(4.0 - h_Ab(ib, 0, 1) * h_Ab(ib, 0, 1));
h_Ab(ib, 0, 2) = 1.0 / h_Ab(ib, 1, 1);
h_Ab(ib, 1, 2) = Kokkos::sqrt(4.0 - h_Ab(ib, 0, 2) * h_Ab(ib, 0, 2));
} else {
// Lb
h_Ab(ib, 0, 0) = Kokkos::sqrt(4.0);
h_Ab(ib, 1, 0) = 1.0 / h_Ab(ib, 0, 0);
h_Ab(ib, 0, 1) = Kokkos::sqrt(4.0 - h_Ab(ib, 1, 0) * h_Ab(ib, 1, 0));
h_Ab(ib, 1, 1) = 1.0 / h_Ab(ib, 0, 1);
h_Ab(ib, 0, 2) = Kokkos::sqrt(4.0 - h_Ab(ib, 1, 1) * h_Ab(ib, 1, 1));
}

h_x_ref(ib, 0) = 3.0 / 14.0;
Expand All @@ -166,15 +182,7 @@ void impl_test_batched_pbtrs_analytical(const int N) {
Kokkos::fence();

Kokkos::deep_copy(x0, ScalarType(1.0));
Kokkos::deep_copy(A_reconst, h_A_reconst);

// Create banded triangluar matrix in normal and banded storage
using ArgUplo = typename ParamTagType::uplo;
create_banded_pds_matrix<View3DType, View3DType, ArgUplo>(A_reconst, A, k, false);
create_banded_triangular_matrix<View3DType, View3DType, ArgUplo>(A_reconst, Ab, k, true);

// Factorize with Pbtrf: A = U**H * U or A = L * L**H
Functor_BatchedSerialPbtrf<DeviceType, View3DType, ParamTagType, AlgoTagType>(Ab).run();
Kokkos::deep_copy(Ab, h_Ab);

// pbtrs (Note, Ab is a factorized matrix of A)
auto info = Functor_BatchedSerialPbtrs<DeviceType, View3DType, View2DType, ParamTagType, AlgoTagType>(Ab, x0).run();
Expand All @@ -194,13 +202,16 @@ void impl_test_batched_pbtrs_analytical(const int N) {
}
}

template <typename DeviceType, typename ScalarType, typename LayoutType, typename ParamTagType, typename AlgoTagType>
/// \brief Implementation details of batched pbtrs test
/// Confirm A * x = b, where
/// Confirm A * x = b, where A is a real symmetric positive definite
/// or complex Hermitian band matrix. A is stored in band storage.
/// A must be factorized as A=U**H*U or A=L*L**H (Cholesky factorization)
/// by pbtrf.
///
/// \param N [in] Batch size of RHS (banded matrix can also be batched matrix)
/// \param k [in] Number of superdiagonals or subdiagonals of matrix A
/// \param BlkSize [in] Block size of matrix A
template <typename DeviceType, typename ScalarType, typename LayoutType, typename ParamTagType, typename AlgoTagType>
void impl_test_batched_pbtrs(const int N, const int k, const int BlkSize) {
using ats = typename Kokkos::ArithTraits<ScalarType>;
using RealType = typename ats::mag_type;
Expand Down Expand Up @@ -293,3 +304,63 @@ int test_batched_pbtrs() {

return 0;
}

#if defined(KOKKOSKERNELS_INST_FLOAT)
/// Batched serial pbtrs test: lower triangular band storage, float scalars.
TEST_F(TestCategory, test_batched_pbtrs_l_float) {
  // Fully qualified so the alias does not depend on a using-directive for
  // KokkosBatched being visible from some other header.
  using algo_tag_type  = KokkosBatched::Algo::Pbtrs::Unblocked;
  using param_tag_type = ::Test::Pbtrs::ParamTag<KokkosBatched::Uplo::Lower>;

  test_batched_pbtrs<TestDevice, float, param_tag_type, algo_tag_type>();
}
/// Batched serial pbtrs test: upper triangular band storage, float scalars.
TEST_F(TestCategory, test_batched_pbtrs_u_float) {
  // Fully qualified so the alias does not depend on a using-directive for
  // KokkosBatched being visible from some other header.
  using algo_tag_type  = KokkosBatched::Algo::Pbtrs::Unblocked;
  using param_tag_type = ::Test::Pbtrs::ParamTag<KokkosBatched::Uplo::Upper>;

  test_batched_pbtrs<TestDevice, float, param_tag_type, algo_tag_type>();
}
#endif

#if defined(KOKKOSKERNELS_INST_DOUBLE)
/// Batched serial pbtrs test: lower triangular band storage, double scalars.
TEST_F(TestCategory, test_batched_pbtrs_l_double) {
  // Fully qualified so the alias does not depend on a using-directive for
  // KokkosBatched being visible from some other header.
  using algo_tag_type  = KokkosBatched::Algo::Pbtrs::Unblocked;
  using param_tag_type = ::Test::Pbtrs::ParamTag<KokkosBatched::Uplo::Lower>;

  test_batched_pbtrs<TestDevice, double, param_tag_type, algo_tag_type>();
}
/// Batched serial pbtrs test: upper triangular band storage, double scalars.
TEST_F(TestCategory, test_batched_pbtrs_u_double) {
  // Fully qualified so the alias does not depend on a using-directive for
  // KokkosBatched being visible from some other header.
  using algo_tag_type  = KokkosBatched::Algo::Pbtrs::Unblocked;
  using param_tag_type = ::Test::Pbtrs::ParamTag<KokkosBatched::Uplo::Upper>;

  test_batched_pbtrs<TestDevice, double, param_tag_type, algo_tag_type>();
}
#endif

#if defined(KOKKOSKERNELS_INST_COMPLEX_FLOAT)
/// Batched serial pbtrs test: lower triangular band storage, Kokkos::complex<float> scalars.
TEST_F(TestCategory, test_batched_pbtrs_l_fcomplex) {
  // Fully qualified so the alias does not depend on a using-directive for
  // KokkosBatched being visible from some other header.
  using algo_tag_type  = KokkosBatched::Algo::Pbtrs::Unblocked;
  using param_tag_type = ::Test::Pbtrs::ParamTag<KokkosBatched::Uplo::Lower>;

  test_batched_pbtrs<TestDevice, Kokkos::complex<float>, param_tag_type, algo_tag_type>();
}
/// Batched serial pbtrs test: upper triangular band storage, Kokkos::complex<float> scalars.
TEST_F(TestCategory, test_batched_pbtrs_u_fcomplex) {
  // Fully qualified so the alias does not depend on a using-directive for
  // KokkosBatched being visible from some other header.
  using algo_tag_type  = KokkosBatched::Algo::Pbtrs::Unblocked;
  using param_tag_type = ::Test::Pbtrs::ParamTag<KokkosBatched::Uplo::Upper>;

  test_batched_pbtrs<TestDevice, Kokkos::complex<float>, param_tag_type, algo_tag_type>();
}
#endif

#if defined(KOKKOSKERNELS_INST_COMPLEX_DOUBLE)
/// Batched serial pbtrs test: lower triangular band storage, Kokkos::complex<double> scalars.
TEST_F(TestCategory, test_batched_pbtrs_l_dcomplex) {
  // Fully qualified so the alias does not depend on a using-directive for
  // KokkosBatched being visible from some other header.
  using algo_tag_type  = KokkosBatched::Algo::Pbtrs::Unblocked;
  using param_tag_type = ::Test::Pbtrs::ParamTag<KokkosBatched::Uplo::Lower>;

  test_batched_pbtrs<TestDevice, Kokkos::complex<double>, param_tag_type, algo_tag_type>();
}
/// Batched serial pbtrs test: upper triangular band storage, Kokkos::complex<double> scalars.
TEST_F(TestCategory, test_batched_pbtrs_u_dcomplex) {
  // Fully qualified so the alias does not depend on a using-directive for
  // KokkosBatched being visible from some other header.
  using algo_tag_type  = KokkosBatched::Algo::Pbtrs::Unblocked;
  using param_tag_type = ::Test::Pbtrs::ParamTag<KokkosBatched::Uplo::Upper>;

  test_batched_pbtrs<TestDevice, Kokkos::complex<double>, param_tag_type, algo_tag_type>();
}
#endif
45 changes: 0 additions & 45 deletions batched/dense/unit_test/Test_Batched_SerialPbtrs_Complex.hpp

This file was deleted.

Loading
Loading