This repository has been archived by the owner on Mar 21, 2024. It is now read-only.

P2322R6 accumulator types for reduce
gevtushenko committed Jun 13, 2022
1 parent 92b501a commit a064b12
Showing 9 changed files with 2,221 additions and 1,236 deletions.
661 changes: 361 additions & 300 deletions cub/agent/agent_reduce.cuh

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion cub/block/specializations/block_reduce_warp_reductions.cuh
@@ -143,7 +143,7 @@ struct BlockReduceWarpReductions
// Share lane aggregates
if (lane_id == 0)
{
-temp_storage.warp_aggregates[warp_id] = warp_aggregate;
+new (temp_storage.warp_aggregates + warp_id) T(warp_aggregate);
}

CTA_SYNC();
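
The assignment is replaced with placement new, presumably because the warp-aggregate slots are treated as raw, never-constructed storage once an accumulator type that need not be default-constructible is threaded through. A minimal sketch of the idiom, using a hypothetical Accum type and storage layout rather than CUB's actual temp_storage:

#include <new> // placement operator new

// Hypothetical accumulator type that cannot be default-constructed.
struct Accum
{
  double value;
  Accum() = delete;
  __host__ __device__ explicit Accum(double v) : value(v) {}
};

__device__ void store_warp_aggregate(int warp_id, const Accum &warp_aggregate)
{
  // Backing storage for up to 32 warp aggregates; no Accum objects are ever
  // default-constructed here, so plain assignment would write to objects
  // whose lifetime never began.
  __shared__ alignas(Accum) unsigned char raw[32 * sizeof(Accum)];
  Accum *slots = reinterpret_cast<Accum *>(raw);

  // Placement new begins the object's lifetime in place instead of assigning.
  new (slots + warp_id) Accum(warp_aggregate);
}
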
14 changes: 9 additions & 5 deletions cub/detail/type_traits.cuh
@@ -32,10 +32,10 @@

#pragma once

#include "../util_cpp_dialect.cuh"
#include "../util_namespace.cuh"
#include <cub/util_cpp_dialect.cuh>
#include <cub/util_namespace.cuh>

-#include <type_traits>
+#include <cuda/std/type_traits>


CUB_NAMESPACE_BEGIN
@@ -44,11 +44,15 @@ namespace detail {
template <typename Invokable, typename... Args>
using invoke_result_t =
#if CUB_CPP_DIALECT < 2017
-typename std::result_of<Invokable(Args...)>::type;
+typename cuda::std::result_of<Invokable(Args...)>::type;
#else // 2017+
-std::invoke_result_t<Invokable, Args...>;
+cuda::std::invoke_result_t<Invokable, Args...>;
#endif

+/// The type of intermediate accumulator (according to P2322R6)
+template <typename Invokable, typename InitT, typename InputT>
+using accumulator_t =
+typename cuda::std::decay<invoke_result_t<Invokable, InitT, InputT>>::type;

} // namespace detail
CUB_NAMESPACE_END
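
To illustrate what the new alias resolves to, the static_asserts below use a hypothetical Plus functor; accumulator_t itself is exactly the alias added in this hunk. The accumulator is the decayed result of invoking the operator on (init, input):

#include <cub/detail/type_traits.cuh>

#include <cuda/std/type_traits>

// Hypothetical binary operator; only its return type matters here.
struct Plus
{
  template <typename T, typename U>
  __host__ __device__ auto operator()(T a, U b) const -> decltype(a + b)
  {
    return a + b;
  }
};

// double init + float input: decay_t<decltype(double{} + float{})> is double.
static_assert(cuda::std::is_same<cub::detail::accumulator_t<Plus, double, float>,
                                 double>::value, "");

// int init + int input: the accumulator stays int.
static_assert(cuda::std::is_same<cub::detail::accumulator_t<Plus, int, int>,
                                 int>::value, "");
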
32 changes: 21 additions & 11 deletions cub/device/device_reduce.cuh
@@ -155,7 +155,7 @@ struct DeviceReduce
// Signed integer type for global offsets
typedef int OffsetT;

-return DispatchReduce<InputIteratorT, OutputIteratorT, OffsetT, ReductionOpT>::Dispatch(
+return DispatchReduce<InputIteratorT, OutputIteratorT, OffsetT, ReductionOpT, T>::Dispatch(
d_temp_storage,
temp_storage_bytes,
d_in,
@@ -239,14 +239,16 @@
cub::detail::non_void_value_t<OutputIteratorT,
cub::detail::value_t<InputIteratorT>>;

-return DispatchReduce<InputIteratorT, OutputIteratorT, OffsetT, cub::Sum>::Dispatch(
+using InitT = OutputT;
+
+return DispatchReduce<InputIteratorT, OutputIteratorT, OffsetT, cub::Sum, InitT>::Dispatch(
d_temp_storage,
temp_storage_bytes,
d_in,
d_out,
num_items,
cub::Sum(),
-OutputT(), // zero-initialize
+InitT{}, // zero-initialize
stream,
debug_synchronous);
}
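
The user-visible effect of threading the init/accumulator type into dispatch shows up in the generic Reduce overload near the top of this file: following P2322R6, the intermediate accumulator is accumulator_t of the operator, init, and input types rather than simply the input type. A sketch under that assumption, with a hypothetical Plus functor and device pointers assumed to be valid:

#include <cub/device/device_reduce.cuh>

#include <cuda_runtime.h>
#include <cstddef>

// Hypothetical binary operator whose result type widens with the init type.
struct Plus
{
  template <typename T, typename U>
  __host__ __device__ auto operator()(T a, U b) const -> decltype(a + b)
  {
    return a + b;
  }
};

// With a double init over float inputs, the intermediate accumulator is
// double (accumulator_t<Plus, double, float>) rather than float.
void sum_in_double(const float *d_in, double *d_out, int num_items)
{
  void *d_temp_storage = nullptr;
  std::size_t temp_storage_bytes = 0;

  // First call only computes the required temporary storage size.
  cub::DeviceReduce::Reduce(d_temp_storage, temp_storage_bytes, d_in, d_out,
                            num_items, Plus{}, 0.0);
  cudaMalloc(&d_temp_storage, temp_storage_bytes);

  // Second call performs the reduction.
  cub::DeviceReduce::Reduce(d_temp_storage, temp_storage_bytes, d_in, d_out,
                            num_items, Plus{}, 0.0);

  cudaFree(d_temp_storage);
}
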
@@ -314,14 +316,16 @@
// The input value type
using InputT = cub::detail::value_t<InputIteratorT>;

-return DispatchReduce<InputIteratorT, OutputIteratorT, OffsetT, cub::Min>::Dispatch(
+using InitT = InputT;
+
+return DispatchReduce<InputIteratorT, OutputIteratorT, OffsetT, cub::Min, InitT>::Dispatch(
d_temp_storage,
temp_storage_bytes,
d_in,
d_out,
num_items,
cub::Min(),
-Traits<InputT>::Max(), // replace with std::numeric_limits<T>::max() when C++11 support is more prevalent
+Traits<InitT>::Max(), // replace with std::numeric_limits<T>::max() when C++11 support is more prevalent
stream,
debug_synchronous);
}
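
The convenience overload above takes no user-supplied init value; dispatch is seeded with Traits<InitT>::Max(), the identity for cub::Min, so the seed cannot affect the result. A minimal usage sketch of the usual two-call temp-storage pattern (error handling omitted, device pointers assumed valid):

#include <cub/device/device_reduce.cuh>

#include <cuda_runtime.h>
#include <cstddef>

// Minimum of num_items ints; the identity seed is supplied internally.
void min_ints(const int *d_in, int *d_out, int num_items)
{
  void *d_temp_storage = nullptr;
  std::size_t temp_storage_bytes = 0;

  // Query the required temporary storage size.
  cub::DeviceReduce::Min(d_temp_storage, temp_storage_bytes, d_in, d_out, num_items);
  cudaMalloc(&d_temp_storage, temp_storage_bytes);

  // Run the reduction.
  cub::DeviceReduce::Min(d_temp_storage, temp_storage_bytes, d_in, d_out, num_items);

  cudaFree(d_temp_storage);
}
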
@@ -396,6 +400,8 @@ struct DeviceReduce
cub::detail::non_void_value_t<OutputIteratorT,
KeyValuePair<OffsetT, InputValueT>>;

+using InitT = OutputTupleT;

// The output value type
using OutputValueT = typename OutputTupleT::Value;

@@ -406,9 +412,9 @@
ArgIndexInputIteratorT d_indexed_in(d_in);

// Initial value
-OutputTupleT initial_value(1, Traits<InputValueT>::Max()); // replace with std::numeric_limits<T>::max() when C++11 support is more prevalent
+InitT initial_value(1, Traits<InputValueT>::Max()); // replace with std::numeric_limits<T>::max() when C++11 support is more prevalent

-return DispatchReduce<ArgIndexInputIteratorT, OutputIteratorT, OffsetT, cub::ArgMin>::Dispatch(
+return DispatchReduce<ArgIndexInputIteratorT, OutputIteratorT, OffsetT, cub::ArgMin, InitT>::Dispatch(
d_temp_storage,
temp_storage_bytes,
d_indexed_in,
@@ -483,14 +489,16 @@
// The input value type
using InputT = cub::detail::value_t<InputIteratorT>;

-return DispatchReduce<InputIteratorT, OutputIteratorT, OffsetT, cub::Max>::Dispatch(
+using InitT = InputT;
+
+return DispatchReduce<InputIteratorT, OutputIteratorT, OffsetT, cub::Max, InitT>::Dispatch(
d_temp_storage,
temp_storage_bytes,
d_in,
d_out,
num_items,
cub::Max(),
-Traits<InputT>::Lowest(), // replace with std::numeric_limits<T>::lowest() when C++11 support is more prevalent
+Traits<InitT>::Lowest(), // replace with std::numeric_limits<T>::lowest() when C++11 support is more prevalent
stream,
debug_synchronous);
}
@@ -568,16 +576,18 @@
// The output value type
using OutputValueT = typename OutputTupleT::Value;

+using InitT = OutputTupleT;

// Wrapped input iterator to produce index-value <OffsetT, InputT> tuples
using ArgIndexInputIteratorT =
ArgIndexInputIterator<InputIteratorT, OffsetT, OutputValueT>;

ArgIndexInputIteratorT d_indexed_in(d_in);

// Initial value
-OutputTupleT initial_value(1, Traits<InputValueT>::Lowest()); // replace with std::numeric_limits<T>::lowest() when C++11 support is more prevalent
+InitT initial_value(1, Traits<InputValueT>::Lowest()); // replace with std::numeric_limits<T>::lowest() when C++11 support is more prevalent

-return DispatchReduce<ArgIndexInputIteratorT, OutputIteratorT, OffsetT, cub::ArgMax>::Dispatch(
+return DispatchReduce<ArgIndexInputIteratorT, OutputIteratorT, OffsetT, cub::ArgMax, InitT>::Dispatch(
d_temp_storage,
temp_storage_bytes,
d_indexed_in,
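
The ArgMin/ArgMax overloads reduce index-value pairs, so the output iterator's value type is a KeyValuePair and the reduction is seeded with the {1, Traits::Lowest()} (or {1, Traits::Max()}) initial value constructed above. A minimal ArgMax sketch (error handling omitted, device pointers assumed valid):

#include <cub/device/device_reduce.cuh>

#include <cuda_runtime.h>
#include <cstddef>

// Writes a single {index, value} pair for the maximum element of d_in.
void argmax_floats(const float *d_in,
                   cub::KeyValuePair<int, float> *d_argmax,
                   int num_items)
{
  void *d_temp_storage = nullptr;
  std::size_t temp_storage_bytes = 0;

  // Query the required temporary storage size.
  cub::DeviceReduce::ArgMax(d_temp_storage, temp_storage_bytes, d_in, d_argmax, num_items);
  cudaMalloc(&d_temp_storage, temp_storage_bytes);

  // Run the arg-max reduction.
  cub::DeviceReduce::ArgMax(d_temp_storage, temp_storage_bytes, d_in, d_argmax, num_items);

  cudaFree(d_temp_storage);
}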