Skip to content
This repository has been archived by the owner on Mar 21, 2024. It is now read-only.

Fix begin_bit == end_bit == 0 for device-wide and segmented sort #481

Merged
merged 5 commits into from
Aug 9, 2022
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
110 changes: 87 additions & 23 deletions cub/device/dispatch/dispatch_radix_sort.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -912,12 +912,12 @@ struct DeviceRadixSortPolicy
struct Policy800 : ChainedPolicy<800, Policy800, Policy700>
{
enum {
PRIMARY_RADIX_BITS = (sizeof(KeyT) > 1) ? 7 : 5,
SINGLE_TILE_RADIX_BITS = (sizeof(KeyT) > 1) ? 6 : 5,
SEGMENTED_RADIX_BITS = (sizeof(KeyT) > 1) ? 6 : 5,
ONESWEEP = sizeof(KeyT) >= sizeof(uint32_t),
ONESWEEP_RADIX_BITS = 8,
OFFSET_64BIT = sizeof(OffsetT) == 8,
PRIMARY_RADIX_BITS = (sizeof(KeyT) > 1) ? 7 : 5,
SINGLE_TILE_RADIX_BITS = (sizeof(KeyT) > 1) ? 6 : 5,
SEGMENTED_RADIX_BITS = (sizeof(KeyT) > 1) ? 6 : 5,
ONESWEEP = sizeof(KeyT) >= sizeof(uint32_t),
ONESWEEP_RADIX_BITS = 8,
OFFSET_64BIT = sizeof(OffsetT) == 8,
};

// Histogram policy
Expand Down Expand Up @@ -1366,7 +1366,7 @@ struct DispatchRadixSort :
ValueT* d_values_tmp2 = (ValueT*)allocations[3];
AtomicOffsetT* d_ctrs = (AtomicOffsetT*)allocations[4];

do {
do {
// initialization
if (CubDebug(error = cudaMemsetAsync(
d_ctrs, 0, num_portions * num_passes * sizeof(AtomicOffsetT), stream))) break;
Expand Down Expand Up @@ -1498,6 +1498,11 @@ struct DispatchRadixSort :
}
}

if (error != cudaSuccess)
gevtushenko marked this conversation as resolved.
Show resolved Hide resolved
{
break;
}

// use the temporary buffers if no overwrite is allowed
if (!is_overwrite_okay && pass == 0)
{
Expand Down Expand Up @@ -1671,6 +1676,58 @@ struct DispatchRadixSort :
return InvokeOnesweep<ActivePolicyT>();
}

CUB_RUNTIME_FUNCTION __forceinline__
cudaError_t InvokeCopy()
{
// is_overwrite_okay == false here
// Return the number of temporary bytes if requested
if (d_temp_storage == nullptr)
{
temp_storage_bytes = 1;
return cudaSuccess;
}

// Copy keys
#ifdef CUB_DETAIL_DEBUG_ENABLE_LOG
_CubLog("Invoking async copy of %lld keys on stream %lld\n", (long long)num_items,
(long long)stream);
#endif
cudaError_t error = cudaSuccess;
gevtushenko marked this conversation as resolved.
Show resolved Hide resolved
error = cudaMemcpyAsync(d_keys.Alternate(), d_keys.Current(), num_items * sizeof(KeyT),
gevtushenko marked this conversation as resolved.
Show resolved Hide resolved
cudaMemcpyDefault, stream);
gevtushenko marked this conversation as resolved.
Show resolved Hide resolved
if (CubDebug(error))
{
return error;
}
if (CubDebug(error = detail::DebugSyncStream(stream)))
{
return error;
}
d_keys.selector ^= 1;

// Copy values if necessary
if (!KEYS_ONLY)
{
#ifdef CUB_DETAIL_DEBUG_ENABLE_LOG
_CubLog("Invoking async copy of %lld values on stream %lld\n",
(long long)num_items, (long long)stream);
#endif
error = cudaMemcpyAsync(d_values.Alternate(), d_values.Current(),
gevtushenko marked this conversation as resolved.
Show resolved Hide resolved
num_items * sizeof(ValueT), cudaMemcpyDefault, stream);
if (CubDebug(error))
{
return error;
}
if (CubDebug(error = detail::DebugSyncStream(stream)))
{
return error;
}
}
d_values.selector ^= 1;

return error;
}

/// Invocation
template <typename ActivePolicyT>
CUB_RUNTIME_FUNCTION __forceinline__
Expand All @@ -1679,15 +1736,23 @@ struct DispatchRadixSort :
typedef typename DispatchRadixSort::MaxPolicy MaxPolicyT;
typedef typename ActivePolicyT::SingleTilePolicy SingleTilePolicyT;

// Return if empty problem
if (num_items == 0)
// Return if empty problem, or if no bits to sort and double-buffering is used
if (num_items == 0 || (begin_bit == end_bit && is_overwrite_okay))
{
if (d_temp_storage == nullptr)
{
temp_storage_bytes = 1;
}
if (d_temp_storage == nullptr)
{
temp_storage_bytes = 1;
}
return cudaSuccess;
}

return cudaSuccess;
// Check if simple copy suffices (is_overwrite_okay == false at this point)
cudaError_t error = cudaSuccess;
bool has_uva = false;
if ((error = HasUVA(has_uva)) != cudaSuccess) return error;
if (begin_bit == end_bit & has_uva)
gevtushenko marked this conversation as resolved.
Show resolved Hide resolved
{
return InvokeCopy();
}

// Force kernel code-generation in all compiler passes
Expand Down Expand Up @@ -2021,7 +2086,7 @@ struct DispatchSegmentedRadixSort :
int radix_bits = ActivePolicyT::SegmentedPolicy::RADIX_BITS;
int alt_radix_bits = ActivePolicyT::AltSegmentedPolicy::RADIX_BITS;
int num_bits = end_bit - begin_bit;
int num_passes = (num_bits + radix_bits - 1) / radix_bits;
int num_passes = CUB_MAX(DivideAndRoundUp(num_bits, radix_bits), 1);
bool is_num_passes_odd = num_passes & 1;
int max_alt_passes = (num_passes * radix_bits) - num_bits;
int alt_end_bit = CUB_MIN(end_bit, begin_bit + (max_alt_passes * alt_radix_bits));
Expand Down Expand Up @@ -2082,15 +2147,14 @@ struct DispatchSegmentedRadixSort :
{
typedef typename DispatchSegmentedRadixSort::MaxPolicy MaxPolicyT;

// Return if empty problem
if (num_items == 0)
// Return if empty problem, or if no bits to sort and double-buffering is used
if (num_items == 0 || (begin_bit == end_bit && is_overwrite_okay))
{
if (d_temp_storage == nullptr)
{
temp_storage_bytes = 1;
}

return cudaSuccess;
if (d_temp_storage == nullptr)
{
temp_storage_bytes = 1;
}
return cudaSuccess;
}

// Force kernel code-generation in all compiler passes
Expand Down
17 changes: 17 additions & 0 deletions cub/util_device.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,23 @@ CUB_RUNTIME_FUNCTION inline int CurrentDevice()
return device;
}

/** \brief Gets whether the current device supports unified addressing */
CUB_RUNTIME_FUNCTION cudaError_t HasUVA(bool& has_uva)
gevtushenko marked this conversation as resolved.
Show resolved Hide resolved
{
has_uva = false;
cudaError_t error = cudaSuccess;
int device = -1;
if (CubDebug(error = cudaGetDevice(&device)) != cudaSuccess) return error;
int uva = 0;
if (CubDebug(error = cudaDeviceGetAttribute(&uva, cudaDevAttrUnifiedAddressing, device))
!= cudaSuccess)
{
return error;
}
has_uva = uva == 1;
return error;
}

/**
* \brief RAII helper which saves the current device and switches to the
* specified device on construction and switches to the saved device on
Expand Down
8 changes: 6 additions & 2 deletions test/test_device_radix_sort.cu
Original file line number Diff line number Diff line change
Expand Up @@ -1418,6 +1418,11 @@ void TestBits(
printf("Testing key bits [%d,%d)\n", begin_bit, end_bit); fflush(stdout);
TestDirection(h_keys, num_items, num_segments, pre_sorted, h_segment_offsets, d_segment_begin_offsets, d_segment_end_offsets, begin_bit, end_bit);

// Equal bits
begin_bit = end_bit = 0;
printf("Testing key bits [%d,%d)\n", begin_bit, end_bit); fflush(stdout);
TestDirection(h_keys, num_items, num_segments, pre_sorted, h_segment_offsets, d_segment_begin_offsets, d_segment_end_offsets, begin_bit, end_bit);

// Across subword boundaries
int mid_bit = sizeof(KeyT) * 4;
printf("Testing key bits [%d,%d)\n", mid_bit - 1, mid_bit + 1); fflush(stdout);
Expand Down Expand Up @@ -1587,7 +1592,7 @@ void TestGen(
{
if (max_items == ~std::size_t(0))
{
max_items = 9000003;
max_items = 8000003;
}

if (max_segments < 0)
Expand Down Expand Up @@ -1650,7 +1655,6 @@ void TestGen(
TestSizes(h_keys.get(), large_num_items, max_segments, true);
fflush(stdout);
}

}


Expand Down