Skip to content

Commit

Permalink
[dist] review updates:
Browse files Browse the repository at this point in the history
- fix binary search usage in cuda/hip
- refactor
- add tests for segmented array assertion

Co-authored-by: Tobias Ribizel <mail@ribizel.de>
Co-authored-by: Pratik Nayak <pratik.nayak4@gmail.com>
Co-authored-by: Yu-Hsiang M. Tsai <yhmtsai@gmail.com>
  • Loading branch information
4 people committed Feb 18, 2025
1 parent 3579168 commit 5b5ddac
Show file tree
Hide file tree
Showing 12 changed files with 156 additions and 179 deletions.
37 changes: 21 additions & 16 deletions common/cuda_hip/distributed/index_map_kernels.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -302,20 +302,20 @@ void map_to_global(
device_partition<const LocalIndexType, const GlobalIndexType> partition,
device_segmented_array<const GlobalIndexType> remote_global_idxs,
experimental::distributed::comm_index_type rank,
const array<LocalIndexType>& local_ids,
const array<LocalIndexType>& local_idxs,
experimental::distributed::index_space is,
array<GlobalIndexType>& global_ids)
array<GlobalIndexType>& global_idxs)
{
auto range_bounds = partition.offsets_begin;
auto starting_indices = partition.starting_indices_begin;
const auto& ranges_by_part = partition.ranges_by_part;
auto local_ids_it = local_ids.get_const_data();
auto input_size = local_ids.get_size();
auto local_idxs_it = local_idxs.get_const_data();
auto input_size = local_idxs.get_size();

auto policy = thrust_policy(exec);

global_ids.resize_and_reset(local_ids.get_size());
auto global_ids_it = global_ids.get_data();
global_idxs.resize_and_reset(local_idxs.get_size());
auto global_idxs_it = global_idxs.get_data();

auto map_local = [rank, ranges_by_part, range_bounds, starting_indices,
partition] __device__(auto lid) {
Expand All @@ -330,11 +330,16 @@ void map_to_global(
auto local_ranges_size =
static_cast<int64>(local_ranges.end - local_ranges.begin);

auto it = binary_search(int64(0), local_ranges_size, [=](const auto i) {
return starting_indices[local_ranges.begin[i]] >= lid;
});
// the binary search finds the first local range, such that the starting
// index is larger than lid, thus lid is contained in the local range
// before that one
auto local_range_id =
it != local_ranges_size ? it : max(int64(0), it - 1);
binary_search(int64(0), local_ranges_size,
[=](const auto i) {
return starting_indices[local_ranges.begin[i]] >
lid;
}) -
1;
auto range_id = local_ranges.begin[local_range_id];

return static_cast<GlobalIndexType>(lid - starting_indices[range_id]) +
Expand Down Expand Up @@ -363,16 +368,16 @@ void map_to_global(
};

if (is == experimental::distributed::index_space::local) {
thrust::transform(policy, local_ids_it, local_ids_it + input_size,
global_ids_it, map_local);
thrust::transform(policy, local_idxs_it, local_idxs_it + input_size,
global_idxs_it, map_local);
}
if (is == experimental::distributed::index_space::non_local) {
thrust::transform(policy, local_ids_it, local_ids_it + input_size,
global_ids_it, map_non_local);
thrust::transform(policy, local_idxs_it, local_idxs_it + input_size,
global_idxs_it, map_non_local);
}
if (is == experimental::distributed::index_space::combined) {
thrust::transform(policy, local_ids_it, local_ids_it + input_size,
global_ids_it, map_combined);
thrust::transform(policy, local_idxs_it, local_idxs_it + input_size,
global_idxs_it, map_combined);
}
}

Expand Down
8 changes: 4 additions & 4 deletions core/distributed/index_map.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -93,15 +93,15 @@ array<LocalIndexType> index_map<LocalIndexType, GlobalIndexType>::map_to_local(
/**
 * Maps local indices to global indices.
 *
 * Creates a result array on the executor `exec_` and dispatches the
 * `map_to_global` kernel with the device views of the partition and of the
 * remote global indices, the rank of this process, the input indices and the
 * index space in which they are defined. The kernel is responsible for sizing
 * and filling the result array.
 *
 * @return the array of mapped global indices
 */
template <typename LocalIndexType, typename GlobalIndexType>
array<GlobalIndexType>
index_map<LocalIndexType, GlobalIndexType>::map_to_global(
const array<LocalIndexType>& local_ids, index_space index_space_v) const
const array<LocalIndexType>& local_idxs, index_space index_space_v) const
{
// output array, bound to the same executor as this index map
array<GlobalIndexType> global_ids(exec_);
array<GlobalIndexType> global_idxs(exec_);

// run the executor-specific kernel
exec_->run(index_map_kernels::make_map_to_global(
to_device_const(partition_.get()), to_device(remote_global_idxs_),
rank_, local_ids, index_space_v, global_ids));
rank_, local_idxs, index_space_v, global_idxs));

return global_ids;
return global_idxs;
}


