Skip to content

Commit

Permalink
computes partition block permutation only once
Browse files Browse the repository at this point in the history
  • Loading branch information
MarcelKoch committed Oct 25, 2021
1 parent ade2305 commit c2bdbc9
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 22 deletions.
30 changes: 16 additions & 14 deletions core/distributed/partition.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ Partition<LocalIndexType>::build_from_mapping(
exec->run(
partition::make_build_from_mapping(*local_mapping.get(), result.get()));
result->compute_range_ranks();
result->compute_block_gather_permutation();
return result;
}

Expand All @@ -83,6 +84,7 @@ Partition<LocalIndexType>::build_from_contiguous(
exec->run(partition::make_build_from_contiguous(*local_ranges.get(),
result.get()));
result->compute_range_ranks();
result->compute_block_gather_permutation();
return result;
}

Expand All @@ -97,6 +99,20 @@ void Partition<LocalIndexType>::compute_range_ranks()
}


template <typename LocalIndexType>
void Partition<LocalIndexType>::compute_block_gather_permutation(
const bool recompute)
{
if (block_gather_permutation_.get_num_elems() == 0 || recompute) {
block_gather_permutation_.resize_and_reset(this->get_size());
block_gather_permutation_.fill(-1);
auto exec = block_gather_permutation_.get_executor();
exec->run(partition::make_build_block_gathered_permute(
this, block_gather_permutation_));
}
}


template <typename LocalIndexType>
void Partition<LocalIndexType>::validate_data() const
{
Expand Down Expand Up @@ -182,19 +198,5 @@ bool is_ordered(const Partition<LocalIndexType>* partition)
GKO_INSTANTIATE_FOR_EACH_INDEX_TYPE(GKO_DECLARE_IS_ORDERED);


template <typename LocalIndexType>
Array<LocalIndexType> build_block_gather_permute(
const Partition<LocalIndexType>* partition)
{
auto exec = partition->get_executor();
Array<LocalIndexType> permute{exec, partition->get_size()};
exec->run(partition::make_build_block_gathered_permute(partition, permute));
return permute;
}
#define GKO_DECLARE_BUILD_BLOCK_GATHER_PERMUTE(_type) \
Array<_type> build_block_gather_permute(const Partition<_type>* partition)
GKO_INSTANTIATE_FOR_EACH_INDEX_TYPE(GKO_DECLARE_BUILD_BLOCK_GATHER_PERMUTE);


} // namespace distributed
} // namespace gko
44 changes: 36 additions & 8 deletions include/ginkgo/core/distributed/partition.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,39 @@ class Partition : public EnablePolymorphicObject<Partition<LocalIndexType>>,
part_sizes_.get_const_data() + part);
}

/**
* Computes a permutation from part-wise locally ordered indices to globally
* ordered.
*
* See @ref get_block_gather_permutation for a use-case description of the
* permutation.
*/
void compute_block_gather_permutation(bool recompute = false);

/**
* Access to the permutation from part-wise locally ordered indices to
* globally ordered ones.
*
* Assume there is data associated with each global index and that it is
* first sorted by the parts and then by local index wrt to that part.
* Consider the following data:
* > [0 1 2 3 4 5]
* with the following partitions:
* > p_1 = [0, 4-5]
* > p_2 = [1-3]
* Then the data sorted as described above would look like:
* > [0 4 5 1 2 3]
* This permutation can be used to get the data into the globally consistent
* ordering from before.
*
* @return index permutation which can be used in the @ref Permutable
* interface
*/
const Array<local_index_type>* get_block_gather_permutation() const
{
return &block_gather_permutation_;
}

/**
* Builds a partition from a given mapping global_index -> part_id.
* @param exec the Executor on which the partition should be built
Expand Down Expand Up @@ -195,7 +228,8 @@ class Partition : public EnablePolymorphicObject<Partition<LocalIndexType>>,
offsets_{exec, num_ranges + 1},
ranks_{exec, num_ranges},
part_sizes_{exec, static_cast<size_type>(num_parts)},
part_ids_{exec, num_ranges}
part_ids_{exec, num_ranges},
block_gather_permutation_{exec}
{
offsets_.fill(0);
ranks_.fill(0);
Expand All @@ -210,6 +244,7 @@ class Partition : public EnablePolymorphicObject<Partition<LocalIndexType>>,
Array<local_index_type> ranks_;
Array<local_index_type> part_sizes_;
Array<comm_index_type> part_ids_;
Array<local_index_type> block_gather_permutation_;
};


Expand All @@ -219,13 +254,6 @@ bool is_connected(const Partition<LocalIndexType>* partition);
template <typename LocalIndexType>
bool is_ordered(const Partition<LocalIndexType>* partition);

/**
* Creates a permutation that maps gathered indices to global indices
*/
template <typename LocalIndexType>
Array<LocalIndexType> build_block_gather_permute(
const Partition<LocalIndexType>* partition);


} // namespace distributed
} // namespace gko
Expand Down

0 comments on commit c2bdbc9

Please sign in to comment.