Skip to content
This repository has been archived by the owner on Mar 21, 2024. It is now read-only.

Commit

Permalink
P2322R6 accumulator types for reduce by key
Browse files Browse the repository at this point in the history
  • Loading branch information
gevtushenko committed Jun 16, 2022
1 parent 8c6ac0d commit 9591ae1
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 24 deletions.
21 changes: 9 additions & 12 deletions cub/agent/agent_reduce_by_key.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,8 @@ template <typename AgentReduceByKeyPolicyT,
typename NumRunsOutputIteratorT,
typename EqualityOpT,
typename ReductionOpT,
typename OffsetT>
typename OffsetT,
typename AccumT>
struct AgentReduceByKey
{
//---------------------------------------------------------------------
Expand All @@ -151,19 +152,15 @@ struct AgentReduceByKey
// The input values type
using ValueInputT = cub::detail::value_t<ValuesInputIteratorT>;

// The output values type
using ValueOutputT =
cub::detail::non_void_value_t<AggregatesOutputIteratorT, ValueInputT>;

// Tuple type for scanning (pairs accumulated segment-value with
// segment-index)
using OffsetValuePairT = KeyValuePair<OffsetT, ValueOutputT>;
using OffsetValuePairT = KeyValuePair<OffsetT, AccumT>;

// Tuple type for pairing keys and values
using KeyValuePairT = KeyValuePair<KeyOutputT, ValueOutputT>;
using KeyValuePairT = KeyValuePair<KeyOutputT, AccumT>;

// Tile status descriptor interface type
using ScanTileStateT = ReduceByKeyScanTileState<ValueOutputT, OffsetT>;
using ScanTileStateT = ReduceByKeyScanTileState<AccumT, OffsetT>;

// Guarded inequality functor
template <typename _EqualityOpT>
Expand Down Expand Up @@ -209,7 +206,7 @@ struct AgentReduceByKey
// if we're performing addition on a primitive type)
static constexpr int HAS_IDENTITY_ZERO =
(std::is_same<ReductionOpT, cub::Sum>::value) &&
(Traits<ValueOutputT>::PRIMITIVE);
(Traits<AccumT>::PRIMITIVE);

// Cache-modified Input iterator wrapper type (for applying cache modifier)
// for keys Wrap the native input pointer with
Expand Down Expand Up @@ -254,7 +251,7 @@ struct AgentReduceByKey
AgentReduceByKeyPolicyT::LOAD_ALGORITHM>;

// Parameterized BlockLoad type for values
using BlockLoadValuesT = BlockLoad<ValueOutputT,
using BlockLoadValuesT = BlockLoad<AccumT,
BLOCK_THREADS,
ITEMS_PER_THREAD,
AgentReduceByKeyPolicyT::LOAD_ALGORITHM>;
Expand All @@ -273,7 +270,7 @@ struct AgentReduceByKey

// Key and value exchange types
typedef KeyOutputT KeyExchangeT[TILE_ITEMS + 1];
typedef ValueOutputT ValueExchangeT[TILE_ITEMS + 1];
typedef AccumT ValueExchangeT[TILE_ITEMS + 1];

// Shared memory type for this thread block
union _TempStorage
Expand Down Expand Up @@ -509,7 +506,7 @@ struct AgentReduceByKey
KeyOutputT prev_keys[ITEMS_PER_THREAD];

// Tile values
ValueOutputT values[ITEMS_PER_THREAD];
AccumT values[ITEMS_PER_THREAD];

// Segment head flags
OffsetT head_flags[ITEMS_PER_THREAD];
Expand Down
26 changes: 15 additions & 11 deletions cub/device/dispatch/dispatch_reduce_by_key.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,8 @@ template <typename AgentReduceByKeyPolicyT,
typename ScanTileStateT,
typename EqualityOpT,
typename ReductionOpT,
typename OffsetT>
typename OffsetT,
typename AccumT>
__launch_bounds__(int(AgentReduceByKeyPolicyT::BLOCK_THREADS)) __global__
void DeviceReduceByKeyKernel(KeysInputIteratorT d_keys_in,
UniqueOutputIteratorT d_unique_out,
Expand All @@ -149,7 +150,8 @@ __launch_bounds__(int(AgentReduceByKeyPolicyT::BLOCK_THREADS)) __global__
NumRunsOutputIteratorT,
EqualityOpT,
ReductionOpT,
OffsetT>;
OffsetT,
AccumT>;

// Shared memory for AgentReduceByKey
__shared__ typename AgentReduceByKeyT::TempStorage temp_storage;
Expand Down Expand Up @@ -206,7 +208,12 @@ template <typename KeysInputIteratorT,
typename NumRunsOutputIteratorT,
typename EqualityOpT,
typename ReductionOpT,
typename OffsetT>
typename OffsetT,
typename AccumT =
detail::accumulator_t<
ReductionOpT,
cub::detail::value_t<ValuesInputIteratorT>,
cub::detail::value_t<ValuesInputIteratorT>>>
struct DispatchReduceByKey
{
//-------------------------------------------------------------------------
Expand All @@ -223,20 +230,16 @@ struct DispatchReduceByKey
// The input values type
using ValueInputT = cub::detail::value_t<ValuesInputIteratorT>;

// The output values type
using ValueOutputT =
cub::detail::non_void_value_t<AggregatesOutputIteratorT, ValueInputT>;

static constexpr int INIT_KERNEL_THREADS = 128;

static constexpr int MAX_INPUT_BYTES = CUB_MAX(sizeof(KeyOutputT),
sizeof(ValueOutputT));
sizeof(AccumT));

static constexpr int COMBINED_INPUT_BYTES = sizeof(KeyOutputT) +
sizeof(ValueOutputT);
sizeof(AccumT);

// Tile status descriptor interface type
using ScanTileStateT = ReduceByKeyScanTileState<ValueOutputT, OffsetT>;
using ScanTileStateT = ReduceByKeyScanTileState<AccumT, OffsetT>;

//-------------------------------------------------------------------------
// Tuning policies
Expand Down Expand Up @@ -691,7 +694,8 @@ struct DispatchReduceByKey
ScanTileStateT,
EqualityOpT,
ReductionOpT,
OffsetT>,
OffsetT,
AccumT>,
reduce_by_key_config)))
{
break;
Expand Down
4 changes: 3 additions & 1 deletion test/test_device_reduce_by_key.cu
Original file line number Diff line number Diff line change
Expand Up @@ -290,9 +290,11 @@ int Solve(
ReductionOpT reduction_op,
int num_items)
{
using AccumT = cub::detail::accumulator_t<ReductionOpT, ValueT, ValueT>;

// First item
KeyT previous = h_keys_in[0];
ValueT aggregate = h_values_in[0];
AccumT aggregate = h_values_in[0];
int num_segments = 0;

// Subsequent items
Expand Down

0 comments on commit 9591ae1

Please sign in to comment.