Skip to content

Commit

Permalink
Review updates.
Browse files Browse the repository at this point in the history
Co-authored-by: Yu-Hsiang Tsai <yhmtsai@gmail.com>
  • Loading branch information
pratikvn and yhmtsai committed Mar 17, 2022
1 parent 091642c commit 7add1a0
Show file tree
Hide file tree
Showing 3 changed files with 33 additions and 1 deletion.
6 changes: 5 additions & 1 deletion omp/matrix/csr_kernels.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -839,6 +839,7 @@ void compute_submatrix_from_index_set(
auto num_row_subsets = row_index_set.get_num_subsets();
auto row_subset_begin = row_index_set.get_subsets_begin();
auto row_subset_end = row_index_set.get_subsets_end();
auto row_superset_indices = row_index_set.get_superset_indices();
auto res_row_ptrs = result->get_row_ptrs();
auto res_col_idxs = result->get_col_idxs();
auto res_values = result->get_values();
Expand All @@ -850,10 +851,13 @@ void compute_submatrix_from_index_set(
const auto src_col_idxs = source->get_const_col_idxs();
const auto src_values = source->get_const_values();

size_type res_nnz = 0;
#pragma unroll
for (size_type set = 0; set < num_row_subsets; ++set) {
for (auto row = row_subset_begin[set]; row < row_subset_end[set];
++row) {
auto local_row =
row - row_subset_begin[set] + row_superset_indices[set];
auto res_nnz = res_row_ptrs[local_row];
for (size_type i = src_ptrs[row]; i < src_ptrs[row + 1]; ++i) {
auto index = src_col_idxs[i];
if (index >= col_index_set.get_size()) {
Expand Down
26 changes: 26 additions & 0 deletions omp/test/matrix/csr_kernels.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -755,6 +755,32 @@ TEST_F(Csr, ComputeSubmatrixIsEquivalentToRef)
}


TEST_F(Csr, CalculateNnzPerRowInIndexSetIsEquivalentToRef)
{
using Mtx = gko::matrix::Csr<>;
using IndexType = int;
using ValueType = double;
set_up_mat_data();
gko::IndexSet<IndexType> rset{
this->ref, {42, 7, 8, 9, 10, 22, 25, 26, 34, 35, 36, 51}};
gko::IndexSet<IndexType> cset{this->ref,
{42, 22, 24, 26, 28, 30, 81, 82, 83, 88}};
gko::IndexSet<IndexType> drset(this->omp, rset);
gko::IndexSet<IndexType> dcset(this->omp, cset);
auto size = this->mtx2->get_size();
auto row_nnz = gko::Array<int>(this->ref, rset.get_num_elems() + 1);
row_nnz.fill(gko::zero<int>());
auto drow_nnz = gko::Array<int>(this->omp, row_nnz);

gko::kernels::reference::csr::calculate_nonzeros_per_row_in_index_set(
this->ref, this->mtx2.get(), rset, cset, &row_nnz);
gko::kernels::omp::csr::calculate_nonzeros_per_row_in_index_set(
this->omp, this->dmtx2.get(), drset, dcset, &drow_nnz);

GKO_ASSERT_ARRAY_EQ(row_nnz, drow_nnz);
}


TEST_F(Csr, ComputeSubmatrixFromIndexSetIsEquivalentToRef)
{
using Mtx = gko::matrix::Csr<>;
Expand Down
2 changes: 2 additions & 0 deletions reference/test/matrix/csr_kernels.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1781,7 +1781,9 @@ TYPED_TEST(Csr, CanGetSubmatrixWithIndexSet)
I<T>{0.0, 3.0, 0.0, 7.5, 1.0} // 6
},
this->exec);

ASSERT_EQ(mat->get_num_stored_elements(), 23);

{
SCOPED_TRACE("Small square 2x2");
auto row_set = gko::IndexSet<index_type>(this->exec, {0, 1});
Expand Down

0 comments on commit 7add1a0

Please sign in to comment.