From 9ef3a8e95cc404d786596cb18d54e2bb389047fa Mon Sep 17 00:00:00 2001
From: Georgy Evtushenko <evtushenko.georgy@gmail.com>
Date: Tue, 26 Apr 2022 22:12:05 +0400
Subject: [PATCH] Fix thrust::reduce_by_key for 2^31 elements

---
 testing/cuda/reduce_by_key.cu             | 112 +++++++++++++++++++++-
 thrust/system/cuda/detail/reduce_by_key.h |  82 ++++++++++++----
 2 files changed, 173 insertions(+), 21 deletions(-)

diff --git a/testing/cuda/reduce_by_key.cu b/testing/cuda/reduce_by_key.cu
index 53c43c081..8ef3632d4 100644
--- a/testing/cuda/reduce_by_key.cu
+++ b/testing/cuda/reduce_by_key.cu
@@ -1,6 +1,11 @@
-#include <unittest/unittest.h>
-#include <thrust/reduce.h>
+#include <thrust/equal.h>
 #include <thrust/execution_policy.h>
+#include <thrust/iterator/counting_iterator.h>
+#include <thrust/iterator/transform_iterator.h>
+#include <thrust/reduce.h>
+#include <unittest/unittest.h>
+
+#include <cstdint>
 
 
 template<typename ExecutionPolicy, typename Iterator1, typename Iterator2, typename Iterator3, typename Iterator4, typename Iterator5>
@@ -286,3 +291,106 @@ void TestReduceByKeyCudaStreamsNoSync()
 }
 DECLARE_UNITTEST(TestReduceByKeyCudaStreamsNoSync);
 
