diff --git a/core/base/index_set.cpp b/core/base/index_set.cpp index eb383882f49..bf61689d567 100644 --- a/core/base/index_set.cpp +++ b/core/base/index_set.cpp @@ -69,6 +69,40 @@ void IndexSet::populate_subsets(const gko::Array &indices) } +template +IndexType IndexSet::get_global_index(const IndexType &index) const +{ + auto exec = this->get_executor(); + auto loc_idx = + Array(exec, std::initializer_list{index}); + auto glob_idx = + Array(exec, std::initializer_list{index}); + + GKO_ASSERT(this->get_num_subsets() >= 1); + exec->run(index_set::make_local_to_global( + this->index_space_size_, &this->subsets_begin_, &this->subsets_end_, + &this->superset_cumulative_indices_, &loc_idx, &glob_idx)); + return glob_idx.get_data()[0]; +} + + +template +IndexType IndexSet::get_local_index(const IndexType &index) const +{ + auto exec = this->get_executor(); + auto loc_idx = + Array(exec, std::initializer_list{index}); + auto glob_idx = + Array(exec, std::initializer_list{index}); + + GKO_ASSERT(this->get_num_subsets() >= 1); + exec->run(index_set::make_global_to_local( + this->index_space_size_, &this->subsets_begin_, &this->subsets_end_, + &this->superset_cumulative_indices_, &glob_idx, &loc_idx)); + return loc_idx.get_data()[0]; +} + + template Array IndexSet::get_global_indices_from_local( const Array &local_indices) const @@ -78,7 +112,7 @@ Array IndexSet::get_global_indices_from_local( gko::Array(exec, local_indices.get_num_elems()); GKO_ASSERT(this->get_num_subsets() >= 1); - exec->run(index_set::make_global_to_local( + exec->run(index_set::make_local_to_global( this->index_space_size_, &this->subsets_begin_, &this->subsets_end_, &this->superset_cumulative_indices_, &local_indices, &global_indices)); return std::move(global_indices); @@ -90,10 +124,11 @@ Array IndexSet::get_local_indices_from_global( const Array &global_indices) const { auto exec = this->get_executor(); - auto local_indices = gko::Array(exec); + auto local_indices = + gko::Array(exec, global_indices.get_num_elems()); GKO_ASSERT(this->get_num_subsets() >= 1); - exec->run(index_set::make_local_to_global( + exec->run(index_set::make_global_to_local( this->index_space_size_, &this->subsets_begin_, &this->subsets_end_, &this->superset_cumulative_indices_, &global_indices, &local_indices)); return std::move(local_indices); diff --git a/core/test/base/index_set.cpp b/core/test/base/index_set.cpp index 8a2945064b5..3f723af6b08 100644 --- a/core/test/base/index_set.cpp +++ b/core/test/base/index_set.cpp @@ -78,16 +78,6 @@ class IndexSet : public ::testing::Test { } } - static void assert_equal_arrays(const T num_elems, const T *a, const T *b) - { - if (num_elems > 0) { - for (auto i = 0; i < num_elems; ++i) { - ASSERT_EQ(a[i], b[i]); - } - } - } - - std::shared_ptr exec; }; @@ -149,48 +139,4 @@ TYPED_TEST(IndexSet, KnowsItsSize) } -TYPED_TEST(IndexSet, CanBeConstructedFromIndices) -{ - auto idx_arr = gko::Array{this->exec, {0, 1, 2, 4, 6, 7, 8, 9}}; - auto begin_comp = gko::Array{this->exec, {0, 4, 6}}; - auto end_comp = gko::Array{this->exec, {3, 5, 10}}; - auto superset_comp = gko::Array{this->exec, {0, 3, 4, 8}}; - - auto idx_set = gko::IndexSet{this->exec, 10, idx_arr}; - - ASSERT_EQ(idx_set.get_size(), 10); - ASSERT_EQ(idx_set.get_num_subsets(), 3); - ASSERT_EQ(idx_set.get_num_subsets(), begin_comp.get_num_elems()); - auto num_elems = idx_set.get_num_subsets(); - this->assert_equal_arrays(num_elems, idx_set.get_subsets_begin(), - begin_comp.get_data()); - this->assert_equal_arrays(num_elems, idx_set.get_subsets_end(), - end_comp.get_data()); - this->assert_equal_arrays(num_elems, idx_set.get_superset_indices(), - superset_comp.get_data()); -} - - -TYPED_TEST(IndexSet, CanBeConstructedFromNonSortedIndices) -{ - auto idx_arr = gko::Array{this->exec, {9, 1, 4, 2, 6, 8, 0, 7}}; - auto begin_comp = gko::Array{this->exec, {0, 4, 6}}; - auto end_comp = gko::Array{this->exec, {3, 5, 10}}; - auto superset_comp = gko::Array{this->exec, {0, 3, 4, 8}}; - - auto idx_set = gko::IndexSet{this->exec, 10, idx_arr}; - - ASSERT_EQ(idx_set.get_size(), 10); - ASSERT_EQ(idx_set.get_num_subsets(), 3); - ASSERT_EQ(idx_set.get_num_subsets(), begin_comp.get_num_elems()); - auto num_elems = idx_set.get_num_subsets(); - this->assert_equal_arrays(num_elems, idx_set.get_subsets_begin(), - begin_comp.get_data()); - this->assert_equal_arrays(num_elems, idx_set.get_subsets_end(), - end_comp.get_data()); - this->assert_equal_arrays(num_elems, idx_set.get_superset_indices(), - superset_comp.get_data()); -} - - } // namespace diff --git a/include/ginkgo/core/base/index_set.hpp b/include/ginkgo/core/base/index_set.hpp index 9c0ea29885b..83efe262e40 100644 --- a/include/ginkgo/core/base/index_set.hpp +++ b/include/ginkgo/core/base/index_set.hpp @@ -116,6 +116,10 @@ class IndexSet { index_type get_num_elems() const { return this->num_stored_indices_; }; + index_type get_global_index(const index_type &local_index) const; + + index_type get_local_index(const index_type &global_index) const; + Array get_global_indices_from_local( const Array &local_indices) const; diff --git a/reference/base/index_set_kernels.cpp b/reference/base/index_set_kernels.cpp index 24842c65b9a..b7338d4cbcc 100644 --- a/reference/base/index_set_kernels.cpp +++ b/reference/base/index_set_kernels.cpp @@ -84,14 +84,16 @@ void populate_subsets(std::shared_ptr exec, tmp_subset_begin.push_back(tmp_indices.get_data()[0]); tmp_subset_superset_index.push_back(0); for (auto i = 1; i < num_indices; ++i) { - if (tmp_indices.get_data()[i] == (tmp_indices.get_data()[i - 1]) + 1) { + if ((tmp_indices.get_data()[i] == + (tmp_indices.get_data()[i - 1] + 1)) || + (tmp_indices.get_data()[i] == tmp_indices.get_data()[i - 1])) { continue; } else { tmp_subset_end.push_back(tmp_indices.get_data()[i - 1] + 1); tmp_subset_superset_index.push_back( tmp_subset_superset_index.back() + tmp_subset_end.back() - tmp_subset_begin.back()); - if (i + 1 < num_indices) { + if (i < num_indices) { tmp_subset_begin.push_back(tmp_indices.get_data()[i]); } } @@ -126,7 +128,32 @@ void global_to_local(std::shared_ptr exec, const Array *superset_indices, const Array *global_indices, Array *local_indices) -{} +{ + for (auto i = 0; i < global_indices->get_num_elems(); ++i) { + auto index = global_indices->get_const_data()[i]; + GKO_ASSERT(index < index_space_size); + auto bucket = std::distance( + subset_begin->get_const_data(), + std::lower_bound( + subset_begin->get_const_data(), + subset_begin->get_const_data() + subset_begin->get_num_elems(), + index, [](const IndexType &sub_idx, const IndexType idx) { + return sub_idx <= idx; + })); + auto shifted_bucket = bucket == 0 ? 0 : (bucket - 1); + if (subset_end->get_const_data()[shifted_bucket] <= index) { + local_indices->get_data()[i] = -1; + } else { + local_indices->get_data()[i] = + index - subset_begin->get_const_data()[shifted_bucket] + + superset_indices->get_const_data()[shifted_bucket]; + } + std::cout << " g index " << index << " bucket " << bucket << " l idx " + << local_indices->get_data()[i] << " subset begin " + << subset_begin->get_const_data()[shifted_bucket] + << std::endl; + } +} GKO_INSTANTIATE_FOR_EACH_INDEX_TYPE( GKO_DECLARE_INDEX_SET_GLOBAL_TO_LOCAL_KERNEL); @@ -140,7 +167,29 @@ void local_to_global(std::shared_ptr exec, const Array *superset_indices, const Array *local_indices, Array *global_indices) -{} +{ + for (auto i = 0; i < local_indices->get_num_elems(); ++i) { + auto index = local_indices->get_const_data()[i]; + GKO_ASSERT( + index <= + (superset_indices + ->get_const_data()[superset_indices->get_num_elems() - 1] - + 1)); + auto bucket = std::distance( + superset_indices->get_const_data(), + std::lower_bound(superset_indices->get_const_data(), + superset_indices->get_const_data() + + superset_indices->get_num_elems(), + index, + [](const IndexType &sup_idx, const IndexType idx) { + return sup_idx <= idx; + })); + auto shifted_bucket = bucket == 0 ? 0 : (bucket - 1); + global_indices->get_data()[i] = + subset_begin->get_const_data()[shifted_bucket] + index - + superset_indices->get_const_data()[shifted_bucket]; + } +} GKO_INSTANTIATE_FOR_EACH_INDEX_TYPE( GKO_DECLARE_INDEX_SET_LOCAL_TO_GLOBAL_KERNEL); diff --git a/reference/test/base/CMakeLists.txt b/reference/test/base/CMakeLists.txt index 3386bb01e20..661a60caf9d 100644 --- a/reference/test/base/CMakeLists.txt +++ b/reference/test/base/CMakeLists.txt @@ -1,4 +1,5 @@ ginkgo_create_test(combination) ginkgo_create_test(composition) +ginkgo_create_test(index_set) ginkgo_create_test(perturbation) ginkgo_create_test(utils) diff --git a/reference/test/base/index_set.cpp b/reference/test/base/index_set.cpp new file mode 100644 index 00000000000..47df1720b96 --- /dev/null +++ b/reference/test/base/index_set.cpp @@ -0,0 +1,210 @@ +/************************************************************* +Copyright (c) 2017-2020, the Ginkgo authors +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions +are met: + +1. Redistributions of source code must retain the above copyright +notice, this list of conditions and the following disclaimer. + +2. Redistributions in binary form must reproduce the above copyright +notice, this list of conditions and the following disclaimer in the +documentation and/or other materials provided with the distribution. + +3. Neither the name of the copyright holder nor the names of its +contributors may be used to endorse or promote products derived from +this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS +IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED +TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A +PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +*************************************************************/ + +#include + + +#include +#include +#include +#include + + +#include "core/test/utils.hpp" + + +namespace { + + +template +class IndexSet : public ::testing::Test { +protected: + using value_type = T; + IndexSet() : exec(gko::ReferenceExecutor::create()) {} + + void TearDown() + { + if (exec != nullptr) { + // ensure that previous calls finished and didn't throw an error + ASSERT_NO_THROW(exec->synchronize()); + } + } + + static void assert_equal_to_original(gko::IndexSet &a) + { + ASSERT_EQ(a.get_size(), 10); + } + + static void assert_equal_index_sets(gko::IndexSet &a, + gko::IndexSet &b) + { + ASSERT_EQ(a.get_size(), b.get_size()); + ASSERT_EQ(a.get_num_subsets(), b.get_num_subsets()); + if (a.get_num_subsets() > 0) { + for (auto i = 0; i < a.get_num_subsets(); ++i) { + EXPECT_EQ(a.get_subsets_begin()[i], b.get_subsets_begin()[i]); + EXPECT_EQ(a.get_subsets_end()[i], b.get_subsets_end()[i]); + EXPECT_EQ(a.get_superset_indices()[i], + b.get_superset_indices()[i]); + } + } + } + + static void assert_equal_arrays(const T num_elems, const T *a, const T *b) + { + if (num_elems > 0) { + for (auto i = 0; i < num_elems; ++i) { + EXPECT_EQ(a[i], b[i]); + } + } + } + + + std::shared_ptr exec; +}; + +TYPED_TEST_SUITE(IndexSet, gko::test::IndexTypes); + + +TYPED_TEST(IndexSet, CanBeConstructedFromIndices) +{ + auto idx_arr = gko::Array{this->exec, {0, 1, 2, 4, 6, 7, 8, 9}}; + auto begin_comp = gko::Array{this->exec, {0, 4, 6}}; + auto end_comp = gko::Array{this->exec, {3, 5, 10}}; + auto superset_comp = gko::Array{this->exec, {0, 3, 4, 8}}; + + auto idx_set = gko::IndexSet{this->exec, 10, idx_arr}; + + ASSERT_EQ(idx_set.get_size(), 10); + ASSERT_EQ(idx_set.get_num_subsets(), 3); + ASSERT_EQ(idx_set.get_num_subsets(), begin_comp.get_num_elems()); + auto num_subsets = idx_set.get_num_subsets(); + this->assert_equal_arrays(num_subsets, idx_set.get_subsets_begin(), + begin_comp.get_data()); + this->assert_equal_arrays(num_subsets, idx_set.get_subsets_end(), + end_comp.get_data()); + this->assert_equal_arrays(num_subsets, idx_set.get_superset_indices(), + superset_comp.get_data()); +} + + +TYPED_TEST(IndexSet, CanBeConstructedFromNonSortedIndices) +{ + auto idx_arr = gko::Array{this->exec, {9, 1, 4, 2, 6, 8, 0, 7}}; + auto begin_comp = gko::Array{this->exec, {0, 4, 6}}; + auto end_comp = gko::Array{this->exec, {3, 5, 10}}; + auto superset_comp = gko::Array{this->exec, {0, 3, 4, 8}}; + + auto idx_set = gko::IndexSet{this->exec, 10, idx_arr}; + + ASSERT_EQ(idx_set.get_size(), 10); + ASSERT_EQ(idx_set.get_num_subsets(), 3); + ASSERT_EQ(idx_set.get_num_subsets(), begin_comp.get_num_elems()); + auto num_subsets = idx_set.get_num_subsets(); + this->assert_equal_arrays(num_subsets, idx_set.get_subsets_begin(), + begin_comp.get_data()); + this->assert_equal_arrays(num_subsets, idx_set.get_subsets_end(), + end_comp.get_data()); + this->assert_equal_arrays(num_subsets, idx_set.get_superset_indices(), + superset_comp.get_data()); +} + + +TYPED_TEST(IndexSet, CanGetGlobalIndex) +{ + auto idx_arr = gko::Array{this->exec, {0, 1, 2, 4, 6, 7, 8, 9}}; + auto idx_set = gko::IndexSet{this->exec, 10, idx_arr}; + ASSERT_EQ(idx_set.get_num_elems(), 8); + EXPECT_EQ(idx_set.get_global_index(0), 0); + EXPECT_EQ(idx_set.get_global_index(1), 1); + EXPECT_EQ(idx_set.get_global_index(2), 2); + EXPECT_EQ(idx_set.get_global_index(3), 4); + EXPECT_EQ(idx_set.get_global_index(4), 6); + EXPECT_EQ(idx_set.get_global_index(5), 7); + EXPECT_EQ(idx_set.get_global_index(6), 8); + EXPECT_EQ(idx_set.get_global_index(7), 9); +} + + +TYPED_TEST(IndexSet, CanGetGlobalIndexFromArrays) +{ + auto idx_arr = gko::Array{this->exec, {0, 1, 2, 4, 6, 7, 8, 9}}; + auto lidx_arr = gko::Array{this->exec, {0, 1, 4, 6, 7}}; + auto gidx_arr = gko::Array{this->exec, {0, 1, 6, 8, 9}}; + auto idx_set = gko::IndexSet{this->exec, 10, idx_arr}; + ASSERT_EQ(idx_set.get_num_elems(), 8); + auto idx_set_gidx = idx_set.get_global_indices_from_local(lidx_arr); + this->assert_equal_arrays(gidx_arr.get_num_elems(), + idx_set_gidx.get_const_data(), + gidx_arr.get_const_data()); +} + + +TYPED_TEST(IndexSet, CanGetLocalIndex) +{ + auto idx_arr = gko::Array{this->exec, {0, 1, 2, 4, 6, 7, 8, 9}}; + auto idx_set = gko::IndexSet{this->exec, 10, idx_arr}; + ASSERT_EQ(idx_set.get_num_elems(), 8); + EXPECT_EQ(idx_set.get_local_index(6), 4); + EXPECT_EQ(idx_set.get_local_index(7), 5); + EXPECT_EQ(idx_set.get_local_index(0), 0); + EXPECT_EQ(idx_set.get_local_index(8), 6); + EXPECT_EQ(idx_set.get_local_index(4), 3); +} + + +TYPED_TEST(IndexSet, CanDetectNonExistentIndices) +{ + auto idx_arr = gko::Array{ + this->exec, {0, 8, 1, 2, 3, 4, 6, 11, 9, 5, 7, 28, 39}}; + auto idx_set = gko::IndexSet{this->exec, 45, idx_arr}; + ASSERT_EQ(idx_set.get_num_elems(), 13); + EXPECT_EQ(idx_set.get_local_index(11), 10); + EXPECT_EQ(idx_set.get_local_index(22), -1); +} + + +TYPED_TEST(IndexSet, CanGetLocalIndexFromArrays) +{ + auto idx_arr = gko::Array{this->exec, {0, 1, 2, 4, 6, 7, 8, 9}}; + auto gidx_arr = gko::Array{this->exec, {6, 0, 4, 8, 9}}; + auto lidx_arr = gko::Array{this->exec, {4, 0, 3, 6, 7}}; + auto idx_set = gko::IndexSet{this->exec, 10, idx_arr}; + ASSERT_EQ(idx_set.get_num_elems(), 8); + auto idx_set_lidx = idx_set.get_local_indices_from_global(gidx_arr); + this->assert_equal_arrays(lidx_arr.get_num_elems(), + idx_set_lidx.get_const_data(), + lidx_arr.get_const_data()); +} + + +} // namespace