Skip to content

Commit

Permalink
review updates
Browse files Browse the repository at this point in the history
Co-authored-by: Yu-Hsiang Tsai <yhmtsai@gmail.com>
  • Loading branch information
pratikvn and yhmtsai committed Jan 23, 2024
1 parent e243cf0 commit 281010c
Show file tree
Hide file tree
Showing 9 changed files with 110 additions and 12 deletions.
10 changes: 6 additions & 4 deletions common/cuda_hip/matrix/batch_dense_kernels.hpp.inc
Original file line number Diff line number Diff line change
Expand Up @@ -178,10 +178,11 @@ __device__ __forceinline__ void scale_add(
const gko::batch::matrix::dense::batch_item<const ValueType>& mat,
const gko::batch::matrix::dense::batch_item<ValueType>& in_out)
{
// TODO: add stride support
for (int iz = threadIdx.x; iz < mat.num_rows * mat.num_cols;
iz += blockDim.x) {
const int row = iz / mat.stride;
const int col = iz % mat.stride;
const int row = iz / mat.num_cols;
const int col = iz % mat.num_cols;
in_out.values[row * in_out.stride + col] =
alpha * in_out.values[row * in_out.stride + col] +
mat.values[row * mat.stride + col];
Expand Down Expand Up @@ -213,10 +214,11 @@ __device__ __forceinline__ void add_scaled_identity(
const ValueType alpha, const ValueType beta,
const gko::batch::matrix::dense::batch_item<ValueType>& mat)
{
// TODO: add stride support
for (int iz = threadIdx.x; iz < mat.num_rows * mat.num_cols;
iz += blockDim.x) {
const int row = iz / mat.stride;
const int col = iz % mat.stride;
const int row = iz / mat.num_cols;
const int col = iz % mat.num_cols;
mat.values[row * mat.stride + col] *= beta;
if (row == col) {
mat.values[row * mat.stride + col] += alpha;
Expand Down
8 changes: 4 additions & 4 deletions core/matrix/batch_csr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -214,11 +214,11 @@ void Csr<ValueType, IndexType>::add_scaled_identity(

auto csr_mat = this->create_const_view_for_item(0);

bool has_diags{false};
exec->run(csr::make_check_diagonal_entries(csr_mat.get(), has_diags));
if (!has_diags) {
bool has_all_diags{false};
exec->run(csr::make_check_diagonal_entries(csr_mat.get(), has_all_diags));
if (!has_all_diags) {
GKO_UNSUPPORTED_MATRIX_PROPERTY(
"The matrix has one or more structurally zero diagonal entries!");
"The matrix is missing one or more diagonal entries!");
}
exec->run(csr::make_add_scaled_identity(
make_temporary_clone(exec, alpha).get(),
Expand Down
8 changes: 4 additions & 4 deletions core/matrix/batch_ell.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -216,11 +216,11 @@ void Ell<ValueType, IndexType>::add_scaled_identity(
auto csr_mat = gko::matrix::Csr<ValueType, IndexType>::create(exec);
this->create_const_view_for_item(0)->convert_to(csr_mat);

bool has_diags{false};
exec->run(ell::make_check_diagonal_entries(csr_mat.get(), has_diags));
if (!has_diags) {
bool has_all_diags{false};
exec->run(ell::make_check_diagonal_entries(csr_mat.get(), has_all_diags));
if (!has_all_diags) {
GKO_UNSUPPORTED_MATRIX_PROPERTY(
"The matrix has one or more structurally zero diagonal entries!");
"The matrix is missing one or more diagonal entries!");
}
exec->run(ell::make_add_scaled_identity(
make_temporary_clone(exec, alpha).get(),
Expand Down
21 changes: 21 additions & 0 deletions reference/test/matrix/batch_csr_kernels.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -222,6 +222,27 @@ TYPED_TEST(Csr, CanAddScaledIdentity)
}


TYPED_TEST(Csr, CanAddScaledIdentityForRectangular)
{
using BMtx = typename TestFixture::BMtx;
using BMVec = typename TestFixture::BMVec;
auto alpha = gko::batch::initialize<BMVec>({{2.0}, {-1.0}}, this->exec);
auto beta = gko::batch::initialize<BMVec>({{3.0}, {-2.0}}, this->exec);
auto mat =
gko::batch::initialize<BMtx>({{{1.0, 2.0, 0.0}, {3.0, 1.0, 1.0}},
{{2.0, -2.0, 0.0}, {1.0, -1.0, 2.0}}},
this->exec, 5);

mat->add_scaled_identity(alpha, beta);

auto result_mat =
gko::batch::initialize<BMtx>({{{5.0, 6.0, 0.0}, {9.0, 5.0, 3.0}},
{{-5.0, 4.0, 0.0}, {-2.0, 1.0, -4.0}}},
this->exec, 5);
GKO_ASSERT_BATCH_MTX_NEAR(mat.get(), result_mat.get(), 0.);
}


TYPED_TEST(Csr, AddScaledIdentityFailsOnMatrixWithoutDiagonal)
{
using BMtx = typename TestFixture::BMtx;
Expand Down
21 changes: 21 additions & 0 deletions reference/test/matrix/batch_dense_kernels.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -205,6 +205,27 @@ TYPED_TEST(Dense, CanAddScaledIdentity)
}


TYPED_TEST(Dense, CanAddScaledIdentityRectangular)
{
using BMtx = typename TestFixture::BMtx;
using BMVec = typename TestFixture::BMVec;
auto alpha = gko::batch::initialize<BMVec>({{2.0}, {-1.0}}, this->exec);
auto beta = gko::batch::initialize<BMVec>({{3.0}, {-2.0}}, this->exec);
auto mat =
gko::batch::initialize<BMtx>({{{1.0, 2.0, 0.0}, {3.0, 1.0, 1.0}},
{{2.0, -2.0, 0.0}, {1.0, -1.0, 2.0}}},
this->exec);

mat->add_scaled_identity(alpha, beta);

auto result_mat =
gko::batch::initialize<BMtx>({{{5.0, 6.0, 0.0}, {9.0, 5.0, 3.0}},
{{-5.0, 4.0, 0.0}, {-2.0, 1.0, -4.0}}},
this->exec);
GKO_ASSERT_BATCH_MTX_NEAR(mat.get(), result_mat.get(), 0.);
}


TYPED_TEST(Dense, ApplyFailsOnWrongNumberOfResultCols)
{
using BMVec = typename TestFixture::BMVec;
Expand Down
21 changes: 21 additions & 0 deletions reference/test/matrix/batch_ell_kernels.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -224,6 +224,27 @@ TYPED_TEST(Ell, CanAddScaledIdentity)
}


TYPED_TEST(Ell, CanAddScaledIdentityForRectangular)
{
using BMtx = typename TestFixture::BMtx;
using BMVec = typename TestFixture::BMVec;
auto alpha = gko::batch::initialize<BMVec>({{2.0}, {-1.0}}, this->exec);
auto beta = gko::batch::initialize<BMVec>({{3.0}, {-2.0}}, this->exec);
auto mat =
gko::batch::initialize<BMtx>({{{1.0, 2.0, 0.0}, {0.0, 1.0, 1.0}},
{{2.0, -2.0, 0.0}, {0.0, -1.0, 2.0}}},
this->exec, 2);

mat->add_scaled_identity(alpha, beta);

auto result_mat =
gko::batch::initialize<BMtx>({{{5.0, 6.0, 0.0}, {0.0, 5.0, 3.0}},
{{-5.0, 4.0, 0.0}, {0.0, 1.0, -4.0}}},
this->exec, 2);
GKO_ASSERT_BATCH_MTX_NEAR(mat.get(), result_mat.get(), 0.);
}


TYPED_TEST(Ell, AddScaledIdentityFailsOnMatrixWithoutDiagonal)
{
using BMtx = typename TestFixture::BMtx;
Expand Down
11 changes: 11 additions & 0 deletions test/matrix/batch_csr_kernels.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -158,3 +158,14 @@ TEST_F(Csr, AddScaledIdentityIsEquivalentToRef)

GKO_ASSERT_BATCH_MTX_NEAR(dmat, mat, r<value_type>::value);
}


TEST_F(Csr, AddScaledIdentityWithRecMatIsEquivalentToRef)
{
set_up_apply_data(2, 5, 151, 148, true);

mat->add_scaled_identity(alpha, beta);
dmat->add_scaled_identity(dalpha, dbeta);

GKO_ASSERT_BATCH_MTX_NEAR(dmat, mat, r<value_type>::value);
}
11 changes: 11 additions & 0 deletions test/matrix/batch_dense_kernels.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -159,3 +159,14 @@ TEST_F(Dense, AddScaledIdentityIsEquivalentToRef)

GKO_ASSERT_BATCH_MTX_NEAR(dmat, mat, r<value_type>::value);
}


TEST_F(Dense, AddScaledIdentityRectMatIsEquivalentToRef)
{
set_up_apply_data(42, 40, 15);

mat->add_scaled_identity(alpha, beta);
dmat->add_scaled_identity(dalpha, dbeta);

GKO_ASSERT_BATCH_MTX_NEAR(dmat, mat, r<value_type>::value);
}
11 changes: 11 additions & 0 deletions test/matrix/batch_ell_kernels.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -156,3 +156,14 @@ TEST_F(Ell, AddScaledIdentityIsEquivalentToRef)

GKO_ASSERT_BATCH_MTX_NEAR(dmat, mat, r<value_type>::value);
}


TEST_F(Ell, AddScaledIdentityWithRecMatIsEquivalentToRef)
{
set_up_apply_data(2, 5, 151, 155, true);

mat->add_scaled_identity(alpha, beta);
dmat->add_scaled_identity(dalpha, dbeta);

GKO_ASSERT_BATCH_MTX_NEAR(dmat, mat, r<value_type>::value);
}

0 comments on commit 281010c

Please sign in to comment.