Skip to content

Commit

Permalink
Some more edge case tests
Browse files Browse the repository at this point in the history
  • Loading branch information
pratikvn committed Nov 10, 2021
1 parent 77087ad commit 57cc8d8
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 20 deletions.
24 changes: 13 additions & 11 deletions dpcpp/matrix/csr_kernels.dp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1651,15 +1651,16 @@ GKO_ENABLE_DEFAULT_HOST(calc_nnz_in_span, calc_nnz_in_span);


template <typename ValueType, typename IndexType>
void compute_submat(size_type num_rows, size_type num_cols, size_type num_nnz,
size_type row_offset, size_type col_offset,
const IndexType* __restrict__ src_row_ptrs,
const IndexType* __restrict__ src_col_idxs,
const ValueType* __restrict__ src_values,
const IndexType* __restrict__ res_row_ptrs,
IndexType* __restrict__ res_col_idxs,
ValueType* __restrict__ res_values,
sycl::nd_item<3> item_ct1)
void compute_submatrix_idxs_and_vals(size_type num_rows, size_type num_cols,
size_type num_nnz, size_type row_offset,
size_type col_offset,
const IndexType* __restrict__ src_row_ptrs,
const IndexType* __restrict__ src_col_idxs,
const ValueType* __restrict__ src_values,
const IndexType* __restrict__ res_row_ptrs,
IndexType* __restrict__ res_col_idxs,
ValueType* __restrict__ res_values,
sycl::nd_item<3> item_ct1)
{
const auto tidx = thread::get_thread_id_flat(item_ct1);
if (tidx < num_rows) {
Expand All @@ -1676,7 +1677,8 @@ void compute_submat(size_type num_rows, size_type num_cols, size_type num_nnz,
}
}

GKO_ENABLE_DEFAULT_HOST(compute_submat, compute_submat);
GKO_ENABLE_DEFAULT_HOST(compute_submatrix_idxs_and_vals,
compute_submatrix_idxs_and_vals);


} // namespace kernel
Expand Down Expand Up @@ -1718,7 +1720,7 @@ void compute_submatrix(std::shared_ptr<const DefaultExecutor> exec,
const auto num_nnz = source->get_num_stored_elements();
auto grid_dim = ceildiv(num_rows, default_block_size);
auto block_dim = default_block_size;
kernel::compute_submat(
kernel::compute_submatrix_idxs_and_vals(
grid_dim, block_dim, 0, exec->get_queue(), num_rows, num_cols, num_nnz,
row_offset, col_offset, source->get_const_row_ptrs(),
source->get_const_col_idxs(), source->get_const_values(),
Expand Down
38 changes: 29 additions & 9 deletions reference/test/matrix/csr_kernels.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1649,15 +1649,13 @@ TYPED_TEST(Csr, CanGetSubmatrix2)
using T = typename TestFixture::value_type;
auto mat = gko::initialize<Mtx>(
{
// clang-format off
I<T>{1.0, 3.0, 4.5, 0.0, 2.0}, // 0
I<T>{1.0, 0.0, 4.5, 7.5, 3.0}, // 1
I<T>{0.0, 3.0, 4.5, 0.0, 2.0}, // 2
I<T>{0.0,-1.0, 2.5, 0.0, 2.0}, // 3
I<T>{1.0, 0.0,-1.0, 3.5, 1.0}, // 4
I<T>{0.0, 1.0, 0.0, 0.0, 2.0}, // 5
I<T>{0.0, 3.0, 0.0, 7.5, 1.0} // 6
// clang-format on
I<T>{1.0, 3.0, 4.5, 0.0, 2.0}, // 0
I<T>{1.0, 0.0, 4.5, 7.5, 3.0}, // 1
I<T>{0.0, 3.0, 4.5, 0.0, 2.0}, // 2
I<T>{0.0, -1.0, 2.5, 0.0, 2.0}, // 3
I<T>{1.0, 0.0, -1.0, 3.5, 1.0}, // 4
I<T>{0.0, 1.0, 0.0, 0.0, 2.0}, // 5
I<T>{0.0, 3.0, 0.0, 7.5, 1.0} // 6
},
this->exec);
ASSERT_EQ(mat->get_num_stored_elements(), 23);
Expand Down Expand Up @@ -1698,6 +1696,28 @@ TYPED_TEST(Csr, CanGetSubmatrix2)

GKO_EXPECT_MTX_NEAR(sub_mat4.get(), ref4.get(), 0.0);
}
{
auto sub_mat5 = mat->create_submatrix(gko::span(0, 7), gko::span(0, 5));
auto ref5 = gko::initialize<Mtx>(
{
I<T>{1.0, 3.0, 4.5, 0.0, 2.0}, // 0
I<T>{1.0, 0.0, 4.5, 7.5, 3.0}, // 1
I<T>{0.0, 3.0, 4.5, 0.0, 2.0}, // 2
I<T>{0.0, -1.0, 2.5, 0.0, 2.0}, // 3
I<T>{1.0, 0.0, -1.0, 3.5, 1.0}, // 4
I<T>{0.0, 1.0, 0.0, 0.0, 2.0}, // 5
I<T>{0.0, 3.0, 0.0, 7.5, 1.0} // 6
},
this->exec);

GKO_EXPECT_MTX_NEAR(sub_mat5.get(), ref5.get(), 0.0);
}
{
auto sub_mat7 = mat->create_submatrix(gko::span(0, 1), gko::span(0, 1));
auto ref7 = gko::initialize<Mtx>({I<T>{1.0}}, this->exec);

GKO_EXPECT_MTX_NEAR(sub_mat7.get(), ref7.get(), 0.0);
}
}


Expand Down

0 comments on commit 57cc8d8

Please sign in to comment.