From d63448cfa68946a151a2729c3968a2537a8db52a Mon Sep 17 00:00:00 2001 From: Andy Adinets Date: Tue, 31 May 2022 23:27:22 +0200 Subject: [PATCH] Copy if begin_bit == end_bit, but overwrite not allowed. --- cub/device/dispatch/dispatch_radix_sort.cuh | 60 +++++++++++++++++---- 1 file changed, 49 insertions(+), 11 deletions(-) diff --git a/cub/device/dispatch/dispatch_radix_sort.cuh b/cub/device/dispatch/dispatch_radix_sort.cuh index c2c296605..c578cf22c 100644 --- a/cub/device/dispatch/dispatch_radix_sort.cuh +++ b/cub/device/dispatch/dispatch_radix_sort.cuh @@ -1276,8 +1276,7 @@ struct DispatchRadixSort : // portions handle inputs with >=2**30 elements, due to the way lookback works // for testing purposes, one portion is <= 2**28 elements const PortionOffsetT PORTION_SIZE = ((1 << 28) - 1) / ONESWEEP_TILE_ITEMS * ONESWEEP_TILE_ITEMS; - // even if begin_bit == end_bit, there will always be at least one pass - int num_passes = CUB_MAX(cub::DivideAndRoundUp(end_bit - begin_bit, RADIX_BITS), 1); + int num_passes = cub::DivideAndRoundUp(end_bit - begin_bit, RADIX_BITS); OffsetT num_portions = static_cast(cub::DivideAndRoundUp(num_items, PORTION_SIZE)); PortionOffsetT max_num_blocks = cub::DivideAndRoundUp( static_cast( @@ -1313,7 +1312,7 @@ struct DispatchRadixSort : ValueT* d_values_tmp2 = (ValueT*)allocations[3]; AtomicOffsetT* d_ctrs = (AtomicOffsetT*)allocations[4]; - do { + do { // initialization if (CubDebug(error = cudaMemsetAsync( d_ctrs, 0, num_portions * num_passes * sizeof(AtomicOffsetT), stream))) break; @@ -1397,8 +1396,9 @@ struct DispatchRadixSort : d_values.d_buffers[1] = d_values_tmp2; } - int current_bit = begin_bit, pass = 0; - do { + for (int current_bit = begin_bit, pass = 0; current_bit < end_bit; + current_bit += RADIX_BITS, ++pass) + { int num_bits = CUB_MIN(end_bit - current_bit, RADIX_BITS); for (OffsetT portion = 0; portion < num_portions; ++portion) { @@ -1465,11 +1465,7 @@ struct DispatchRadixSort : } d_keys.selector ^= 1; d_values.selector ^= 1; - - // next pass - current_bit += RADIX_BITS; - ++pass; - } while (current_bit < end_bit); + } } while (0); return error; @@ -1563,7 +1559,7 @@ struct DispatchRadixSort : // Pass planning. Run passes of the alternate digit-size configuration until we have an even multiple of our preferred digit size int num_bits = end_bit - begin_bit; - int num_passes = CUB_MAX(cub::DivideAndRoundUp(num_bits, pass_config.radix_bits), 1); + int num_passes = cub::DivideAndRoundUp(num_bits, pass_config.radix_bits); bool is_num_passes_odd = num_passes & 1; int max_alt_passes = (num_passes * pass_config.radix_bits) - num_bits; int alt_end_bit = CUB_MIN(end_bit, begin_bit + (max_alt_passes * alt_pass_config.radix_bits)); @@ -1643,6 +1639,42 @@ struct DispatchRadixSort : return InvokeOnesweep(); } + CUB_RUNTIME_FUNCTION __forceinline__ + cudaError_t InvokeCopy() + { + // is_overwrite_okay == false here + // Return the number of temporary bytes if requested + if (d_temp_storage == nullptr) + { + temp_storage_bytes = 1; + return cudaSuccess; + } + + // Copy keys + cudaError_t error = cudaSuccess; + error = cudaMemcpyAsync(d_keys.Alternate(), d_keys.Current(), num_items * sizeof(KeyT), + cudaMemcpyDefault, stream); + if (CubDebug(error)) + { + return error; + } + d_keys.selector ^= 1; + + // Copy values if necessary + if (!KEYS_ONLY) + { + error = cudaMemcpyAsync(d_values.Alternate(), d_values.Current(), + num_items * sizeof(ValueT), cudaMemcpyDefault, stream); + if (CubDebug(error)) + { + return error; + } + } + d_values.selector ^= 1; + + return error; + } + /// Invocation template CUB_RUNTIME_FUNCTION __forceinline__ @@ -1661,6 +1693,12 @@ struct DispatchRadixSort : return cudaSuccess; } + // Check if simple copy suffices (is_overwrite_okay == false at this point) + if (begin_bit == end_bit) + { + return InvokeCopy(); + } + // Force kernel code-generation in all compiler passes if (num_items <= (SingleTilePolicyT::BLOCK_THREADS * SingleTilePolicyT::ITEMS_PER_THREAD)) {