Skip to content

Commit

Permalink
Add conversions into given matrix
Browse files Browse the repository at this point in the history
  • Loading branch information
fritzgoebel committed Nov 4, 2020
1 parent 7eadc1f commit 0e17307
Show file tree
Hide file tree
Showing 6 changed files with 354 additions and 14 deletions.
33 changes: 33 additions & 0 deletions core/matrix/dense.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -812,6 +812,17 @@ Dense<ValueType>::make_complex() const
}


template <typename ValueType>
void Dense<ValueType>::make_complex(Dense<to_complex<ValueType>> *result) const
{
auto exec = this->get_executor();

GKO_ASSERT_EQUAL_DIMENSIONS(this, result);

exec->run(dense::make_make_complex(this, result));
}


template <typename ValueType>
std::unique_ptr<typename Dense<ValueType>::absolute_type>
Dense<ValueType>::get_real() const
Expand All @@ -826,6 +837,17 @@ Dense<ValueType>::get_real() const
}


template <typename ValueType>
void Dense<ValueType>::get_real(Dense<remove_complex<ValueType>> *result) const
{
auto exec = this->get_executor();

GKO_ASSERT_EQUAL_DIMENSIONS(this, result);

exec->run(dense::make_get_real(this, result));
}


template <typename ValueType>
std::unique_ptr<typename Dense<ValueType>::absolute_type>
Dense<ValueType>::get_imag() const
Expand All @@ -840,6 +862,17 @@ Dense<ValueType>::get_imag() const
}


template <typename ValueType>
void Dense<ValueType>::get_imag(Dense<remove_complex<ValueType>> *result) const
{
auto exec = this->get_executor();

GKO_ASSERT_EQUAL_DIMENSIONS(this, result);

exec->run(dense::make_get_imag(this, result));
}


#define GKO_DECLARE_DENSE_MATRIX(_type) class Dense<_type>
GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE(GKO_DECLARE_DENSE_MATRIX);

Expand Down
39 changes: 39 additions & 0 deletions cuda/test/matrix/dense_kernels.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -654,6 +654,19 @@ TEST_F(Dense, MakeComplexIsEquivalentToRef)
}


TEST_F(Dense, MakeComplexWithGivenResultIsEquivalentToRef)
{
set_up_apply_data();

auto complex_x = ComplexMtx::create(ref, x->get_size());
x->make_complex(complex_x.get());
auto dcomplex_x = ComplexMtx::create(cuda, x->get_size());
dx->make_complex(dcomplex_x.get());

GKO_ASSERT_MTX_NEAR(complex_x, dcomplex_x, 0);
}


TEST_F(Dense, GetRealIsEquivalentToRef)
{
set_up_apply_data();
Expand All @@ -665,6 +678,19 @@ TEST_F(Dense, GetRealIsEquivalentToRef)
}


TEST_F(Dense, GetRealWithGivenResultIsEquivalentToRef)
{
set_up_apply_data();

auto real_x = Mtx::create(ref, x->get_size());
x->get_real(real_x.get());
auto dreal_x = Mtx::create(cuda, dx->get_size());
dx->get_real(dreal_x.get());

GKO_ASSERT_MTX_NEAR(real_x, dreal_x, 0);
}


TEST_F(Dense, GetImagIsEquivalentToRef)
{
set_up_apply_data();
Expand All @@ -676,4 +702,17 @@ TEST_F(Dense, GetImagIsEquivalentToRef)
}


TEST_F(Dense, GetImagWithGivenResultIsEquivalentToRef)
{
set_up_apply_data();

auto imag_x = Mtx::create(ref, x->get_size());
x->get_imag(imag_x.get());
auto dimag_x = Mtx::create(cuda, dx->get_size());
dx->get_imag(dimag_x.get());

GKO_ASSERT_MTX_NEAR(imag_x, dimag_x, 0);
}


} // namespace
39 changes: 39 additions & 0 deletions hip/test/matrix/dense_kernels.hip.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -637,6 +637,19 @@ TEST_F(Dense, MakeComplexIsEquivalentToRef)
}


TEST_F(Dense, MakeComplexWithGivenResultIsEquivalentToRef)
{
set_up_apply_data();

auto complex_x = ComplexMtx::create(ref, x->get_size());
x->make_complex(complex_x.get());
auto dcomplex_x = ComplexMtx::create(hip, x->get_size());
dx->make_complex(dcomplex_x.get());

GKO_ASSERT_MTX_NEAR(complex_x, dcomplex_x, 0);
}


