Skip to content

Commit

Permalink
Fix index space size detection for an IndexSet built from an initializer list (use the largest element + 1 instead of the list length)
Browse files Browse the repository at this point in the history
  • Loading branch information
pratikvn committed Feb 4, 2022
1 parent 2f8e21d commit a942ad2
Show file tree
Hide file tree
Showing 3 changed files with 51 additions and 33 deletions.
4 changes: 3 additions & 1 deletion include/ginkgo/core/base/index_set.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,9 @@ class IndexSet : public EnablePolymorphicObject<IndexSet<IndexType>> {
std::initializer_list<IndexType> init_list,
const bool is_sorted = false)
: EnablePolymorphicObject<IndexSet>(std::move(executor)),
index_space_size_(init_list.size())
index_space_size_(
*(std::max_element(std::begin(init_list), std::end(init_list))) +
1)
{
this->populate_subsets(
Array<IndexType>(this->get_executor(), init_list), is_sorted);
Expand Down
40 changes: 24 additions & 16 deletions omp/matrix/csr_kernels.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -749,44 +749,51 @@ namespace index_set {


// NOTE(review): this is a diff view whose +/- markers were stripped. The
// three-line signature and the `auto exec = index_set.get_executor();` line
// appear to be the removed version; the four-line signature that takes `exec`
// as its first parameter appears to be the added one. The same holds for the
// two variants of the kernel's leading arguments further down.
//
// map_global_to_local: allocates an Array on `exec` with one entry per element
// of `global_indices`, fills it by calling the omp
// `index_set::global_to_local` kernel with the size and subset description of
// `index_set`, and returns that Array. `is_sorted` is forwarded to the kernel
// unchanged.
template <typename IndexType>
Array<IndexType> map_global_to_local(const IndexSet<IndexType>& index_set,
const Array<IndexType>& global_indices,
const bool is_sorted)
Array<IndexType> map_global_to_local(
std::shared_ptr<const DefaultExecutor> exec,
const IndexSet<IndexType>& index_set,
const Array<IndexType>& global_indices, const bool is_sorted)
{
auto exec = index_set.get_executor();
// Output buffer: same length as the input index list.
auto local_indices =
gko::Array<IndexType>(exec, global_indices.get_num_elems());

// The index set must hold at least one subset before the kernel is called.
GKO_ASSERT(index_set.get_num_subsets() >= 1);
gko::kernels::omp::index_set::global_to_local(
as<gko::OmpExecutor>(exec), index_set.get_size(),
index_set.get_num_subsets(), index_set.get_subsets_begin(),
index_set.get_subsets_end(), index_set.get_superset_indices(),
exec, index_set.get_size(), index_set.get_num_subsets(),
index_set.get_subsets_begin(), index_set.get_subsets_end(),
index_set.get_superset_indices(),
// Number of indices to translate, cast to IndexType.
static_cast<IndexType>(local_indices.get_num_elems()),
global_indices.get_const_data(), local_indices.get_data(), is_sorted);
return local_indices;
}


// NOTE(review): diff view without +/- markers. The one-line signature start
// without `exec`, the `auto exec = ...` line and the single-line
// map_global_to_local call appear to be the removed version; the lines that
// pass `exec` explicitly appear to be the added one.
//
// get_local_index: looks up a single `index` in `index_set`. It wraps the
// index in a one-element Array on `exec`, passes it through
// map_global_to_local with is_sorted set to true, and returns the single
// result read back with exec->copy_val_to_host.
template <typename IndexType>
IndexType get_local_index(const IndexSet<IndexType>& index_set,
IndexType get_local_index(std::shared_ptr<const DefaultExecutor> exec,
const IndexSet<IndexType>& index_set,
const IndexType index)
{
auto exec = index_set.get_executor();
// One-element input array holding just `index`.
const auto global_idx =
Array<IndexType>(exec, std::initializer_list<IndexType>{index});
// One-element result array built from the mapping call's return value.
auto local_idx = Array<IndexType>(
exec, index_set::map_global_to_local(index_set, global_idx, true));
exec,
index_set::map_global_to_local(exec, index_set, global_idx, true));

return exec->copy_val_to_host(local_idx.get_data());
}


// NOTE(review): diff view without +/- markers. The single-line signature and
// the first two body lines (unconditional lookup) appear to be the removed
// version; the two-line signature taking `exec` and the if/else body appear
// to be the added one.
//
// contains: reports whether `input_index` is part of `index_set`. In the
// if/else variant, any index at or beyond index_set.get_size() yields false
// without a lookup; otherwise the result is true exactly when
// get_local_index returns something other than invalid_index<IndexType>().
template <typename IndexType>
bool contains(const IndexSet<IndexType>& index_set, const IndexType input_index)
bool contains(std::shared_ptr<const DefaultExecutor> exec,
const IndexSet<IndexType>& index_set, const IndexType input_index)
{
auto local_index = index_set::get_local_index(index_set, input_index);
return local_index != invalid_index<IndexType>();
// Only the upper bound is checked here.
// NOTE(review): a negative input_index still reaches the lookup; confirm
// the kernel handles that case.
if (input_index >= index_set.get_size()) {
return false;
} else {
auto local_index =
index_set::get_local_index(exec, index_set, input_index);
return local_index != invalid_index<IndexType>();
}
}


Expand All @@ -810,7 +817,7 @@ void calculate_nonzeros_per_row_in_index_set(
row_nnz->get_data()[res_row] = zero<IndexType>();
for (size_type nnz = source->get_const_row_ptrs()[row];
nnz < source->get_const_row_ptrs()[row + 1]; ++nnz) {
if (index_set::contains(col_index_set,
if (index_set::contains(exec, col_index_set,
source->get_const_col_idxs()[nnz])) {
row_nnz->get_data()[res_row]++;
}
Expand Down Expand Up @@ -885,9 +892,10 @@ void compute_submatrix_from_index_set(
src_row_ptrs[row + 1] - src_row_ptrs[row], 0);
for (size_type nnz = src_row_ptrs[row]; nnz < src_row_ptrs[row + 1];
++nnz) {
if (index_set::contains(col_index_set, src_col_idxs[nnz])) {
if (index_set::contains(exec, col_index_set,
src_col_idxs[nnz])) {
res_col_idxs[res_nnz] = index_set::get_local_index(
col_index_set, src_col_idxs[nnz]);
exec, col_index_set, src_col_idxs[nnz]);
res_values[res_nnz] = src_values[nnz];
res_nnz++;
}
Expand Down
40 changes: 24 additions & 16 deletions reference/matrix/csr_kernels.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -647,44 +647,51 @@ namespace index_set {


// NOTE(review): this is a diff view whose +/- markers were stripped. The
// three-line signature and the `auto exec = index_set.get_executor();` line
// appear to be the removed version; the four-line signature that takes `exec`
// as its first parameter appears to be the added one. The same holds for the
// two variants of the kernel's leading arguments further down.
//
// map_global_to_local: allocates an Array on `exec` with one entry per element
// of `global_indices`, fills it by calling the reference
// `index_set::global_to_local` kernel with the size and subset description of
// `index_set`, and returns that Array. `is_sorted` is forwarded to the kernel
// unchanged.
template <typename IndexType>
Array<IndexType> map_global_to_local(const IndexSet<IndexType>& index_set,
const Array<IndexType>& global_indices,
const bool is_sorted)
Array<IndexType> map_global_to_local(
std::shared_ptr<const DefaultExecutor> exec,
const IndexSet<IndexType>& index_set,
const Array<IndexType>& global_indices, const bool is_sorted)
{
auto exec = index_set.get_executor();
// Output buffer: same length as the input index list.
auto local_indices =
gko::Array<IndexType>(exec, global_indices.get_num_elems());

// The index set must hold at least one subset before the kernel is called.
GKO_ASSERT(index_set.get_num_subsets() >= 1);
gko::kernels::reference::index_set::global_to_local(
as<gko::ReferenceExecutor>(exec), index_set.get_size(),
index_set.get_num_subsets(), index_set.get_subsets_begin(),
index_set.get_subsets_end(), index_set.get_superset_indices(),
exec, index_set.get_size(), index_set.get_num_subsets(),
index_set.get_subsets_begin(), index_set.get_subsets_end(),
index_set.get_superset_indices(),
// Number of indices to translate, cast to IndexType.
static_cast<IndexType>(local_indices.get_num_elems()),
global_indices.get_const_data(), local_indices.get_data(), is_sorted);
return local_indices;
}


// NOTE(review): diff view without +/- markers. The one-line signature start
// without `exec`, the `auto exec = ...` line and the single-line
// map_global_to_local call appear to be the removed version; the lines that
// pass `exec` explicitly appear to be the added one.
//
// get_local_index: looks up a single `index` in `index_set`. It wraps the
// index in a one-element Array on `exec`, passes it through
// map_global_to_local with is_sorted set to true, and returns the single
// result read back with exec->copy_val_to_host.
template <typename IndexType>
IndexType get_local_index(const IndexSet<IndexType>& index_set,
IndexType get_local_index(std::shared_ptr<const DefaultExecutor> exec,
const IndexSet<IndexType>& index_set,
const IndexType index)
{
auto exec = index_set.get_executor();
// One-element input array holding just `index`.
const auto global_idx =
Array<IndexType>(exec, std::initializer_list<IndexType>{index});
// One-element result array built from the mapping call's return value.
auto local_idx = Array<IndexType>(
exec, index_set::map_global_to_local(index_set, global_idx, true));
exec,
index_set::map_global_to_local(exec, index_set, global_idx, true));

return exec->copy_val_to_host(local_idx.get_data());
}


// NOTE(review): diff view without +/- markers. The single-line signature and
// the first two body lines (unconditional lookup) appear to be the removed
// version; the two-line signature taking `exec` and the if/else body appear
// to be the added one.
//
// contains: reports whether `input_index` is part of `index_set`. In the
// if/else variant, any index at or beyond index_set.get_size() yields false
// without a lookup; otherwise the result is true exactly when
// get_local_index returns something other than invalid_index<IndexType>().
template <typename IndexType>
bool contains(const IndexSet<IndexType>& index_set, const IndexType input_index)
bool contains(std::shared_ptr<const DefaultExecutor> exec,
const IndexSet<IndexType>& index_set, const IndexType input_index)
{
auto local_index = index_set::get_local_index(index_set, input_index);
return local_index != invalid_index<IndexType>();
// Only the upper bound is checked here.
// NOTE(review): a negative input_index still reaches the lookup; confirm
// the kernel handles that case.
if (input_index >= index_set.get_size()) {
return false;
} else {
auto local_index =
index_set::get_local_index(exec, index_set, input_index);
return local_index != invalid_index<IndexType>();
}
}


Expand All @@ -708,7 +715,7 @@ void calculate_nonzeros_per_row_in_index_set(
row_nnz->get_data()[res_row] = zero<IndexType>();
for (size_type nnz = source->get_const_row_ptrs()[row];
nnz < source->get_const_row_ptrs()[row + 1]; ++nnz) {
if (index_set::contains(col_index_set,
if (index_set::contains(exec, col_index_set,
source->get_const_col_idxs()[nnz])) {
row_nnz->get_data()[res_row]++;
}
Expand Down Expand Up @@ -784,9 +791,10 @@ void compute_submatrix_from_index_set(
src_row_ptrs[row + 1] - src_row_ptrs[row], 0);
for (size_type nnz = src_row_ptrs[row]; nnz < src_row_ptrs[row + 1];
++nnz) {
if (index_set::contains(col_index_set, src_col_idxs[nnz])) {
if (index_set::contains(exec, col_index_set,
src_col_idxs[nnz])) {
res_col_idxs[res_nnz] = index_set::get_local_index(
col_index_set, src_col_idxs[nnz]);
exec, col_index_set, src_col_idxs[nnz]);
res_values[res_nnz] = src_values[nnz];
res_nnz++;
}
Expand Down

0 comments on commit a942ad2

Please sign in to comment.