From e0c7983be152a90fc639cda77e514350f666d2f3 Mon Sep 17 00:00:00 2001 From: Tobias Ribizel Date: Fri, 3 Feb 2023 19:23:50 +0100 Subject: [PATCH] review updates and interface improvements MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - pull in convert_to/move_to pointer_param overloads - make Executor::copy_from work on shared_ptr - remove unnecessary add_scaled_identity overload - simplify gko::write - deprecate moving copy_from - simplify move_from - make temporary_clone/conversion work on pointer_param - make make_(const_)dense_view work on pointer_param - make stopping criterion updater work on pointer_param - add documentation - fix accidentally deleted copy constructor - deprecate moving copy_from Co-authored-by: Thomas Grützmacher Co-authored-by: Marcel Koch --- core/distributed/matrix.cpp | 9 +- core/matrix/dense.cpp | 69 +++++----- core/test/base/mtx_io.cpp | 6 +- include/ginkgo/core/base/executor.hpp | 20 +-- include/ginkgo/core/base/lin_op.hpp | 33 +++-- include/ginkgo/core/base/mtx_io.hpp | 11 +- .../ginkgo/core/base/polymorphic_object.hpp | 118 ++++++------------ include/ginkgo/core/base/temporary_clone.hpp | 29 +++-- .../ginkgo/core/base/temporary_conversion.hpp | 6 +- include/ginkgo/core/base/utils_helper.hpp | 41 +++++- include/ginkgo/core/distributed/matrix.hpp | 8 +- include/ginkgo/core/matrix/coo.hpp | 20 +-- include/ginkgo/core/matrix/csr.hpp | 4 +- include/ginkgo/core/matrix/dense.hpp | 23 ++-- include/ginkgo/core/stop/criterion.hpp | 17 ++- 15 files changed, 215 insertions(+), 199 deletions(-) diff --git a/core/distributed/matrix.cpp b/core/distributed/matrix.cpp index f20201ad2db..64f4d1b2bf7 100644 --- a/core/distributed/matrix.cpp +++ b/core/distributed/matrix.cpp @@ -66,7 +66,7 @@ Matrix::Matrix( template Matrix::Matrix( std::shared_ptr exec, mpi::communicator comm, - const LinOp* local_matrix_type) + pointer_param local_matrix_type) : Matrix(exec, comm, local_matrix_type, local_matrix_type) {} @@ -74,7 +74,8 @@ Matrix::Matrix( template Matrix::Matrix( std::shared_ptr exec, mpi::communicator comm, - const LinOp* local_matrix_template, const LinOp* non_local_matrix_template) + pointer_param local_matrix_template, + pointer_param non_local_matrix_template) : EnableDistributedLinOp< Matrix>{exec}, DistributedBase{comm}, @@ -173,8 +174,8 @@ void Matrix::read_distributed( // build local, non-local matrix data and communication structures exec->run(matrix::make_build_local_nonlocal( - data, make_temporary_clone(exec, row_partition.get()).get(), - make_temporary_clone(exec, col_partition.get()).get(), local_part, + data, make_temporary_clone(exec, row_partition).get(), + make_temporary_clone(exec, col_partition).get(), local_part, local_row_idxs, local_col_idxs, local_values, non_local_row_idxs, non_local_col_idxs, non_local_values, recv_gather_idxs, recv_sizes_array, non_local_to_global_)); diff --git a/core/matrix/dense.cpp b/core/matrix/dense.cpp index 177ac9c7368..b8d4ff9ba1c 100644 --- a/core/matrix/dense.cpp +++ b/core/matrix/dense.cpp @@ -156,7 +156,7 @@ template void Dense::scale(pointer_param alpha) { auto exec = this->get_executor(); - this->scale_impl(make_temporary_clone(exec, alpha.get()).get()); + this->scale_impl(make_temporary_clone(exec, alpha).get()); } @@ -164,7 +164,7 @@ template void Dense::inv_scale(pointer_param alpha) { auto exec = this->get_executor(); - this->inv_scale_impl(make_temporary_clone(exec, alpha.get()).get()); + this->inv_scale_impl(make_temporary_clone(exec, alpha).get()); } @@ -173,8 +173,8 @@ void Dense::add_scaled(pointer_param alpha, pointer_param b) { auto exec = this->get_executor(); - this->add_scaled_impl(make_temporary_clone(exec, alpha.get()).get(), - make_temporary_clone(exec, b.get()).get()); + this->add_scaled_impl(make_temporary_clone(exec, alpha).get(), + make_temporary_clone(exec, b).get()); } @@ -183,8 +183,8 @@ void Dense::sub_scaled(pointer_param alpha, pointer_param b) { auto exec = this->get_executor(); - this->sub_scaled_impl(make_temporary_clone(exec, alpha.get()).get(), - make_temporary_clone(exec, b.get()).get()); + this->sub_scaled_impl(make_temporary_clone(exec, alpha).get(), + make_temporary_clone(exec, b).get()); } @@ -193,9 +193,8 @@ void Dense::compute_dot(pointer_param b, pointer_param result) const { auto exec = this->get_executor(); - this->compute_dot_impl( - make_temporary_clone(exec, b.get()).get(), - make_temporary_output_clone(exec, result.get()).get()); + this->compute_dot_impl(make_temporary_clone(exec, b).get(), + make_temporary_output_clone(exec, result).get()); } @@ -205,8 +204,8 @@ void Dense::compute_conj_dot(pointer_param b, { auto exec = this->get_executor(); this->compute_conj_dot_impl( - make_temporary_clone(exec, b.get()).get(), - make_temporary_output_clone(exec, result.get()).get()); + make_temporary_clone(exec, b).get(), + make_temporary_output_clone(exec, result).get()); } @@ -214,8 +213,7 @@ template void Dense::compute_norm2(pointer_param result) const { auto exec = this->get_executor(); - this->compute_norm2_impl( - make_temporary_output_clone(exec, result.get()).get()); + this->compute_norm2_impl(make_temporary_output_clone(exec, result).get()); } @@ -223,8 +221,7 @@ template void Dense::compute_norm1(pointer_param result) const { auto exec = this->get_executor(); - this->compute_norm1_impl( - make_temporary_output_clone(exec, result.get()).get()); + this->compute_norm1_impl(make_temporary_output_clone(exec, result).get()); } @@ -355,10 +352,10 @@ void Dense::compute_dot(pointer_param b, tmp.clear(); tmp.set_executor(exec); } - auto local_b = make_temporary_clone(exec, b.get()); - auto local_res = make_temporary_clone(exec, result.get()); - auto dense_b = make_temporary_conversion(local_b.get()); - auto dense_res = make_temporary_conversion(local_res.get()); + auto local_b = make_temporary_clone(exec, b); + auto local_res = make_temporary_clone(exec, result); + auto dense_b = make_temporary_conversion(local_b); + auto dense_res = make_temporary_conversion(local_res); exec->run( dense::make_compute_dot(this, dense_b.get(), dense_res.get(), tmp)); } @@ -390,10 +387,10 @@ void Dense::compute_conj_dot(pointer_param b, tmp.clear(); tmp.set_executor(exec); } - auto local_b = make_temporary_clone(exec, b.get()); - auto local_res = make_temporary_clone(exec, result.get()); - auto dense_b = make_temporary_conversion(local_b.get()); - auto dense_res = make_temporary_conversion(local_res.get()); + auto local_b = make_temporary_clone(exec, b); + auto local_res = make_temporary_clone(exec, result); + auto dense_b = make_temporary_conversion(local_b); + auto dense_res = make_temporary_conversion(local_res); exec->run(dense::make_compute_conj_dot(this, dense_b.get(), dense_res.get(), tmp)); } @@ -424,7 +421,7 @@ void Dense::compute_norm2(pointer_param result, tmp.clear(); tmp.set_executor(exec); } - auto local_result = make_temporary_clone(exec, result.get()); + auto local_result = make_temporary_clone(exec, result); auto dense_res = make_temporary_conversion>( local_result.get()); exec->run(dense::make_compute_norm2(this, dense_res.get(), tmp)); @@ -453,7 +450,7 @@ void Dense::compute_norm1(pointer_param result, tmp.clear(); tmp.set_executor(exec); } - auto local_result = make_temporary_clone(exec, result.get()); + auto local_result = make_temporary_clone(exec, result); auto dense_res = make_temporary_conversion>( local_result.get()); exec->run(dense::make_compute_norm1(this, dense_res.get(), tmp)); @@ -1026,7 +1023,7 @@ void Dense::transpose(pointer_param> output) const GKO_ASSERT_EQUAL_DIMENSIONS(output, gko::transpose(this->get_size())); auto exec = this->get_executor(); exec->run(dense::make_transpose( - this, make_temporary_output_clone(exec, output.get()).get())); + this, make_temporary_output_clone(exec, output).get())); } @@ -1037,7 +1034,7 @@ void Dense::conj_transpose( GKO_ASSERT_EQUAL_DIMENSIONS(output, gko::transpose(this->get_size())); auto exec = this->get_executor(); exec->run(dense::make_conj_transpose( - this, make_temporary_output_clone(exec, output.get()).get())); + this, make_temporary_output_clone(exec, output).get())); } @@ -1344,8 +1341,8 @@ void Dense::row_gather(pointer_param alpha, pointer_param beta, pointer_param out) const { - auto dense_alpha = make_temporary_conversion(alpha.get()); - auto dense_beta = make_temporary_conversion(beta.get()); + auto dense_alpha = make_temporary_conversion(alpha); + auto dense_beta = make_temporary_conversion(beta); GKO_ASSERT_EQUAL_DIMENSIONS(dense_alpha, gko::dim<2>(1, 1)); GKO_ASSERT_EQUAL_DIMENSIONS(dense_beta, gko::dim<2>(1, 1)); gather_mixed_real_complex( @@ -1362,8 +1359,8 @@ void Dense::row_gather(pointer_param alpha, pointer_param beta, pointer_param out) const { - auto dense_alpha = make_temporary_conversion(alpha.get()); - auto dense_beta = make_temporary_conversion(beta.get()); + auto dense_alpha = make_temporary_conversion(alpha); + auto dense_beta = make_temporary_conversion(beta); GKO_ASSERT_EQUAL_DIMENSIONS(dense_alpha, gko::dim<2>(1, 1)); GKO_ASSERT_EQUAL_DIMENSIONS(dense_beta, gko::dim<2>(1, 1)); gather_mixed_real_complex( @@ -1498,7 +1495,7 @@ void Dense::extract_diagonal( GKO_ASSERT_EQ(output->get_size()[0], diag_size); exec->run(dense::make_extract_diagonal( - this, make_temporary_output_clone(exec, output.get()).get())); + this, make_temporary_output_clone(exec, output).get())); } @@ -1538,7 +1535,7 @@ void Dense::compute_absolute( auto exec = this->get_executor(); exec->run(dense::make_outplace_absolute_dense( - this, make_temporary_output_clone(exec, output.get()).get())); + this, make_temporary_output_clone(exec, output).get())); } @@ -1559,7 +1556,7 @@ void Dense::make_complex(pointer_param result) const auto exec = this->get_executor(); exec->run(dense::make_make_complex( - this, make_temporary_output_clone(exec, result.get()).get())); + this, make_temporary_output_clone(exec, result).get())); } @@ -1580,7 +1577,7 @@ void Dense::get_real(pointer_param result) const auto exec = this->get_executor(); exec->run(dense::make_get_real( - this, make_temporary_output_clone(exec, result.get()).get())); + this, make_temporary_output_clone(exec, result).get())); } @@ -1601,7 +1598,7 @@ void Dense::get_imag(pointer_param result) const auto exec = this->get_executor(); exec->run(dense::make_get_imag( - this, make_temporary_output_clone(exec, result.get()).get())); + this, make_temporary_output_clone(exec, result).get())); } diff --git a/core/test/base/mtx_io.cpp b/core/test/base/mtx_io.cpp index 2575dd79bfd..e04542e7964 100644 --- a/core/test/base/mtx_io.cpp +++ b/core/test/base/mtx_io.cpp @@ -1087,12 +1087,16 @@ TYPED_TEST(RealDummyLinOpTest, WritesLinOpToStreamDefault) auto lin_op = gko::read>( iss, gko::ReferenceExecutor::create()); std::ostringstream oss{}; + std::ostringstream oss_const{}; - write(oss, lend(lin_op)); + write(oss, lin_op); + write(oss_const, std::unique_ptr>{ + std::move(lin_op)}); ASSERT_EQ(oss.str(), "%%MatrixMarket matrix coordinate real general\n2 3 6\n1 1 1\n1 " "2 3\n1 3 2\n2 1 0\n2 2 5\n2 3 0\n"); + ASSERT_EQ(oss_const.str(), oss.str()); } diff --git a/include/ginkgo/core/base/executor.hpp b/include/ginkgo/core/base/executor.hpp index a3fa7800275..5c7e21a815c 100644 --- a/include/ginkgo/core/base/executor.hpp +++ b/include/ginkgo/core/base/executor.hpp @@ -765,19 +765,19 @@ class Executor : public log::EnableLogging { * where the data will be copied to */ template - void copy_from(const Executor* src_exec, size_type num_elems, + void copy_from(pointer_param src_exec, size_type num_elems, const T* src_ptr, T* dest_ptr) const { const auto src_loc = reinterpret_cast(src_ptr); const auto dest_loc = reinterpret_cast(dest_ptr); this->template log( - src_exec, this, src_loc, dest_loc, num_elems * sizeof(T)); - if (this != src_exec) { + src_exec.get(), this, src_loc, dest_loc, num_elems * sizeof(T)); + if (this != src_exec.get()) { src_exec->template log( - src_exec, this, src_loc, dest_loc, num_elems * sizeof(T)); + src_exec.get(), this, src_loc, dest_loc, num_elems * sizeof(T)); } try { - this->raw_copy_from(src_exec, num_elems * sizeof(T), src_ptr, + this->raw_copy_from(src_exec.get(), num_elems * sizeof(T), src_ptr, dest_ptr); } catch (NotSupported&) { #if (GKO_VERBOSE_LEVEL >= 1) && !defined(NDEBUG) @@ -787,7 +787,7 @@ class Executor : public log::EnableLogging { << std::endl; #endif auto src_master = src_exec->get_master().get(); - if (num_elems > 0 && src_master != src_exec) { + if (num_elems > 0 && src_master != src_exec.get()) { auto* master_ptr = src_exec->get_master()->alloc(num_elems); src_master->copy_from(src_exec, num_elems, src_ptr, master_ptr); @@ -796,10 +796,10 @@ class Executor : public log::EnableLogging { } } this->template log( - src_exec, this, src_loc, dest_loc, num_elems * sizeof(T)); - if (this != src_exec) { + src_exec.get(), this, src_loc, dest_loc, num_elems * sizeof(T)); + if (this != src_exec.get()) { src_exec->template log( - src_exec, this, src_loc, dest_loc, num_elems * sizeof(T)); + src_exec.get(), this, src_loc, dest_loc, num_elems * sizeof(T)); } } @@ -879,6 +879,8 @@ class Executor : public log::EnableLogging { this->EnableLogging::remove_logger(logger); } + using EnableLogging::remove_logger; + /** * Sets the logger event propagation mode for the executor. * This controls whether events that happen at objects created on this diff --git a/include/ginkgo/core/base/lin_op.hpp b/include/ginkgo/core/base/lin_op.hpp index 300fe84f0cc..882f4d3c33e 100644 --- a/include/ginkgo/core/base/lin_op.hpp +++ b/include/ginkgo/core/base/lin_op.hpp @@ -161,8 +161,8 @@ class LinOp : public EnableAbstractPolymorphicObject { x.get()); this->validate_application_parameters(b.get(), x.get()); auto exec = this->get_executor(); - this->apply_impl(make_temporary_clone(exec, b.get()).get(), - make_temporary_clone(exec, x.get()).get()); + this->apply_impl(make_temporary_clone(exec, b).get(), + make_temporary_clone(exec, x).get()); this->template log(this, b.get(), x.get()); return this; @@ -178,8 +178,8 @@ class LinOp : public EnableAbstractPolymorphicObject { x.get()); this->validate_application_parameters(b.get(), x.get()); auto exec = this->get_executor(); - this->apply_impl(make_temporary_clone(exec, b.get()).get(), - make_temporary_clone(exec, x.get()).get()); + this->apply_impl(make_temporary_clone(exec, b).get(), + make_temporary_clone(exec, x).get()); this->template log(this, b.get(), x.get()); return this; @@ -203,10 +203,10 @@ class LinOp : public EnableAbstractPolymorphicObject { this->validate_application_parameters(alpha.get(), b.get(), beta.get(), x.get()); auto exec = this->get_executor(); - this->apply_impl(make_temporary_clone(exec, alpha.get()).get(), - make_temporary_clone(exec, b.get()).get(), - make_temporary_clone(exec, beta.get()).get(), - make_temporary_clone(exec, x.get()).get()); + this->apply_impl(make_temporary_clone(exec, alpha).get(), + make_temporary_clone(exec, b).get(), + make_temporary_clone(exec, beta).get(), + make_temporary_clone(exec, x).get()); this->template log( this, alpha.get(), b.get(), beta.get(), x.get()); return this; @@ -225,10 +225,10 @@ class LinOp : public EnableAbstractPolymorphicObject { this->validate_application_parameters(alpha.get(), b.get(), beta.get(), x.get()); auto exec = this->get_executor(); - this->apply_impl(make_temporary_clone(exec, alpha.get()).get(), - make_temporary_clone(exec, b.get()).get(), - make_temporary_clone(exec, beta.get()).get(), - make_temporary_clone(exec, x.get()).get()); + this->apply_impl(make_temporary_clone(exec, alpha).get(), + make_temporary_clone(exec, b).get(), + make_temporary_clone(exec, beta).get(), + make_temporary_clone(exec, x).get()); this->template log( this, alpha.get(), b.get(), beta.get(), x.get()); return this; @@ -838,7 +838,8 @@ class ScaledIdentityAddable { * @param b Scalar to multiply this before adding the scaled identity to * it. */ - void add_scaled_identity(const LinOp* const a, const LinOp* const b) + void add_scaled_identity(pointer_param const a, + pointer_param const b) { GKO_ASSERT_IS_SCALAR(a); GKO_ASSERT_IS_SCALAR(b); @@ -847,12 +848,6 @@ class ScaledIdentityAddable { add_scaled_identity_impl(ae.get(), be.get()); } - void add_scaled_identity(pointer_param const a, - pointer_param const b) - { - add_scaled_identity(a.get(), b.get()); - } - private: virtual void add_scaled_identity_impl(const LinOp* a, const LinOp* b) = 0; }; diff --git a/include/ginkgo/core/base/mtx_io.hpp b/include/ginkgo/core/base/mtx_io.hpp index 5e106adb8b4..a6efded28db 100644 --- a/include/ginkgo/core/base/mtx_io.hpp +++ b/include/ginkgo/core/base/mtx_io.hpp @@ -320,12 +320,13 @@ struct mtx_io_traits { * @param matrix the matrix to write * @param layout the layout used in the output */ -template >> +template inline void write( StreamType&& os, MatrixPtrType&& matrix, - layout_type layout = detail::mtx_io_traits::default_layout) + layout_type layout = detail::mtx_io_traits>>>::default_layout) { + using MatrixType = detail::pointee>; matrix_data data{}; @@ -347,10 +348,10 @@ inline void write( * @param os output stream where the data is to be written * @param matrix the matrix to write */ -template >> +template inline void write_binary(StreamType&& os, MatrixPtrType&& matrix) { + using MatrixType = detail::pointee>; matrix_data data{}; diff --git a/include/ginkgo/core/base/polymorphic_object.hpp b/include/ginkgo/core/base/polymorphic_object.hpp index b4b5512e863..3f848d413f5 100644 --- a/include/ginkgo/core/base/polymorphic_object.hpp +++ b/include/ginkgo/core/base/polymorphic_object.hpp @@ -175,12 +175,20 @@ class PolymorphicObject : public log::EnableLogging { * @param other the object to move from * * @return this + * + * @tparam Derived the actual pointee type of the parameter, it needs to be + * derived from PolymorphicObject. + * @tparam Deleter the deleter of the unique_ptr parameter */ - template - std::enable_if_t< - std::is_base_of>::value, - PolymorphicObject>* - copy_from(std::unique_ptr&& other) + template + [[deprecated( + "This function will be removed in a future release, the replacement " + "will copy instead of move. If a move in intended, use move_to " + "instead.")]] std:: + enable_if_t< + std::is_base_of>::value, + PolymorphicObject>* + copy_from(std::unique_ptr&& other) { this->template log( exec_.get(), other.get(), this); @@ -190,15 +198,25 @@ class PolymorphicObject : public log::EnableLogging { return copied; } - template + /** + * @copydoc copy_from(const PolymorphicObject*) + * + * @tparam Derived the actual pointee type of the parameter, it needs to be + * derived from PolymorphicObject. + * @tparam Deleter the deleter of the unique_ptr parameter + */ + template std::enable_if_t< std::is_base_of>::value, PolymorphicObject>* - copy_from(const std::unique_ptr& other) + copy_from(const std::unique_ptr& other) { return this->copy_from(other.get()); } + /** + * @copydoc copy_from(const PolymorphicObject*) + */ PolymorphicObject* copy_from( const std::shared_ptr& other) { @@ -216,54 +234,14 @@ class PolymorphicObject : public log::EnableLogging { * * @return this */ - PolymorphicObject* move_from(PolymorphicObject* other) - { - this->template log( - exec_.get(), other, this); - auto moved = this->move_from_impl(other); - this->template log( - exec_.get(), other, this); - return moved; - } - - /** - * Moves another object into this object. - * - * This is the polymorphic equivalent of the move assignment operator. - * - * @see move_from_impl(std::unique_ptr) - * - * @param other the object to move from - * - * @return this - */ - template - std::enable_if_t< - std::is_base_of>::value, - PolymorphicObject>* - move_from(std::unique_ptr&& other) + PolymorphicObject* move_from(pointer_param other) { this->template log( exec_.get(), other.get(), this); - auto copied = this->copy_from_impl(std::move(other)); + auto moved = this->move_from_impl(other.get()); this->template log( exec_.get(), other.get(), this); - return copied; - } - - template - std::enable_if_t< - std::is_base_of>::value, - PolymorphicObject>* - move_from(const std::unique_ptr& other) - { - return move_from(other.get()); - } - - PolymorphicObject* move_from( - const std::shared_ptr& other) - { - return move_from(other.get()); + return moved; } /** @@ -430,10 +408,14 @@ class EnableAbstractPolymorphicObject : public PolymorphicBase { } template - std::enable_if_t< - std::is_base_of>::value, - AbstractObject>* - copy_from(std::unique_ptr&& other) + [[deprecated( + "This function will be removed in a future release, the replacement " + "will copy instead of move. If a move in intended, use move_to " + "instead.")]] std:: + enable_if_t< + std::is_base_of>::value, + AbstractObject>* + copy_from(std::unique_ptr&& other) { return static_cast( this->PolymorphicBase::copy_from(std::move(other))); @@ -454,34 +436,10 @@ class EnableAbstractPolymorphicObject : public PolymorphicBase { return copy_from(other.get()); } - AbstractObject* move_from(PolymorphicObject* other) - { - return static_cast( - this->PolymorphicBase::move_from(other)); - } - - template - std::enable_if_t< - std::is_base_of>::value, - AbstractObject>* - move_from(std::unique_ptr&& other) + AbstractObject* move_from(pointer_param other) { return static_cast( - this->PolymorphicBase::move_from(std::move(other))); - } - - template - std::enable_if_t< - std::is_base_of>::value, - AbstractObject>* - move_from(const std::unique_ptr& other) - { - return move_from(other.get()); - } - - AbstractObject* move_from(const std::shared_ptr& other) - { - return move_from(other.get()); + this->PolymorphicBase::move_from(other.get())); } AbstractObject* clear() diff --git a/include/ginkgo/core/base/temporary_clone.hpp b/include/ginkgo/core/base/temporary_clone.hpp index 87ba89fb0f2..33bd431956a 100644 --- a/include/ginkgo/core/base/temporary_clone.hpp +++ b/include/ginkgo/core/base/temporary_clone.hpp @@ -140,19 +140,19 @@ class temporary_clone { * @param copy_data should the data be copied to the executor, or should * only the result be copied back afterwards? */ - explicit temporary_clone(std::shared_ptr exec, pointer ptr, - bool copy_data = true) + explicit temporary_clone(std::shared_ptr exec, + pointer_param ptr, bool copy_data = true) { if (ptr->get_executor()->memory_accessible(exec)) { // just use the object we already have - handle_ = handle_type(ptr, null_deleter()); + handle_ = handle_type(ptr.get(), null_deleter()); } else { // clone the object to the new executor and make sure it's copied // back before we delete it handle_ = handle_type(temporary_clone_helper::create( - std::move(exec), ptr, copy_data) + std::move(exec), ptr.get(), copy_data) .release(), - copy_back_deleter(ptr)); + copy_back_deleter(ptr.get())); } } @@ -198,11 +198,12 @@ class temporary_clone { * @param exec the executor where the clone will be created * @param ptr a pointer to the object of which the clone will be created */ -template -detail::temporary_clone make_temporary_clone( - std::shared_ptr exec, T* ptr) +template +detail::temporary_clone())>> +make_temporary_clone(std::shared_ptr exec, Ptr&& ptr) { - return detail::temporary_clone(std::move(exec), ptr); + using T = std::remove_reference_t())>; + return detail::temporary_clone(std::move(exec), std::forward(ptr)); } @@ -218,14 +219,16 @@ detail::temporary_clone make_temporary_clone( * @param exec the executor where the uninitialized clone will be created * @param ptr a pointer to the object of which the clone will be created */ -template -detail::temporary_clone make_temporary_output_clone( - std::shared_ptr exec, T* ptr) +template +detail::temporary_clone())>> +make_temporary_output_clone(std::shared_ptr exec, Ptr&& ptr) { + using T = std::remove_reference_t())>; static_assert( !std::is_const::value, "make_temporary_output_clone should only be used on non-const objects"); - return detail::temporary_clone(std::move(exec), ptr, false); + return detail::temporary_clone(std::move(exec), std::forward(ptr), + false); } diff --git a/include/ginkgo/core/base/temporary_conversion.hpp b/include/ginkgo/core/base/temporary_conversion.hpp index e94974fce58..520f1e899aa 100644 --- a/include/ginkgo/core/base/temporary_conversion.hpp +++ b/include/ginkgo/core/base/temporary_conversion.hpp @@ -234,14 +234,14 @@ class temporary_conversion { * try out for converting ptr to type T. */ template - static temporary_conversion create(lin_op_type* ptr) + static temporary_conversion create(pointer_param ptr) { T* cast_ptr{}; - if ((cast_ptr = dynamic_cast(ptr))) { + if ((cast_ptr = dynamic_cast(ptr.get()))) { return handle_type{cast_ptr, null_deleter{}}; } else { return conversion_helper::template convert< - T>(ptr); + T>(ptr.get()); } } diff --git a/include/ginkgo/core/base/utils_helper.hpp b/include/ginkgo/core/base/utils_helper.hpp index 240843ead54..705219a1138 100644 --- a/include/ginkgo/core/base/utils_helper.hpp +++ b/include/ginkgo/core/base/utils_helper.hpp @@ -56,28 +56,60 @@ namespace gko { class Executor; +/** + * This class is used for function parameters in the place of raw pointers. + * It can be converted to from raw pointers, shared pointers and unique pointers + * of the specified type or any derived type. This allows functions to be called + * without having to use gko::lend or calling .get() for every pointer argument. + * It probably has no use outside of function parameters, as it is immutable. + * + * @tparam T the pointed-to type + */ template class pointer_param { public: + /** Initializes the pointer_param from a raw pointer. */ pointer_param(T* ptr) : ptr_{ptr} {} + /** Initializes the pointer_param from a shared_ptr. */ template ::value>* = nullptr> pointer_param(const std::shared_ptr& ptr) : pointer_param{ptr.get()} {} + /** Initializes the pointer_param from a unique_ptr. */ template ::value>* = nullptr> pointer_param(const std::unique_ptr& ptr) : pointer_param{ptr.get()} {} + pointer_param(const pointer_param&) = default; + + pointer_param(pointer_param&&) = default; + + /** Initializes the pointer_param from a pointer_param of a derived type. */ + template ::value>* = nullptr> + pointer_param(const pointer_param& ptr) : pointer_param{ptr.get()} + {} + + /** @return a reference to the underlying pointee. */ T& operator*() const { return *ptr_; } + /** @return the underlying pointer. */ T* operator->() const { return ptr_; } + /** @return the underlying pointer. */ T* get() const { return ptr_; } + /** @return true iff the underlying pointer is non-null. */ + explicit operator bool() const { return ptr_; } + + pointer_param& operator=(const pointer_param&) = delete; + + pointer_param& operator=(pointer_param&&) = delete; + private: T* ptr_; }; @@ -97,6 +129,11 @@ struct pointee_impl> { using type = T; }; +template +struct pointee_impl> { + using type = T; +}; + template struct pointee_impl> { using type = T; @@ -276,7 +313,7 @@ inline typename std::remove_reference::type&& give( * same as calling .get() on the smart pointer. */ template -[[deprecated("no longer necessary due to pointer_param")]] inline +[[deprecated("no longer necessary, just pass the object without lend")]] inline typename std::enable_if::value, detail::pointee*>::type lend(const Pointer& p) @@ -295,7 +332,7 @@ template * returns `p`. */ template -[[deprecated("no longer necessary due to pointer_param")]] inline +[[deprecated("no longer necessary, just pass the object without lend")]] inline typename std::enable_if::value, detail::pointee*>::type lend(const Pointer& p) diff --git a/include/ginkgo/core/distributed/matrix.hpp b/include/ginkgo/core/distributed/matrix.hpp index d7ffd4acc51..0eaec728a1c 100644 --- a/include/ginkgo/core/distributed/matrix.hpp +++ b/include/ginkgo/core/distributed/matrix.hpp @@ -516,7 +516,8 @@ class Matrix * same runtime type. */ explicit Matrix(std::shared_ptr exec, - mpi::communicator comm, const LinOp* matrix_template); + mpi::communicator comm, + pointer_param matrix_template); /** * Creates an empty distributed matrix with specified types for the local @@ -533,8 +534,9 @@ class Matrix * constructed with the same runtime type. */ explicit Matrix(std::shared_ptr exec, - mpi::communicator comm, const LinOp* local_matrix_template, - const LinOp* non_local_matrix_template); + mpi::communicator comm, + pointer_param local_matrix_template, + pointer_param non_local_matrix_template); /** * Starts a non-blocking communication of the values of b that are shared diff --git a/include/ginkgo/core/matrix/coo.hpp b/include/ginkgo/core/matrix/coo.hpp index 9214dc2b0b3..5e6e7cf6b5d 100644 --- a/include/ginkgo/core/matrix/coo.hpp +++ b/include/ginkgo/core/matrix/coo.hpp @@ -220,8 +220,8 @@ class Coo : public EnableLinOp>, { this->validate_application_parameters(b.get(), x.get()); auto exec = this->get_executor(); - this->apply2_impl(make_temporary_clone(exec, b.get()).get(), - make_temporary_clone(exec, x.get()).get()); + this->apply2_impl(make_temporary_clone(exec, b).get(), + make_temporary_clone(exec, x).get()); return this; } @@ -233,8 +233,8 @@ class Coo : public EnableLinOp>, { this->validate_application_parameters(b.get(), x.get()); auto exec = this->get_executor(); - this->apply2_impl(make_temporary_clone(exec, b.get()).get(), - make_temporary_clone(exec, x.get()).get()); + this->apply2_impl(make_temporary_clone(exec, b).get(), + make_temporary_clone(exec, x).get()); return this; } @@ -253,9 +253,9 @@ class Coo : public EnableLinOp>, this->validate_application_parameters(b.get(), x.get()); GKO_ASSERT_EQUAL_DIMENSIONS(alpha, dim<2>(1, 1)); auto exec = this->get_executor(); - this->apply2_impl(make_temporary_clone(exec, alpha.get()).get(), - make_temporary_clone(exec, b.get()).get(), - make_temporary_clone(exec, x.get()).get()); + this->apply2_impl(make_temporary_clone(exec, alpha).get(), + make_temporary_clone(exec, b).get(), + make_temporary_clone(exec, x).get()); return this; } @@ -269,9 +269,9 @@ class Coo : public EnableLinOp>, this->validate_application_parameters(b.get(), x.get()); GKO_ASSERT_EQUAL_DIMENSIONS(alpha, dim<2>(1, 1)); auto exec = this->get_executor(); - this->apply2_impl(make_temporary_clone(exec, alpha.get()).get(), - make_temporary_clone(exec, b.get()).get(), - make_temporary_clone(exec, x.get()).get()); + this->apply2_impl(make_temporary_clone(exec, alpha).get(), + make_temporary_clone(exec, b).get(), + make_temporary_clone(exec, x).get()); return this; } diff --git a/include/ginkgo/core/matrix/csr.hpp b/include/ginkgo/core/matrix/csr.hpp index be9681c6ad0..e9156dbe69f 100644 --- a/include/ginkgo/core/matrix/csr.hpp +++ b/include/ginkgo/core/matrix/csr.hpp @@ -930,7 +930,7 @@ class Csr : public EnableLinOp>, { auto exec = this->get_executor(); GKO_ASSERT_EQUAL_DIMENSIONS(alpha, dim<2>(1, 1)); - this->scale_impl(make_temporary_clone(exec, alpha.get()).get()); + this->scale_impl(make_temporary_clone(exec, alpha).get()); } /** @@ -943,7 +943,7 @@ class Csr : public EnableLinOp>, { auto exec = this->get_executor(); GKO_ASSERT_EQUAL_DIMENSIONS(alpha, dim<2>(1, 1)); - this->inv_scale_impl(make_temporary_clone(exec, alpha.get()).get()); + this->inv_scale_impl(make_temporary_clone(exec, alpha).get()); } /** diff --git a/include/ginkgo/core/matrix/dense.hpp b/include/ginkgo/core/matrix/dense.hpp index 27ae225f5b8..f634ff3ec7f 100644 --- a/include/ginkgo/core/matrix/dense.hpp +++ b/include/ginkgo/core/matrix/dense.hpp @@ -35,6 +35,7 @@ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. #include +#include #include @@ -1294,11 +1295,14 @@ struct temporary_clone_helper> { * * @param vector the vector on which to create the view */ -template -std::unique_ptr> make_dense_view( - matrix::Dense* vector) +template +std::unique_ptr< + matrix::Dense>::value_type>> +make_dense_view(VecPtr&& vector) { - return matrix::Dense::create_view_of(vector); + using value_type = + typename detail::pointee>::value_type; + return matrix::Dense::create_view_of(vector); } @@ -1309,11 +1313,14 @@ std::unique_ptr> make_dense_view( * * @param vector the vector on which to create the view */ -template -std::unique_ptr> make_const_dense_view( - const matrix::Dense* vector) +template +std::unique_ptr>::value_type>> +make_const_dense_view(VecPtr&& vector) { - return matrix::Dense::create_const_view_of(vector); + using value_type = + typename detail::pointee>::value_type; + return matrix::Dense::create_const_view_of(vector); } diff --git a/include/ginkgo/core/stop/criterion.hpp b/include/ginkgo/core/stop/criterion.hpp index d5df4378ad9..0e2b3a038fd 100644 --- a/include/ginkgo/core/stop/criterion.hpp +++ b/include/ginkgo/core/stop/criterion.hpp @@ -114,13 +114,22 @@ class Criterion : public EnableAbstractPolymorphicObject { return *this; \ } \ mutable _type _name##_ {} +#define GKO_UPDATER_REGISTER_PTR_PARAMETER(_type, _name) \ + const Updater& _name(pointer_param<_type> value) const \ + { \ + _name##_ = value.get(); \ + return *this; \ + } \ + mutable _type* _name##_ {} GKO_UPDATER_REGISTER_PARAMETER(size_type, num_iterations); - GKO_UPDATER_REGISTER_PARAMETER(const LinOp*, residual); - GKO_UPDATER_REGISTER_PARAMETER(const LinOp*, residual_norm); - GKO_UPDATER_REGISTER_PARAMETER(const LinOp*, implicit_sq_residual_norm); - GKO_UPDATER_REGISTER_PARAMETER(const LinOp*, solution); + GKO_UPDATER_REGISTER_PTR_PARAMETER(const LinOp, residual); + GKO_UPDATER_REGISTER_PTR_PARAMETER(const LinOp, residual_norm); + GKO_UPDATER_REGISTER_PTR_PARAMETER(const LinOp, + implicit_sq_residual_norm); + GKO_UPDATER_REGISTER_PTR_PARAMETER(const LinOp, solution); +#undef GKO_UPDATER_REGISTER_PTR_PARAMETER #undef GKO_UPDATER_REGISTER_PARAMETER private: