Skip to content

Commit

Permalink
add build_row_ptrs tests
Browse files Browse the repository at this point in the history
  • Loading branch information
upsj committed Nov 8, 2021
1 parent 0339bde commit 3c39b8c
Showing 1 changed file with 47 additions and 0 deletions.
47 changes: 47 additions & 0 deletions omp/test/components/device_matrix_data_kernels.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#include <ginkgo/core/base/matrix_data.hpp>


#include "core/components/device_matrix_data_kernels.hpp"
#include "core/test/utils.hpp"
#include "core/test/utils/assertions.hpp"

Expand Down Expand Up @@ -159,4 +160,50 @@ TYPED_TEST(DeviceMatrixData, DoesntRemoveZerosIfThereAreNone)
}


TYPED_TEST(DeviceMatrixData, BuildRowPtrsIsEquivalentToRef)
{
using value_type = typename TestFixture::value_type;
using index_type = typename TestFixture::index_type;
using device_matrix_data = gko::device_matrix_data<value_type, index_type>;
auto ref = gko::ReferenceExecutor::create();
auto ref_sorted_data =
device_matrix_data::create_from_host(ref, this->sorted_host_data);
auto device_sorted_data = device_matrix_data::create_from_host(
this->exec, this->sorted_host_data);
gko::Array<index_type> row_ptrs(ref, ref_sorted_data.size[0] + 1);
gko::Array<index_type> drow_ptrs(this->exec,
device_sorted_data.size[0] + 1);

gko::kernels::reference::components::build_row_ptrs(
ref, ref_sorted_data.nonzeros, ref_sorted_data.size[0],
row_ptrs.get_data());
gko::kernels::omp::components::build_row_ptrs(
this->exec, device_sorted_data.nonzeros, device_sorted_data.size[0],
drow_ptrs.get_data());

GKO_ASSERT_ARRAY_EQ(row_ptrs, drow_ptrs);
}


TYPED_TEST(DeviceMatrixData, BuildEmptyRowPtrsIsEquivalentToRef)
{
using value_type = typename TestFixture::value_type;
using index_type = typename TestFixture::index_type;
using device_matrix_data = gko::device_matrix_data<value_type, index_type>;
auto ref = gko::ReferenceExecutor::create();
auto ref_data = device_matrix_data{ref, gko::dim<2>{10, 10}};
auto device_data = device_matrix_data{this->exec, gko::dim<2>{10, 10}};
gko::Array<index_type> row_ptrs(ref, ref_data.size[0] + 1);
gko::Array<index_type> drow_ptrs(this->exec, device_data.size[0] + 1);

gko::kernels::reference::components::build_row_ptrs(
ref, ref_data.nonzeros, ref_data.size[0], row_ptrs.get_data());
gko::kernels::omp::components::build_row_ptrs(
this->exec, device_data.nonzeros, device_data.size[0],
drow_ptrs.get_data());

GKO_ASSERT_ARRAY_EQ(row_ptrs, drow_ptrs);
}


} // namespace

0 comments on commit 3c39b8c

Please sign in to comment.