Skip to content
This repository has been archived by the owner on Mar 21, 2024. It is now read-only.

Commit

Permalink
Merge pull request #1721 from senior-zero/fix-main/github/scan_interm…
Browse files Browse the repository at this point in the history
…ediate_type

Fixing scan accumulator types for NVIDIA/cub#511
  • Loading branch information
gevtushenko authored Aug 3, 2022
2 parents f2ba086 + 20ba21c commit f3c28da
Show file tree
Hide file tree
Showing 4 changed files with 27 additions and 12 deletions.
6 changes: 4 additions & 2 deletions thrust/system/cuda/detail/async/exclusive_scan.h
Original file line number Diff line number Diff line change
Expand Up @@ -79,12 +79,14 @@ async_exclusive_scan_n(execution_policy<DerivedPolicy>& policy,
OutputIt,
BinaryOp,
InputValueT,
thrust::detail::int32_t>;
thrust::detail::int32_t,
InitialValueType>;
using Dispatch64 = cub::DispatchScan<ForwardIt,
OutputIt,
BinaryOp,
InputValueT,
thrust::detail::int64_t>;
thrust::detail::int64_t,
InitialValueType>;

InputValueT init_value(init);

Expand Down
7 changes: 5 additions & 2 deletions thrust/system/cuda/detail/async/inclusive_scan.h
Original file line number Diff line number Diff line change
Expand Up @@ -72,16 +72,19 @@ async_inclusive_scan_n(execution_policy<DerivedPolicy>& policy,
OutputIt out,
BinaryOp op)
{
using AccumT = typename thrust::iterator_traits<ForwardIt>::value_type;
using Dispatch32 = cub::DispatchScan<ForwardIt,
OutputIt,
BinaryOp,
cub::NullType,
thrust::detail::int32_t>;
thrust::detail::int32_t,
AccumT>;
using Dispatch64 = cub::DispatchScan<ForwardIt,
OutputIt,
BinaryOp,
cub::NullType,
thrust::detail::int64_t>;
thrust::detail::int64_t,
AccumT>;

auto const device_alloc = get_async_device_allocator(policy);
unique_eager_event ev;
Expand Down
13 changes: 9 additions & 4 deletions thrust/system/cuda/detail/scan.h
Original file line number Diff line number Diff line change
Expand Up @@ -60,16 +60,19 @@ OutputIt inclusive_scan_n_impl(thrust::cuda_cub::execution_policy<Derived> &poli
OutputIt result,
ScanOp scan_op)
{
using AccumT = typename thrust::iterator_traits<InputIt>::value_type;
using Dispatch32 = cub::DispatchScan<InputIt,
OutputIt,
ScanOp,
cub::NullType,
thrust::detail::int32_t>;
thrust::detail::int32_t,
AccumT>;
using Dispatch64 = cub::DispatchScan<InputIt,
OutputIt,
ScanOp,
cub::NullType,
thrust::detail::int64_t>;
thrust::detail::int64_t,
AccumT>;

cudaStream_t stream = thrust::cuda_cub::stream(policy);
cudaError_t status;
Expand Down Expand Up @@ -141,12 +144,14 @@ OutputIt exclusive_scan_n_impl(thrust::cuda_cub::execution_policy<Derived> &poli
OutputIt,
ScanOp,
InputValueT,
thrust::detail::int32_t>;
thrust::detail::int32_t,
InitValueT>;
using Dispatch64 = cub::DispatchScan<InputIt,
OutputIt,
ScanOp,
InputValueT,
thrust::detail::int64_t>;
thrust::detail::int64_t,
InitValueT>;

cudaStream_t stream = thrust::cuda_cub::stream(policy);
cudaError_t status;
Expand Down
13 changes: 9 additions & 4 deletions thrust/system/cuda/detail/scan_by_key.h
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,7 @@ ValuesOutIt inclusive_scan_by_key_n(
thrust::detail::try_unwrap_contiguous_iterator_return_t<ValuesInIt>;
using ValuesOutUnwrapIt =
thrust::detail::try_unwrap_contiguous_iterator_return_t<ValuesOutIt>;
using AccumT = typename thrust::iterator_traits<ValuesInUnwrapIt>::value_type;

auto keys_unwrap = thrust::detail::try_unwrap_contiguous_iterator(keys);
auto values_unwrap = thrust::detail::try_unwrap_contiguous_iterator(values);
Expand All @@ -98,14 +99,16 @@ ValuesOutIt inclusive_scan_by_key_n(
EqualityOpT,
ScanOpT,
cub::NullType,
thrust::detail::int32_t>;
thrust::detail::int32_t,
AccumT>;
using Dispatch64 = cub::DispatchScanByKey<KeysInUnwrapIt,
ValuesInUnwrapIt,
ValuesOutUnwrapIt,
EqualityOpT,
ScanOpT,
cub::NullType,
thrust::detail::int64_t>;
thrust::detail::int64_t,
AccumT>;

cudaStream_t stream = thrust::cuda_cub::stream(policy);
cudaError_t status{};
Expand Down Expand Up @@ -209,14 +212,16 @@ ValuesOutIt exclusive_scan_by_key_n(
EqualityOpT,
ScanOpT,
InitValueT,
thrust::detail::int32_t>;
thrust::detail::int32_t,
InitValueT>;
using Dispatch64 = cub::DispatchScanByKey<KeysInUnwrapIt,
ValuesInUnwrapIt,
ValuesOutUnwrapIt,
EqualityOpT,
ScanOpT,
InitValueT,
thrust::detail::int64_t>;
thrust::detail::int64_t,
InitValueT>;

cudaStream_t stream = thrust::cuda_cub::stream(policy);
cudaError_t status{};
Expand Down

0 comments on commit f3c28da

Please sign in to comment.