Skip to content
This repository has been archived by the owner on Mar 21, 2024. It is now read-only.

Commit

Permalink
Copy if begin_bit == end_bit, but overwrite not allowed.
Browse files Browse the repository at this point in the history
  • Loading branch information
canonizer committed May 31, 2022
1 parent a68b45e commit d63448c
Showing 1 changed file with 49 additions and 11 deletions.
60 changes: 49 additions & 11 deletions cub/device/dispatch/dispatch_radix_sort.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -1276,8 +1276,7 @@ struct DispatchRadixSort :
// portions handle inputs with >=2**30 elements, due to the way lookback works
// for testing purposes, one portion is <= 2**28 elements
const PortionOffsetT PORTION_SIZE = ((1 << 28) - 1) / ONESWEEP_TILE_ITEMS * ONESWEEP_TILE_ITEMS;
// even if begin_bit == end_bit, there will always be at least one pass
int num_passes = CUB_MAX(cub::DivideAndRoundUp(end_bit - begin_bit, RADIX_BITS), 1);
int num_passes = cub::DivideAndRoundUp(end_bit - begin_bit, RADIX_BITS);
OffsetT num_portions = static_cast<OffsetT>(cub::DivideAndRoundUp(num_items, PORTION_SIZE));
PortionOffsetT max_num_blocks = cub::DivideAndRoundUp(
static_cast<int>(
Expand Down Expand Up @@ -1313,7 +1312,7 @@ struct DispatchRadixSort :
ValueT* d_values_tmp2 = (ValueT*)allocations[3];
AtomicOffsetT* d_ctrs = (AtomicOffsetT*)allocations[4];

do {
do {
// initialization
if (CubDebug(error = cudaMemsetAsync(
d_ctrs, 0, num_portions * num_passes * sizeof(AtomicOffsetT), stream))) break;
Expand Down Expand Up @@ -1397,8 +1396,9 @@ struct DispatchRadixSort :
d_values.d_buffers[1] = d_values_tmp2;
}

int current_bit = begin_bit, pass = 0;
do {
for (int current_bit = begin_bit, pass = 0; current_bit < end_bit;
current_bit += RADIX_BITS, ++pass)
{
int num_bits = CUB_MIN(end_bit - current_bit, RADIX_BITS);
for (OffsetT portion = 0; portion < num_portions; ++portion)
{
Expand Down Expand Up @@ -1465,11 +1465,7 @@ struct DispatchRadixSort :
}
d_keys.selector ^= 1;
d_values.selector ^= 1;

// next pass
current_bit += RADIX_BITS;
++pass;
} while (current_bit < end_bit);
}
} while (0);

return error;
Expand Down Expand Up @@ -1563,7 +1559,7 @@ struct DispatchRadixSort :

// Pass planning. Run passes of the alternate digit-size configuration until we have an even multiple of our preferred digit size
int num_bits = end_bit - begin_bit;
int num_passes = CUB_MAX(cub::DivideAndRoundUp(num_bits, pass_config.radix_bits), 1);
int num_passes = cub::DivideAndRoundUp(num_bits, pass_config.radix_bits);
bool is_num_passes_odd = num_passes & 1;
int max_alt_passes = (num_passes * pass_config.radix_bits) - num_bits;
int alt_end_bit = CUB_MIN(end_bit, begin_bit + (max_alt_passes * alt_pass_config.radix_bits));
Expand Down Expand Up @@ -1643,6 +1639,42 @@ struct DispatchRadixSort :
return InvokeOnesweep<ActivePolicyT>();
}

/**
 * Degenerate-range path: duplicate the current buffers into the alternate
 * buffers instead of sorting.
 *
 * The caller dispatches here when begin_bit == end_bit and
 * is_overwrite_okay == false.
 *
 * When d_temp_storage is null this is a size query: temp_storage_bytes is
 * set to 1 and nothing is copied.
 *
 * Otherwise num_items keys (and, unless KEYS_ONLY, num_items values) are
 * copied from the Current() buffer to the Alternate() buffer with
 * cudaMemcpyAsync on `stream`, and the selectors are flipped so the copy
 * becomes the current buffer.
 *
 * @return cudaSuccess, or the first error reported by cudaMemcpyAsync.
 *         On a failed key copy no selector is flipped; on a failed value
 *         copy only the key selector has been flipped.
 */
CUB_RUNTIME_FUNCTION __forceinline__
cudaError_t InvokeCopy()
{
    // Size query only: report the temporary storage requirement and leave.
    if (d_temp_storage == nullptr)
    {
        temp_storage_bytes = 1;
        return cudaSuccess;
    }

    cudaError_t status = cudaSuccess;

    // Keys: current -> alternate, then make the alternate buffer current.
    const size_t key_bytes = num_items * sizeof(KeyT);
    if (CubDebug(status = cudaMemcpyAsync(d_keys.Alternate(),
                                          d_keys.Current(),
                                          key_bytes,
                                          cudaMemcpyDefault,
                                          stream)))
    {
        return status;
    }
    d_keys.selector ^= 1;

    // Values: only pairs sorts have a payload to move.
    if (!KEYS_ONLY)
    {
        const size_t value_bytes = num_items * sizeof(ValueT);
        if (CubDebug(status = cudaMemcpyAsync(d_values.Alternate(),
                                              d_values.Current(),
                                              value_bytes,
                                              cudaMemcpyDefault,
                                              stream)))
        {
            return status;
        }
    }

    // The value selector is flipped even for keys-only sorts.
    d_values.selector ^= 1;

    return status;
}

/// Invocation
template <typename ActivePolicyT>
CUB_RUNTIME_FUNCTION __forceinline__
Expand All @@ -1661,6 +1693,12 @@ struct DispatchRadixSort :
return cudaSuccess;
}

// Check if simple copy suffices (is_overwrite_okay == false at this point)
if (begin_bit == end_bit)
{
return InvokeCopy();
}

// Force kernel code-generation in all compiler passes
if (num_items <= (SingleTilePolicyT::BLOCK_THREADS * SingleTilePolicyT::ITEMS_PER_THREAD))
{
Expand Down

0 comments on commit d63448c

Please sign in to comment.