This repository has been archived by the owner on Mar 21, 2024. It is now read-only.

Fix main/GitHub/warp reduce #516

Merged · 3 commits · Jul 26, 2022
120 changes: 105 additions & 15 deletions cub/warp/specializations/warp_reduce_shfl.cuh
@@ -39,6 +39,9 @@
#include "../../util_type.cuh"

#include <stdint.h>
#include <type_traits>

#include <nv/target>

CUB_NAMESPACE_BEGIN

@@ -455,35 +458,122 @@ struct WarpReduceShfl
//---------------------------------------------------------------------
// Reduction operations
//---------------------------------------------------------------------

-/// Reduction
-template <
-    bool ALL_LANES_VALID,    ///< Whether all lanes in each warp are contributing a valid fold of items
-    typename ReductionOp>
-__device__ __forceinline__ T Reduce(
+template <typename ReductionOp>
+__device__ __forceinline__ T ReduceImpl(
+    Int2Type<0> /* all_lanes_valid */,
T input, ///< [in] Calling thread's input
int valid_items, ///< [in] Total number of valid items across the logical warp
ReductionOp reduction_op) ///< [in] Binary reduction operator
{
-int last_lane = (ALL_LANES_VALID) ?
-    LOGICAL_WARP_THREADS - 1 :
-    valid_items - 1;
+int last_lane = valid_items - 1;

T output = input;

// // Iterate reduction steps
// #pragma unroll
// for (int STEP = 0; STEP < STEPS; STEP++)
// {
// output = ReduceStep(output, reduction_op, last_lane, 1 << STEP, Int2Type<IsInteger<T>::IS_SMALL_UNSIGNED>());
// }
// Template-iterate reduction steps
ReduceStep(output, reduction_op, last_lane, Int2Type<0>());

return output;
}

template <typename ReductionOp>
__device__ __forceinline__ T ReduceImpl(
Int2Type<1> /* all_lanes_valid */,
T input, ///< [in] Calling thread's input
int /* valid_items */, ///< [in] Total number of valid items across the logical warp
ReductionOp reduction_op) ///< [in] Binary reduction operator
{
int last_lane = LOGICAL_WARP_THREADS - 1;

T output = input;

// Template-iterate reduction steps
ReduceStep(output, reduction_op, last_lane, Int2Type<0>());

return output;
}

// Warp reduce functions are not supported by nvc++ (NVBug 3694682)
#ifndef _NVHPC_CUDA
template <class U = T>
__device__ __forceinline__
typename std::enable_if<
std::is_same<int, U>::value
|| std::is_same<unsigned int, U>::value, T>::type
ReduceImpl(Int2Type<1> /* all_lanes_valid */,
T input,
int /* valid_items */,
cub::Sum /* reduction_op */)
{
T output = input;

NV_IF_TARGET(NV_PROVIDES_SM_80,
(output = __reduce_add_sync(member_mask, input);),
(output = ReduceImpl<cub::Sum>(Int2Type<1>{},
input,
LOGICAL_WARP_THREADS,
cub::Sum{});));

return output;
}

template <class U = T>
__device__ __forceinline__
typename std::enable_if<
std::is_same<int, U>::value
|| std::is_same<unsigned int, U>::value, T>::type
ReduceImpl(Int2Type<1> /* all_lanes_valid */,
T input,
int /* valid_items */,
cub::Min /* reduction_op */)
{
T output = input;

NV_IF_TARGET(NV_PROVIDES_SM_80,
(output = __reduce_min_sync(member_mask, input);),
(output = ReduceImpl<cub::Min>(Int2Type<1>{},
input,
LOGICAL_WARP_THREADS,
cub::Min{});));

return output;
}

template <class U = T>
__device__ __forceinline__
typename std::enable_if<
std::is_same<int, U>::value
|| std::is_same<unsigned int, U>::value, T>::type
ReduceImpl(Int2Type<1> /* all_lanes_valid */,
T input,
int /* valid_items */,
cub::Max /* reduction_op */)
{
T output = input;

NV_IF_TARGET(NV_PROVIDES_SM_80,
(output = __reduce_max_sync(member_mask, input);),
(output = ReduceImpl<cub::Max>(Int2Type<1>{},
input,
LOGICAL_WARP_THREADS,
cub::Max{});));

return output;
}
#endif // _NVHPC_CUDA

/// Reduction
template <
bool ALL_LANES_VALID, ///< Whether all lanes in each warp are contributing a valid fold of items
typename ReductionOp>
__device__ __forceinline__ T Reduce(
T input, ///< [in] Calling thread's input
int valid_items, ///< [in] Total number of valid items across the logical warp
ReductionOp reduction_op) ///< [in] Binary reduction operator
{
return ReduceImpl(
Int2Type<ALL_LANES_VALID>{}, input, valid_items, reduction_op);
}


/// Segmented reduction
template <
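The three hardware-accelerated ReduceImpl overloads above use NV_IF_TARGET to pick the SM80 __reduce_add_sync / __reduce_min_sync / __reduce_max_sync intrinsics when the target architecture provides them, falling back to the generic shuffle-based ReduceImpl otherwise. A minimal standalone sketch of the same dispatch pattern follows; the warp_sum helper is hypothetical and not part of this PR, and it assumes a full 32-lane member mask.

#include <nv/target>

// Hypothetical helper, not CUB code: sum a value across a full warp, using the
// SM80 redux intrinsic when compiling for that target and a shuffle butterfly
// otherwise. Both paths leave the warp-wide sum in every lane.
__device__ __forceinline__ unsigned int warp_sum(unsigned int value)
{
  const unsigned int full_mask = 0xffffffffu;
  unsigned int result = value;

  NV_IF_TARGET(NV_PROVIDES_SM_80,
               (result = __reduce_add_sync(full_mask, value);),
               (
                 // Pre-SM80 fallback: butterfly reduction over warp shuffles.
                 for (int offset = 16; offset > 0; offset /= 2)
                 {
                   result += __shfl_xor_sync(full_mask, result, offset);
                 }
               ));

  return result;
}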
65 changes: 65 additions & 0 deletions cub/warp/warp_reduce.cuh
@@ -603,6 +603,71 @@ public:
//@} end member group
};

template <typename T, int LEGACY_PTX_ARCH>
class WarpReduce<T, 1, LEGACY_PTX_ARCH>
{
private:
using _TempStorage = cub::NullType;

public:
struct TempStorage : Uninitialized<_TempStorage>
{};

__device__ __forceinline__ WarpReduce(TempStorage & /*temp_storage */)
{}

__device__ __forceinline__ T Sum(T input) { return input; }

__device__ __forceinline__ T Sum(T input, int /* valid_items */)
{
return input;
}

template <typename FlagT>
__device__ __forceinline__ T HeadSegmentedSum(T input, FlagT /* head_flag */)
{
return input;
}

template <typename FlagT>
__device__ __forceinline__ T TailSegmentedSum(T input, FlagT /* tail_flag */)
{
return input;
}

template <typename ReductionOp>
__device__ __forceinline__ T Reduce(T input, ReductionOp /* reduction_op */)
{
return input;
}

template <typename ReductionOp>
__device__ __forceinline__ T Reduce(T input,
ReductionOp /* reduction_op */,
int /* valid_items */)
{
return input;
}

template <typename ReductionOp, typename FlagT>
__device__ __forceinline__ T
HeadSegmentedReduce(T input,
FlagT /* head_flag */,
ReductionOp /* reduction_op */)
{
return input;
}

template <typename ReductionOp, typename FlagT>
__device__ __forceinline__ T
TailSegmentedReduce(T input,
FlagT /* tail_flag */,
ReductionOp /* reduction_op */)
{
return input;
}
};

/** @} */ // end group WarpModule

CUB_NAMESPACE_END
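The new WarpReduce<T, 1, LEGACY_PTX_ARCH> specialization makes every reduction over a one-lane logical warp an identity operation: each thread is its own warp, so its aggregate is simply its input, and no temp storage or synchronization is needed. A hedged usage sketch follows; the kernel name and buffers are hypothetical, while the WarpReduce API is the one defined in this file.

#include <cub/warp/warp_reduce.cuh>

// Hypothetical kernel, not part of this PR: drives the LOGICAL_WARP_THREADS == 1
// specialization. Each thread forms its own logical warp, so Sum() returns the
// thread's own input unchanged.
__global__ void single_lane_sum(const int *d_in, int *d_out)
{
  using WarpReduce = cub::WarpReduce<int, 1>;

  // TempStorage wraps cub::NullType for this specialization and is unused.
  __shared__ typename WarpReduce::TempStorage temp_storage;

  int thread_data = d_in[threadIdx.x];
  int aggregate   = WarpReduce(temp_storage).Sum(thread_data);  // == thread_data

  d_out[threadIdx.x] = aggregate;
}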
13 changes: 8 additions & 5 deletions test/test_warp_reduce.cu
@@ -418,6 +418,8 @@ void Initialize(
RandomBits(bits, flag_entropy);
h_flags[i] = bits & 0x1;
}
h_flags[warps * warp_threads] = {};
h_tail_out[warps * warp_threads] = {};

// Accumulate segments (lane 0 of each warp is implicitly a segment head)
for (int warp = 0; warp < warps; ++warp)
@@ -483,9 +485,9 @@ void TestReduce(

// Allocate host arrays
T *h_in = new T[BLOCK_THREADS];
-int *h_flags = new int[BLOCK_THREADS];
+int *h_flags = new int[BLOCK_THREADS + 1];
T *h_out = new T[BLOCK_THREADS];
-T *h_tail_out = new T[BLOCK_THREADS];
+T *h_tail_out = new T[BLOCK_THREADS + 1];

// Initialize problem
Initialize(gen_mode, -1, h_in, h_flags, WARPS, LOGICAL_WARP_THREADS, valid_warp_threads, reduction_op, h_out, h_tail_out);
@@ -578,9 +580,9 @@ void TestSegmentedReduce(
// Allocate host arrays
int compare;
T *h_in = new T[BLOCK_THREADS];
-int *h_flags = new int[BLOCK_THREADS];
-T *h_head_out = new T[BLOCK_THREADS];
-T *h_tail_out = new T[BLOCK_THREADS];
+int *h_flags = new int[BLOCK_THREADS + 1];
+T *h_head_out = new T[BLOCK_THREADS + 1];
+T *h_tail_out = new T[BLOCK_THREADS + 1];

// Initialize problem
Initialize(gen_mode, flag_entropy, h_in, h_flags, WARPS, LOGICAL_WARP_THREADS, LOGICAL_WARP_THREADS, reduction_op, h_head_out, h_tail_out);
@@ -817,6 +819,7 @@ int main(int argc, char** argv)
Test<16>();
Test<9>();
Test<7>();
Test<1>();

return 0;
}
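The flag and tail-output host buffers grow by one element because Initialize() now zero-fills the slot at index warps * warp_threads, one past the last thread; sizing the arrays BLOCK_THREADS + 1 keeps that sentinel write in bounds. A minimal illustration of the sizing follows, assuming BLOCK_THREADS == WARPS * LOGICAL_WARP_THREADS as in the test above; the constants are hypothetical.

// Sketch only: mirrors the allocation pattern above with made-up sizes.
constexpr int WARPS = 2;
constexpr int LOGICAL_WARP_THREADS = 1;      // the newly tested logical warp size
constexpr int BLOCK_THREADS = WARPS * LOGICAL_WARP_THREADS;

int *h_flags = new int[BLOCK_THREADS + 1];   // +1 for the sentinel slot
h_flags[WARPS * LOGICAL_WARP_THREADS] = {};  // written by Initialize(); in bounds only with the +1
delete[] h_flags;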