Skip to content

Commit

Permalink
Merge pull request #4570 from Rombur/sort
Browse files Browse the repository at this point in the history
Fix Kokkos_Sort when using integer and HIP
  • Loading branch information
dalg24 authored Dec 14, 2021
2 parents 6212cc0 + 4d91027 commit 7ff417b
Showing 1 changed file with 24 additions and 4 deletions.
28 changes: 24 additions & 4 deletions algorithms/src/Kokkos_Sort.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -437,7 +437,10 @@ struct BinOp1D {
BinOp1D(int max_bins__, typename KeyViewType::const_value_type min,
typename KeyViewType::const_value_type max)
: max_bins_(max_bins__ + 1),
mul_(1.0 * max_bins__ / (max - min)),
// Cast to int64_t to avoid possible overflow in (max - min) for integral key types
mul_(std::is_integral<typename KeyViewType::const_value_type>::value
? 1.0 * max_bins__ / (int64_t(max) - int64_t(min))
: 1.0 * max_bins__ / (max - min)),
range_(max - min),
min_(min) {
// For integral types the number of bins may be larger than the range
Expand All @@ -450,11 +453,25 @@ struct BinOp1D {
}

// Determine bin index from key value
template <class ViewType>
template <
class ViewType,
std::enable_if_t<!std::is_integral<typename ViewType::value_type>::value,
bool> = true>
KOKKOS_INLINE_FUNCTION int bin(ViewType& keys, const int& i) const {
return int(mul_ * (keys(i) - min_));
}

// Determine bin index from key value
template <
class ViewType,
std::enable_if_t<std::is_integral<typename ViewType::value_type>::value,
bool> = true>
KOKKOS_INLINE_FUNCTION int bin(ViewType& keys, const int& i) const {
// The cast to int64_t is necessary because otherwise HIP returns the wrong
// result.
return int(mul_ * (int64_t(keys(i)) - int64_t(min_)));
}

// Return maximum bin index + 1
KOKKOS_INLINE_FUNCTION
int max_bins() const { return max_bins_; }
Expand Down Expand Up @@ -579,10 +596,13 @@ std::enable_if_t<Kokkos::is_execution_space<ExecutionSpace>::value> sort(
// TODO: figure out a better max_bins than this ...
int64_t max_bins = view.extent(0) / 2;
if (std::is_integral<typename ViewType::non_const_value_type>::value) {
// Widen to int64_t to avoid possible overflow when subtracting integer values
int64_t const max_val = result.max_val;
int64_t const min_val = result.min_val;
// using 10M as the cutoff for special behavior (roughly 40MB for the count
// array)
if ((result.max_val - result.min_val) < 10000000) {
max_bins = result.max_val - result.min_val + 1;
if ((max_val - min_val) < 10000000) {
max_bins = max_val - min_val + 1;
sort_in_bins = false;
}
}
Expand Down

0 comments on commit 7ff417b

Please sign in to comment.