diff --git a/omp/matrix/csr_kernels.cpp b/omp/matrix/csr_kernels.cpp index 26cdee258d7..28692d76f03 100644 --- a/omp/matrix/csr_kernels.cpp +++ b/omp/matrix/csr_kernels.cpp @@ -744,62 +744,6 @@ GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE( GKO_DECLARE_CSR_CALC_NNZ_PER_ROW_IN_SPAN_KERNEL); -// TODO: FIXME -namespace index_set { - - -template -Array map_global_to_local( - std::shared_ptr exec, - const IndexSet& index_set, - const Array& global_indices, const bool is_sorted) -{ - auto local_indices = - gko::Array(exec, global_indices.get_num_elems()); - - GKO_ASSERT(index_set.get_num_subsets() >= 1); - gko::kernels::omp::index_set::global_to_local( - exec, index_set.get_size(), index_set.get_num_subsets(), - index_set.get_subsets_begin(), index_set.get_subsets_end(), - index_set.get_superset_indices(), - static_cast(local_indices.get_num_elems()), - global_indices.get_const_data(), local_indices.get_data(), is_sorted); - return local_indices; -} - - -template -IndexType get_local_index(std::shared_ptr exec, - const IndexSet& index_set, - const IndexType index) -{ - const auto global_idx = - Array(exec, std::initializer_list{index}); - auto local_idx = Array( - exec, - index_set::map_global_to_local(exec, index_set, global_idx, true)); - - return exec->copy_val_to_host(local_idx.get_data()); -} - - -template -bool contains(std::shared_ptr exec, - const IndexSet& index_set, const IndexType input_index) -{ - if (input_index >= index_set.get_size()) { - return false; - } else { - auto local_index = - index_set::get_local_index(exec, index_set, input_index); - return local_index != invalid_index(); - } -} - - -} // namespace index_set - - template void calculate_nonzeros_per_row_in_index_set( std::shared_ptr exec, @@ -811,14 +755,26 @@ void calculate_nonzeros_per_row_in_index_set( auto num_row_subsets = row_index_set.get_num_subsets(); auto row_subset_begin = row_index_set.get_subsets_begin(); auto row_subset_end = row_index_set.get_subsets_end(); + auto src_ptrs = source->get_const_row_ptrs(); for (size_type set = 0; set < num_row_subsets; ++set) { for (size_type row = row_subset_begin[set]; row < row_subset_end[set]; ++row) { row_nnz->get_data()[res_row] = zero(); - for (size_type nnz = source->get_const_row_ptrs()[row]; - nnz < source->get_const_row_ptrs()[row + 1]; ++nnz) { - if (index_set::contains(exec, col_index_set, - source->get_const_col_idxs()[nnz])) { + Array l_idxs( + exec, + static_cast(src_ptrs[row + 1] - src_ptrs[row])); + gko::kernels::omp::index_set::global_to_local( + exec, col_index_set.get_size(), col_index_set.get_num_subsets(), + col_index_set.get_subsets_begin(), + col_index_set.get_subsets_end(), + col_index_set.get_superset_indices(), + static_cast(l_idxs.get_num_elems()), + source->get_const_col_idxs() + src_ptrs[row], l_idxs.get_data(), + false); + for (size_type nnz = 0; nnz < (src_ptrs[row + 1] - src_ptrs[row]); + ++nnz) { + auto l_idx = l_idxs.get_const_data()[nnz]; + if (l_idx != invalid_index()) { row_nnz->get_data()[res_row]++; } } @@ -888,19 +844,26 @@ void compute_submatrix_from_index_set( for (size_type set = 0; set < num_row_subsets; ++set) { for (size_type row = row_subset_begin[set]; row < row_subset_end[set]; ++row) { - auto local_map = std::vector( - src_row_ptrs[row + 1] - src_row_ptrs[row], 0); - for (size_type nnz = src_row_ptrs[row]; nnz < src_row_ptrs[row + 1]; - ++nnz) { - if (index_set::contains(exec, col_index_set, - src_col_idxs[nnz])) { - res_col_idxs[res_nnz] = index_set::get_local_index( - exec, col_index_set, src_col_idxs[nnz]); - res_values[res_nnz] = src_values[nnz]; + Array l_idxs( + exec, static_cast(src_row_ptrs[row + 1] - + src_row_ptrs[row])); + gko::kernels::omp::index_set::global_to_local( + exec, col_index_set.get_size(), col_index_set.get_num_subsets(), + col_index_set.get_subsets_begin(), + col_index_set.get_subsets_end(), + col_index_set.get_superset_indices(), + static_cast(l_idxs.get_num_elems()), + source->get_const_col_idxs() + src_row_ptrs[row], + l_idxs.get_data(), false); + for (size_type nnz = 0; + nnz < (src_row_ptrs[row + 1] - src_row_ptrs[row]); ++nnz) { + auto l_idx = l_idxs.get_const_data()[nnz]; + if (l_idx != invalid_index()) { + res_col_idxs[res_nnz] = l_idx; + res_values[res_nnz] = src_values[nnz + src_row_ptrs[row]]; res_nnz++; } } - // res_nnz = res_row_ptrs[row_index_set.get_local_index(row)]; } } } diff --git a/reference/matrix/csr_kernels.cpp b/reference/matrix/csr_kernels.cpp index fa94ca01da4..e4f01dbb3c1 100644 --- a/reference/matrix/csr_kernels.cpp +++ b/reference/matrix/csr_kernels.cpp @@ -625,62 +625,6 @@ GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE( GKO_DECLARE_CSR_CALC_NNZ_PER_ROW_IN_SPAN_KERNEL); -// TODO: FIXME -namespace index_set { - - -template -Array map_global_to_local( - std::shared_ptr exec, - const IndexSet& index_set, - const Array& global_indices, const bool is_sorted) -{ - auto local_indices = - gko::Array(exec, global_indices.get_num_elems()); - - GKO_ASSERT(index_set.get_num_subsets() >= 1); - gko::kernels::reference::index_set::global_to_local( - exec, index_set.get_size(), index_set.get_num_subsets(), - index_set.get_subsets_begin(), index_set.get_subsets_end(), - index_set.get_superset_indices(), - static_cast(local_indices.get_num_elems()), - global_indices.get_const_data(), local_indices.get_data(), is_sorted); - return local_indices; -} - - -template -IndexType get_local_index(std::shared_ptr exec, - const IndexSet& index_set, - const IndexType index) -{ - const auto global_idx = - Array(exec, std::initializer_list{index}); - auto local_idx = Array( - exec, - index_set::map_global_to_local(exec, index_set, global_idx, true)); - - return exec->copy_val_to_host(local_idx.get_data()); -} - - -template -bool contains(std::shared_ptr exec, - const IndexSet& index_set, const IndexType input_index) -{ - if (input_index >= index_set.get_size()) { - return false; - } else { - auto local_index = - index_set::get_local_index(exec, index_set, input_index); - return local_index != invalid_index(); - } -} - - -} // namespace index_set - - template void calculate_nonzeros_per_row_in_index_set( std::shared_ptr exec, @@ -692,14 +636,26 @@ void calculate_nonzeros_per_row_in_index_set( auto num_row_subsets = row_index_set.get_num_subsets(); auto row_subset_begin = row_index_set.get_subsets_begin(); auto row_subset_end = row_index_set.get_subsets_end(); + auto src_ptrs = source->get_const_row_ptrs(); for (size_type set = 0; set < num_row_subsets; ++set) { for (size_type row = row_subset_begin[set]; row < row_subset_end[set]; ++row) { row_nnz->get_data()[res_row] = zero(); - for (size_type nnz = source->get_const_row_ptrs()[row]; - nnz < source->get_const_row_ptrs()[row + 1]; ++nnz) { - if (index_set::contains(exec, col_index_set, - source->get_const_col_idxs()[nnz])) { + Array l_idxs( + exec, + static_cast(src_ptrs[row + 1] - src_ptrs[row])); + gko::kernels::reference::index_set::global_to_local( + exec, col_index_set.get_size(), col_index_set.get_num_subsets(), + col_index_set.get_subsets_begin(), + col_index_set.get_subsets_end(), + col_index_set.get_superset_indices(), + static_cast(l_idxs.get_num_elems()), + source->get_const_col_idxs() + src_ptrs[row], l_idxs.get_data(), + false); + for (size_type nnz = 0; nnz < (src_ptrs[row + 1] - src_ptrs[row]); + ++nnz) { + auto l_idx = l_idxs.get_const_data()[nnz]; + if (l_idx != invalid_index()) { row_nnz->get_data()[res_row]++; } } @@ -770,19 +726,26 @@ void compute_submatrix_from_index_set( for (size_type set = 0; set < num_row_subsets; ++set) { for (size_type row = row_subset_begin[set]; row < row_subset_end[set]; ++row) { - auto local_map = std::vector( - src_row_ptrs[row + 1] - src_row_ptrs[row], 0); - for (size_type nnz = src_row_ptrs[row]; nnz < src_row_ptrs[row + 1]; - ++nnz) { - if (index_set::contains(exec, col_index_set, - src_col_idxs[nnz])) { - res_col_idxs[res_nnz] = index_set::get_local_index( - exec, col_index_set, src_col_idxs[nnz]); - res_values[res_nnz] = src_values[nnz]; + Array l_idxs( + exec, static_cast(src_row_ptrs[row + 1] - + src_row_ptrs[row])); + gko::kernels::reference::index_set::global_to_local( + exec, col_index_set.get_size(), col_index_set.get_num_subsets(), + col_index_set.get_subsets_begin(), + col_index_set.get_subsets_end(), + col_index_set.get_superset_indices(), + static_cast(l_idxs.get_num_elems()), + source->get_const_col_idxs() + src_row_ptrs[row], + l_idxs.get_data(), false); + for (size_type nnz = 0; + nnz < (src_row_ptrs[row + 1] - src_row_ptrs[row]); ++nnz) { + auto l_idx = l_idxs.get_const_data()[nnz]; + if (l_idx != invalid_index()) { + res_col_idxs[res_nnz] = l_idx; + res_values[res_nnz] = src_values[nnz + src_row_ptrs[row]]; res_nnz++; } } - // res_nnz = res_row_ptrs[row_index_set.get_local_index(row)]; } } }