TEST_F(Dense, GetRealIsEquivalentToRef)
{
set_up_apply_data();
Expand All @@ -648,6 +661,19 @@ TEST_F(Dense, GetRealIsEquivalentToRef)
}


TEST_F(Dense, GetRealWithGivenResultIsEquivalentToRef)
{
set_up_apply_data();

auto real_x = Mtx::create(ref, x->get_size());
x->get_real(real_x.get());
auto dreal_x = Mtx::create(hip, dx->get_size());
dx->get_real(dreal_x.get());

GKO_ASSERT_MTX_NEAR(real_x, dreal_x, 0);
}


TEST_F(Dense, GetImagIsEquivalentToRef)
{
set_up_apply_data();
Expand All @@ -659,4 +685,17 @@ TEST_F(Dense, GetImagIsEquivalentToRef)
}


TEST_F(Dense, GetImagWithGivenResultIsEquivalentToRef)
{
set_up_apply_data();

auto imag_x = Mtx::create(ref, x->get_size());
x->get_imag(imag_x.get());
auto dimag_x = Mtx::create(hip, dx->get_size());
dx->get_imag(dimag_x.get());

GKO_ASSERT_MTX_NEAR(imag_x, dimag_x, 0);
}


} // namespace
18 changes: 18 additions & 0 deletions include/ginkgo/core/matrix/dense.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -262,18 +262,36 @@ class Dense
*/
std::unique_ptr<complex_type> make_complex() const;

/**
* Writes a complex copy of the original matrix to a given complex matrix.
* If the original matrix was real, the imaginary part of the result will
* be zero.
*/
void make_complex(Dense<to_complex<ValueType>> *result) const;

/**
* Creates a new real matrix and extracts the real part of the original
* matrix into that.
*/
std::unique_ptr<absolute_type> get_real() const;

/**
* Extracts the real part of the original matrix into a given real matrix.
*/
void get_real(Dense<remove_complex<ValueType>> *result) const;

/**
* Creates a new real matrix and extracts the imaginary part of the
* original matrix into that.
*/
std::unique_ptr<absolute_type> get_imag() const;

/**
* Extracts the imaginary part of the original matrix into a given real
* matrix.
*/
void get_imag(Dense<remove_complex<ValueType>> *result) const;

/**
* Returns a pointer to the array of values of the matrix.
*
Expand Down
39 changes: 39 additions & 0 deletions omp/test/matrix/dense_kernels.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -769,6 +769,19 @@ TEST_F(Dense, MakeComplexIsEquivalentToRef)
}


TEST_F(Dense, MakeComplexWithGivenResultIsEquivalentToRef)
{
set_up_apply_data();

auto complex_x = ComplexMtx::create(ref, x->get_size());
x->make_complex(complex_x.get());
auto dcomplex_x = ComplexMtx::create(omp, x->get_size());
dx->make_complex(dcomplex_x.get());

GKO_ASSERT_MTX_NEAR(complex_x, dcomplex_x, 0);
}


TEST_F(Dense, GetRealIsEquivalentToRef)
{
set_up_apply_data();
Expand All @@ -780,6 +793,19 @@ TEST_F(Dense, GetRealIsEquivalentToRef)
}


TEST_F(Dense, GetRealWithGivenResultIsEquivalentToRef)
{
set_up_apply_data();

auto real_x = Mtx::create(ref, x->get_size());
x->get_real(real_x.get());
auto dreal_x = Mtx::create(omp, dx->get_size());
dx->get_real(dreal_x.get());

GKO_ASSERT_MTX_NEAR(real_x, dreal_x, 0);
}


TEST_F(Dense, GetImagIsEquivalentToRef)
{
set_up_apply_data();
Expand All @@ -791,4 +817,17 @@ TEST_F(Dense, GetImagIsEquivalentToRef)
}


TEST_F(Dense, GetImagWithGivenResultIsEquivalentToRef)
{
set_up_apply_data();

auto imag_x = Mtx::create(ref, x->get_size());
x->get_imag(imag_x.get());
auto dimag_x = Mtx::create(omp, dx->get_size());
dx->get_imag(dimag_x.get());

GKO_ASSERT_MTX_NEAR(imag_x, dimag_x, 0);
}


} // namespace
Loading

0 comments on commit 0e17307

Please sign in to comment.