Skip to content

Commit

Permalink
Add tests for mixed precision ELL
Browse files Browse the repository at this point in the history
  • Loading branch information
fritzgoebel committed Mar 12, 2021
1 parent 1723c8b commit f4265c8
Show file tree
Hide file tree
Showing 4 changed files with 990 additions and 2 deletions.
222 changes: 221 additions & 1 deletion cuda/test/matrix/ell_kernels.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ class Ell : public ::testing::Test {
protected:
using Mtx = gko::matrix::Ell<>;
using Vec = gko::matrix::Dense<>;
using Vec2 = gko::matrix::Dense<float>;
using ComplexVec = gko::matrix::Dense<std::complex<double>>;

Ell() : rand_engine(42) {}
Expand Down Expand Up @@ -92,38 +93,59 @@ class Ell : public ::testing::Test {
stride);
mtx->copy_from(gen_mtx(num_rows, num_cols));
expected = gen_mtx(num_rows, num_vectors);
expected2 = Vec2::create(ref);
expected2->copy_from(expected.get());
y = gen_mtx(num_cols, num_vectors);
y2 = Vec2::create(ref);
y2->copy_from(y.get());
alpha = gko::initialize<Vec>({2.0}, ref);
alpha2 = gko::initialize<Vec2>({2.0}, ref);
beta = gko::initialize<Vec>({-1.0}, ref);
beta2 = gko::initialize<Vec2>({-1.0}, ref);
dmtx = Mtx::create(cuda);
dmtx->copy_from(mtx.get());
dresult = Vec::create(cuda);
dresult->copy_from(expected.get());
dresult2 = Vec2::create(cuda);
dresult2->copy_from(expected2.get());
dy = Vec::create(cuda);
dy->copy_from(y.get());
dy2 = Vec2::create(cuda);
dy2->copy_from(y2.get());
dalpha = Vec::create(cuda);
dalpha->copy_from(alpha.get());
dalpha2 = Vec2::create(cuda);
dalpha2->copy_from(alpha2.get());
dbeta = Vec::create(cuda);
dbeta->copy_from(beta.get());
dbeta2 = Vec2::create(cuda);
dbeta2->copy_from(beta2.get());
}


std::shared_ptr<gko::ReferenceExecutor> ref;
std::shared_ptr<const gko::CudaExecutor> cuda;

std::ranlux48 rand_engine;

std::unique_ptr<Mtx> mtx;
std::unique_ptr<Vec> expected;
std::unique_ptr<Vec2> expected2;
std::unique_ptr<Vec> y;
std::unique_ptr<Vec2> y2;
std::unique_ptr<Vec> alpha;
std::unique_ptr<Vec2> alpha2;
std::unique_ptr<Vec> beta;
std::unique_ptr<Vec2> beta2;

std::unique_ptr<Mtx> dmtx;
std::unique_ptr<Vec> dresult;
std::unique_ptr<Vec2> dresult2;
std::unique_ptr<Vec> dy;
std::unique_ptr<Vec2> dy2;
std::unique_ptr<Vec> dalpha;
std::unique_ptr<Vec2> dalpha2;
std::unique_ptr<Vec> dbeta;
std::unique_ptr<Vec2> dbeta2;
};


Expand All @@ -138,6 +160,39 @@ TEST_F(Ell, SimpleApplyIsEquivalentToRef)
}


TEST_F(Ell, MixedSimpleApplyIsEquivalentToRef1)
{
set_up_apply_data();

mtx->apply(y2.get(), expected2.get());
dmtx->apply(dy2.get(), dresult2.get());

GKO_ASSERT_MTX_NEAR(dresult2, expected2, 1e-14);
}


TEST_F(Ell, MixedSimpleApplyIsEquivalentToRef2)
{
set_up_apply_data();

mtx->apply(y2.get(), expected.get());
dmtx->apply(dy2.get(), dresult.get());

GKO_ASSERT_MTX_NEAR(dresult, expected, 1e-14);
}