Expand Down
4 changes: 2 additions & 2 deletions core/distributed/index_map_kernels.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -95,8 +95,8 @@ namespace kernels {
device_partition<const _ltype, const _gtype> partition, \
device_segmented_array<const _gtype> remote_global_idxs, \
experimental::distributed::comm_index_type rank, \
const array<_ltype>& local_ids, \
experimental::distributed::index_space is, array<_gtype>& global_ids)
const array<_ltype>& local_idxs, \
experimental::distributed::index_space is, array<_gtype>& global_idxs)


#define GKO_DECLARE_ALL_AS_TEMPLATES \
Expand Down
25 changes: 19 additions & 6 deletions core/test/utils/assertions.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
#include <ginkgo/core/base/math.hpp>
#include <ginkgo/core/base/mtx_io.hpp>
#include <ginkgo/core/base/name_demangling.hpp>
#include <ginkgo/core/base/segmented_array.hpp>
#include <ginkgo/core/matrix/dense.hpp>

#include "core/base/batch_utilities.hpp"
Expand Down Expand Up @@ -1014,19 +1015,19 @@ ::testing::AssertionResult segmented_array_equal(
second.get_const_flat_data())
.copy_to_array();

auto buffer_result = array_equal(first_expression, second_expression,
view_first, view_second);
if (buffer_result == ::testing::AssertionFailure()) {
return buffer_result << "Buffers of the segmented arrays mismatch";
}

auto offsets_result =
array_equal(first_expression, second_expression, first.get_offsets(),
second.get_offsets());
if (offsets_result == ::testing::AssertionFailure()) {
return offsets_result << "Offsets of the segmented arrays mismatch";
}

auto buffer_result = array_equal(first_expression, second_expression,
view_first, view_second);
if (buffer_result == ::testing::AssertionFailure()) {
return buffer_result << "Buffers of the segmented arrays mismatch";
}

return ::testing::AssertionSuccess();
}

Expand Down Expand Up @@ -1414,6 +1415,18 @@ T* plain_ptr(T* ptr)
}


