From 66c753dcb8a56ed4bf0a307bf7df540d70846660 Mon Sep 17 00:00:00 2001 From: Fritz Goebel Date: Tue, 9 Jun 2020 14:46:27 +0200 Subject: [PATCH] add extract_diagonal for Hybrid --- common/matrix/hybrid_kernels.hpp.inc | 19 +++++++++- core/device_hooks/common_kernels.inc.cpp | 6 ++++ core/matrix/coo.cpp | 2 +- core/matrix/csr.cpp | 2 +- core/matrix/dense.cpp | 2 +- core/matrix/ell.cpp | 2 +- core/matrix/hybrid.cpp | 14 ++++++++ core/matrix/hybrid_kernels.hpp | 9 ++++- core/matrix/sellp.cpp | 2 +- cuda/matrix/coo_kernels.cu | 6 ++-- cuda/matrix/csr_kernels.cu | 6 ++-- cuda/matrix/ell_kernels.cu | 8 ++--- cuda/matrix/hybrid_kernels.cu | 25 +++++++++++++ cuda/test/matrix/hybrid_kernels.cpp | 17 +++++++++ hip/matrix/coo_kernels.hip.cpp | 6 ++-- hip/matrix/csr_kernels.hip.cpp | 8 ++--- hip/matrix/ell_kernels.hip.cpp | 8 ++--- hip/matrix/hybrid_kernels.hip.cpp | 26 ++++++++++++++ hip/test/matrix/hybrid_kernels.hip.cpp | 17 +++++++++ include/ginkgo/core/matrix/coo.hpp | 2 +- include/ginkgo/core/matrix/csr.hpp | 2 +- include/ginkgo/core/matrix/dense.hpp | 2 +- include/ginkgo/core/matrix/ell.hpp | 2 +- include/ginkgo/core/matrix/hybrid.hpp | 7 ++++ include/ginkgo/core/matrix/sellp.hpp | 2 +- omp/matrix/hybrid_kernels.cpp | 24 +++++++++++++ omp/test/matrix/hybrid_kernels.cpp | 17 +++++++++ reference/matrix/hybrid_kernels.cpp | 23 ++++++++++++ reference/test/matrix/hybrid_kernels.cpp | 45 ++++++++++++++++++++++++ 29 files changed, 278 insertions(+), 33 deletions(-) diff --git a/common/matrix/hybrid_kernels.hpp.inc b/common/matrix/hybrid_kernels.hpp.inc index a2c9d2c7ae4..f058da8d1b7 100644 --- a/common/matrix/hybrid_kernels.hpp.inc +++ b/common/matrix/hybrid_kernels.hpp.inc @@ -139,4 +139,21 @@ __global__ __launch_bounds__(default_block_size) void add( } -} // namespace kernel \ No newline at end of file +template +__global__ __launch_bounds__(default_block_size) void coo_extract_diagonal( + size_type nnz, const ValueType *__restrict__ orig_values, + const IndexType *__restrict__ orig_row_idxs, + const IndexType *__restrict__ orig_col_idxs, size_type diag_stride, + ValueType *__restrict__ diag) +{ + const auto tidx = thread::get_thread_id_flat(); + + if (tidx < nnz) { + if (orig_row_idxs[tidx] == orig_col_idxs[tidx]) { + diag[diag_stride * orig_row_idxs[tidx]] = orig_values[tidx]; + } + } +} + + +} // namespace kernel diff --git a/core/device_hooks/common_kernels.inc.cpp b/core/device_hooks/common_kernels.inc.cpp index 8aa8ca93ead..c22b2ab4bf8 100644 --- a/core/device_hooks/common_kernels.inc.cpp +++ b/core/device_hooks/common_kernels.inc.cpp @@ -730,6 +730,12 @@ GKO_NOT_COMPILED(GKO_HOOK_MODULE); GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE( GKO_DECLARE_HYBRID_COUNT_NONZEROS_KERNEL); +template +GKO_DECLARE_HYBRID_EXTRACT_DIAGONAL_KERNEL(ValueType, IndexType) +GKO_NOT_COMPILED(GKO_HOOK_MODULE); +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE( + GKO_DECLARE_HYBRID_EXTRACT_DIAGONAL_KERNEL); + } // namespace hybrid diff --git a/core/matrix/coo.cpp b/core/matrix/coo.cpp index 78cf1c31493..9c01866a50f 100644 --- a/core/matrix/coo.cpp +++ b/core/matrix/coo.cpp @@ -217,7 +217,7 @@ void Coo::write(mat_data &data) const template -void Coo::extract_diagonal(Dense *diag) +void Coo::extract_diagonal(Dense *diag) const { GKO_ASSERT_EQ(std::min(this->get_size()[0], this->get_size()[1]), diag->get_size()[0]); diff --git a/core/matrix/csr.cpp b/core/matrix/csr.cpp index 7e9480de2ee..4966faf0b72 100644 --- a/core/matrix/csr.cpp +++ b/core/matrix/csr.cpp @@ -500,7 +500,7 @@ bool Csr::is_sorted_by_column_index() const template -void Csr::extract_diagonal(Dense *diag) +void Csr::extract_diagonal(Dense *diag) const { GKO_ASSERT_EQ(std::min(this->get_size()[0], this->get_size()[1]), diag->get_size()[0]); diff --git a/core/matrix/dense.cpp b/core/matrix/dense.cpp index 160ded7559c..216906eb2ab 100644 --- a/core/matrix/dense.cpp +++ b/core/matrix/dense.cpp @@ -739,7 +739,7 @@ std::unique_ptr Dense::inverse_column_permute( template -void Dense::extract_diagonal(Dense *diag) +void Dense::extract_diagonal(Dense *diag) const { GKO_ASSERT_EQ(std::min(this->get_size()[0], this->get_size()[1]), diag->get_size()[0]); diff --git a/core/matrix/ell.cpp b/core/matrix/ell.cpp index 3635661fc81..abfb57a34c7 100644 --- a/core/matrix/ell.cpp +++ b/core/matrix/ell.cpp @@ -235,7 +235,7 @@ void Ell::write(mat_data &data) const template -void Ell::extract_diagonal(Dense *diag) +void Ell::extract_diagonal(Dense *diag) const { GKO_ASSERT_EQ(std::min(this->get_size()[0], this->get_size()[1]), diag->get_size()[0]); diff --git a/core/matrix/hybrid.cpp b/core/matrix/hybrid.cpp index adbb48bd1aa..78518cedabf 100644 --- a/core/matrix/hybrid.cpp +++ b/core/matrix/hybrid.cpp @@ -54,6 +54,7 @@ namespace hybrid { GKO_REGISTER_OPERATION(convert_to_dense, hybrid::convert_to_dense); GKO_REGISTER_OPERATION(convert_to_csr, hybrid::convert_to_csr); GKO_REGISTER_OPERATION(count_nonzeros, hybrid::count_nonzeros); +GKO_REGISTER_OPERATION(extract_diagonal, hybrid::extract_diagonal); } // namespace hybrid @@ -263,6 +264,19 @@ void Hybrid::write(mat_data &data) const } +template +void Hybrid::extract_diagonal( + Dense *diag) const +{ + GKO_ASSERT_EQ(std::min(this->get_size()[0], this->get_size()[1]), + diag->get_size()[0]); + GKO_ASSERT_EQ(diag->get_size()[1], 1); + + auto exec = this->get_executor(); + exec->run(hybrid::make_extract_diagonal(this, diag)); +} + + #define GKO_DECLARE_HYBRID_MATRIX(ValueType, IndexType) \ class Hybrid GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE(GKO_DECLARE_HYBRID_MATRIX); diff --git a/core/matrix/hybrid_kernels.hpp b/core/matrix/hybrid_kernels.hpp index 788fe66e15b..8a3a7cc8563 100644 --- a/core/matrix/hybrid_kernels.hpp +++ b/core/matrix/hybrid_kernels.hpp @@ -59,13 +59,20 @@ namespace kernels { const matrix::Hybrid *source, \ size_type *result) +#define GKO_DECLARE_HYBRID_EXTRACT_DIAGONAL_KERNEL(ValueType, IndexType) \ + void extract_diagonal(std::shared_ptr exec, \ + const matrix::Hybrid *orig, \ + matrix::Dense *diag) + #define GKO_DECLARE_ALL_AS_TEMPLATES \ template \ GKO_DECLARE_HYBRID_CONVERT_TO_DENSE_KERNEL(ValueType, IndexType); \ template \ GKO_DECLARE_HYBRID_CONVERT_TO_CSR_KERNEL(ValueType, IndexType); \ template \ - GKO_DECLARE_HYBRID_COUNT_NONZEROS_KERNEL(ValueType, IndexType) + GKO_DECLARE_HYBRID_COUNT_NONZEROS_KERNEL(ValueType, IndexType); \ + template \ + GKO_DECLARE_HYBRID_EXTRACT_DIAGONAL_KERNEL(ValueType, IndexType) namespace omp { diff --git a/core/matrix/sellp.cpp b/core/matrix/sellp.cpp index 2ef1e33e019..7839667e4e7 100644 --- a/core/matrix/sellp.cpp +++ b/core/matrix/sellp.cpp @@ -287,7 +287,7 @@ void Sellp::write(mat_data &data) const template -void Sellp::extract_diagonal(Dense *diag) +void Sellp::extract_diagonal(Dense *diag) const { GKO_ASSERT_EQ(std::min(this->get_size()[0], this->get_size()[1]), diag->get_size()[0]); diff --git a/cuda/matrix/coo_kernels.cu b/cuda/matrix/coo_kernels.cu index 004d0a2365d..99c3e2d5f8a 100644 --- a/cuda/matrix/coo_kernels.cu +++ b/cuda/matrix/coo_kernels.cu @@ -250,9 +250,9 @@ void extract_diagonal(std::shared_ptr exec, const matrix::Coo *orig, matrix::Dense *diag) { - auto nnz = orig->get_num_stored_elements(); - auto diag_size = diag->get_size()[0]; - auto diag_stride = diag->get_stride(); + const auto nnz = orig->get_num_stored_elements(); + const auto diag_size = diag->get_size()[0]; + const auto diag_stride = diag->get_stride(); auto num_blocks = ceildiv(diag_size, default_block_size); const auto orig_values = orig->get_const_values(); diff --git a/cuda/matrix/csr_kernels.cu b/cuda/matrix/csr_kernels.cu index f63d64b5d82..ff46d72952b 100644 --- a/cuda/matrix/csr_kernels.cu +++ b/cuda/matrix/csr_kernels.cu @@ -1035,9 +1035,9 @@ void extract_diagonal(std::shared_ptr exec, const matrix::Csr *orig, matrix::Dense *diag) { - auto nnz = orig->get_num_stored_elements(); - auto diag_size = diag->get_size()[0]; - auto diag_stride = diag->get_stride(); + const auto nnz = orig->get_num_stored_elements(); + const auto diag_size = diag->get_size()[0]; + const auto diag_stride = diag->get_stride(); auto num_blocks = ceildiv(config::warp_size * diag_size, default_block_size); diff --git a/cuda/matrix/ell_kernels.cu b/cuda/matrix/ell_kernels.cu index fb13ce0b7ef..8b8a67f4c5a 100644 --- a/cuda/matrix/ell_kernels.cu +++ b/cuda/matrix/ell_kernels.cu @@ -367,10 +367,10 @@ void extract_diagonal(std::shared_ptr exec, const matrix::Ell *orig, matrix::Dense *diag) { - auto max_nnz_per_row = orig->get_num_stored_elements_per_row(); - auto orig_stride = orig->get_stride(); - auto diag_size = diag->get_size()[0]; - auto diag_stride = diag->get_stride(); + const auto max_nnz_per_row = orig->get_num_stored_elements_per_row(); + const auto orig_stride = orig->get_stride(); + const auto diag_size = diag->get_size()[0]; + const auto diag_stride = diag->get_stride(); auto num_blocks = ceildiv(diag_size, default_block_size); const auto orig_values = orig->get_const_values(); diff --git a/cuda/matrix/hybrid_kernels.cu b/cuda/matrix/hybrid_kernels.cu index d18f2228b2c..be43dc19b97 100644 --- a/cuda/matrix/hybrid_kernels.cu +++ b/cuda/matrix/hybrid_kernels.cu @@ -179,6 +179,31 @@ GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE( GKO_DECLARE_HYBRID_COUNT_NONZEROS_KERNEL); +template +void extract_diagonal(std::shared_ptr exec, + const matrix::Hybrid *orig, + matrix::Dense *diag) +{ + gko::kernels::cuda::ell::extract_diagonal(exec, orig->get_ell(), diag); + + const auto coo_row_idxs = orig->get_const_coo_row_idxs(); + const auto coo_col_idxs = orig->get_const_coo_col_idxs(); + const auto coo_values = orig->get_const_coo_values(); + const auto coo_nnz = orig->get_coo_num_stored_elements(); + + const auto diag_stride = diag->get_stride(); + const auto num_blocks = ceildiv(coo_nnz, default_block_size); + auto diag_values = diag->get_values(); + + kernel::coo_extract_diagonal<<>>( + coo_nnz, as_cuda_type(coo_values), as_cuda_type(coo_row_idxs), + as_cuda_type(coo_col_idxs), diag_stride, as_cuda_type(diag_values)); +} + +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE( + GKO_DECLARE_HYBRID_EXTRACT_DIAGONAL_KERNEL); + + } // namespace hybrid } // namespace cuda } // namespace kernels diff --git a/cuda/test/matrix/hybrid_kernels.cpp b/cuda/test/matrix/hybrid_kernels.cpp index f3225882021..864077171e8 100644 --- a/cuda/test/matrix/hybrid_kernels.cpp +++ b/cuda/test/matrix/hybrid_kernels.cpp @@ -219,4 +219,21 @@ TEST_F(Hybrid, MoveToCsrIsEquivalentToRef) } +TEST_F(Hybrid, ExtractDiagonalIsquivalentToRef) +{ + set_up_apply_data(); + + auto diag_size = std::min(mtx->get_size()[0], mtx->get_size()[1]); + + auto diag = gko::matrix::Dense<>::create(mtx->get_executor(), + gko::dim<2>(diag_size, 1)); + auto ddiag = gko::matrix::Dense<>::create(dmtx->get_executor(), + gko::dim<2>(diag_size, 1)); + mtx->extract_diagonal(lend(diag)); + dmtx->extract_diagonal(lend(ddiag)); + + GKO_ASSERT_MTX_NEAR(diag.get(), ddiag.get(), 0); +} + + } // namespace diff --git a/hip/matrix/coo_kernels.hip.cpp b/hip/matrix/coo_kernels.hip.cpp index b78ca24d028..5745056055c 100644 --- a/hip/matrix/coo_kernels.hip.cpp +++ b/hip/matrix/coo_kernels.hip.cpp @@ -262,9 +262,9 @@ void extract_diagonal(std::shared_ptr exec, const matrix::Coo *orig, matrix::Dense *diag) { - auto nnz = orig->get_num_stored_elements(); - auto diag_size = diag->get_size()[0]; - auto diag_stride = diag->get_stride(); + const auto nnz = orig->get_num_stored_elements(); + const auto diag_size = diag->get_size()[0]; + const auto diag_stride = diag->get_stride(); auto num_blocks = ceildiv(diag_size, default_block_size); const auto orig_values = orig->get_const_values(); diff --git a/hip/matrix/csr_kernels.hip.cpp b/hip/matrix/csr_kernels.hip.cpp index 1b98a8a9af3..073b67ab919 100644 --- a/hip/matrix/csr_kernels.hip.cpp +++ b/hip/matrix/csr_kernels.hip.cpp @@ -1147,10 +1147,10 @@ void extract_diagonal(std::shared_ptr exec, const matrix::Csr *orig, matrix::Dense *diag) { - auto nnz = orig->get_num_stored_elements(); - auto diag_size = diag->get_size()[0]; - auto diag_stride = diag->get_stride(); - auto num_blocks = + const auto nnz = orig->get_num_stored_elements(); + const auto diag_size = diag->get_size()[0]; + const auto diag_stride = diag->get_stride(); + const auto num_blocks = ceildiv(config::warp_size * diag_size, default_block_size); const auto orig_values = orig->get_const_values(); diff --git a/hip/matrix/ell_kernels.hip.cpp b/hip/matrix/ell_kernels.hip.cpp index 70502df7b0c..6a268dd51c8 100644 --- a/hip/matrix/ell_kernels.hip.cpp +++ b/hip/matrix/ell_kernels.hip.cpp @@ -376,10 +376,10 @@ void extract_diagonal(std::shared_ptr exec, const matrix::Ell *orig, matrix::Dense *diag) { - auto max_nnz_per_row = orig->get_num_stored_elements_per_row(); - auto orig_stride = orig->get_stride(); - auto diag_size = diag->get_size()[0]; - auto diag_stride = diag->get_stride(); + const auto max_nnz_per_row = orig->get_num_stored_elements_per_row(); + const auto orig_stride = orig->get_stride(); + const auto diag_size = diag->get_size()[0]; + const auto diag_stride = diag->get_stride(); auto num_blocks = ceildiv(diag_size, default_block_size); const auto orig_values = orig->get_const_values(); diff --git a/hip/matrix/hybrid_kernels.hip.cpp b/hip/matrix/hybrid_kernels.hip.cpp index 558dc22e83f..74573e991fb 100644 --- a/hip/matrix/hybrid_kernels.hip.cpp +++ b/hip/matrix/hybrid_kernels.hip.cpp @@ -186,6 +186,32 @@ GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE( GKO_DECLARE_HYBRID_COUNT_NONZEROS_KERNEL); +template +void extract_diagonal(std::shared_ptr exec, + const matrix::Hybrid *orig, + matrix::Dense *diag) +{ + gko::kernels::hip::ell::extract_diagonal(exec, orig->get_ell(), diag); + + const auto coo_row_idxs = orig->get_const_coo_row_idxs(); + const auto coo_col_idxs = orig->get_const_coo_col_idxs(); + const auto coo_values = orig->get_const_coo_values(); + const auto coo_nnz = orig->get_coo_num_stored_elements(); + + const auto diag_stride = diag->get_stride(); + const auto num_blocks = ceildiv(coo_nnz, default_block_size); + auto diag_values = diag->get_values(); + + hipLaunchKernelGGL(num_blocks, default_block_size, 0, 0, coo_nnz, + as_cuda_type(coo_values), as_cuda_type(coo_row_idxs), + as_cuda_type(coo_col_idxs), diag_stride, + as_cuda_type(diag_values)); +} + +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE( + GKO_DECLARE_HYBRID_EXTRACT_DIAGONAL_KERNEL); + + } // namespace hybrid } // namespace hip } // namespace kernels diff --git a/hip/test/matrix/hybrid_kernels.hip.cpp b/hip/test/matrix/hybrid_kernels.hip.cpp index 83d2cc37c86..0f134e1d0b4 100644 --- a/hip/test/matrix/hybrid_kernels.hip.cpp +++ b/hip/test/matrix/hybrid_kernels.hip.cpp @@ -219,4 +219,21 @@ TEST_F(Hybrid, MoveToCsrIsEquivalentToRef) } +TEST_F(Hybrid, ExtractDiagonalIsquivalentToRef) +{ + set_up_apply_data(); + + auto diag_size = std::min(mtx->get_size()[0], mtx->get_size()[1]); + + auto diag = gko::matrix::Dense<>::create(mtx->get_executor(), + gko::dim<2>(diag_size, 1)); + auto ddiag = gko::matrix::Dense<>::create(dmtx->get_executor(), + gko::dim<2>(diag_size, 1)); + mtx->extract_diagonal(lend(diag)); + dmtx->extract_diagonal(lend(ddiag)); + + GKO_ASSERT_MTX_NEAR(diag.get(), ddiag.get(), 0); +} + + } // namespace diff --git a/include/ginkgo/core/matrix/coo.hpp b/include/ginkgo/core/matrix/coo.hpp index a4fa669ecc1..3a2f402e29a 100644 --- a/include/ginkgo/core/matrix/coo.hpp +++ b/include/ginkgo/core/matrix/coo.hpp @@ -251,7 +251,7 @@ class Coo : public EnableLinOp>, * * @param diag the vector into which the diagonal will be written */ - void extract_diagonal(Dense *diag); + void extract_diagonal(Dense *diag) const; protected: /** diff --git a/include/ginkgo/core/matrix/csr.hpp b/include/ginkgo/core/matrix/csr.hpp index f408c62ebc3..be5095dd78b 100644 --- a/include/ginkgo/core/matrix/csr.hpp +++ b/include/ginkgo/core/matrix/csr.hpp @@ -748,7 +748,7 @@ class Csr : public EnableLinOp>, * * @param diag the vector into which the diagonal will be written */ - void extract_diagonal(Dense *diag); + void extract_diagonal(Dense *diag) const; protected: /** diff --git a/include/ginkgo/core/matrix/dense.hpp b/include/ginkgo/core/matrix/dense.hpp index 982b3f6211b..6484078951b 100644 --- a/include/ginkgo/core/matrix/dense.hpp +++ b/include/ginkgo/core/matrix/dense.hpp @@ -262,7 +262,7 @@ class Dense : public EnableLinOp>, * * @param diag the vector into which the diagonal will be written */ - void extract_diagonal(Dense *diag); + void extract_diagonal(Dense *diag) const; /** diff --git a/include/ginkgo/core/matrix/ell.hpp b/include/ginkgo/core/matrix/ell.hpp index 425ac5ee915..753e51865ac 100644 --- a/include/ginkgo/core/matrix/ell.hpp +++ b/include/ginkgo/core/matrix/ell.hpp @@ -223,7 +223,7 @@ class Ell : public EnableLinOp>, * * @param diag the vector into which the diagonal will be written */ - void extract_diagonal(Dense *diag); + void extract_diagonal(Dense *diag) const; protected: /** diff --git a/include/ginkgo/core/matrix/hybrid.hpp b/include/ginkgo/core/matrix/hybrid.hpp index 5aa6887b7fb..8089ff3e57a 100644 --- a/include/ginkgo/core/matrix/hybrid.hpp +++ b/include/ginkgo/core/matrix/hybrid.hpp @@ -583,6 +583,13 @@ class Hybrid return *this; } + /** + * Extracts the diagonal entries of the matrix into a vector. + * + * @param diag the vector into which the diagonal will be written + */ + void extract_diagonal(Dense *diag) const; + protected: /** * Creates an uninitialized Hybrid matrix of specified method. diff --git a/include/ginkgo/core/matrix/sellp.hpp b/include/ginkgo/core/matrix/sellp.hpp index 7564a7aa45c..928b5a81e72 100644 --- a/include/ginkgo/core/matrix/sellp.hpp +++ b/include/ginkgo/core/matrix/sellp.hpp @@ -275,7 +275,7 @@ class Sellp : public EnableLinOp>, * * @param diag the vector into which the diagonal will be written */ - void extract_diagonal(Dense *diag); + void extract_diagonal(Dense *diag) const; protected: /** diff --git a/omp/matrix/hybrid_kernels.cpp b/omp/matrix/hybrid_kernels.cpp index 8282d7c7ab8..c8c377c59d8 100644 --- a/omp/matrix/hybrid_kernels.cpp +++ b/omp/matrix/hybrid_kernels.cpp @@ -210,6 +210,30 @@ GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE( GKO_DECLARE_HYBRID_COUNT_NONZEROS_KERNEL); +template +void extract_diagonal(std::shared_ptr exec, + const matrix::Hybrid *orig, + matrix::Dense *diag) +{ + gko::kernels::omp::ell::extract_diagonal(exec, orig->get_ell(), diag); + + const auto coo_row_idxs = orig->get_const_coo_row_idxs(); + const auto coo_col_idxs = orig->get_const_coo_col_idxs(); + const auto coo_values = orig->get_const_coo_values(); + const auto coo_nnz = orig->get_coo_num_stored_elements(); + +#pragma omp parallel for + for (size_type i = 0; i < coo_nnz; i++) { + if (coo_row_idxs[i] == coo_col_idxs[i]) { + diag->at(coo_row_idxs[i], 0) = coo_values[i]; + } + } +} + +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE( + GKO_DECLARE_HYBRID_EXTRACT_DIAGONAL_KERNEL); + + } // namespace hybrid } // namespace omp } // namespace kernels diff --git a/omp/test/matrix/hybrid_kernels.cpp b/omp/test/matrix/hybrid_kernels.cpp index 47e809fc5bd..0910e1a6059 100644 --- a/omp/test/matrix/hybrid_kernels.cpp +++ b/omp/test/matrix/hybrid_kernels.cpp @@ -218,4 +218,21 @@ TEST_F(Hybrid, MoveToCsrIsEquivalentToRef) } +TEST_F(Hybrid, ExtractDiagonalIsquivalentToRef) +{ + set_up_apply_data(); + + auto diag_size = std::min(mtx->get_size()[0], mtx->get_size()[1]); + + auto diag = gko::matrix::Dense<>::create(mtx->get_executor(), + gko::dim<2>(diag_size, 1)); + auto ddiag = gko::matrix::Dense<>::create(dmtx->get_executor(), + gko::dim<2>(diag_size, 1)); + mtx->extract_diagonal(lend(diag)); + dmtx->extract_diagonal(lend(ddiag)); + + GKO_ASSERT_MTX_NEAR(diag.get(), ddiag.get(), 0); +} + + } // namespace diff --git a/reference/matrix/hybrid_kernels.cpp b/reference/matrix/hybrid_kernels.cpp index 74e126334e2..f128bfbaa85 100644 --- a/reference/matrix/hybrid_kernels.cpp +++ b/reference/matrix/hybrid_kernels.cpp @@ -156,6 +156,29 @@ GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE( GKO_DECLARE_HYBRID_COUNT_NONZEROS_KERNEL); +template +void extract_diagonal(std::shared_ptr exec, + const matrix::Hybrid *orig, + matrix::Dense *diag) +{ + gko::kernels::reference::ell::extract_diagonal(exec, orig->get_ell(), diag); + + const auto coo_row_idxs = orig->get_const_coo_row_idxs(); + const auto coo_col_idxs = orig->get_const_coo_col_idxs(); + const auto coo_values = orig->get_const_coo_values(); + const auto coo_nnz = orig->get_coo_num_stored_elements(); + + for (size_type i = 0; i < coo_nnz; i++) { + if (coo_row_idxs[i] == coo_col_idxs[i]) { + diag->at(coo_row_idxs[i], 0) = coo_values[i]; + } + } +} + +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE( + GKO_DECLARE_HYBRID_EXTRACT_DIAGONAL_KERNEL); + + } // namespace hybrid } // namespace reference } // namespace kernels diff --git a/reference/test/matrix/hybrid_kernels.cpp b/reference/test/matrix/hybrid_kernels.cpp index d00502ca2fb..4d8ada06aaf 100644 --- a/reference/test/matrix/hybrid_kernels.cpp +++ b/reference/test/matrix/hybrid_kernels.cpp @@ -518,4 +518,49 @@ TYPED_TEST(Hybrid, MovesWithStrideToDense) } +TYPED_TEST(Hybrid, ExtractsDiagonal) +{ + auto matrix = this->mtx1->clone(); + auto exec = matrix->get_executor(); + using T = typename TestFixture::value_type; + auto diag = gko::matrix::Dense::create( + exec, + gko::dim<2>(std::min(matrix->get_size()[0], matrix->get_size()[1]), 1)); + + matrix->extract_diagonal(gko::lend(diag)); + + GKO_ASSERT_MTX_NEAR(diag, l({{1.}, {5.}}), 0.0); +} + + +TYPED_TEST(Hybrid, ExtractsDiagonalWithStride) +{ + auto matrix = this->mtx2->clone(); + auto exec = matrix->get_executor(); + using T = typename TestFixture::value_type; + auto diag = gko::matrix::Dense::create( + exec, + gko::dim<2>(std::min(matrix->get_size()[0], matrix->get_size()[1]), 1)); + + matrix->extract_diagonal(gko::lend(diag)); + + GKO_ASSERT_MTX_NEAR(diag, l({{1.}, {5.}}), 0.0); +} + + +TYPED_TEST(Hybrid, ExtractsDiagonalWithoutZeros) +{ + auto matrix = this->mtx3->clone(); + auto exec = matrix->get_executor(); + using T = typename TestFixture::value_type; + auto diag = gko::matrix::Dense::create( + exec, + gko::dim<2>(std::min(matrix->get_size()[0], matrix->get_size()[1]), 1)); + + matrix->extract_diagonal(gko::lend(diag)); + + GKO_ASSERT_MTX_NEAR(diag, l({{1.}, {5.}}), 0.0); +} + + } // namespace