Skip to content
This repository has been archived by the owner on Mar 21, 2024. It is now read-only.

Fix begin_bit == end_bit == 0 for device-wide and segmented sort #481

Merged
merged 5 commits into from
Aug 9, 2022
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 6 additions & 3 deletions cub/device/dispatch/dispatch_radix_sort.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -1694,7 +1694,7 @@ struct DispatchRadixSort :
#endif
cudaError_t error = cudaSuccess;
gevtushenko marked this conversation as resolved.
Show resolved Hide resolved
error = cudaMemcpyAsync(d_keys.Alternate(), d_keys.Current(), num_items * sizeof(KeyT),
gevtushenko marked this conversation as resolved.
Show resolved Hide resolved
cudaMemcpyDeviceToDevice, stream);
cudaMemcpyDefault, stream);
gevtushenko marked this conversation as resolved.
Show resolved Hide resolved
if (CubDebug(error))
{
return error;
Expand All @@ -1713,7 +1713,7 @@ struct DispatchRadixSort :
(long long)num_items, (long long)stream);
#endif
error = cudaMemcpyAsync(d_values.Alternate(), d_values.Current(),
gevtushenko marked this conversation as resolved.
Show resolved Hide resolved
num_items * sizeof(ValueT), cudaMemcpyDeviceToDevice, stream);
num_items * sizeof(ValueT), cudaMemcpyDefault, stream);
if (CubDebug(error))
{
return error;
Expand Down Expand Up @@ -1747,7 +1747,10 @@ struct DispatchRadixSort :
}

// Check if simple copy suffices (is_overwrite_okay == false at this point)
if (begin_bit == end_bit)
cudaError_t error = cudaSuccess;
bool has_uva = false;
if ((error = HasUVA(has_uva)) != cudaSuccess) return error;
if (begin_bit == end_bit & has_uva)
gevtushenko marked this conversation as resolved.
Show resolved Hide resolved
{
return InvokeCopy();
}
Expand Down
17 changes: 17 additions & 0 deletions cub/util_device.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,23 @@ CUB_RUNTIME_FUNCTION inline int CurrentDevice()
return device;
}

/**
 * \brief Gets whether the current device supports unified addressing
 *
 * Queries the \p cudaDevAttrUnifiedAddressing attribute of the device
 * returned by \p cudaGetDevice.
 *
 * \param[out] has_uva
 *   Set to \p true only when both runtime queries succeed and the attribute
 *   value is 1. It is reset to \p false on entry, so it is also \p false
 *   whenever an error is returned.
 *
 * \return
 *   \p cudaSuccess when both queries succeed; otherwise the error reported by
 *   the failing call.
 */
// Declared `inline` (like CurrentDevice() in this header): the definition lives
// in a header, so without it every translation unit that includes the header
// would emit its own external definition and linking would fail.
CUB_RUNTIME_FUNCTION inline cudaError_t HasUVA(bool& has_uva)
{
    // Pessimistic default so early error returns leave a well-defined value.
    has_uva = false;

    cudaError_t error = cudaSuccess;

    int device = -1;
    if (CubDebug(error = cudaGetDevice(&device)) != cudaSuccess)
    {
        return error;
    }

    int uva = 0;
    if (CubDebug(error = cudaDeviceGetAttribute(&uva, cudaDevAttrUnifiedAddressing, device))
        != cudaSuccess)
    {
        return error;
    }

    has_uva = (uva == 1);
    return error;
}

/**
* \brief RAII helper which saves the current device and switches to the
* specified device on construction and switches to the saved device on
Expand Down