// Patch summary (cpp/src/groupby/sort/group_nunique.cu):
//   * Drops two #include lines (their targets are not visible in this text).
//   * Adds the functor `is_unique_iterator_fn`. Its operator()(i) yields 1 when row i
//     should be counted as a new distinct value and 0 otherwise:
//       - a row is countable when `nulls` evaluates to false, when
//         null_handling == null_policy::INCLUDE, or when v.is_valid_nocheck(i) holds;
//       - a countable row is flagged when it is the first row of its group
//         (group_offsets[group_labels[i]] == i) or when `equal` reports that it differs
//         from row i - 1. The first-row test is evaluated first, so the comparison
//         against row i - 1 is not reached for the first row of a group.
//   * In nunique_functor, replaces the separate `values.has_nulls()` / else branches
//     (each with its own device lambda and reduce_by_key call) with one transform
//     iterator built from is_unique_iterator_fn and nullate::DYNAMIC{values.has_nulls()},
//     followed by a single thrust::reduce_by_key keyed on group_labels that writes into
//     result->mutable_view().
// NOTE(review): this text is a unified diff flattened onto two physical lines, and the
// contents of angle brackets appear to be missing (bare `template`, bare `#include`,
// `static_cast(is_unique)`, `equal.template operator()(i, i - 1)`) — recover the original
// patch before applying or compiling it.
// NOTE(review): the comparison against row i - 1 presumably relies on the values being
// sorted within each group, as the inline "new unique value in sorted" remark suggests —
// confirm against the caller.
diff --git a/cpp/src/groupby/sort/group_nunique.cu b/cpp/src/groupby/sort/group_nunique.cu index 478060cbd16..b719698b6b5 100644 --- a/cpp/src/groupby/sort/group_nunique.cu +++ b/cpp/src/groupby/sort/group_nunique.cu @@ -16,7 +16,6 @@ #include #include -#include #include #include #include @@ -24,7 +23,6 @@ #include #include -#include #include #include #include @@ -34,6 +32,41 @@ namespace cudf { namespace groupby { namespace detail { namespace { + +template +struct is_unique_iterator_fn { + Nullate nulls; + column_device_view const v; + element_equality_comparator equal; + null_policy null_handling; + size_type const* group_offsets; + size_type const* group_labels; + + is_unique_iterator_fn(Nullate nulls, + column_device_view const& v, + null_policy null_handling, + size_type const* group_offsets, + size_type const* group_labels) + : nulls{nulls}, + v{v}, + equal{nulls, v, v}, + null_handling{null_handling}, + group_offsets{group_offsets}, + group_labels{group_labels} + { + } + + __device__ size_type operator()(size_type i) + { + bool is_input_countable = + !nulls || (null_handling == null_policy::INCLUDE || v.is_valid_nocheck(i)); + bool is_unique = is_input_countable && + (group_offsets[group_labels[i]] == i || // first element or + (not equal.template operator()(i, i - 1))); // new unique value in sorted + return static_cast(is_unique); + } +}; + struct nunique_functor { template std::enable_if_t(), std::unique_ptr> operator()( @@ -50,49 +83,21 @@ struct nunique_functor { if (num_groups == 0) { return result; } - auto values_view = column_device_view::create(values, stream); - if (values.has_nulls()) { - auto equal = element_equality_comparator{nullate::YES{}, *values_view, *values_view}; - auto is_unique_iterator = thrust::make_transform_iterator( - thrust::make_counting_iterator(0), - [v = *values_view, - equal, - null_handling, - group_offsets = group_offsets.data(), - group_labels = group_labels.data()] __device__(auto i) -> size_type { - bool 
is_input_countable = - (null_handling == null_policy::INCLUDE || v.is_valid_nocheck(i)); - bool is_unique = is_input_countable && - (group_offsets[group_labels[i]] == i || // first element or - (not equal.operator()(i, i - 1))); // new unique value in sorted - return static_cast(is_unique); - }); + auto values_view = column_device_view::create(values, stream); + auto is_unique_iterator = thrust::make_transform_iterator( + thrust::make_counting_iterator(0), + is_unique_iterator_fn{nullate::DYNAMIC{values.has_nulls()}, + *values_view, + null_handling, + group_offsets.data(), + group_labels.data()}); + thrust::reduce_by_key(rmm::exec_policy(stream), + group_labels.begin(), + group_labels.end(), + is_unique_iterator, + thrust::make_discard_iterator(), + result->mutable_view().begin()); - thrust::reduce_by_key(rmm::exec_policy(stream), - group_labels.begin(), - group_labels.end(), - is_unique_iterator, - thrust::make_discard_iterator(), - result->mutable_view().begin()); - } else { - auto equal = element_equality_comparator{nullate::NO{}, *values_view, *values_view}; - auto is_unique_iterator = thrust::make_transform_iterator( - thrust::make_counting_iterator(0), - [v = *values_view, - equal, - group_offsets = group_offsets.data(), - group_labels = group_labels.data()] __device__(auto i) -> size_type { - bool is_unique = group_offsets[group_labels[i]] == i || // first element or - (not equal.operator()(i, i - 1)); // new unique value in sorted - return static_cast(is_unique); - }); - thrust::reduce_by_key(rmm::exec_policy(stream), - group_labels.begin(), - group_labels.end(), - is_unique_iterator, - thrust::make_discard_iterator(), - result->mutable_view().begin()); - } return result; }