This repository has been archived by the owner on Mar 21, 2024. It is now read-only.

Fix overflow in reduce #592

Merged 2 commits on Nov 25, 2022
95 changes: 63 additions & 32 deletions cub/agent/agent_reduce.cuh
@@ -355,7 +355,7 @@ struct AgentReduce
{
AccumT thread_aggregate{};

- if (even_share.block_offset + TILE_ITEMS > even_share.block_end)
+ if (even_share.block_end - even_share.block_offset < TILE_ITEMS)
Contributor:

To clarify my assumptions: is it always true that even_share.block_end <= even_share.block_offset? Can they be equal?

Collaborator:

I am not happy that we transform the condition differently here and below. I like @canonizer's suggestion below.

Collaborator Author:

@miscco, @canonizer's suggestion doesn't change the fact that this line, or the line below, has to be changed.

Collaborator Author:

@canonizer, for segmented reduce, block_offset is always less than block_end. For reduce, the number of blocks is about RoundUp(num_items, tile_size), while block_offset is just block_id * TILE_ITEMS and block_end is num_items. The case of num_items == 0 is treated differently, so I don't think block_end can be equal to block_offset. Could you elaborate on why it's relevant here?
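
A minimal host-side illustration of the arithmetic described above, using a hypothetical 32-bit OffsetT and a made-up tile size of 512 (only the relationships mirror the comment), showing why the old check can wrap around while the rewritten one cannot:

#include <cstdint>
#include <cstdio>

int main()
{
  using OffsetT = std::uint32_t;          // hypothetical offset type
  const OffsetT tile_items = 512;         // hypothetical tile size
  const OffsetT num_items  = 0xFFFFFFFFu; // close to the OffsetT maximum

  // Last tile of the range: block_offset = block_id * tile_items < block_end = num_items
  const OffsetT block_id     = (num_items - 1) / tile_items;
  const OffsetT block_offset = block_id * tile_items; // 4294966784
  const OffsetT block_end    = num_items;             // 4294967295, i.e. 511 valid items remain

  // Old check: block_offset + tile_items wraps past the 32-bit maximum to 0,
  // so the partially-full tile is mistaken for a full one (prints 0).
  std::printf("old check sees partial tile: %d\n", block_offset + tile_items > block_end);

  // New check: block_end - block_offset cannot overflow because block_offset < block_end,
  // so the 511 remaining items are detected correctly (prints 1).
  std::printf("new check sees partial tile: %d\n", block_end - block_offset < tile_items);
  return 0;
}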

Collaborator:

I am wondering whether the whole algorithm would be simpler if we used `int valid_items = even_share.block_end - even_share.block_offset;` as the main variable instead of repeatedly computing the remaining number of items.

That said, the change is definitely correct, and a large-scale refactor is a bit too much right now.
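
For illustration, a rough sketch of that suggestion (not part of this PR): track the remaining item count in one variable instead of recomputing `block_end - block_offset`. It assumes the same AgentReduce members used by the new ConsumeFullTileRange below (`ConsumeTile`, `TILE_ITEMS`, `even_share`, `can_vectorize`, `thread_aggregate`) and is meant to show the shape of the refactor, not a reviewed drop-in replacement:

// Caller guarantees at least one full tile, so remaining >= TILE_ITEMS here
OffsetT remaining = even_share.block_end - even_share.block_offset;

// First full tile
ConsumeTile<true>(thread_aggregate,
                  even_share.block_offset,
                  TILE_ITEMS,
                  Int2Type<true>(),
                  can_vectorize);

// Advance only while the next tile owned by this block is still within range;
// this keeps `block_offset += block_stride` from ever overflowing.
while (remaining >= even_share.block_stride)
{
  even_share.block_offset += even_share.block_stride;
  remaining -= even_share.block_stride;

  if (remaining < TILE_ITEMS)
  {
    break; // only a partially-full tile is left for this block
  }

  // Subsequent full tile
  ConsumeTile<false>(thread_aggregate,
                     even_share.block_offset,
                     TILE_ITEMS,
                     Int2Type<true>(),
                     can_vectorize);
}

// Consume the trailing partially-full tile, if any
if (remaining > 0 && remaining < TILE_ITEMS)
{
  ConsumeTile<false>(thread_aggregate,
                     even_share.block_offset,
                     static_cast<int>(remaining),
                     Int2Type<false>(),
                     can_vectorize);
}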

{
// First tile isn't full (not all threads have valid items)
int valid_items = even_share.block_end - even_share.block_offset;
@@ -368,35 +368,9 @@ struct AgentReduce
.Reduce(thread_aggregate, reduction_op, valid_items);
}

- // At least one full block
- ConsumeTile<true>(thread_aggregate,
- even_share.block_offset,
- TILE_ITEMS,
- Int2Type<true>(),
- can_vectorize);
- even_share.block_offset += even_share.block_stride;

- // Consume subsequent full tiles of input
- while (even_share.block_offset + TILE_ITEMS <= even_share.block_end)
- {
- ConsumeTile<false>(thread_aggregate,
- even_share.block_offset,
- TILE_ITEMS,
- Int2Type<true>(),
- can_vectorize);
- even_share.block_offset += even_share.block_stride;
- }

- // Consume a partially-full tile
- if (even_share.block_offset < even_share.block_end)
- {
- int valid_items = even_share.block_end - even_share.block_offset;
- ConsumeTile<false>(thread_aggregate,
- even_share.block_offset,
- valid_items,
- Int2Type<false>(),
- can_vectorize);
- }
+ // Extracting this into a function saves 8% of the generated kernel size by allowing the
+ // block reduction below to be reused. It also works around a hang in nvcc.
+ ConsumeFullTileRange(thread_aggregate, even_share, can_vectorize);

// Compute block-wide reduction (all threads have valid items)
return BlockReduceT(temp_storage.reduce)
@@ -428,8 +402,7 @@ struct AgentReduce
__device__ __forceinline__ AccumT
ConsumeTiles(GridEvenShare<OffsetT> &even_share)
{
- // Initialize GRID_MAPPING_STRIP_MINE even-share descriptor for this thread
- // block
+ // Initialize GRID_MAPPING_STRIP_MINE even-share descriptor for this thread block
even_share.template BlockInit<TILE_ITEMS, GRID_MAPPING_STRIP_MINE>();

return (IsAligned(d_in, Int2Type<ATTEMPT_VECTORIZATION>()))
@@ -438,6 +411,64 @@
: ConsumeRange(even_share,
Int2Type < false && ATTEMPT_VECTORIZATION > ());
}

+ private:
+ /**
+ * @brief Reduce a contiguous segment of input tiles with more than `TILE_ITEMS` elements
+ * @param even_share GridEvenShare descriptor
+ * @param can_vectorize Whether or not we can vectorize loads
+ */
+ template <int CAN_VECTORIZE>
+ __device__ __forceinline__ void
+ ConsumeFullTileRange(AccumT &thread_aggregate,
+ GridEvenShare<OffsetT> &even_share,
+ Int2Type<CAN_VECTORIZE> can_vectorize)
+ {
+ // At least one full block
+ ConsumeTile<true>(thread_aggregate,
+ even_share.block_offset,
+ TILE_ITEMS,
+ Int2Type<true>(),
+ can_vectorize);

+ if (even_share.block_end - even_share.block_offset < even_share.block_stride)
+ {
+ // Exit early to handle offset overflow
+ return;
+ }

+ even_share.block_offset += even_share.block_stride;

+ // Consume subsequent full tiles of input; at least one full tile was processed, so
+ // `even_share.block_end >= TILE_ITEMS`
+ while (even_share.block_offset <= even_share.block_end - TILE_ITEMS)
+ {
+ ConsumeTile<false>(thread_aggregate,
+ even_share.block_offset,
+ TILE_ITEMS,
+ Int2Type<true>(),
+ can_vectorize);

+ if (even_share.block_end - even_share.block_offset < even_share.block_stride)
+ {
+ // Exit early to handle offset overflow
+ return;
+ }

+ even_share.block_offset += even_share.block_stride;
+ }

+ // Consume a partially-full tile
+ if (even_share.block_offset < even_share.block_end)
+ {
+ int valid_items = even_share.block_end - even_share.block_offset;
+ ConsumeTile<false>(thread_aggregate,
+ even_share.block_offset,
+ valid_items,
+ Int2Type<false>(),
+ can_vectorize);
+ }
+ }
};

CUB_NAMESPACE_END
14 changes: 7 additions & 7 deletions test/test_device_reduce.cu
@@ -1333,10 +1333,10 @@ __global__ void InitializeTestAccumulatorTypes(int num_items,
}
}

- template <typename T>
- void TestBigIndicesHelper(int magnitude)
+ template <typename T,
+ typename OffsetT>
+ void TestBigIndicesHelper(OffsetT num_items)
{
- const std::size_t num_items = 1ll << magnitude;
thrust::constant_iterator<T> const_iter(T{1});
thrust::device_vector<std::size_t> out(1);
std::size_t* d_out = thrust::raw_pointer_cast(out.data());
@@ -1360,10 +1360,10 @@ void TestBigIndicesHelper(int magnitude)
template <typename T>
void TestBigIndices()
{
- TestBigIndicesHelper<T>(30);
- TestBigIndicesHelper<T>(31);
- TestBigIndicesHelper<T>(32);
- TestBigIndicesHelper<T>(33);
+ TestBigIndicesHelper<T, std::uint32_t>(1ull << 30);
+ TestBigIndicesHelper<T, std::uint32_t>(1ull << 31);
+ TestBigIndicesHelper<T, std::uint32_t>((1ull << 32) - 1);
+ TestBigIndicesHelper<T, std::uint64_t>(1ull << 33);
}

void TestAccumulatorTypes()