Skip to content
This repository has been archived by the owner on Mar 21, 2024. It is now read-only.

Commit

Permalink
Fix overflow in reduce
Browse files Browse the repository at this point in the history
  • Loading branch information
gevtushenko committed Nov 20, 2022
1 parent 5ae7439 commit 207b66b
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 9 deletions.
22 changes: 20 additions & 2 deletions cub/agent/agent_reduce.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -355,7 +355,7 @@ struct AgentReduce
{
AccumT thread_aggregate{};

if (even_share.block_offset + TILE_ITEMS > even_share.block_end)
if (even_share.block_end - even_share.block_offset < TILE_ITEMS)
{
// First tile isn't full (not all threads have valid items)
int valid_items = even_share.block_end - even_share.block_offset;
Expand All @@ -374,16 +374,34 @@ struct AgentReduce
TILE_ITEMS,
Int2Type<true>(),
can_vectorize);

// Exit early to handle offset overflow
if (even_share.block_end - even_share.block_offset < even_share.block_stride)
{
// Compute block-wide reduction (all threads have valid items)
return BlockReduceT(temp_storage.reduce)
.Reduce(thread_aggregate, reduction_op);
}

even_share.block_offset += even_share.block_stride;

// Consume subsequent full tiles of input
while (even_share.block_offset + TILE_ITEMS <= even_share.block_end)
while (even_share.block_offset <= even_share.block_end - TILE_ITEMS)
{
ConsumeTile<false>(thread_aggregate,
even_share.block_offset,
TILE_ITEMS,
Int2Type<true>(),
can_vectorize);

// Exit early to handle offset overflow
if (even_share.block_end - even_share.block_offset < even_share.block_stride)
{
// Compute block-wide reduction (all threads have valid items)
return BlockReduceT(temp_storage.reduce)
.Reduce(thread_aggregate, reduction_op);
}

even_share.block_offset += even_share.block_stride;
}

Expand Down
14 changes: 7 additions & 7 deletions test/test_device_reduce.cu
Original file line number Diff line number Diff line change
Expand Up @@ -1333,10 +1333,10 @@ __global__ void InitializeTestAccumulatorTypes(int num_items,
}
}

template <typename T>
void TestBigIndicesHelper(int magnitude)
template <typename T,
typename OffsetT>
void TestBigIndicesHelper(OffsetT num_items)
{
const std::size_t num_items = 1ll << magnitude;
thrust::constant_iterator<T> const_iter(T{1});
thrust::device_vector<std::size_t> out(1);
std::size_t* d_out = thrust::raw_pointer_cast(out.data());
Expand All @@ -1360,10 +1360,10 @@ void TestBigIndicesHelper(int magnitude)
template <typename T>
void TestBigIndices()
{
TestBigIndicesHelper<T>(30);
TestBigIndicesHelper<T>(31);
TestBigIndicesHelper<T>(32);
TestBigIndicesHelper<T>(33);
TestBigIndicesHelper<T, std::uint32_t>(1ull << 30);
TestBigIndicesHelper<T, std::uint32_t>(1ull << 31);
TestBigIndicesHelper<T, std::uint32_t>((1ull << 32) - 1);
TestBigIndicesHelper<T, std::uint64_t>(1ull << 33);
}

void TestAccumulatorTypes()
Expand Down

0 comments on commit 207b66b

Please sign in to comment.