diff --git a/cub/agent/agent_reduce_by_key.cuh b/cub/agent/agent_reduce_by_key.cuh index 2f1ef88ac..abc146579 100644 --- a/cub/agent/agent_reduce_by_key.cuh +++ b/cub/agent/agent_reduce_by_key.cuh @@ -134,7 +134,8 @@ template + typename OffsetT, + typename AccumT> struct AgentReduceByKey { //--------------------------------------------------------------------- @@ -151,19 +152,15 @@ struct AgentReduceByKey // The input values type using ValueInputT = cub::detail::value_t; - // The output values type - using ValueOutputT = - cub::detail::non_void_value_t; - // Tuple type for scanning (pairs accumulated segment-value with // segment-index) - using OffsetValuePairT = KeyValuePair; + using OffsetValuePairT = KeyValuePair; // Tuple type for pairing keys and values - using KeyValuePairT = KeyValuePair; + using KeyValuePairT = KeyValuePair; // Tile status descriptor interface type - using ScanTileStateT = ReduceByKeyScanTileState; + using ScanTileStateT = ReduceByKeyScanTileState; // Guarded inequality functor template @@ -209,7 +206,7 @@ struct AgentReduceByKey // if we're performing addition on a primitive type) static constexpr int HAS_IDENTITY_ZERO = (std::is_same::value) && - (Traits::PRIMITIVE); + (Traits::PRIMITIVE); // Cache-modified Input iterator wrapper type (for applying cache modifier) // for keys Wrap the native input pointer with @@ -254,7 +251,7 @@ struct AgentReduceByKey AgentReduceByKeyPolicyT::LOAD_ALGORITHM>; // Parameterized BlockLoad type for values - using BlockLoadValuesT = BlockLoad; @@ -273,7 +270,7 @@ struct AgentReduceByKey // Key and value exchange types typedef KeyOutputT KeyExchangeT[TILE_ITEMS + 1]; - typedef ValueOutputT ValueExchangeT[TILE_ITEMS + 1]; + typedef AccumT ValueExchangeT[TILE_ITEMS + 1]; // Shared memory type for this thread block union _TempStorage @@ -509,7 +506,7 @@ struct AgentReduceByKey KeyOutputT prev_keys[ITEMS_PER_THREAD]; // Tile values - ValueOutputT values[ITEMS_PER_THREAD]; + AccumT values[ITEMS_PER_THREAD]; // Segment head flags OffsetT head_flags[ITEMS_PER_THREAD]; diff --git a/cub/device/dispatch/dispatch_reduce_by_key.cuh b/cub/device/dispatch/dispatch_reduce_by_key.cuh index d909ce92f..694ae6fbd 100644 --- a/cub/device/dispatch/dispatch_reduce_by_key.cuh +++ b/cub/device/dispatch/dispatch_reduce_by_key.cuh @@ -127,7 +127,8 @@ template + typename OffsetT, + typename AccumT> __launch_bounds__(int(AgentReduceByKeyPolicyT::BLOCK_THREADS)) __global__ void DeviceReduceByKeyKernel(KeysInputIteratorT d_keys_in, UniqueOutputIteratorT d_unique_out, @@ -149,7 +150,8 @@ __launch_bounds__(int(AgentReduceByKeyPolicyT::BLOCK_THREADS)) __global__ NumRunsOutputIteratorT, EqualityOpT, ReductionOpT, - OffsetT>; + OffsetT, + AccumT>; // Shared memory for AgentReduceByKey __shared__ typename AgentReduceByKeyT::TempStorage temp_storage; @@ -206,7 +208,12 @@ template + typename OffsetT, + typename AccumT = + detail::accumulator_t< + ReductionOpT, + cub::detail::value_t, + cub::detail::value_t>> struct DispatchReduceByKey { //------------------------------------------------------------------------- @@ -223,20 +230,16 @@ struct DispatchReduceByKey // The input values type using ValueInputT = cub::detail::value_t; - // The output values type - using ValueOutputT = - cub::detail::non_void_value_t; - static constexpr int INIT_KERNEL_THREADS = 128; static constexpr int MAX_INPUT_BYTES = CUB_MAX(sizeof(KeyOutputT), - sizeof(ValueOutputT)); + sizeof(AccumT)); static constexpr int COMBINED_INPUT_BYTES = sizeof(KeyOutputT) + - sizeof(ValueOutputT); + sizeof(AccumT); // Tile status descriptor interface type - using ScanTileStateT = ReduceByKeyScanTileState; + using ScanTileStateT = ReduceByKeyScanTileState; //------------------------------------------------------------------------- // Tuning policies @@ -691,7 +694,8 @@ struct DispatchReduceByKey ScanTileStateT, EqualityOpT, ReductionOpT, - OffsetT>, + OffsetT, + AccumT>, reduce_by_key_config))) { break; diff --git a/test/test_device_reduce_by_key.cu b/test/test_device_reduce_by_key.cu index 168c4ecac..2085053e4 100644 --- a/test/test_device_reduce_by_key.cu +++ b/test/test_device_reduce_by_key.cu @@ -290,9 +290,11 @@ int Solve( ReductionOpT reduction_op, int num_items) { + using AccumT = cub::detail::accumulator_t; + // First item KeyT previous = h_keys_in[0]; - ValueT aggregate = h_values_in[0]; + AccumT aggregate = h_values_in[0]; int num_segments = 0; // Subsequent items