+
+// Maps indices to key ids
+class div_op : public thrust::unary_function<std::int64_t, std::int64_t>
+{
+  std::int64_t m_divisor;
+
+public:
+  __host__ div_op(std::int64_t divisor)
+    : m_divisor(divisor)
+  {}
+
+  __host__ __device__
+  std::int64_t operator()(std::int64_t x) const
+  {
+    return x / m_divisor;
+  }
+};
+
+// Produces unique sequence for key
+class mod_op : public thrust::unary_function<std::int64_t, std::int64_t>
+{
+  std::int64_t m_divisor;
+
+public:
+  __host__ mod_op(std::int64_t divisor)
+    : m_divisor(divisor)
+  {}
+
+  __host__ __device__
+  std::int64_t operator()(std::int64_t x) const
+  {
+    // div: 2          
+    // idx: 0 1   2 3   4 5 
+    // key: 0 0 | 1 1 | 2 2 
+    // mod: 0 1 | 0 1 | 0 1
+    // ret: 0 1   1 2   2 3
+    return (x % m_divisor) + (x / m_divisor);
+  }
+};
+
+
+void TestReduceByKeyWithBigIndexesHelper(int magnitude)
+{
+  const std::int64_t key_size_magnitude = 8;
+  ASSERT_EQUAL(true, key_size_magnitude < magnitude);
+
+  const std::int64_t num_items       = 1ll << magnitude;
+  const std::int64_t num_unique_keys = 1ll << key_size_magnitude;
+
+  // Size of each key group
+  const std::int64_t key_size = num_items / num_unique_keys;
+
+  using counting_it      = thrust::counting_iterator<std::int64_t>;
+  using transform_key_it = thrust::transform_iterator<div_op, counting_it>;
+  using transform_val_it = thrust::transform_iterator<mod_op, counting_it>;
+
+  counting_it count_begin(0ll);
+  counting_it count_end = count_begin + num_items;
+  ASSERT_EQUAL(static_cast<std::int64_t>(thrust::distance(count_begin, count_end)),
+               num_items);
+
+  transform_key_it keys_begin(count_begin, div_op{key_size});
+  transform_key_it keys_end(count_end, div_op{key_size});
+
+  transform_val_it values_begin(count_begin, mod_op{key_size});
+
+  thrust::device_vector<std::int64_t> output_keys(num_unique_keys);
+  thrust::device_vector<std::int64_t> output_values(num_unique_keys);
+
+  // example:
+  //  items:        6
+  //  unique_keys:  2
+  //  key_size:     3
+  //  keys:         0 0 0 | 1 1 1 
+  //  values:       0 1 2 | 1 2 3
+  //  result:       3       6     = sum(range(key_size)) + key_size * key_id
+  thrust::reduce_by_key(keys_begin,
+                        keys_end,
+                        values_begin,
+                        output_keys.begin(),
+                        output_values.begin());
+
+  ASSERT_EQUAL(
+    true,
+    thrust::equal(output_keys.begin(), output_keys.end(), count_begin));
+
+  thrust::host_vector<std::int64_t> result = output_values;
+
+  const std::int64_t sum = (key_size - 1) * key_size / 2;
+  for (std::int64_t key_id = 0; key_id < num_unique_keys; key_id++)
+  {
+    ASSERT_EQUAL(result[key_id], sum + key_id * key_size);
+  }
+}
+
+void TestReduceByKeyWithBigIndexes()
+{
+  TestReduceByKeyWithBigIndexesHelper(30);
+  TestReduceByKeyWithBigIndexesHelper(31);
+  TestReduceByKeyWithBigIndexesHelper(32);
+  TestReduceByKeyWithBigIndexesHelper(33);
+}
+DECLARE_UNITTEST(TestReduceByKeyWithBigIndexes);
diff --git a/thrust/system/cuda/detail/reduce_by_key.h b/thrust/system/cuda/detail/reduce_by_key.h
index ba66f6d88..87a5bb454 100644
--- a/thrust/system/cuda/detail/reduce_by_key.h
+++ b/thrust/system/cuda/detail/reduce_by_key.h
@@ -445,8 +445,9 @@ namespace __reduce_by_key {
         {
           if (segment_flags[ITEM])
           {
-            storage.raw_exchange[segment_indices[ITEM] -
-                                 num_tile_segments_prefix] = scatter_items[ITEM];
+            int idx = static_cast<int>(segment_indices[ITEM] -
+                                       num_tile_segments_prefix);
+            storage.raw_exchange[idx] = scatter_items[ITEM];
           }
         }
 
@@ -786,7 +787,7 @@ namespace __reduce_by_key {
         // so just assign one tile per block
         //
         int  tile_idx          = blockIdx.x;
-        Size tile_offset       = tile_idx * ITEMS_PER_TILE;
+        Size tile_offset       = static_cast<Size>(tile_idx) * ITEMS_PER_TILE;
         Size num_remaining     = num_items - tile_offset;
 
         if (num_remaining > ITEMS_PER_TILE)
@@ -962,7 +963,8 @@ namespace __reduce_by_key {
     return status;
   }
 
-  template <typename Derived,
+  template <typename Size,
+            typename Derived,
             typename KeysInputIt,
             typename ValuesInputIt,
             typename KeysOutputIt,
@@ -971,24 +973,23 @@ namespace __reduce_by_key {
             typename ReductionOp>
   THRUST_RUNTIME_FUNCTION
   pair<KeysOutputIt, ValuesOutputIt>
-  reduce_by_key(execution_policy<Derived>& policy,
-                KeysInputIt                keys_first,
-                KeysInputIt                keys_last,
-                ValuesInputIt              values_first,
-                KeysOutputIt               keys_output,
-                ValuesOutputIt             values_output,
-                EqualityOp                 equality_op,
-                ReductionOp                reduction_op)
+  reduce_by_key_dispatch(execution_policy<Derived>& policy,
+                         KeysInputIt                keys_first,
+                         Size                       num_items,
+                         ValuesInputIt              values_first,
+                         KeysOutputIt               keys_output,
+                         ValuesOutputIt             values_output,
+                         EqualityOp                 equality_op,
+                         ReductionOp                reduction_op)
   {
-    typedef int size_type;
-
-    size_type    num_items          = static_cast<size_type>(thrust::distance(keys_first, keys_last));
     size_t       temp_storage_bytes = 0;
     cudaStream_t stream             = cuda_cub::stream(policy);
     bool         debug_sync         = THRUST_DEBUG_SYNC_FLAG;
 
     if (num_items == 0)
+    {
       return thrust::make_pair(keys_output, values_output);
+    }
 
     cudaError_t status;
     status = doit_step(NULL,
@@ -997,7 +998,7 @@ namespace __reduce_by_key {
                        values_first,
                        keys_output,
                        values_output,
-                       reinterpret_cast<size_type*>(NULL),
+                       reinterpret_cast<Size*>(NULL),
                        equality_op,
                        reduction_op,
                        num_items,
@@ -1005,7 +1006,7 @@ namespace __reduce_by_key {
                        debug_sync);
     cuda_cub::throw_on_error(status, "reduce_by_key failed on 1st step");
 
-    size_t allocation_sizes[2] = {sizeof(size_type), temp_storage_bytes};
+    size_t allocation_sizes[2] = {sizeof(Size), temp_storage_bytes};
     void * allocations[2]      = {NULL, NULL};
 
     size_t storage_size = 0;
@@ -1026,8 +1027,8 @@ namespace __reduce_by_key {
                                  allocation_sizes);
     cuda_cub::throw_on_error(status, "reduce failed on 2nd alias_storage");
 
-    size_type* d_num_runs_out
-      = thrust::detail::aligned_reinterpret_cast<size_type*>(allocations[0]);
+    Size* d_num_runs_out
+      = thrust::detail::aligned_reinterpret_cast<Size*>(allocations[0]);
 
     status = doit_step(allocations[1],
                        temp_storage_bytes,
@@ -1054,6 +1055,49 @@ namespace __reduce_by_key {
     );
   }
 
+  template <typename Derived,
+            typename KeysInputIt,
+            typename ValuesInputIt,
+            typename KeysOutputIt,
+            typename ValuesOutputIt,
+            typename EqualityOp,
+            typename ReductionOp>
+  THRUST_RUNTIME_FUNCTION
+  pair<KeysOutputIt, ValuesOutputIt>
+  reduce_by_key(execution_policy<Derived>& policy,
+                KeysInputIt                keys_first,
+                KeysInputIt                keys_last,
+                ValuesInputIt              values_first,
+                KeysOutputIt               keys_output,
+                ValuesOutputIt             values_output,
+                EqualityOp                 equality_op,
+                ReductionOp                reduction_op)
+  {
+    using size_type = typename iterator_traits<KeysInputIt>::difference_type;
+
+    size_type num_items = thrust::distance(keys_first, keys_last);
+
+    if (num_items == 0)
+    {
+      return thrust::make_pair(keys_output, values_output);
+    }
+
+    pair<KeysOutputIt, ValuesOutputIt> result{};
+    THRUST_INDEX_TYPE_DISPATCH(result,
+                               reduce_by_key_dispatch,
+                               num_items,
+                               (policy,
+                                keys_first,
+                                num_items_fixed,
+                                values_first,
+                                keys_output,
+                                values_output,
+                                equality_op,
+                                reduction_op));
+
+    return result;
+  }
+
 }    // namespace __reduce_by_key
 
 //-------------------------