This repository has been archived by the owner on Mar 21, 2024. It is now read-only.

Commit

Merge pull request #564 from senior-zero/fix-main/github/block_radix_rank
Fix block radix rank for blocks whose thread count is not a multiple of 32
gevtushenko authored Sep 2, 2022
2 parents 97659a3 + 3346cfc commit eee0ca9
Showing 4 changed files with 517 additions and 60 deletions.
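In brief: `MatchAny` builds its peer mask from full-warp ballots, so in a thread block whose size is not a multiple of 32 the last, partial warp can report phantom peers in lanes that don't exist. The fix clips every match result to the warp's active lanes. Below is a minimal host-side sketch of the clipping arithmetic only (variable names are illustrative, not from the commit):

```c++
#include <cassert>
#include <cstdio>

int main()
{
    // A 48-thread block: one full warp plus a partial warp
    // with 48 % 32 = 16 active lanes.
    constexpr int block_threads = 48;
    constexpr int warp_threads = 32;
    constexpr int partial_warp_threads = block_threads % warp_threads; // 16

    // Suppose a full-warp ballot claimed peers in lanes 3, 12, 20, and 31.
    unsigned int full_warp_mask = (1u << 3) | (1u << 12) | (1u << 20) | (1u << 31);

    // Clip to the active lanes: ~(~0u << 16) == 0x0000ffff.
    unsigned int active_lanes = ~(~0u << partial_warp_threads);
    unsigned int peer_mask = full_warp_mask & active_lanes;

    assert(peer_mask == ((1u << 3) | (1u << 12))); // lanes 20 and 31 are dropped
    printf("clipped peer mask: 0x%08x\n", peer_mask);
    return 0;
}
```

The same `~(~0 << WARP_ACTIVE_THREADS)` expression performs this clipping in `warp_matcher_t::match_any` (cub/util_ptx.cuh, below).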
24 changes: 3 additions & 21 deletions cub/agent/agent_radix_sort_downsweep.cuh
@@ -136,27 +136,9 @@ struct AgentRadixSortDownsweep
     using ValuesItr = CacheModifiedInputIterator<LOAD_MODIFIER, ValueT, OffsetT>;
 
     // Radix ranking type to use
-    using BlockRadixRankT = cub::detail::conditional_t<
-      RANK_ALGORITHM == RADIX_RANK_BASIC,
-      BlockRadixRank<BLOCK_THREADS, RADIX_BITS, IS_DESCENDING, false, SCAN_ALGORITHM>,
-      cub::detail::conditional_t<
-        RANK_ALGORITHM == RADIX_RANK_MEMOIZE,
-        BlockRadixRank<BLOCK_THREADS, RADIX_BITS, IS_DESCENDING, true, SCAN_ALGORITHM>,
-        cub::detail::conditional_t<
-          RANK_ALGORITHM == RADIX_RANK_MATCH,
-          BlockRadixRankMatch<BLOCK_THREADS, RADIX_BITS, IS_DESCENDING, SCAN_ALGORITHM>,
-          cub::detail::conditional_t<
-            RANK_ALGORITHM == RADIX_RANK_MATCH_EARLY_COUNTS_ANY,
-            BlockRadixRankMatchEarlyCounts<BLOCK_THREADS,
-                                           RADIX_BITS,
-                                           IS_DESCENDING,
-                                           SCAN_ALGORITHM,
-                                           WARP_MATCH_ANY>,
-            BlockRadixRankMatchEarlyCounts<BLOCK_THREADS,
-                                           RADIX_BITS,
-                                           IS_DESCENDING,
-                                           SCAN_ALGORITHM,
-                                           WARP_MATCH_ATOMIC_OR>>>>>;
+    using BlockRadixRankT =
+      cub::detail::block_radix_rank_t<
+        RANK_ALGORITHM, BLOCK_THREADS, RADIX_BITS, IS_DESCENDING, SCAN_ALGORITHM>;
 
     // Digit extractor type
     using DigitExtractorT = BFEDigitExtractor<KeyT>;
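The nested `cub::detail::conditional_t` chain removed here reappears, generalized, as the `block_radix_rank_t` alias added to cub/block/block_radix_rank.cuh further down, so the selection logic lives in one place. For readers unfamiliar with the pattern, here is a reduced, self-contained sketch of enum-driven type selection; the enum and type names are invented for illustration:

```c++
#include <type_traits>

enum Algorithm { BASIC, MEMOIZE, MATCH };

template <int BlockThreads> struct RankBasic {};
template <int BlockThreads> struct RankMemoize {};
template <int BlockThreads> struct RankMatch {};

// Pick an implementation type from an enum value at compile time,
// mirroring (in reduced form) what block_radix_rank_t does below.
template <Algorithm Alg, int BlockThreads>
using rank_t = std::conditional_t<
    Alg == BASIC,
    RankBasic<BlockThreads>,
    std::conditional_t<Alg == MEMOIZE,
                       RankMemoize<BlockThreads>,
                       RankMatch<BlockThreads>>>;

static_assert(std::is_same_v<rank_t<MEMOIZE, 128>, RankMemoize<128>>, "");
```

Centralizing the chain in a single alias keeps every agent that ranks digits in sync instead of each one repeating the selection logic.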
117 changes: 107 additions & 10 deletions cub/block/block_radix_rank.cuh
@@ -92,6 +92,35 @@ struct BlockRadixRankEmptyCallback
 };
 
 
+namespace detail
+{
+
+template <int Bits, int PartialWarpThreads, int PartialWarpId>
+struct warp_in_block_matcher_t
+{
+  static __device__ unsigned int match_any(unsigned int label, unsigned int warp_id)
+  {
+    if (warp_id == static_cast<unsigned int>(PartialWarpId))
+    {
+      return MatchAny<Bits, PartialWarpThreads>(label);
+    }
+
+    return MatchAny<Bits>(label);
+  }
+};
+
+template <int Bits, int PartialWarpId>
+struct warp_in_block_matcher_t<Bits, 0, PartialWarpId>
+{
+  static __device__ unsigned int match_any(unsigned int label, unsigned int warp_id)
+  {
+    return MatchAny<Bits>(label);
+  }
+};
+
+} // namespace detail
+
+
 /**
  * \brief BlockRadixRank provides operations for ranking unsigned integer types within a CUDA thread block.
  * \ingroup BlockModule
@@ -114,17 +143,34 @@ struct BlockRadixRankEmptyCallback
  * \par Performance Considerations
  * - \granularity
  *
- * \par Examples
  * \par
- * - <b>Example 1:</b> Simple radix rank of 32-bit integer keys
- *      \code
- *      #include <cub/cub.cuh>
+ * \code
+ * #include <cub/cub.cuh>
  *
- *      template <int BLOCK_THREADS>
- *      __global__ void ExampleKernel(...)
- *      {
+ * __global__ void ExampleKernel(...)
+ * {
+ *     constexpr int block_threads = 2;
+ *     constexpr int radix_bits = 5;
+ *
+ *     // Specialize BlockRadixRank for a 1D block of 2 threads
+ *     using block_radix_rank = cub::BlockRadixRank<block_threads, radix_bits>;
+ *     using storage_t = typename block_radix_rank::TempStorage;
+ *
+ *     // Allocate shared memory for BlockRadixRank
+ *     __shared__ storage_t temp_storage;
+ *
+ *     // Obtain a segment of consecutive items that are blocked across threads
+ *     int keys[2];
+ *     int ranks[2];
+ *     ...
  *
- *      \endcode
+ *     cub::BFEDigitExtractor<int> extractor(0, radix_bits);
+ *     block_radix_rank(temp_storage).RankKeys(keys, ranks, extractor);
+ *
+ *     ...
+ * \endcode
+ * Suppose the set of input `keys` across the block of threads is `{ [16,10], [9,11] }`.
+ * The corresponding output `ranks` in those threads will be `{ [3,1], [0,2] }`.
  *
  * \par Re-using dynamically allocated shared memory
  * The following example under the examples/block folder illustrates usage of
@@ -528,6 +574,7 @@ private:

         LOG_WARP_THREADS = CUB_LOG_WARP_THREADS(0),
         WARP_THREADS = 1 << LOG_WARP_THREADS,
+        PARTIAL_WARP_THREADS = BLOCK_THREADS % WARP_THREADS,
         WARPS = (BLOCK_THREADS + WARP_THREADS - 1) / WARP_THREADS,
 
         PADDED_WARPS = ((WARPS & 0x1) == 0) ?
@@ -698,7 +745,11 @@ public:
             digit = RADIX_DIGITS - digit - 1;
 
         // Mask of peers who have same digit as me
-        uint32_t peer_mask = MatchAny<RADIX_BITS>(digit);
+        uint32_t peer_mask =
+          detail::warp_in_block_matcher_t<
+            RADIX_BITS,
+            PARTIAL_WARP_THREADS,
+            WARPS - 1>::match_any(digit, warp_id);
 
         // Pointer to smem digit counter for this key
         digit_counters[ITEM] = &temp_storage.aliasable.warp_digit_counters[digit][warp_id];
@@ -844,7 +895,9 @@ struct BlockRadixRankMatchEarlyCounts
         BINS_TRACKED_PER_THREAD = BINS_PER_THREAD,
         FULL_BINS = BINS_PER_THREAD * BLOCK_THREADS == RADIX_DIGITS,
         WARP_THREADS = CUB_PTX_WARP_THREADS,
+        PARTIAL_WARP_THREADS = BLOCK_THREADS % WARP_THREADS,
         BLOCK_WARPS = BLOCK_THREADS / WARP_THREADS,
+        PARTIAL_WARP_ID = BLOCK_WARPS - 1,
         WARP_MASK = ~0,
         NUM_MATCH_MASKS = MATCH_ALGORITHM == WARP_MATCH_ATOMIC_OR ? BLOCK_WARPS : 0,
         // Guard against declaring zero-sized array:
@@ -1040,7 +1093,10 @@ struct BlockRadixRankMatchEarlyCounts
         for (int u = 0; u < KEYS_PER_THREAD; ++u)
         {
             int bin = Digit(keys[u]);
-            int bin_mask = MatchAny<RADIX_BITS>(bin);
+            int bin_mask = detail::warp_in_block_matcher_t<RADIX_BITS,
+                                                           PARTIAL_WARP_THREADS,
+                                                           BLOCK_WARPS - 1>::match_any(bin,
+                                                                                       warp);
             int leader = (WARP_THREADS - 1) - __clz(bin_mask);
             int warp_offset = 0;
             int popc = __popc(bin_mask & LaneMaskLe());
@@ -1125,6 +1181,47 @@ struct BlockRadixRankMatchEarlyCounts
 };
 
 
+namespace detail
+{
+
+// `BlockRadixRank` doesn't conform to the typical pattern: it doesn't expose the
+// algorithm as a template parameter. The other rank algorithms don't share a common
+// set of template parameters, which prevents multi-dimensional thread block
+// specializations.
+//
+// TODO(senior-zero) for 3.0:
+// - Put existing implementations into the detail namespace
+// - Support multi-dimensional thread blocks in the rest of implementations
+// - Repurpose BlockRadixRank as an entry name with the algorithm template parameter
+template <RadixRankAlgorithm RankAlgorithm,
+          int BlockDimX,
+          int RadixBits,
+          bool IsDescending,
+          BlockScanAlgorithm ScanAlgorithm>
+using block_radix_rank_t = cub::detail::conditional_t<
+  RankAlgorithm == RADIX_RANK_BASIC,
+  BlockRadixRank<BlockDimX, RadixBits, IsDescending, false, ScanAlgorithm>,
+  cub::detail::conditional_t<
+    RankAlgorithm == RADIX_RANK_MEMOIZE,
+    BlockRadixRank<BlockDimX, RadixBits, IsDescending, true, ScanAlgorithm>,
+    cub::detail::conditional_t<
+      RankAlgorithm == RADIX_RANK_MATCH,
+      BlockRadixRankMatch<BlockDimX, RadixBits, IsDescending, ScanAlgorithm>,
+      cub::detail::conditional_t<
+        RankAlgorithm == RADIX_RANK_MATCH_EARLY_COUNTS_ANY,
+        BlockRadixRankMatchEarlyCounts<BlockDimX,
+                                       RadixBits,
+                                       IsDescending,
+                                       ScanAlgorithm,
+                                       WARP_MATCH_ANY>,
+        BlockRadixRankMatchEarlyCounts<BlockDimX,
+                                       RadixBits,
+                                       IsDescending,
+                                       ScanAlgorithm,
+                                       WARP_MATCH_ATOMIC_OR>>>>>;
+
+} // namespace detail
+
+
 CUB_NAMESPACE_END
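The two new enum members above drive `warp_in_block_matcher_t`: when `BLOCK_THREADS % WARP_THREADS` is zero, the `PartialWarpThreads == 0` specialization is selected and every warp takes the plain `MatchAny<Bits>` path; otherwise the last warp's matches are computed with `MatchAny<Bits, PARTIAL_WARP_THREADS>` and clipped accordingly. A quick standalone check of the arithmetic (not from the commit):

```c++
#include <cstdio>

int main()
{
    constexpr int warp_threads = 32;

    for (int block_threads : {48, 64, 80, 96, 160})
    {
        int warps   = (block_threads + warp_threads - 1) / warp_threads; // ceiling
        int partial = block_threads % warp_threads; // 0 means no partial warp

        printf("%3d threads -> %d warp(s), partial warp lanes: %2d\n",
               block_threads, warps, partial);
    }
    return 0;
}
```

For 64, 96, and 160 threads the partial-lane count is 0 and the masking is compiled away; for 48 and 80 the trailing warp has 16 active lanes.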
93 changes: 64 additions & 29 deletions cub/util_ptx.cuh
@@ -686,43 +686,78 @@ __device__ __forceinline__ T ShuffleIndex(
 }
 
 
+namespace detail
+{
+
+/**
+ * Implementation detail for `MatchAny`. It provides specializations for full and
+ * partial warps. For partial warps, inactive threads must be masked out; this is
+ * done in the partial-warp specialization below.
+ * Usage:
+ * ```
+ * // returns a mask of threads having the same 4 least-significant bits of `label`
+ * // in a warp with 16 active threads
+ * warp_matcher_t<4, 16>::match_any(label);
+ *
+ * // returns a mask of threads having the same 4 least-significant bits of `label`
+ * // in a warp with 32 active threads (no extra work is done)
+ * warp_matcher_t<4, 32>::match_any(label);
+ * ```
+ */
+template <int LABEL_BITS, int WARP_ACTIVE_THREADS>
+struct warp_matcher_t
+{
+  static __device__ unsigned int match_any(unsigned int label)
+  {
+    return warp_matcher_t<LABEL_BITS, 32>::match_any(label) & ~(~0 << WARP_ACTIVE_THREADS);
+  }
+};
+
+template <int LABEL_BITS>
+struct warp_matcher_t<LABEL_BITS, CUB_PTX_WARP_THREADS>
+{
+  // match.any.sync.b32 is slower when matching a few bits, so use a ballot loop instead
+  static __device__ unsigned int match_any(unsigned int label)
+  {
+    unsigned int retval;
+
+    // Extract masks of common threads for each bit
+    #pragma unroll
+    for (int BIT = 0; BIT < LABEL_BITS; ++BIT)
+    {
+      unsigned int mask;
+      unsigned int current_bit = 1 << BIT;
+      asm ("{\n"
+           "    .reg .pred p;\n"
+           "    and.b32 %0, %1, %2;"
+           "    setp.eq.u32 p, %0, %2;\n"
+           "    vote.ballot.sync.b32 %0, p, 0xffffffff;\n"
+           "    @!p not.b32 %0, %0;\n"
+           "}\n" : "=r"(mask) : "r"(label), "r"(current_bit));
+
+      // Remove peers who differ
+      retval = (BIT == 0) ? mask : retval & mask;
+    }
+
+    return retval;
+  }
+};
+
+} // namespace detail
+
 /**
  * Compute a 32b mask of threads having the same least-significant
  * LABEL_BITS of \p label as the calling thread.
  */
-template <int LABEL_BITS>
-inline __device__ unsigned int MatchAny(unsigned int label)
-{
-    unsigned int retval;
-
-    // Extract masks of common threads for each bit
-    #pragma unroll
-    for (int BIT = 0; BIT < LABEL_BITS; ++BIT)
-    {
-        unsigned int mask;
-        unsigned int current_bit = 1 << BIT;
-        asm ("{\n"
-             "    .reg .pred p;\n"
-             "    and.b32 %0, %1, %2;"
-             "    setp.eq.u32 p, %0, %2;\n"
-             "    vote.ballot.sync.b32 %0, p, 0xffffffff;\n"
-             "    @!p not.b32 %0, %0;\n"
-             "}\n" : "=r"(mask) : "r"(label), "r"(current_bit));
-
-        // Remove peers who differ
-        retval = (BIT == 0) ? mask : retval & mask;
-    }
-
-    return retval;
-
-    // // VOLTA match
-    // unsigned int retval;
-    // asm ("{\n"
-    //      "    match.any.sync.b32 %0, %1, 0xffffffff;\n"
-    //      "}\n" : "=r"(retval) : "r"(label));
-    // return retval;
-}
+template <int LABEL_BITS, int WARP_ACTIVE_THREADS = CUB_PTX_WARP_THREADS>
+inline __device__ unsigned int MatchAny(unsigned int label)
+{
+  return detail::warp_matcher_t<LABEL_BITS, WARP_ACTIVE_THREADS>::match_any(label);
+}
 
 CUB_NAMESPACE_END
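A minimal usage sketch of the extended entry point (hypothetical kernel and launch, not part of the commit; `cudaMallocManaged` is used for brevity). Because the new template parameter defaults to `CUB_PTX_WARP_THREADS`, existing `MatchAny<BITS>(label)` call sites compile unchanged:

```c++
#include <cub/util_ptx.cuh>
#include <cstdio>

// One partial warp of 16 threads; each thread matches peers on the
// 4 least-significant bits of its label (label = threadIdx.x % 4).
__global__ void MatchDemo(unsigned int *peer_masks)
{
    unsigned int label = threadIdx.x % 4;

    // The second template argument clips the result to the 16 active lanes;
    // with the default (a full warp) the mask could also report lanes 16..31,
    // which don't exist in this launch.
    peer_masks[threadIdx.x] = cub::MatchAny<4, 16>(label);
}

int main()
{
    unsigned int *peer_masks;
    cudaMallocManaged(&peer_masks, 16 * sizeof(unsigned int));

    MatchDemo<<<1, 16>>>(peer_masks);
    cudaDeviceSynchronize();

    for (int t = 0; t < 16; ++t)
    {
        // Threads with equal threadIdx.x % 4 are peers,
        // e.g. threads 0, 4, 8, 12 all report 0x00001111.
        printf("thread %2d: peers 0x%08x\n", t, peer_masks[t]);
    }

    cudaFree(peer_masks);
    return 0;
}
```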
