diff --git a/common/cuda_hip/matrix/batch_ell_kernels.hpp.inc b/common/cuda_hip/matrix/batch_ell_kernels.hpp.inc index e55e7a60471..5c00358c5a0 100644 --- a/common/cuda_hip/matrix/batch_ell_kernels.hpp.inc +++ b/common/cuda_hip/matrix/batch_ell_kernels.hpp.inc @@ -33,7 +33,7 @@ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. template __device__ __forceinline__ void simple_apply( - const gko::batch::matrix::batch_ell::batch_item& mat, + const gko::batch::matrix::ell::batch_item& mat, const ValueType* const __restrict__ b, ValueType* const __restrict__ x) { const auto num_rows = mat.num_rows; @@ -60,7 +60,7 @@ template __global__ __launch_bounds__( default_block_size, sm_oversubscription) void simple_apply_kernel(const gko::batch::matrix:: - batch_ell::uniform_batch< + ell::uniform_batch< const ValueType> mat, const gko::batch:: @@ -88,7 +88,7 @@ __global__ __launch_bounds__( template __device__ __forceinline__ void advanced_apply( const ValueType alpha, - const gko::batch::matrix::batch_ell::batch_item& mat, + const gko::batch::matrix::ell::batch_item& mat, const ValueType* const __restrict__ b, const ValueType beta, ValueType* const __restrict__ x) { @@ -121,10 +121,9 @@ __global__ __launch_bounds__( const ValueType> alpha, const gko::batch::matrix:: - batch_ell:: - uniform_batch< - const ValueType> - mat, + ell::uniform_batch< + const ValueType> + mat, const gko::batch:: multi_vector:: uniform_batch< diff --git a/core/matrix/batch_ell.cpp b/core/matrix/batch_ell.cpp index 0d903b10968..f421fdf2b49 100644 --- a/core/matrix/batch_ell.cpp +++ b/core/matrix/batch_ell.cpp @@ -104,22 +104,10 @@ template std::unique_ptr> Ell::create_with_config_of( ptr_param> other) -{ - // De-referencing `other` before calling the functions (instead of - // using operator `->`) is currently required to be compatible with - // CUDA 10.1. - // Otherwise, it results in a compile error. - return (*other).create_with_same_config(); -} - - -template -std::unique_ptr> -Ell::create_with_same_config() const { return Ell::create( - this->get_executor(), this->get_size(), - this->get_num_stored_elements_per_row()); + other->get_executor(), other->get_size(), + other->get_num_stored_elements_per_row()); } @@ -163,12 +151,7 @@ template void Ell::apply_impl(const MultiVector* b, MultiVector* x) const { - GKO_ASSERT_EQ(b->get_num_batch_items(), this->get_num_batch_items()); - GKO_ASSERT_EQ(this->get_num_batch_items(), x->get_num_batch_items()); - - GKO_ASSERT_CONFORMANT(this->get_common_size(), b->get_common_size()); - GKO_ASSERT_EQUAL_ROWS(this->get_common_size(), x->get_common_size()); - GKO_ASSERT_EQUAL_COLS(b->get_common_size(), x->get_common_size()); + this->validate_application_parameters(b, x); this->get_executor()->run(ell::make_simple_apply(this, b, x)); } @@ -179,14 +162,7 @@ void Ell::apply_impl(const MultiVector* alpha, const MultiVector* beta, MultiVector* x) const { - GKO_ASSERT_EQ(b->get_num_batch_items(), this->get_num_batch_items()); - GKO_ASSERT_EQ(this->get_num_batch_items(), x->get_num_batch_items()); - - GKO_ASSERT_CONFORMANT(this->get_common_size(), b->get_common_size()); - GKO_ASSERT_EQUAL_ROWS(this->get_common_size(), x->get_common_size()); - GKO_ASSERT_EQUAL_COLS(b->get_common_size(), x->get_common_size()); - GKO_ASSERT_EQUAL_DIMENSIONS(alpha->get_common_size(), gko::dim<2>(1, 1)); - GKO_ASSERT_EQUAL_DIMENSIONS(beta->get_common_size(), gko::dim<2>(1, 1)); + this->validate_application_parameters(alpha, b, beta, x); this->get_executor()->run( ell::make_advanced_apply(alpha, this, b, beta, x)); } diff --git a/core/matrix/batch_struct.hpp b/core/matrix/batch_struct.hpp index 2eed40882bc..eeeeebd53d6 100644 --- a/core/matrix/batch_struct.hpp +++ b/core/matrix/batch_struct.hpp @@ -83,7 +83,7 @@ struct uniform_batch { } // namespace dense -namespace batch_ell { +namespace ell { /** @@ -109,7 +109,7 @@ struct batch_item { template struct uniform_batch { using value_type = ValueType; - using index_type = int; + using index_type = int32; using entry_type = batch_item; ValueType* values; @@ -127,7 +127,7 @@ struct uniform_batch { }; -} // namespace batch_ell +} // namespace ell template @@ -165,8 +165,8 @@ GKO_ATTRIBUTES GKO_INLINE dense::batch_item extract_batch_item( template -GKO_ATTRIBUTES GKO_INLINE batch_ell::batch_item to_const( - const batch_ell::batch_item& b) +GKO_ATTRIBUTES GKO_INLINE ell::batch_item to_const( + const ell::batch_item& b) { return {b.values, b.col_idxs, b.stride, b.num_rows, b.num_cols, b.num_stored_elems_per_row}; @@ -174,8 +174,8 @@ GKO_ATTRIBUTES GKO_INLINE batch_ell::batch_item to_const( template -GKO_ATTRIBUTES GKO_INLINE batch_ell::uniform_batch to_const( - const batch_ell::uniform_batch& ub) +GKO_ATTRIBUTES GKO_INLINE ell::uniform_batch to_const( + const ell::uniform_batch& ub) { return {ub.values, ub.col_idxs, ub.num_batch_items, ub.stride, ub.num_rows, ub.num_cols, ub.num_stored_elems_per_row}; @@ -183,8 +183,8 @@ GKO_ATTRIBUTES GKO_INLINE batch_ell::uniform_batch to_const( template -GKO_ATTRIBUTES GKO_INLINE batch_ell::batch_item extract_batch_item( - const batch_ell::uniform_batch& batch, const size_type batch_idx) +GKO_ATTRIBUTES GKO_INLINE ell::batch_item extract_batch_item( + const ell::uniform_batch& batch, const size_type batch_idx) { return {batch.values + batch_idx * batch.num_stored_elems_per_row * batch.num_rows, @@ -196,7 +196,7 @@ GKO_ATTRIBUTES GKO_INLINE batch_ell::batch_item extract_batch_item( } template -GKO_ATTRIBUTES GKO_INLINE batch_ell::batch_item extract_batch_item( +GKO_ATTRIBUTES GKO_INLINE ell::batch_item extract_batch_item( ValueType* const batch_values, int* const batch_col_idxs, const int stride, const int num_rows, const int num_cols, int num_elems_per_row, const size_type batch_idx) diff --git a/core/test/matrix/batch_ell.cpp b/core/test/matrix/batch_ell.cpp index 2830705bf5f..e4dcab23917 100644 --- a/core/test/matrix/batch_ell.cpp +++ b/core/test/matrix/batch_ell.cpp @@ -144,6 +144,7 @@ TYPED_TEST(Ell, SparseMtxKnowsItsSizeAndValues) TYPED_TEST(Ell, CanBeEmpty) { auto empty = gko::batch::matrix::Ell::create(this->exec); + this->assert_empty(empty.get()); } @@ -151,6 +152,7 @@ TYPED_TEST(Ell, CanBeEmpty) TYPED_TEST(Ell, ReturnsNullValuesArrayWhenEmpty) { auto empty = gko::batch::matrix::Ell::create(this->exec); + ASSERT_EQ(empty->get_const_values(), nullptr); } @@ -284,7 +286,6 @@ TYPED_TEST(Ell, CanBeConstructedFromEllMatrices) using value_type = typename TestFixture::value_type; using EllMtx = typename TestFixture::EllMtx; using size_type = gko::size_type; - auto mat1 = gko::initialize({{-1.0, 0.0, 0.0}, {0.0, 2.5, 3.5}}, this->exec); auto mat2 = @@ -304,15 +305,14 @@ TYPED_TEST(Ell, CanBeConstructedFromEllMatricesByDuplication) using index_type = int; using EllMtx = typename TestFixture::EllMtx; using size_type = gko::size_type; - auto mat1 = gko::initialize({{1.0, 0.0, 0.0}, {0.0, 2.0, 0.0}}, this->exec); - auto bat_m = gko::batch::create_from_item>( this->exec, std::vector{mat1.get(), mat1.get(), mat1.get()}, mat1->get_num_stored_elements_per_row()); + auto m = gko::batch::create_from_item>( this->exec, 3, mat1.get(), mat1->get_num_stored_elements_per_row()); @@ -326,7 +326,6 @@ TYPED_TEST(Ell, CanBeConstructedByDuplicatingEllMatrices) using index_type = int; using EllMtx = typename TestFixture::EllMtx; using size_type = gko::size_type; - auto mat1 = gko::initialize({{-1.0, 0.0, 0.0}, {0.0, 2.5, 0.0}}, this->exec); auto mat2 = @@ -372,6 +371,7 @@ TYPED_TEST(Ell, CanBeListConstructed) { using value_type = typename TestFixture::value_type; using index_type = int; + auto m = gko::batch::initialize>( {{0.0, -1.0}, {1.0, 0.0}}, this->exec); diff --git a/cuda/matrix/batch_dense_kernels.cu b/cuda/matrix/batch_dense_kernels.cu index dd82e15b8cc..c693a3ae861 100644 --- a/cuda/matrix/batch_dense_kernels.cu +++ b/cuda/matrix/batch_dense_kernels.cu @@ -36,7 +36,8 @@ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. #include -#include +#include +#include #include "core/base/batch_struct.hpp" diff --git a/cuda/matrix/batch_ell_kernels.cu b/cuda/matrix/batch_ell_kernels.cu index ee6a99f04ca..6dd268a2d8e 100644 --- a/cuda/matrix/batch_ell_kernels.cu +++ b/cuda/matrix/batch_ell_kernels.cu @@ -34,18 +34,16 @@ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. #include -#include -#include +#include +#include #include "core/base/batch_struct.hpp" #include "core/matrix/batch_struct.hpp" #include "cuda/base/batch_struct.hpp" #include "cuda/base/config.hpp" -#include "cuda/base/cublas_bindings.hpp" -#include "cuda/base/pointer_mode_guard.hpp" #include "cuda/base/thrust.cuh" #include "cuda/components/cooperative_groups.cuh" #include "cuda/components/reduction.cuh" diff --git a/cuda/matrix/batch_struct.hpp b/cuda/matrix/batch_struct.hpp index 7a6a4ac7f00..e2db1ea6e97 100644 --- a/cuda/matrix/batch_struct.hpp +++ b/cuda/matrix/batch_struct.hpp @@ -38,6 +38,7 @@ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. #include +#include #include "core/base/batch_struct.hpp" @@ -91,16 +92,16 @@ get_batch_struct(batch::matrix::Dense* const op) * Generates an immutable uniform batch struct from a batch of ell matrices. */ template -inline batch::matrix::batch_ell::uniform_batch> +inline batch::matrix::ell::uniform_batch> get_batch_struct(const batch::matrix::Ell* const op) { return {as_cuda_type(op->get_const_values()), op->get_const_col_idxs(), op->get_num_batch_items(), - static_cast(op->get_common_size()[0]), - static_cast(op->get_common_size()[0]), - static_cast(op->get_common_size()[1]), - static_cast(op->get_num_stored_elements_per_row())}; + static_cast(op->get_common_size()[0]), + static_cast(op->get_common_size()[0]), + static_cast(op->get_common_size()[1]), + static_cast(op->get_num_stored_elements_per_row())}; } @@ -108,16 +109,16 @@ get_batch_struct(const batch::matrix::Ell* const op) * Generates a uniform batch struct from a batch of ell matrices. */ template -inline batch::matrix::batch_ell::uniform_batch> -get_batch_struct(batch::matrix::Ell* const op) +inline batch::matrix::ell::uniform_batch> get_batch_struct( + batch::matrix::Ell* const op) { return {as_cuda_type(op->get_values()), op->get_col_idxs(), op->get_num_batch_items(), - static_cast(op->get_common_size()[0]), - static_cast(op->get_common_size()[0]), - static_cast(op->get_common_size()[1]), - static_cast(op->get_num_stored_elements_per_row())}; + static_cast(op->get_common_size()[0]), + static_cast(op->get_common_size()[0]), + static_cast(op->get_common_size()[1]), + static_cast(op->get_num_stored_elements_per_row())}; } diff --git a/dpcpp/matrix/batch_ell_kernels.dp.cpp b/dpcpp/matrix/batch_ell_kernels.dp.cpp index 1d1210cc270..fca265eceb0 100644 --- a/dpcpp/matrix/batch_ell_kernels.dp.cpp +++ b/dpcpp/matrix/batch_ell_kernels.dp.cpp @@ -39,17 +39,13 @@ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. #include -#include #include -#include #include #include "core/base/batch_struct.hpp" -#include "core/components/prefix_sum_kernels.hpp" #include "core/matrix/batch_struct.hpp" #include "dpcpp/base/batch_struct.hpp" -#include "dpcpp/base/config.hpp" #include "dpcpp/base/dim3.dp.hpp" #include "dpcpp/base/dpct.hpp" #include "dpcpp/base/helper.hpp" @@ -98,19 +94,19 @@ void simple_apply(std::shared_ptr exec, } // Launch a kernel that has nbatches blocks, each block has max group size - (exec->get_queue())->submit([&](sycl::handler& cgh) { + exec->get_queue()->submit([&](sycl::handler& cgh) { cgh.parallel_for( - sycl_nd_range(grid, block), - [=](sycl::nd_item<3> item_ct1) - [[sycl::reqd_sub_group_size(config::warp_size)]] { - auto group = item_ct1.get_group(); - auto group_id = group.get_group_linear_id(); - const auto mat_b = - batch::matrix::extract_batch_item(mat_ub, group_id); - const auto b_b = batch::extract_batch_item(b_ub, group_id); - const auto x_b = batch::extract_batch_item(x_ub, group_id); - simple_apply_kernel(mat_b, b_b, x_b, item_ct1); - }); + sycl_nd_range(grid, block), [= + ](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size( + config::warp_size)]] { + auto group = item_ct1.get_group(); + auto group_id = group.get_group_linear_id(); + const auto mat_b = + batch::matrix::extract_batch_item(mat_ub, group_id); + const auto b_b = batch::extract_batch_item(b_ub, group_id); + const auto x_b = batch::extract_batch_item(x_ub, group_id); + simple_apply_kernel(mat_b, b_b, x_b, item_ct1); + }); }); } @@ -145,24 +141,24 @@ void advanced_apply(std::shared_ptr exec, const dim3 grid(num_batch_items); // Launch a kernel that has nbatches blocks, each block has max group size - (exec->get_queue())->submit([&](sycl::handler& cgh) { + exec->get_queue()->submit([&](sycl::handler& cgh) { cgh.parallel_for( - sycl_nd_range(grid, block), - [=](sycl::nd_item<3> item_ct1) - [[sycl::reqd_sub_group_size(config::warp_size)]] { - auto group = item_ct1.get_group(); - auto group_id = group.get_group_linear_id(); - const auto mat_b = - batch::matrix::extract_batch_item(mat_ub, group_id); - const auto b_b = batch::extract_batch_item(b_ub, group_id); - const auto x_b = batch::extract_batch_item(x_ub, group_id); - const auto alpha_b = - batch::extract_batch_item(alpha_ub, group_id); - const auto beta_b = - batch::extract_batch_item(beta_ub, group_id); - advanced_apply_kernel(alpha_b, mat_b, b_b, beta_b, x_b, - item_ct1); - }); + sycl_nd_range(grid, block), [= + ](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size( + config::warp_size)]] { + auto group = item_ct1.get_group(); + auto group_id = group.get_group_linear_id(); + const auto mat_b = + batch::matrix::extract_batch_item(mat_ub, group_id); + const auto b_b = batch::extract_batch_item(b_ub, group_id); + const auto x_b = batch::extract_batch_item(x_ub, group_id); + const auto alpha_b = + batch::extract_batch_item(alpha_ub, group_id); + const auto beta_b = + batch::extract_batch_item(beta_ub, group_id); + advanced_apply_kernel(alpha_b, mat_b, b_b, beta_b, x_b, + item_ct1); + }); }); } diff --git a/dpcpp/matrix/batch_ell_kernels.hpp.inc b/dpcpp/matrix/batch_ell_kernels.hpp.inc index 1048f2f8ff8..7500ae9e060 100644 --- a/dpcpp/matrix/batch_ell_kernels.hpp.inc +++ b/dpcpp/matrix/batch_ell_kernels.hpp.inc @@ -32,7 +32,7 @@ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. template __dpct_inline__ void simple_apply_kernel( - const gko::batch::matrix::batch_ell::batch_item& mat, + const gko::batch::matrix::ell::batch_item& mat, const gko::batch::multi_vector::batch_item& b, const gko::batch::multi_vector::batch_item& x, sycl::nd_item<3>& item_ct1) @@ -56,7 +56,7 @@ __dpct_inline__ void simple_apply_kernel( template __dpct_inline__ void advanced_apply_kernel( const gko::batch::multi_vector::batch_item& alpha, - const gko::batch::matrix::batch_ell::batch_item& mat, + const gko::batch::matrix::ell::batch_item& mat, const gko::batch::multi_vector::batch_item& b, const gko::batch::multi_vector::batch_item& beta, const gko::batch::multi_vector::batch_item& x, diff --git a/dpcpp/matrix/batch_struct.hpp b/dpcpp/matrix/batch_struct.hpp index 35ff1148dd5..f857653e05e 100644 --- a/dpcpp/matrix/batch_struct.hpp +++ b/dpcpp/matrix/batch_struct.hpp @@ -38,6 +38,7 @@ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. #include +#include #include "core/base/batch_struct.hpp" @@ -90,16 +91,16 @@ inline batch::matrix::dense::uniform_batch get_batch_struct( * Generates an immutable uniform batch struct from a batch of ell matrices. */ template -inline batch::matrix::batch_ell::uniform_batch -get_batch_struct(const batch::matrix::Ell* const op) +inline batch::matrix::ell::uniform_batch get_batch_struct( + const batch::matrix::Ell* const op) { return {op->get_const_values(), op->get_const_col_idxs(), op->get_num_batch_items(), - static_cast(op->get_common_size()[0]), - static_cast(op->get_common_size()[0]), - static_cast(op->get_common_size()[1]), - static_cast(op->get_num_stored_elements_per_row())}; + static_cast(op->get_common_size()[0]), + static_cast(op->get_common_size()[0]), + static_cast(op->get_common_size()[1]), + static_cast(op->get_num_stored_elements_per_row())}; } @@ -107,16 +108,16 @@ get_batch_struct(const batch::matrix::Ell* const op) * Generates a uniform batch struct from a batch of ell matrices. */ template -inline batch::matrix::batch_ell::uniform_batch get_batch_struct( +inline batch::matrix::ell::uniform_batch get_batch_struct( batch::matrix::Ell* const op) { return {op->get_values(), op->get_col_idxs(), op->get_num_batch_items(), - static_cast(op->get_common_size()[0]), - static_cast(op->get_common_size()[0]), - static_cast(op->get_common_size()[1]), - static_cast(op->get_num_stored_elements_per_row())}; + static_cast(op->get_common_size()[0]), + static_cast(op->get_common_size()[0]), + static_cast(op->get_common_size()[1]), + static_cast(op->get_num_stored_elements_per_row())}; } diff --git a/hip/matrix/batch_ell_kernels.hip.cpp b/hip/matrix/batch_ell_kernels.hip.cpp index fdd52c38f57..5c6d5179a21 100644 --- a/hip/matrix/batch_ell_kernels.hip.cpp +++ b/hip/matrix/batch_ell_kernels.hip.cpp @@ -35,18 +35,16 @@ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. #include #include -#include -#include +#include +#include #include "core/base/batch_struct.hpp" #include "core/matrix/batch_struct.hpp" #include "hip/base/batch_struct.hip.hpp" #include "hip/base/config.hip.hpp" -#include "hip/base/hipblas_bindings.hip.hpp" -#include "hip/base/pointer_mode_guard.hip.hpp" #include "hip/base/thrust.hip.hpp" #include "hip/components/cooperative_groups.hip.hpp" #include "hip/components/reduction.hip.hpp" diff --git a/hip/matrix/batch_struct.hip.hpp b/hip/matrix/batch_struct.hip.hpp index a43d7d058b0..6f15b2d966a 100644 --- a/hip/matrix/batch_struct.hip.hpp +++ b/hip/matrix/batch_struct.hip.hpp @@ -38,6 +38,7 @@ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. #include +#include #include "core/base/batch_struct.hpp" @@ -91,16 +92,16 @@ get_batch_struct(batch::matrix::Dense* const op) * Generates an immutable uniform batch struct from a batch of ell matrices. */ template -inline batch::matrix::batch_ell::uniform_batch> +inline batch::matrix::ell::uniform_batch> get_batch_struct(const batch::matrix::Ell* const op) { return {as_hip_type(op->get_const_values()), op->get_const_col_idxs(), op->get_num_batch_items(), - static_cast(op->get_common_size()[0]), - static_cast(op->get_common_size()[0]), - static_cast(op->get_common_size()[1]), - static_cast(op->get_num_stored_elements_per_row())}; + static_cast(op->get_common_size()[0]), + static_cast(op->get_common_size()[0]), + static_cast(op->get_common_size()[1]), + static_cast(op->get_num_stored_elements_per_row())}; } @@ -108,16 +109,16 @@ get_batch_struct(const batch::matrix::Ell* const op) * Generates a uniform batch struct from a batch of ell matrices. */ template -inline batch::matrix::batch_ell::uniform_batch> -get_batch_struct(batch::matrix::Ell* const op) +inline batch::matrix::ell::uniform_batch> get_batch_struct( + batch::matrix::Ell* const op) { return {as_hip_type(op->get_values()), op->get_col_idxs(), op->get_num_batch_items(), - static_cast(op->get_common_size()[0]), - static_cast(op->get_common_size()[0]), - static_cast(op->get_common_size()[1]), - static_cast(op->get_num_stored_elements_per_row())}; + static_cast(op->get_common_size()[0]), + static_cast(op->get_common_size()[0]), + static_cast(op->get_common_size()[1]), + static_cast(op->get_num_stored_elements_per_row())}; } diff --git a/include/ginkgo/core/matrix/batch_ell.hpp b/include/ginkgo/core/matrix/batch_ell.hpp index 5cb5f73dec5..6f3db1bb96b 100644 --- a/include/ginkgo/core/matrix/batch_ell.hpp +++ b/include/ginkgo/core/matrix/batch_ell.hpp @@ -356,14 +356,6 @@ class Ell final col_idxs_.get_num_elems()); } - /** - * Creates a Ell matrix with the same configuration as the callers - * matrix. - * - * @returns a Ell matrix with the same configuration as the caller. - */ - std::unique_ptr create_with_same_config() const; - void apply_impl(const MultiVector* b, MultiVector* x) const; diff --git a/omp/matrix/batch_dense_kernels.cpp b/omp/matrix/batch_dense_kernels.cpp index 2d0b7ed4d40..b91a4133dba 100644 --- a/omp/matrix/batch_dense_kernels.cpp +++ b/omp/matrix/batch_dense_kernels.cpp @@ -36,8 +36,8 @@ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. #include -#include -#include +#include +#include #include "core/base/batch_struct.hpp" diff --git a/omp/matrix/batch_ell_kernels.cpp b/omp/matrix/batch_ell_kernels.cpp index 20ea4614e7d..17710a97366 100644 --- a/omp/matrix/batch_ell_kernels.cpp +++ b/omp/matrix/batch_ell_kernels.cpp @@ -36,8 +36,8 @@ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. #include -#include -#include +#include +#include #include "core/base/batch_struct.hpp" diff --git a/reference/matrix/batch_dense_kernels.cpp b/reference/matrix/batch_dense_kernels.cpp index 3d7ef03a3bd..87d73bb8e34 100644 --- a/reference/matrix/batch_dense_kernels.cpp +++ b/reference/matrix/batch_dense_kernels.cpp @@ -36,9 +36,8 @@ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. #include -#include -#include -#include +#include +#include #include "core/base/batch_struct.hpp" diff --git a/reference/matrix/batch_ell_kernels.cpp b/reference/matrix/batch_ell_kernels.cpp index a3f69827c02..1d3a0e1ef94 100644 --- a/reference/matrix/batch_ell_kernels.cpp +++ b/reference/matrix/batch_ell_kernels.cpp @@ -36,9 +36,8 @@ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. #include -#include -#include -#include +#include +#include #include "core/base/batch_struct.hpp" diff --git a/reference/matrix/batch_ell_kernels.hpp.inc b/reference/matrix/batch_ell_kernels.hpp.inc index 41d0a00ddcd..44de2a57af9 100644 --- a/reference/matrix/batch_ell_kernels.hpp.inc +++ b/reference/matrix/batch_ell_kernels.hpp.inc @@ -32,7 +32,7 @@ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. template inline void simple_apply_kernel( - const gko::batch::matrix::batch_ell::batch_item& a, + const gko::batch::matrix::ell::batch_item& a, const gko::batch::multi_vector::batch_item& b, const gko::batch::multi_vector::batch_item& c) { @@ -55,7 +55,7 @@ inline void simple_apply_kernel( template inline void advanced_apply_kernel( const ValueType alpha, - const gko::batch::matrix::batch_ell::batch_item& a, + const gko::batch::matrix::ell::batch_item& a, const gko::batch::multi_vector::batch_item& b, const ValueType beta, const gko::batch::multi_vector::batch_item& c) diff --git a/reference/matrix/batch_struct.hpp b/reference/matrix/batch_struct.hpp index 3b562450ee0..fb0e08c16f5 100644 --- a/reference/matrix/batch_struct.hpp +++ b/reference/matrix/batch_struct.hpp @@ -95,16 +95,16 @@ inline batch::matrix::dense::uniform_batch get_batch_struct( * Generates an immutable uniform batch struct from a batch of ell matrices. */ template -inline batch::matrix::batch_ell::uniform_batch -get_batch_struct(const batch::matrix::Ell* const op) +inline batch::matrix::ell::uniform_batch get_batch_struct( + const batch::matrix::Ell* const op) { return {op->get_const_values(), op->get_const_col_idxs(), op->get_num_batch_items(), - static_cast(op->get_common_size()[0]), - static_cast(op->get_common_size()[0]), - static_cast(op->get_common_size()[1]), - static_cast(op->get_num_stored_elements_per_row())}; + static_cast(op->get_common_size()[0]), + static_cast(op->get_common_size()[0]), + static_cast(op->get_common_size()[1]), + static_cast(op->get_num_stored_elements_per_row())}; } @@ -112,16 +112,16 @@ get_batch_struct(const batch::matrix::Ell* const op) * Generates a uniform batch struct from a batch of ell matrices. */ template -inline batch::matrix::batch_ell::uniform_batch get_batch_struct( +inline batch::matrix::ell::uniform_batch get_batch_struct( batch::matrix::Ell* const op) { return {op->get_values(), op->get_col_idxs(), op->get_num_batch_items(), - static_cast(op->get_common_size()[0]), - static_cast(op->get_common_size()[0]), - static_cast(op->get_common_size()[1]), - static_cast(op->get_num_stored_elements_per_row())}; + static_cast(op->get_common_size()[0]), + static_cast(op->get_common_size()[0]), + static_cast(op->get_common_size()[1]), + static_cast(op->get_num_stored_elements_per_row())}; } diff --git a/reference/test/matrix/batch_ell_kernels.cpp b/reference/test/matrix/batch_ell_kernels.cpp index 76b681c69f7..8a5806a9513 100644 --- a/reference/test/matrix/batch_ell_kernels.cpp +++ b/reference/test/matrix/batch_ell_kernels.cpp @@ -58,15 +58,13 @@ class Ell : public ::testing::Test { protected: using value_type = T; using size_type = gko::size_type; - using Mtx = gko::batch::matrix::Ell; - using MVec = gko::batch::MultiVector; + using BMtx = gko::batch::matrix::Ell; + using BMVec = gko::batch::MultiVector; using EllMtx = gko::matrix::Ell; using DenseMtx = gko::matrix::Dense; - using ComplexMtx = gko::to_complex; - using RealMtx = gko::remove_complex; Ell() : exec(gko::ReferenceExecutor::create()), - mtx_0(gko::batch::initialize( + mtx_0(gko::batch::initialize( {{I({1.0, -1.0, 1.5}), I({-2.0, 2.0, 3.0})}, {{1.0, -2.0, -0.5}, {1.0, -2.5, 4.0}}}, exec)), @@ -74,7 +72,7 @@ class Ell : public ::testing::Test { {I({1.0, -1.0, 1.5}), I({-2.0, 2.0, 3.0})}, exec)), mtx_01(gko::initialize( {I({1.0, -2.0, -0.5}), I({1.0, -2.5, 4.0})}, exec)), - b_0(gko::batch::initialize( + b_0(gko::batch::initialize( {{I({1.0, 0.0, 1.0}), I({2.0, 0.0, 1.0}), I({1.0, 0.0, 2.0})}, {I({-1.0, 1.0, 1.0}), I({1.0, -1.0, 1.0}), @@ -88,7 +86,7 @@ class Ell : public ::testing::Test { {I({-1.0, 1.0, 1.0}), I({1.0, -1.0, 1.0}), I({1.0, 0.0, 2.0})}, exec)), - x_0(gko::batch::initialize( + x_0(gko::batch::initialize( {{I({2.0, 0.0, 1.0}), I({2.0, 0.0, 2.0})}, {I({-2.0, 1.0, 1.0}), I({1.0, -1.0, -1.0})}}, exec)), @@ -99,13 +97,13 @@ class Ell : public ::testing::Test { {} std::shared_ptr exec; - std::unique_ptr mtx_0; + std::unique_ptr mtx_0; std::unique_ptr mtx_00; std::unique_ptr mtx_01; - std::unique_ptr b_0; + std::unique_ptr b_0; std::unique_ptr b_00; std::unique_ptr b_01; - std::unique_ptr x_0; + std::unique_ptr x_0; std::unique_ptr x_00; std::unique_ptr x_01; @@ -121,38 +119,10 @@ TYPED_TEST(Ell, AppliesToBatchMultiVector) using T = typename TestFixture::value_type; this->mtx_0->apply(this->b_0.get(), this->x_0.get()); + this->mtx_00->apply(this->b_00.get(), this->x_00.get()); this->mtx_01->apply(this->b_01.get(), this->x_01.get()); - - auto res = gko::batch::unbatch>(this->x_0.get()); - - GKO_ASSERT_MTX_NEAR(res[0].get(), this->x_00.get(), 0.); - GKO_ASSERT_MTX_NEAR(res[1].get(), this->x_01.get(), 0.); -} - - -TYPED_TEST(Ell, AppliesLinearCombinationWithSameAlphaToBatchMultiVector) -{ - using Mtx = typename TestFixture::Mtx; - using MVec = typename TestFixture::MVec; - using DenseMtx = typename TestFixture::DenseMtx; - using T = typename TestFixture::value_type; - auto alpha = gko::batch::initialize(2, {1.5}, this->exec); - auto beta = gko::batch::initialize(2, {-4.0}, this->exec); - auto alpha0 = gko::initialize({1.5}, this->exec); - auto alpha1 = gko::initialize({1.5}, this->exec); - auto beta0 = gko::initialize({-4.0}, this->exec); - auto beta1 = gko::initialize({-4.0}, this->exec); - - this->mtx_0->apply(alpha.get(), this->b_0.get(), beta.get(), - this->x_0.get()); - this->mtx_00->apply(alpha0.get(), this->b_00.get(), beta0.get(), - this->x_00.get()); - this->mtx_01->apply(alpha1.get(), this->b_01.get(), beta1.get(), - this->x_01.get()); - auto res = gko::batch::unbatch>(this->x_0.get()); - GKO_ASSERT_MTX_NEAR(res[0].get(), this->x_00.get(), 0.); GKO_ASSERT_MTX_NEAR(res[1].get(), this->x_01.get(), 0.); } @@ -160,12 +130,12 @@ TYPED_TEST(Ell, AppliesLinearCombinationWithSameAlphaToBatchMultiVector) TYPED_TEST(Ell, AppliesLinearCombinationToBatchMultiVector) { - using Mtx = typename TestFixture::Mtx; - using MVec = typename TestFixture::MVec; + using BMtx = typename TestFixture::BMtx; + using BMVec = typename TestFixture::BMVec; using DenseMtx = typename TestFixture::DenseMtx; using T = typename TestFixture::value_type; - auto alpha = gko::batch::initialize({{1.5}, {-1.0}}, this->exec); - auto beta = gko::batch::initialize({{2.5}, {-4.0}}, this->exec); + auto alpha = gko::batch::initialize({{1.5}, {-1.0}}, this->exec); + auto beta = gko::batch::initialize({{2.5}, {-4.0}}, this->exec); auto alpha0 = gko::initialize({1.5}, this->exec); auto alpha1 = gko::initialize({-1.0}, this->exec); auto beta0 = gko::initialize({2.5}, this->exec); @@ -173,13 +143,12 @@ TYPED_TEST(Ell, AppliesLinearCombinationToBatchMultiVector) this->mtx_0->apply(alpha.get(), this->b_0.get(), beta.get(), this->x_0.get()); + this->mtx_00->apply(alpha0.get(), this->b_00.get(), beta0.get(), this->x_00.get()); this->mtx_01->apply(alpha1.get(), this->b_01.get(), beta1.get(), this->x_01.get()); - auto res = gko::batch::unbatch>(this->x_0.get()); - GKO_ASSERT_MTX_NEAR(res[0].get(), this->x_00.get(), 0.); GKO_ASSERT_MTX_NEAR(res[1].get(), this->x_01.get(), 0.); } @@ -187,8 +156,8 @@ TYPED_TEST(Ell, AppliesLinearCombinationToBatchMultiVector) TYPED_TEST(Ell, ApplyFailsOnWrongNumberOfResultCols) { - using MVec = typename TestFixture::MVec; - auto res = MVec::create(this->exec, gko::batch_dim<2>{2, gko::dim<2>{2}}); + using BMVec = typename TestFixture::BMVec; + auto res = BMVec::create(this->exec, gko::batch_dim<2>{2, gko::dim<2>{2}}); ASSERT_THROW(this->mtx_0->apply(this->b_0.get(), res.get()), gko::DimensionMismatch); @@ -197,8 +166,8 @@ TYPED_TEST(Ell, ApplyFailsOnWrongNumberOfResultCols) TYPED_TEST(Ell, ApplyFailsOnWrongNumberOfResultRows) { - using MVec = typename TestFixture::MVec; - auto res = MVec::create(this->exec, gko::batch_dim<2>{2, gko::dim<2>{3}}); + using BMVec = typename TestFixture::BMVec; + auto res = BMVec::create(this->exec, gko::batch_dim<2>{2, gko::dim<2>{3}}); ASSERT_THROW(this->mtx_0->apply(this->b_0.get(), res.get()), gko::DimensionMismatch); @@ -207,9 +176,9 @@ TYPED_TEST(Ell, ApplyFailsOnWrongNumberOfResultRows) TYPED_TEST(Ell, ApplyFailsOnWrongInnerDimension) { - using MVec = typename TestFixture::MVec; + using BMVec = typename TestFixture::BMVec; auto res = - MVec::create(this->exec, gko::batch_dim<2>{2, gko::dim<2>{2, 3}}); + BMVec::create(this->exec, gko::batch_dim<2>{2, gko::dim<2>{2, 3}}); ASSERT_THROW(this->mtx_0->apply(res.get(), this->x_0.get()), gko::DimensionMismatch); @@ -218,13 +187,13 @@ TYPED_TEST(Ell, ApplyFailsOnWrongInnerDimension) TYPED_TEST(Ell, AdvancedApplyFailsOnWrongInnerDimension) { - using MVec = typename TestFixture::MVec; + using BMVec = typename TestFixture::BMVec; auto res = - MVec::create(this->exec, gko::batch_dim<2>{2, gko::dim<2>{2, 3}}); + BMVec::create(this->exec, gko::batch_dim<2>{2, gko::dim<2>{2, 3}}); auto alpha = - MVec::create(this->exec, gko::batch_dim<2>{2, gko::dim<2>{1, 1}}); + BMVec::create(this->exec, gko::batch_dim<2>{2, gko::dim<2>{1, 1}}); auto beta = - MVec::create(this->exec, gko::batch_dim<2>{2, gko::dim<2>{1, 1}}); + BMVec::create(this->exec, gko::batch_dim<2>{2, gko::dim<2>{1, 1}}); ASSERT_THROW( this->mtx_0->apply(alpha.get(), res.get(), beta.get(), this->x_0.get()), @@ -234,13 +203,13 @@ TYPED_TEST(Ell, AdvancedApplyFailsOnWrongInnerDimension) TYPED_TEST(Ell, AdvancedApplyFailsOnWrongAlphaDimension) { - using MVec = typename TestFixture::MVec; + using BMVec = typename TestFixture::BMVec; auto res = - MVec::create(this->exec, gko::batch_dim<2>{2, gko::dim<2>{3, 3}}); + BMVec::create(this->exec, gko::batch_dim<2>{2, gko::dim<2>{3, 3}}); auto alpha = - MVec::create(this->exec, gko::batch_dim<2>{2, gko::dim<2>{2, 1}}); + BMVec::create(this->exec, gko::batch_dim<2>{2, gko::dim<2>{2, 1}}); auto beta = - MVec::create(this->exec, gko::batch_dim<2>{2, gko::dim<2>{1, 1}}); + BMVec::create(this->exec, gko::batch_dim<2>{2, gko::dim<2>{1, 1}}); ASSERT_THROW( this->mtx_0->apply(alpha.get(), res.get(), beta.get(), this->x_0.get()), diff --git a/test/matrix/batch_ell_kernels.cpp b/test/matrix/batch_ell_kernels.cpp index bc1e0c7fb42..083af0a0938 100644 --- a/test/matrix/batch_ell_kernels.cpp +++ b/test/matrix/batch_ell_kernels.cpp @@ -55,18 +55,18 @@ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. class Ell : public CommonTestFixture { protected: - using Mtx = gko::batch::matrix::Ell; - using MVec = gko::batch::MultiVector; + using BMtx = gko::batch::matrix::Ell; + using BMVec = gko::batch::MultiVector; Ell() : rand_engine(15) {} - template - std::unique_ptr gen_mtx(const gko::size_type num_batch_items, - gko::size_type num_rows, - gko::size_type num_cols, - int num_elems_per_row) + template + std::unique_ptr gen_mtx(const gko::size_type num_batch_items, + gko::size_type num_rows, + gko::size_type num_cols, + int num_elems_per_row) { - return gko::test::generate_random_batch_matrix( + return gko::test::generate_random_batch_matrix( num_batch_items, num_rows, num_cols, std::uniform_int_distribution<>(num_elems_per_row, num_elems_per_row), @@ -74,11 +74,11 @@ class Ell : public CommonTestFixture { num_elems_per_row); } - std::unique_ptr gen_mvec(const gko::size_type num_batch_items, - gko::size_type num_rows, - gko::size_type num_cols) + std::unique_ptr gen_mvec(const gko::size_type num_batch_items, + gko::size_type num_rows, + gko::size_type num_cols) { - return gko::test::generate_random_batch_matrix( + return gko::test::generate_random_batch_matrix( num_batch_items, num_rows, num_cols, std::uniform_int_distribution<>(num_cols, num_cols), std::normal_distribution<>(-1.0, 1.0), rand_engine, ref); @@ -89,15 +89,16 @@ class Ell : public CommonTestFixture { { const int num_rows = 252; const int num_cols = 32; - x = gen_mtx(batch_size, num_rows, num_cols, num_elems_per_row); + GKO_ASSERT(num_elems_per_row <= num_cols); + mat = gen_mtx(batch_size, num_rows, num_cols, num_elems_per_row); y = gen_mvec(batch_size, num_cols, num_vecs); alpha = gen_mvec(batch_size, 1, 1); beta = gen_mvec(batch_size, 1, 1); - dx = gko::clone(exec, x); + dmat = gko::clone(exec, mat); dy = gko::clone(exec, y); dalpha = gko::clone(exec, alpha); dbeta = gko::clone(exec, beta); - expected = MVec::create( + expected = BMVec::create( ref, gko::batch_dim<2>(batch_size, gko::dim<2>{num_rows, num_vecs})); expected->fill(gko::one()); @@ -107,16 +108,16 @@ class Ell : public CommonTestFixture { std::ranlux48 rand_engine; const size_t batch_size = 11; - std::unique_ptr x; - std::unique_ptr y; - std::unique_ptr alpha; - std::unique_ptr beta; - std::unique_ptr expected; - std::unique_ptr dresult; - std::unique_ptr dx; - std::unique_ptr dy; - std::unique_ptr dalpha; - std::unique_ptr dbeta; + std::unique_ptr mat; + std::unique_ptr y; + std::unique_ptr alpha; + std::unique_ptr beta; + std::unique_ptr expected; + std::unique_ptr dresult; + std::unique_ptr dmat; + std::unique_ptr dy; + std::unique_ptr dalpha; + std::unique_ptr dbeta; }; @@ -124,8 +125,8 @@ TEST_F(Ell, SingleVectorApplyIsEquivalentToRef) { set_up_apply_data(1); - x->apply(y.get(), expected.get()); - dx->apply(dy.get(), dresult.get()); + mat->apply(y.get(), expected.get()); + dmat->apply(dy.get(), dresult.get()); GKO_ASSERT_BATCH_MTX_NEAR(dresult, expected, r::value); } @@ -135,8 +136,8 @@ TEST_F(Ell, SingleVectorAdvancedApplyIsEquivalentToRef) { set_up_apply_data(1); - x->apply(alpha.get(), y.get(), beta.get(), expected.get()); - dx->apply(dalpha.get(), dy.get(), dbeta.get(), dresult.get()); + mat->apply(alpha.get(), y.get(), beta.get(), expected.get()); + dmat->apply(dalpha.get(), dy.get(), dbeta.get(), dresult.get()); GKO_ASSERT_BATCH_MTX_NEAR(dresult, expected, r::value); } diff --git a/test/test_install/test_install.cpp b/test/test_install/test_install.cpp index 7e53ea8f165..c00bb594ecd 100644 --- a/test/test_install/test_install.cpp +++ b/test/test_install/test_install.cpp @@ -219,13 +219,20 @@ int main() auto test = batch_multi_vector_type::create(exec); } - // core/base/batch_dense.hpp + // core/matrix/batch_dense.hpp { using type1 = float; using batch_dense_type = gko::batch::matrix::Dense; auto test = batch_dense_type::create(exec); } + // core/matrix/batch_ell.hpp + { + using type1 = float; + using batch_ell_type = gko::batch::matrix::Ell; + auto test = batch_ell_type::create(exec); + } + // core/base/combination.hpp { using type1 = int;