TEST_F(Ell, MixedSimpleApplyIsEquivalentToRef3)
{
set_up_apply_data();

mtx->apply(y.get(), expected2.get());
dmtx->apply(dy.get(), dresult2.get());

GKO_ASSERT_MTX_NEAR(dresult2, expected2, 1e-14);
}


TEST_F(Ell, AdvancedApplyIsEquivalentToRef)
{
set_up_apply_data();
Expand All @@ -149,6 +204,39 @@ TEST_F(Ell, AdvancedApplyIsEquivalentToRef)
}


TEST_F(Ell, MixedAdvancedApplyIsEquivalentToRef1)
{
set_up_apply_data();

mtx->apply(alpha2.get(), y2.get(), beta2.get(), expected2.get());
dmtx->apply(dalpha2.get(), dy2.get(), dbeta2.get(), dresult2.get());

GKO_ASSERT_MTX_NEAR(dresult2, expected2, 1e-14);
}


TEST_F(Ell, MixedAdvancedApplyIsEquivalentToRef2)
{
set_up_apply_data();

mtx->apply(alpha2.get(), y2.get(), beta.get(), expected.get());
dmtx->apply(dalpha2.get(), dy2.get(), dbeta.get(), dresult.get());

GKO_ASSERT_MTX_NEAR(dresult, expected, 1e-14);
}


TEST_F(Ell, MixedAdvancedApplyIsEquivalentToRef3)
{
set_up_apply_data();

mtx->apply(alpha.get(), y.get(), beta2.get(), expected2.get());
dmtx->apply(dalpha.get(), dy.get(), dbeta2.get(), dresult2.get());

GKO_ASSERT_MTX_NEAR(dresult2, expected2, 1e-14);
}


TEST_F(Ell, SimpleApplyWithStrideIsEquivalentToRef)
{
set_up_apply_data(532, 231, 1, 300, 600);
Expand All @@ -160,6 +248,39 @@ TEST_F(Ell, SimpleApplyWithStrideIsEquivalentToRef)
}


TEST_F(Ell, MixedSimpleApplyWithStrideIsEquivalentToRef1)
{
set_up_apply_data(532, 231, 1, 300, 600);

mtx->apply(y2.get(), expected2.get());
dmtx->apply(dy2.get(), dresult2.get());

GKO_ASSERT_MTX_NEAR(dresult2, expected2, 1e-14);
}


TEST_F(Ell, MixedSimpleApplyWithStrideIsEquivalentToRef2)
{
set_up_apply_data(532, 231, 1, 300, 600);

mtx->apply(y2.get(), expected.get());
dmtx->apply(dy2.get(), dresult.get());

GKO_ASSERT_MTX_NEAR(dresult, expected, 1e-14);
}


TEST_F(Ell, MixedSimpleApplyWithStrideIsEquivalentToRef3)
{
set_up_apply_data(532, 231, 1, 300, 600);

mtx->apply(y.get(), expected2.get());
dmtx->apply(dy.get(), dresult2.get());

GKO_ASSERT_MTX_NEAR(dresult2, expected2, 1e-14);
}


TEST_F(Ell, AdvancedApplyWithStrideIsEquivalentToRef)
{
set_up_apply_data(532, 231, 1, 300, 600);
Expand All @@ -170,6 +291,39 @@ TEST_F(Ell, AdvancedApplyWithStrideIsEquivalentToRef)
}


TEST_F(Ell, MixedAdvancedApplyWithStrideIsEquivalentToRef1)
{
set_up_apply_data(532, 231, 1, 300, 600);

mtx->apply(alpha2.get(), y2.get(), beta2.get(), expected2.get());
dmtx->apply(dalpha2.get(), dy2.get(), dbeta2.get(), dresult2.get());

GKO_ASSERT_MTX_NEAR(dresult2, expected2, 1e-14);
}


TEST_F(Ell, MixedAdvancedApplyWithStrideIsEquivalentToRef2)
{
set_up_apply_data(532, 231, 1, 300, 600);

mtx->apply(alpha2.get(), y2.get(), beta.get(), expected.get());
dmtx->apply(dalpha2.get(), dy2.get(), dbeta.get(), dresult.get());

GKO_ASSERT_MTX_NEAR(dresult, expected, 1e-14);
}


