Fix begin_bit == end_bit == 0 for device-wide and segmented sort #481
This pull request should address #353.
@canonizer thank you for addressing this! I'm a bit concerned about the approach, though. I wonder if we could short-circuit in the begin_bit == end_bit case. For instance, if is_overwrite_okay == true, we wouldn't do anything, since the double buffer would already contain the proper data. Otherwise, we might just copy the data. I've written a simple benchmark below that might help in understanding the impact of this approach. In the double-buffer case the short circuit is a no-op, which is definitely faster. Otherwise, memcpy is about 40% faster than actually sorting anything.
#include <cub/cub.cuh>
#include <thrust/device_vector.h>

#include <cstddef>
#include <cstdint>
#include <iostream>

// Runs either the real radix sort or the proposed short circuit for the
// begin_bit == end_bit case. Following the CUB convention, a call with
// d_temp_storage == nullptr only reports the required temp_storage_bytes.
void sort(
    std::uint8_t *d_temp_storage, std::size_t &temp_storage_bytes,
    int *d_keys_in, int *d_keys_out,
    int num_items,
    bool use_buffer, bool short_circuit)
{
    const int begin_bit = 0;
    const int end_bit = begin_bit;

    cub::DoubleBuffer<int> d_keys(d_keys_in, d_keys_out);

    if (use_buffer) {
        if (short_circuit) {
            temp_storage_bytes = 1; // noop
        } else {
            cub::DeviceRadixSort::SortKeys(
                d_temp_storage, temp_storage_bytes,
                d_keys, num_items, begin_bit, end_bit);
        }
    } else {
        if (short_circuit) {
            if (d_temp_storage == nullptr) {
                temp_storage_bytes = 1;
            } else {
                // Overwrite is not allowed, so the result is a plain copy.
                cudaMemcpy(d_keys_out, d_keys_in, sizeof(int) * num_items, cudaMemcpyDeviceToDevice);
            }
        } else {
            cub::DeviceRadixSort::SortKeys(
                d_temp_storage, temp_storage_bytes,
                d_keys_in, d_keys_out, num_items, begin_bit, end_bit);
        }
    }
}

int main()
{
    const int num_items = 128 * 1024 * 1024;

    thrust::device_vector<int> in(num_items);
    thrust::device_vector<int> out(num_items);

    int *d_keys_in = thrust::raw_pointer_cast(in.data());
    int *d_keys_out = thrust::raw_pointer_cast(out.data());

    std::uint8_t *d_temp_storage{};
    std::size_t temp_storage_bytes = 0;

    const bool use_buffer = false;
    const bool short_circuit = true;

    // First call: query the temporary storage size.
    sort(d_temp_storage, temp_storage_bytes, d_keys_in, d_keys_out, num_items,
         use_buffer, short_circuit);

    thrust::device_vector<std::uint8_t> temp_storage(temp_storage_bytes);
    d_temp_storage = thrust::raw_pointer_cast(temp_storage.data());

    cudaEvent_t begin, end;
    cudaEventCreate(&begin);
    cudaEventCreate(&end);

    // Second call: the timed run.
    cudaEventRecord(begin);
    sort(d_temp_storage, temp_storage_bytes, d_keys_in, d_keys_out, num_items,
         use_buffer, short_circuit);
    cudaEventRecord(end);
    cudaEventSynchronize(end);

    float ms{};
    cudaEventElapsedTime(&ms, begin, end);
    std::cout << ms << "ms" << std::endl;

    cudaEventDestroy(end);
    cudaEventDestroy(begin);
}
I don't think this approach is applicable to the segmented version, but I'd like to know your opinion on it for the non-segmented API. Are there any downsides I'm missing?
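For reference, the decision being proposed can be written down as a small host-side sketch. The names `Action` and `choose_action` are hypothetical and exist only for illustration; they are not CUB identifiers, and this is not how the dispatch layer is actually structured. The `is_overwrite_okay` parameter stands in for the flag discussed above.

```cpp
#include <cassert>

// Hypothetical names, for illustration only; not part of CUB.
enum class Action { NoOp, Copy, Sort };

// Decide what a radix sort over the bit range [begin_bit, end_bit) has to do.
// is_overwrite_okay == true corresponds to the DoubleBuffer overloads, where
// the input buffer may be reused as the output.
inline Action choose_action(int begin_bit, int end_bit, bool is_overwrite_okay)
{
    assert(begin_bit <= end_bit);

    if (begin_bit == end_bit) {
        // No bits to sort on: the keys are already in their final order.
        // With a double buffer the data is already in place; otherwise the
        // input still has to be copied to the separate output buffer.
        return is_overwrite_okay ? Action::NoOp : Action::Copy;
    }
    return Action::Sort;
}
```

Under these assumptions, only the `Action::Sort` branch would need to launch the sorting kernels.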
@senior-zero @allisonvacanti I've added short-circuiting when |
In @senior-zero's earlier comment, he suggested having the is_overwrite_okay == false case just do a copy and skip the sorting altogether. Can we add that optimization?
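A rough way to see why a copy is sufficient: a stable radix sort over an empty bit range performs zero passes, so its output is the input in its original order. The host-side reference below is a minimal sketch of that argument, assuming a one-bit-per-pass LSD radix sort; `radix_sort_bits` is a hypothetical helper, not CUB code, and CUB's real implementation processes several bits per pass on the device.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical host-side reference: stable LSD radix sort that only looks at
// bits [begin_bit, end_bit) of each key, one bit per pass.
inline std::vector<std::uint32_t> radix_sort_bits(std::vector<std::uint32_t> keys,
                                                  int begin_bit, int end_bit)
{
    assert(0 <= begin_bit && begin_bit <= end_bit && end_bit <= 32);

    for (int bit = begin_bit; bit < end_bit; ++bit) {
        // Stable split: keys with a 0 in this bit go first, order preserved.
        std::stable_partition(keys.begin(), keys.end(),
                              [bit](std::uint32_t k) { return ((k >> bit) & 1u) == 0u; });
    }
    // With begin_bit == end_bit the loop never runs, so this is a plain copy.
    return keys;
}
```

Calling `radix_sort_bits(keys, 0, 0)` returns the keys unchanged, which is the behavior the copy shortcut reproduces without launching any sorting work.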
Thanks for your comments! @allisonvacanti I've addressed your comments. @senior-zero I've added the copy shortcut if Could you take another look? |
@canonizer Can you rebase this on |
- Copy if begin_bit == end_bit, but overwrite not allowed
- Fix style
- When begin_bit == end_bit and double-buffering, don't do any sorting work
- Uncommented segmented sort test
- begin_bit == end_bit == 0 for upsweep/downsweep and segmented sort
- Fixed begin_bit == end_bit == 0 case
d63448c to 9b50753
@allisonvacanti @senior-zero I've synced with the latest |
@senior-zero I've addressed your comments. Could you take another look?
Thank you for this optimization! I'll start testing now.
Fix begin_bit == end_bit == 0 for device-wide and segmented sort.