diff --git a/core/CMakeLists.txt b/core/CMakeLists.txt index 4a19ca18962..a81493d2666 100644 --- a/core/CMakeLists.txt +++ b/core/CMakeLists.txt @@ -16,6 +16,7 @@ target_sources(ginkgo base/mpi.cpp base/mtx_io.cpp base/perturbation.cpp + base/segmented_array.cpp base/timer.cpp base/version.cpp config/property_tree.cpp diff --git a/core/base/segmented_array.cpp b/core/base/segmented_array.cpp new file mode 100644 index 00000000000..4fba0851bfc --- /dev/null +++ b/core/base/segmented_array.cpp @@ -0,0 +1,144 @@ +// SPDX-FileCopyrightText: 2017 - 2024 The Ginkgo authors +// +// SPDX-License-Identifier: BSD-3-Clause + +#include + + +#include "core/base/array_access.hpp" +#include "core/components/prefix_sum_kernels.hpp" + + +namespace gko { +namespace { + + +GKO_REGISTER_OPERATION(prefix_sum, components::prefix_sum_nonnegative); + + +} + +template +segmented_array::segmented_array(std::shared_ptr exec) + : buffer_(exec), offsets_(exec, 1) +{ + offsets_.fill(0); +} + + +array sizes_to_offsets(const gko::array& sizes) +{ + auto exec = sizes.get_executor(); + array offsets(exec, sizes.get_size() + 1); + exec->copy(sizes.get_size(), sizes.get_const_data(), offsets.get_data()); + exec->run(make_prefix_sum(offsets.get_data(), offsets.get_size())); + return offsets; +} + + +template +segmented_array segmented_array::create_from_sizes( + const gko::array& sizes) +{ + return create_from_offsets(sizes_to_offsets(sizes)); +} + + +template +segmented_array segmented_array::create_from_sizes( + gko::array buffer, const gko::array& sizes) +{ + return create_from_offsets(std::move(buffer), sizes_to_offsets(sizes)); +} + + +template +segmented_array segmented_array::create_from_offsets( + gko::array offsets) +{ + GKO_THROW_IF_INVALID(offsets.get_size() > 0, + "The offsets for segmented_arrays require at least " + "one element."); + auto size = + static_cast(get_element(offsets, offsets.get_size() - 1)); + return create_from_offsets(array{offsets.get_executor(), size}, + std::move(offsets)); +} + + +template +segmented_array segmented_array::create_from_offsets( + gko::array buffer, gko::array offsets) +{ + GKO_ASSERT_EQ(buffer.get_size(), + get_element(offsets, offsets.get_size() - 1)); + segmented_array result(buffer.get_executor()); + result.offsets_ = std::move(offsets); + result.buffer_ = std::move(buffer); + return result; +} + + +template +segmented_array::segmented_array(std::shared_ptr exec, + segmented_array&& other) + : segmented_array(exec) +{ + *this = std::move(other); +} + + +template +segmented_array::segmented_array(std::shared_ptr exec, + const segmented_array& other) + : segmented_array(exec) +{ + *this = other; +} + + +template +segmented_array::segmented_array(const segmented_array& other) + : segmented_array(other.get_executor()) +{ + *this = other; +} + + +template +segmented_array::segmented_array(segmented_array&& other) + : segmented_array(other.get_executor()) +{ + *this = std::move(other); +} + + +template +segmented_array& segmented_array::operator=(const segmented_array& other) +{ + if (this != &other) { + buffer_ = other.buffer_; + offsets_ = other.offsets_; + } + return *this; +} + + +template +segmented_array& segmented_array::operator=(segmented_array&& other) +{ + if (this != &other) { + buffer_ = std::move(other.buffer_); + offsets_ = std::exchange(other.offsets_, + array{other.get_executor(), {0}}); + } + return *this; +} + + +#define GKO_DECLARE_SEGMENTED_ARRAY(_type) class segmented_array<_type> + +GKO_INSTANTIATE_FOR_EACH_POD_TYPE(GKO_DECLARE_SEGMENTED_ARRAY); + + +} // namespace gko diff --git a/core/base/segmented_array.hpp b/core/base/segmented_array.hpp new file mode 100644 index 00000000000..a3d8c5dc337 --- /dev/null +++ b/core/base/segmented_array.hpp @@ -0,0 +1,57 @@ +// SPDX-FileCopyrightText: 2017 - 2024 The Ginkgo authors +// +// SPDX-License-Identifier: BSD-3-Clause + +#ifndef GINKGO_SEGMENTED_ARRAY_HPP +#define GINKGO_SEGMENTED_ARRAY_HPP + + +#include + + +namespace gko { + + +/** + * Helper struct storing an array segment + * + * @tparam T The value type of the array + */ +template +struct array_segment { + T* begin; + T* end; +}; + + +/** + * Helper function to create a device-compatible view of an array segment. + */ +template +constexpr array_segment get_array_segment(segmented_array& sarr, + size_type segment_id) +{ + assert(segment_id < sarr.get_segment_count()); + auto offsets = sarr.get_offsets().get_const_data(); + auto data = sarr.get_flat_data(); + return {data + offsets[segment_id], data + offsets[segment_id + 1]}; +} + + +/** + * Helper function to create a device-compatible view of a const array segment. + */ +template +constexpr array_segment get_array_segment( + const segmented_array& sarr, size_type segment_id) +{ + assert(segment_id < sarr.get_segment_count()); + auto offsets = sarr.get_offsets().get_const_data(); + auto data = sarr.get_const_flat_data(); + return {data + offsets[segment_id], data + offsets[segment_id + 1]}; +} + + +} // namespace gko + +#endif // GINKGO_SEGMENTED_ARRAY_HPP diff --git a/include/ginkgo/core/base/segmented_array.hpp b/include/ginkgo/core/base/segmented_array.hpp new file mode 100644 index 00000000000..40b4691a580 --- /dev/null +++ b/include/ginkgo/core/base/segmented_array.hpp @@ -0,0 +1,225 @@ +// SPDX-FileCopyrightText: 2017 - 2024 The Ginkgo authors +// +// SPDX-License-Identifier: BSD-3-Clause + +#pragma once +#include + + +#include +#include + + +namespace gko { + +/** + * \brief A minimal interface for a segmented array. + * + * The segmented array is stored as a flat buffer with an offsets array. + * The segment `i` contains the index range `[offset[i], offset[i + 1])` of the + * flat buffer. + * + * \tparam T value type stored in the arrays + */ +template +struct segmented_array { + /** + * Create an empty segmented array + * + * @param exec executor for storage arrays + */ + explicit segmented_array(std::shared_ptr exec); + + /** + * Creates an uninitialized segmented array with predefined segment sizes. + * + * @param exec executor for storage arrays + * @param sizes the sizes of each segment + */ + static segmented_array create_from_sizes(const gko::array& sizes); + + /** + * Creates a segmented array from a flat buffer and segment sizes. + * + * @param buffer the flat buffer whose size has to match the sum of sizes + * @param sizes the sizes of each segment + */ + static segmented_array create_from_sizes(gko::array buffer, + const gko::array& sizes); + + /** + * Creates an uninitialized segmented array from offsets. + * + * @param offsets the index offsets for each segment, and the total size of + * the buffer as last element + */ + static segmented_array create_from_offsets(gko::array offsets); + + /** + * Creates a segmented array from a flat buffer and offsets. + * + * @param buffer the flat buffer whose size has to match the last element + * of offsets + * @param offsets the index offsets for each segment, and the total size of + * the buffer as last element + */ + static segmented_array create_from_offsets(gko::array buffer, + gko::array offsets); + + /** + * Copies a segmented array to a different executor. + * + * @param exec the executor to copy to + * @param other the segmented array to copy from + */ + segmented_array(std::shared_ptr exec, + const segmented_array& other); + + /** + * Moves a segmented array to a different executor. + * + * @param exec the executor to move to + * @param other the segmented array to move from + */ + segmented_array(std::shared_ptr exec, + segmented_array&& other); + + segmented_array(const segmented_array& other); + + segmented_array(segmented_array&& other) noexcept(false); + + segmented_array& operator=(const segmented_array& other); + + segmented_array& operator=(segmented_array&&) noexcept(false); + + /** + * Get the total size of the stored buffer. + * + * @return the total size of the stored buffer. + */ + size_type get_size() const; + + /** + * Get the number of segments. + * + * @return the number of segments + */ + size_type get_segment_count() const; + + /** + * Access to the flat buffer. + * + * @return the flat buffer + */ + T* get_flat_data(); + + /** + * Const-access to the flat buffer + * + * @return the flat buffer + */ + const T* get_const_flat_data() const; + + /** + * Access to the segment offsets. + * + * @return the segment offsets + */ + const gko::array& get_offsets() const; + + /** + * Access the executor. + * + * @return the executor + */ + std::shared_ptr get_executor() const; + +private: + gko::array buffer_; + gko::array offsets_; +}; + +template +size_type segmented_array::get_size() const +{ + return buffer_.get_size(); +} + + +template +size_type segmented_array::get_segment_count() const +{ + return offsets_.get_size() ? offsets_.get_size() - 1 : 0; +} + + +template +T* segmented_array::get_flat_data() +{ + return buffer_.get_data(); +} + + +template +const T* segmented_array::get_const_flat_data() const +{ + return buffer_.get_const_data(); +} + + +template +const gko::array& segmented_array::get_offsets() const +{ + return offsets_; +} + + +template +std::shared_ptr segmented_array::get_executor() const +{ + return buffer_.get_executor(); +} + + +namespace detail { +template +struct temporary_clone_helper> { + static std::unique_ptr> create( + std::shared_ptr exec, segmented_array* ptr, + bool copy_data) + { + if (copy_data) { + return std::make_unique>( + make_array_view(exec, ptr->get_size(), ptr->get_flat_data()), + ptr->get_offsets()); + } else { + return std::make_unique>(std::move(exec), + ptr->get_offsets()); + } + } +}; + +template +struct temporary_clone_helper> { + static std::unique_ptr> create( + std::shared_ptr exec, const segmented_array* ptr, + bool) + { + return std::make_unique>( + make_array_view(exec, ptr->get_size(), ptr->get_const_flat_data()), + ptr->get_offsets()); + } +}; + + +template +class copy_back_deleter> + : public copy_back_deleter_from_assignment> { +public: + using copy_back_deleter_from_assignment< + segmented_array>::copy_back_deleter_from_assignment; +}; + + +} // namespace detail +} // namespace gko diff --git a/include/ginkgo/ginkgo.hpp b/include/ginkgo/ginkgo.hpp index 8a42757c9d4..9052565c9f3 100644 --- a/include/ginkgo/ginkgo.hpp +++ b/include/ginkgo/ginkgo.hpp @@ -42,6 +42,7 @@ #include #include #include +#include #include #include #include