/**
* Checks if two `gko::segmented_array`s are equal.
*
* Both the flat array buffer and the offsets of both arrays are tested
* for equality.
*
* Has to be called from within a google test unit test.
* Internally calls gko::test::assertions::segmented_array_equal().
*
* @param _array1 first segmented array
* @param _array2 second segmented array
*/
#define GKO_ASSERT_SEGMENTED_ARRAY_EQ(_array1, _array2) \
{ \
ASSERT_PRED_FORMAT2(::gko::test::assertions::segmented_array_equal, \
Expand Down
48 changes: 47 additions & 1 deletion core/test/utils/assertions_test.cpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// SPDX-FileCopyrightText: 2017 - 2024 The Ginkgo authors
// SPDX-FileCopyrightText: 2017 - 2025 The Ginkgo authors
//
// SPDX-License-Identifier: BSD-3-Clause

Expand All @@ -8,6 +8,7 @@

#include <gtest/gtest.h>

#include <ginkgo/core/base/segmented_array.hpp>
#include <ginkgo/core/matrix/csr.hpp>
#include <ginkgo/core/matrix/dense.hpp>

Expand Down Expand Up @@ -218,4 +219,49 @@ TEST_F(ArraysNear, CanUseShortNotation)
}


/**
 * Test fixture providing segmented arrays for checking the
 * segmented-array equality assertion.
 *
 * All arrays live on a reference executor and are created with
 * `create_from_sizes` from a value array and an array of segment sizes.
 */
class SegmentedArraysEqual : public ::testing::Test {
protected:
using array = gko::array<double>;
using iarray = gko::array<gko::int64>;
using segmented_array = gko::segmented_array<double>;

// declared first, since the member initializers below use it
std::shared_ptr<gko::Executor> exec = gko::ReferenceExecutor::create();

// baseline array: values 1..5, segment sizes {2, 1, 2}
segmented_array arr1 = segmented_array::create_from_sizes(
array{exec, {1, 2, 3, 4, 5}}, iarray{exec, {2, 1, 2}});
// same values and segment sizes as arr1
segmented_array arr2 = segmented_array::create_from_sizes(
array{exec, {1, 2, 3, 4, 5}}, iarray{exec, {2, 1, 2}});
// same segment sizes as arr1, but the last two values differ
segmented_array arr3 = segmented_array::create_from_sizes(
array{exec, {1, 2, 3, 5, 6}}, iarray{exec, {2, 1, 2}});
// same values as arr1, but a different number of segments
segmented_array arr4 = segmented_array::create_from_sizes(
array{exec, {1, 2, 3, 4, 5}}, iarray{exec, {3, 2}});
// same values and number of segments as arr1, but different segment sizes
segmented_array arr5 = segmented_array::create_from_sizes(
array{exec, {1, 2, 3, 4, 5}}, iarray{exec, {1, 2, 2}});
};


// Two segmented arrays built from identical values and segment sizes must
// pass the equality assertion.
TEST_F(SegmentedArraysEqual, SucceedsIfEqual)
{
GKO_ASSERT_SEGMENTED_ARRAY_EQ(arr1, arr2);
}


// The equality predicate must report a failure when the flat values differ.
TEST_F(SegmentedArraysEqual, FailsIfValuesDifferent)
{
    // arr1 and arr3 have the same segment sizes but different values.
    // Asserting equality here would make this test itself fail, so the
    // predicate is called directly and its result is required to be a
    // failure.
    ASSERT_FALSE(::gko::test::assertions::segmented_array_equal(
        "arr1", "arr3", arr1, arr3));
}


// The equality predicate must report a failure when the number of segments
// differs.
TEST_F(SegmentedArraysEqual, FailsIfOffsetsDifferent1)
{
    // arr1 and arr4 hold the same values, but arr4 is split into two
    // segments instead of three. Asserting equality here would make this
    // test itself fail, so the predicate result is required to be a failure.
    ASSERT_FALSE(::gko::test::assertions::segmented_array_equal(
        "arr1", "arr4", arr1, arr4));
}


// The equality predicate must report a failure when the segment boundaries
// differ, even if the number of segments is the same.
TEST_F(SegmentedArraysEqual, FailsIfOffsetsDifferent2)
{
    // arr1 and arr5 hold the same values in three segments each, but with
    // different segment sizes. Asserting equality here would make this test
    // itself fail, so the predicate result is required to be a failure.
    ASSERT_FALSE(::gko::test::assertions::segmented_array_equal(
        "arr1", "arr5", arr1, arr5));
}


} // namespace
4 changes: 2 additions & 2 deletions dpcpp/distributed/index_map_kernels.dp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -50,9 +50,9 @@ void map_to_global(
device_partition<const LocalIndexType, const GlobalIndexType> partition,
device_segmented_array<const GlobalIndexType> remote_global_idxs,
experimental::distributed::comm_index_type rank,
const array<LocalIndexType>& local_ids,
const array<LocalIndexType>& local_idxs,
experimental::distributed::index_space is,
array<GlobalIndexType>& global_ids) GKO_NOT_IMPLEMENTED;
array<GlobalIndexType>& global_idxs) GKO_NOT_IMPLEMENTED;

