diff --git a/core/CMakeLists.txt b/core/CMakeLists.txt index 21f188ccc5f..047d362e8dd 100644 --- a/core/CMakeLists.txt +++ b/core/CMakeLists.txt @@ -5,6 +5,7 @@ target_sources(ginkgo PRIVATE base/array.cpp base/batch_multi_vector.cpp + base/block_operator.cpp base/combination.cpp base/composition.cpp base/dense_cache.cpp diff --git a/core/base/block_operator.cpp b/core/base/block_operator.cpp new file mode 100644 index 00000000000..1efc5e3b152 --- /dev/null +++ b/core/base/block_operator.cpp @@ -0,0 +1,302 @@ +// SPDX-FileCopyrightText: 2017 - 2024 The Ginkgo authors +// +// SPDX-License-Identifier: BSD-3-Clause + +#include + + +#include + + +#include +#include + + +#include "core/base/dispatch_helper.hpp" + + +namespace gko { +namespace { + + +template +auto dispatch_dense(Fn&& fn, LinOp* v) +{ + return run*, matrix::Dense*, + matrix::Dense>*, + matrix::Dense>*>(v, std::forward(fn)); +} + + +template +auto create_vector_blocks(LinOpType* vector, + const std::vector& spans) +{ + return [=](size_type i) { + return dispatch_dense( + [&](auto* dense) -> std::unique_ptr { + GKO_ENSURE_IN_BOUNDS(i, spans.size()); + return dense->create_submatrix(spans[i], + {0, dense->get_size()[1]}); + }, + const_cast(vector)); + }; +} + + +const LinOp* find_non_zero_in_row( + const std::vector>>& blocks, + size_type row) +{ + auto it = std::find_if(blocks[row].begin(), blocks[row].end(), + [](const auto& b) { return b.get() != nullptr; }); + GKO_THROW_IF_INVALID(it != blocks[row].end(), + "Encountered row with only nullptrs."); + return it->get(); +} + + +const LinOp* find_non_zero_in_col( + const std::vector>>& blocks, + size_type col) +{ + auto it = std::find_if(blocks.begin(), blocks.end(), [col](const auto& b) { + return b[col].get() != nullptr; + }); + GKO_THROW_IF_INVALID(it != blocks.end(), + "Encountered columns with only nullptrs."); + return it->at(col).get(); +} + + +void validate_blocks( + const std::vector>>& blocks) +{ + GKO_THROW_IF_INVALID(blocks.empty() || !blocks.front().empty(), + "Blocks must either be empty, or a 2D std::vector."); + // all rows have same number of columns + for (size_type row = 1; row < blocks.size(); ++row) { + GKO_ASSERT_EQ(blocks[row].size(), blocks.front().size()); + } + // within each row and each column the blocks have the same number of rows + // and columns respectively + for (size_type row = 0; row < blocks.size(); ++row) { + auto non_zero_row = find_non_zero_in_row(blocks, row); + for (size_type col = 0; col < blocks.front().size(); ++col) { + auto non_zero_col = find_non_zero_in_col(blocks, col); + if (blocks[row][col]) { + GKO_ASSERT_EQUAL_COLS(blocks[row][col], non_zero_col); + GKO_ASSERT_EQUAL_ROWS(blocks[row][col], non_zero_row); + } + } + } +} + + +template +std::vector compute_local_spans( + size_type num_blocks, + const std::vector>>& blocks, + Fn&& get_size) +{ + validate_blocks(blocks); + std::vector local_spans; + size_type offset = 0; + for (size_type i = 0; i < num_blocks; ++i) { + auto local_size = get_size(i); + local_spans.emplace_back(offset, offset + local_size); + offset += local_size; + } + return local_spans; +} + + +dim<2> compute_global_size( + const std::vector>>& blocks) +{ + validate_blocks(blocks); + if (blocks.empty()) { + return {}; + } + size_type num_rows = 0; + for (size_type row = 0; row < blocks.size(); ++row) { + num_rows += find_non_zero_in_row(blocks, row)->get_size()[0]; + } + size_type num_cols = 0; + for (size_type col = 0; col < blocks.front().size(); ++col) { + num_cols += find_non_zero_in_col(blocks, col)->get_size()[1]; + } + return {num_rows, num_cols}; +} + + +} // namespace + + +std::unique_ptr BlockOperator::create( + std::shared_ptr exec) +{ + return std::unique_ptr(new BlockOperator(std::move(exec))); +} + + +std::unique_ptr BlockOperator::create( + std::shared_ptr exec, + std::vector>> blocks) +{ + return std::unique_ptr( + new BlockOperator(std::move(exec), std::move(blocks))); +} + + +BlockOperator::BlockOperator(std::shared_ptr exec) + : EnableLinOp(std::move(exec)) +{} + + +BlockOperator::BlockOperator( + std::shared_ptr exec, + std::vector>> blocks) + : EnableLinOp(exec, compute_global_size(blocks)), + block_size_(blocks.empty() + ? dim<2>{} + : dim<2>(blocks.size(), blocks.front().size())), + row_spans_(compute_local_spans( + block_size_[0], blocks, + [&](auto i) { + return find_non_zero_in_row(blocks, i)->get_size()[0]; + })), + col_spans_(compute_local_spans(block_size_[1], blocks, [&](auto i) { + return find_non_zero_in_col(blocks, i)->get_size()[1]; + })) +{ + for (auto& row : blocks) { + for (auto& block : row) { + if (block && block->get_executor() != exec) { + blocks_.push_back(gko::clone(exec, block)); + } else { + blocks_.push_back(std::move(block)); + } + } + } +} + + +void init_one_cache(std::shared_ptr exec, + const detail::DenseCache& one_cache) +{ + if (one_cache.get() == nullptr) { + one_cache.init(std::move(exec), {1, 1}); + one_cache->fill(one()); + } +} + + +void BlockOperator::apply_impl(const LinOp* b, LinOp* x) const +{ + auto block_b = create_vector_blocks(b, col_spans_); + auto block_x = create_vector_blocks(x, row_spans_); + + init_one_cache(this->get_executor(), one_); + for (size_type row = 0; row < block_size_[0]; ++row) { + bool first_in_row = true; + for (size_type col = 0; col < block_size_[1]; ++col) { + if (!block_at(row, col)) { + continue; + } + if (first_in_row) { + block_at(row, col)->apply(block_b(col), block_x(row)); + first_in_row = false; + } else { + block_at(row, col)->apply(one_.get(), block_b(col), one_.get(), + block_x(row)); + } + } + } +} + + +void BlockOperator::apply_impl(const LinOp* alpha, const LinOp* b, + const LinOp* beta, LinOp* x) const +{ + auto block_b = create_vector_blocks(b, col_spans_); + auto block_x = create_vector_blocks(x, row_spans_); + + init_one_cache(this->get_executor(), one_); + for (size_type row = 0; row < block_size_[0]; ++row) { + bool first_in_row = true; + for (size_type col = 0; col < block_size_[1]; ++col) { + if (!block_at(row, col)) { + continue; + } + if (first_in_row) { + block_at(row, col)->apply(alpha, block_b(col), beta, + block_x(row)); + first_in_row = false; + } else { + block_at(row, col)->apply(alpha, block_b(col), one_.get(), + block_x(row)); + } + } + } +} + + +BlockOperator::BlockOperator(const BlockOperator& other) + : EnableLinOp(other.get_executor()) +{ + *this = other; +} + + +BlockOperator::BlockOperator(BlockOperator&& other) noexcept + : EnableLinOp(other.get_executor()) +{ + *this = std::move(other); +} + + +BlockOperator& BlockOperator::operator=(const BlockOperator& other) +{ + if (this != &other) { + auto exec = this->get_executor(); + + set_size(other.get_size()); + block_size_ = other.get_block_size(); + col_spans_ = other.col_spans_; + row_spans_ = other.row_spans_; + blocks_.clear(); + for (const auto& block : other.blocks_) { + blocks_.emplace_back(block == nullptr ? nullptr + : gko::clone(exec, block)); + } + } + return *this; +} + + +BlockOperator& BlockOperator::operator=(BlockOperator&& other) +{ + if (this != &other) { + auto exec = this->get_executor(); + + set_size(other.get_size()); + other.set_size({}); + + block_size_ = std::exchange(other.block_size_, dim<2>{}); + col_spans_ = std::move(other.col_spans_); + row_spans_ = std::move(other.row_spans_); + blocks_ = std::move(other.blocks_); + if (exec != other.get_executor()) { + for (auto& block : blocks_) { + if (block != nullptr) { + block = gko::clone(exec, block); + } + } + } + } + return *this; +} + + +} // namespace gko diff --git a/core/base/dispatch_helper.hpp b/core/base/dispatch_helper.hpp index 5c030dedd17..62eccfde0c6 100644 --- a/core/base/dispatch_helper.hpp +++ b/core/base/dispatch_helper.hpp @@ -13,20 +13,55 @@ namespace gko { +namespace detail { + + +/** + * + * @copydoc run + * + * @note this is the end case + */ +template +ReturnType run_impl(T obj, Func&&, Args&&...) +{ + GKO_NOT_SUPPORTED(obj); +} + +/** + * @copydoc run + * + * @note This has additionally the return type encoded. + */ +template +ReturnType run_impl(T obj, Func&& f, Args&&... args) +{ + if (auto dobj = dynamic_cast(obj)) { + return f(dobj, std::forward(args)...); + } else { + return run_impl(obj, std::forward(f), + std::forward(args)...); + } +} /** * run uses template to go through the list and select the valid * template and run it. * - * @tparam T the type of input object - * @tparam Func the function will run if the object can be converted to K - * @tparam ...Args the additional arguments for the Func + * @tparam Base the Base class with one template + * @tparam T the type of input object waiting converted + * @tparam Func the validation + * @tparam ...Args the variadic arguments. * * @note this is the end case */ -template -void run(T obj, Func, Args...) +template class Base, typename T, + typename Func, typename... Args> +ReturnType run_impl(T obj, Func, Args...) { GKO_NOT_SUPPORTED(obj); } @@ -35,45 +70,68 @@ void run(T obj, Func, Args...) * run uses template to go through the list and select the valid * template and run it. * - * @tparam K the current type tried in the conversion + * @tparam Base the Base class with one template + * @tparam K the current template type of B. pointer of const Base is tried + * in the conversion. * @tparam ...Types the other types will be tried in the conversion if K fails - * @tparam T the type of input object - * @tparam Func the function will run if the object can be converted to K + * @tparam T the type of input object waiting converted + * @tparam Func the function will run if the object can be converted to pointer + * of const Base * @tparam ...Args the additional arguments for the Func * * @param obj the input object waiting converted * @param f the function will run if obj can be converted successfully * @param args the additional arguments for the function */ -template -void run(T obj, Func f, Args... args) +template class Base, typename K, + typename... Types, typename T, typename Func, typename... Args> +ReturnType run_impl(T obj, Func&& f, Args&&... args) { - if (auto dobj = dynamic_cast(obj)) { - f(dobj, args...); + if (auto dobj = std::dynamic_pointer_cast>(obj)) { + return f(dobj, args...); } else { - run(obj, f, args...); + return run_impl( + obj, std::forward(f), std::forward(args)...); } } + +} // namespace detail + + /** * run uses template to go through the list and select the valid * template and run it. * - * @tparam Base the Base class with one template - * @tparam T the type of input object waiting converted - * @tparam Func the validation - * @tparam ...Args the variadic arguments. + * @tparam K the current type tried in the conversion + * @tparam ...Types the other types will be tried in the conversion if K fails + * @tparam T the type of input object + * @tparam Func the function will run if the object can be converted to K + * @tparam ...Args the additional arguments for the Func * - * @note this is the end case + * @param obj the input object waiting converted + * @param f the function will run if obj can be converted successfully + * @param args the additional arguments for the function + * + * @note This assumes that each invocation of f with types (K, Types...) + * returns the same type + * + * @return the result of f invoked with obj cast to the first matching type */ -template