TEST_F(Ell, MixedAdvancedApplyWithStrideIsEquivalentToRef3)
{
set_up_apply_data(532, 231, 1, 300, 600);

mtx->apply(alpha.get(), y.get(), beta2.get(), expected2.get());
dmtx->apply(dalpha.get(), dy.get(), dbeta2.get(), dresult2.get());

GKO_ASSERT_MTX_NEAR(dresult2, expected2, 1e-14);
}


TEST_F(Ell, SimpleApplyWithStrideToDenseMatrixIsEquivalentToRef)
{
set_up_apply_data(532, 231, 3, 300, 600);
Expand All @@ -181,6 +335,39 @@ TEST_F(Ell, SimpleApplyWithStrideToDenseMatrixIsEquivalentToRef)
}


TEST_F(Ell, MixedSimpleApplyWithStrideToDenseMatrixIsEquivalentToRef1)
{
set_up_apply_data(532, 231, 3, 300, 600);

mtx->apply(y2.get(), expected2.get());
dmtx->apply(dy2.get(), dresult2.get());

GKO_ASSERT_MTX_NEAR(dresult2, expected2, 1e-14);
}


TEST_F(Ell, MixedSimpleApplyWithStrideToDenseMatrixIsEquivalentToRef2)
{
set_up_apply_data(532, 231, 3, 300, 600);

mtx->apply(y2.get(), expected.get());
dmtx->apply(dy2.get(), dresult.get());

GKO_ASSERT_MTX_NEAR(dresult, expected, 1e-14);
}


TEST_F(Ell, MixedSimpleApplyWithStrideToDenseMatrixIsEquivalentToRef3)
{
set_up_apply_data(532, 231, 3, 300, 600);

mtx->apply(y.get(), expected2.get());
dmtx->apply(dy.get(), dresult2.get());

GKO_ASSERT_MTX_NEAR(dresult2, expected2, 1e-14);
}


TEST_F(Ell, AdvancedApplyWithStrideToDenseMatrixIsEquivalentToRef)
{
set_up_apply_data(532, 231, 3, 300, 600);
Expand All @@ -192,6 +379,39 @@ TEST_F(Ell, AdvancedApplyWithStrideToDenseMatrixIsEquivalentToRef)
}


TEST_F(Ell, MixedAdvancedApplyWithStrideToDenseMatrixIsEquivalentToRef1)
{
set_up_apply_data(532, 231, 3, 300, 600);

mtx->apply(alpha2.get(), y2.get(), beta2.get(), expected2.get());
dmtx->apply(dalpha2.get(), dy2.get(), dbeta2.get(), dresult2.get());

GKO_ASSERT_MTX_NEAR(dresult2, expected2, 1e-14);
}


TEST_F(Ell, MixedAdvancedApplyWithStrideToDenseMatrixIsEquivalentToRef2)
{
set_up_apply_data(532, 231, 3, 300, 600);

mtx->apply(alpha2.get(), y2.get(), beta.get(), expected.get());
dmtx->apply(dalpha2.get(), dy2.get(), dbeta.get(), dresult.get());

GKO_ASSERT_MTX_NEAR(dresult, expected, 1e-14);
}


TEST_F(Ell, MixedAdvancedApplyWithStrideToDenseMatrixIsEquivalentToRef3)
{
set_up_apply_data(532, 231, 3, 300, 600);

mtx->apply(alpha.get(), y.get(), beta2.get(), expected2.get());
dmtx->apply(dalpha.get(), dy.get(), dbeta2.get(), dresult2.get());

GKO_ASSERT_MTX_NEAR(dresult2, expected2, 1e-14);
}


TEST_F(Ell, SimpleApplyByAtomicIsEquivalentToRef)
{
set_up_apply_data(10, 10000);
Expand Down
Loading

0 comments on commit f4265c8

Please sign in to comment.