GKO_INSTANTIATE_FOR_EACH_LOCAL_GLOBAL_INDEX_TYPE(
GKO_DECLARE_INDEX_MAP_MAP_TO_GLOBAL);
Expand Down
5 changes: 1 addition & 4 deletions dpcpp/distributed/partition_kernels.dp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -140,10 +140,7 @@ void build_ranges_by_part(std::shared_ptr<const DefaultExecutor> exec,

range_ids.resize_and_reset(num_ranges);
auto range_ids_ptr = range_ids.get_data();
// fill range_ids with 0,...,num_ranges - 1
run_kernel(
exec, [] GKO_KERNEL(auto i, auto rid) { rid[i] = i; }, num_ranges,
range_ids_ptr);
components::fill_seq_array(exec, range_ids_ptr, num_ranges);

oneapi::dpl::stable_sort(policy, range_ids_ptr, range_ids_ptr + num_ranges,
[range_parts](const auto rid_a, const auto rid_b) {
Expand Down
7 changes: 4 additions & 3 deletions include/ginkgo/core/distributed/index_map.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -85,15 +85,16 @@ struct index_map {
/**
* Maps local indices to global indices
*
* @param local_ids the local indices to map
* @param local_idxs the local indices to map
* @param index_space_v the index space in which the passed-in local
* indices are defined
*
* @return the mapped global indices. Any local index that is not in the
* specified index space is mapped to invalid_index
*/
array<GlobalIndexType> map_to_global(const array<LocalIndexType>& local_ids,
index_space index_space_v) const;
array<GlobalIndexType> map_to_global(
const array<LocalIndexType>& local_idxs,
index_space index_space_v) const;

/**
* \brief get size of index_space::local
Expand Down
44 changes: 15 additions & 29 deletions omp/distributed/index_map_kernels.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -245,71 +245,57 @@ void map_to_global(
device_partition<const LocalIndexType, const GlobalIndexType> partition,
device_segmented_array<const GlobalIndexType> remote_global_idxs,
experimental::distributed::comm_index_type rank,
const array<LocalIndexType>& local_ids,
const array<LocalIndexType>& local_idxs,
experimental::distributed::index_space is,
array<GlobalIndexType>& global_ids)
array<GlobalIndexType>& global_idxs)
{
const auto& ranges_by_part = partition.ranges_by_part;
auto local_ranges = ranges_by_part.get_segment(rank);

global_ids.resize_and_reset(local_ids.get_size());
global_idxs.resize_and_reset(local_idxs.get_size());

auto local_size =
static_cast<LocalIndexType>(partition.part_sizes_begin[rank]);
auto remote_size = static_cast<LocalIndexType>(
remote_global_idxs.flat_end - remote_global_idxs.flat_begin);
size_type local_range_id = 0;
if (is == experimental::distributed::index_space::local) {
#pragma omp parallel for firstprivate(local_range_id)
for (size_type i = 0; i < local_ids.get_size(); ++i) {
auto lid = local_ids.get_const_data()[i];
for (size_type i = 0; i < local_idxs.get_size(); ++i) {
auto lid = local_idxs.get_const_data()[i];

if (is == experimental::distributed::index_space::local) {
if (0 <= lid && lid < local_size) {
local_range_id =
find_local_range(lid, rank, partition, local_range_id);
global_ids.get_data()[i] = map_to_global(
global_idxs.get_data()[i] = map_to_global(
lid, partition, local_ranges.begin[local_range_id]);
} else {
global_ids.get_data()[i] = invalid_index<GlobalIndexType>();
global_idxs.get_data()[i] = invalid_index<GlobalIndexType>();
}
}
}
if (is == experimental::distributed::index_space::non_local) {
#pragma omp parallel for
for (size_type i = 0; i < local_ids.get_size(); ++i) {
auto lid = local_ids.get_const_data()[i];

} else if (is == experimental::distributed::index_space::non_local) {
if (0 <= lid && lid < remote_size) {
global_ids.get_data()[i] = remote_global_idxs.flat_begin[lid];
global_idxs.get_data()[i] = remote_global_idxs.flat_begin[lid];
} else {
global_ids.get_data()[i] = invalid_index<GlobalIndexType>();
global_idxs.get_data()[i] = invalid_index<GlobalIndexType>();
}
}
}
if (is == experimental::distributed::index_space::combined) {
#pragma omp parallel for firstprivate(local_range_id)
for (size_type i = 0; i < local_ids.get_size(); ++i) {
auto lid = local_ids.get_const_data()[i];

} else if (is == experimental::distributed::index_space::combined) {
if (0 <= lid && lid < local_size) {
local_range_id =
find_local_range(lid, rank, partition, local_range_id);
global_ids.get_data()[i] = map_to_global(
global_idxs.get_data()[i] = map_to_global(
lid, partition, local_ranges.begin[local_range_id]);
} else if (local_size <= lid && lid < local_size + remote_size) {
global_ids.get_data()[i] =
global_idxs.get_data()[i] =
remote_global_idxs.flat_begin[lid - local_size];
} else {
global_ids.get_data()[i] = invalid_index<GlobalIndexType>();
global_idxs.get_data()[i] = invalid_index<GlobalIndexType>();
}
}
}
}

GKO_INSTANTIATE_FOR_EACH_LOCAL_GLOBAL_INDEX_TYPE(
GKO_DECLARE_INDEX_MAP_MAP_TO_GLOBAL);


} // namespace index_map
} // namespace omp
} // namespace kernels
Expand Down
30 changes: 11 additions & 19 deletions reference/distributed/index_map_kernels.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -205,14 +205,14 @@ void map_to_global(
device_partition<const LocalIndexType, const GlobalIndexType> partition,
device_segmented_array<const GlobalIndexType> remote_global_idxs,
experimental::distributed::comm_index_type rank,
const array<LocalIndexType>& local_ids,
const array<LocalIndexType>& local_idxs,
experimental::distributed::index_space is,
array<GlobalIndexType>& global_ids)
array<GlobalIndexType>& global_idxs)
{
const auto& ranges_by_part = partition.ranges_by_part;
auto local_ranges = ranges_by_part.get_segment(rank);

global_ids.resize_and_reset(local_ids.get_size());
global_idxs.resize_and_reset(local_idxs.get_size());

auto local_size =
static_cast<LocalIndexType>(partition.part_sizes_begin[rank]);
Expand Down Expand Up @@ -246,22 +246,14 @@ void map_to_global(
}
};

if (is == experimental::distributed::index_space::local) {
for (size_type i = 0; i < local_ids.get_size(); ++i) {
auto lid = local_ids.get_const_data()[i];
global_ids.get_data()[i] = map_local(lid);
}
}
if (is == experimental::distributed::index_space::non_local) {
for (size_type i = 0; i < local_ids.get_size(); ++i) {
auto lid = local_ids.get_const_data()[i];
global_ids.get_data()[i] = map_non_local(lid);
}
}
if (is == experimental::distributed::index_space::combined) {
for (size_type i = 0; i < local_ids.get_size(); ++i) {
auto lid = local_ids.get_const_data()[i];
global_ids.get_data()[i] = map_combined(lid);
for (size_type i = 0; i < local_idxs.get_size(); ++i) {
auto lid = local_idxs.get_const_data()[i];
if (is == experimental::distributed::index_space::local) {
global_idxs.get_data()[i] = map_local(lid);
} else if (is == experimental::distributed::index_space::non_local) {
global_idxs.get_data()[i] = map_non_local(lid);
} else if (is == experimental::distributed::index_space::combined) {
global_idxs.get_data()[i] = map_combined(lid);
}
}
}
Expand Down
Loading

0 comments on commit 5b5ddac

Please sign in to comment.