Skip to content

Commit

Permalink
GH-45572: [C++][Compute] Add rank_normal function
Browse files Browse the repository at this point in the history
  • Loading branch information
pitrou committed Feb 19, 2025
1 parent c7a9100 commit f671856
Show file tree
Hide file tree
Showing 8 changed files with 405 additions and 37 deletions.
1 change: 1 addition & 0 deletions cpp/src/arrow/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -529,6 +529,7 @@ set(ARROW_UTIL_SRCS
util/logger.cc
util/logging.cc
util/key_value_metadata.cc
util/math_internal.cc
util/memory.cc
util/mutex.cc
util/ree_util.cc
Expand Down
74 changes: 58 additions & 16 deletions cpp/src/arrow/compute/kernels/vector_rank.cc
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
#include "arrow/compute/function.h"
#include "arrow/compute/kernels/vector_sort_internal.h"
#include "arrow/compute/registry.h"
#include "arrow/util/math_internal.h"

namespace arrow::compute::internal {

Expand Down Expand Up @@ -62,16 +63,6 @@ void MarkDuplicates(const NullPartitionResult& sorted, ValueSelector&& value_sel
}
}

// Returns the address of a function-local static holding
// RankOptions::Defaults(). The object is created on the first call and the
// same pointer is returned on every call.
const RankOptions* GetDefaultRankOptions() {
  static const auto default_options = RankOptions::Defaults();
  return &default_options;
}

// Returns the address of a function-local static holding
// RankQuantileOptions::Defaults(). The object is created on the first call
// and the same pointer is returned on every call.
const RankQuantileOptions* GetDefaultQuantileRankOptions() {
  static const auto default_options = RankQuantileOptions::Defaults();
  return &default_options;
}

template <typename ArrowType>
Result<NullPartitionResult> DoSortAndMarkDuplicate(
ExecContext* ctx, uint64_t* indices_begin, uint64_t* indices_end, const Array& input,
Expand Down Expand Up @@ -164,8 +155,9 @@ class SortAndMarkDuplicate : public TypeVisitor {
NullPartitionResult sorted_{};
};

// A helper class that emits rankings for the "rank_quantile" function
struct QuantileRanker {
// A CRTP-based helper class for "rank_normal" and "rank_quantile"
template <typename Derived>
struct BaseQuantileRanker {
Result<Datum> CreateRankings(ExecContext* ctx, const NullPartitionResult& sorted) {
const int64_t length = sorted.overall_end() - sorted.overall_begin();
ARROW_ASSIGN_OR_RAISE(auto rankings,
Expand All @@ -187,10 +179,11 @@ struct QuantileRanker {
}
// The run length, i.e. the frequency of the current value
int64_t freq = run_end - it;
double quantile = (cum_freq + 0.5 * freq) / static_cast<double>(length);
const double quantile = (cum_freq + 0.5 * freq) / static_cast<double>(length);
const double value = Derived::TransformValue(quantile);
// Output quantile rank values
for (; it < run_end; ++it) {
out_begin[original_index(*it)] = quantile;
out_begin[original_index(*it)] = value;
}
cum_freq += freq;
}
Expand All @@ -199,6 +192,18 @@ struct QuantileRanker {
}
};

// Ranker used by the "rank_quantile" function.
// The transform hook is the identity: the quantile it receives is returned
// unchanged.
struct QuantileRanker : BaseQuantileRanker<QuantileRanker> {
  static double TransformValue(double q) { return q; }
};

// Ranker used by the "rank_normal" function.
// The transform hook maps the quantile it receives through
// ::arrow::internal::NormalPPF and returns the result.
struct NormalRanker : BaseQuantileRanker<NormalRanker> {
  static double TransformValue(double q) {
    const double transformed = ::arrow::internal::NormalPPF(q);
    return transformed;
  }
};

// A helper class that emits rankings for the "rank" function
struct OrdinalRanker {
explicit OrdinalRanker(RankOptions::Tiebreaker tiebreaker) : tiebreaker_(tiebreaker) {}
Expand Down Expand Up @@ -294,6 +299,20 @@ const FunctionDoc rank_quantile_doc(
"The handling of nulls and NaNs can be changed in RankQuantileOptions."),
{"input"}, "RankQuantileOptions");

// User-facing documentation for the "rank_normal" function.
// The arguments below are, in order: a one-line summary, the long description,
// the list of argument names and the name of the options class
// (NOTE(review): ordering assumed from usage here; confirm against the
// FunctionDoc constructor).
// The description states that results are what "rank_quantile" would produce,
// passed through the normal percent-point function.
const FunctionDoc rank_normal_doc(
"Compute normal (gaussian) ranks of an array",
("This function computes a normal (gaussian) rank of the input array.\n"
"By default, null values are considered greater than any other value and\n"
"are therefore sorted at the end of the input. For floating-point types,\n"
"NaNs are considered greater than any other non-null value, but smaller\n"
"than null values.\n"
"The results are finite real values. They are obtained as if first\n"
"calling the \"rank_quantile\" function and then applying the normal\n"
"percent-point function (PPF) to the resulting quantile values.\n"
"\n"
"The handling of nulls and NaNs can be changed in RankQuantileOptions."),
{"input"}, "RankQuantileOptions");

template <typename Derived>
class RankMetaFunctionBase : public MetaFunction {
public:
Expand Down Expand Up @@ -361,11 +380,14 @@ class RankMetaFunction : public RankMetaFunctionBase<RankMetaFunction> {
}

RankMetaFunction()
: RankMetaFunctionBase("rank", Arity::Unary(), rank_doc, GetDefaultRankOptions()) {}
: RankMetaFunctionBase("rank", Arity::Unary(), rank_doc, &kDefaultOptions) {}

static inline const auto kDefaultOptions = RankOptions::Defaults();
};

class RankQuantileMetaFunction : public RankMetaFunctionBase<RankQuantileMetaFunction> {
public:
using Base = RankMetaFunctionBase<RankQuantileMetaFunction>;
using FunctionOptionsType = RankQuantileOptions;
using RankerType = QuantileRanker;

Expand All @@ -375,14 +397,34 @@ class RankQuantileMetaFunction : public RankMetaFunctionBase<RankQuantileMetaFun

RankQuantileMetaFunction()
: RankMetaFunctionBase("rank_quantile", Arity::Unary(), rank_quantile_doc,
GetDefaultQuantileRankOptions()) {}
&kDefaultOptions) {}

static inline const auto kDefaultOptions = RankQuantileOptions::Defaults();
};

// Meta function registered under the name "rank_normal".
//
// It takes RankQuantileOptions (the same options class that documents
// "rank_quantile") and produces its output through NormalRanker.
class RankNormalMetaFunction : public RankMetaFunctionBase<RankNormalMetaFunction> {
 public:
  // Alias for this class's own CRTP base. It must name the
  // RankNormalMetaFunction instantiation, not the one belonging to
  // RankQuantileMetaFunction.
  using Base = RankMetaFunctionBase<RankNormalMetaFunction>;
  using FunctionOptionsType = RankQuantileOptions;
  using RankerType = NormalRanker;

  // Always returns true, regardless of the options.
  // NOTE(review): presumably tells the base class that runs of equal values
  // must be marked before ranking; confirm in RankMetaFunctionBase.
  static bool NeedsDuplicates(const RankQuantileOptions&) { return true; }

  // Returns a default-constructed ranker; the options are not consulted, so
  // the parameter is left unnamed.
  static RankerType GetRanker(const RankQuantileOptions&) { return RankerType(); }

  RankNormalMetaFunction()
      : RankMetaFunctionBase("rank_normal", Arity::Unary(), rank_normal_doc,
                             &kDefaultOptions) {}

  // Default options instance; its address is passed to the base constructor.
  static inline const auto kDefaultOptions = RankQuantileOptions::Defaults();
};

} // namespace

// Adds the rank meta functions defined in this file to `registry`:
// RankMetaFunction ("rank"), RankQuantileMetaFunction ("rank_quantile") and
// RankNormalMetaFunction ("rank_normal").
// The result of each AddFunction() call is handed to DCHECK_OK.
void RegisterVectorRank(FunctionRegistry* registry) {
DCHECK_OK(registry->AddFunction(std::make_shared<RankMetaFunction>()));
DCHECK_OK(registry->AddFunction(std::make_shared<RankQuantileMetaFunction>()));
DCHECK_OK(registry->AddFunction(std::make_shared<RankNormalMetaFunction>()));
}

} // namespace arrow::compute::internal
111 changes: 91 additions & 20 deletions cpp/src/arrow/compute/kernels/vector_sort_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2205,7 +2205,7 @@ TEST_F(TestNestedSortIndices, SortRecordBatch) { TestSort(GetRecordBatch()); }
TEST_F(TestNestedSortIndices, SortTable) { TestSort(GetTable()); }

// ----------------------------------------------------------------------
// Tests for Rank and Quantile Rank
// Tests for Rank, Quantile Rank and Normal Rank

class BaseTestRank : public ::testing::Test {
protected:
Expand Down Expand Up @@ -2471,43 +2471,84 @@ TEST_F(TestRank, EmptyChunks) {

class TestRankQuantile : public BaseTestRank {
public:
void AssertRankQuantile(const DatumVector& datums, SortOrder order,
NullPlacement null_placement,
const std::shared_ptr<Array>& expected) {
void AssertRankQuantileGeneric(const std::string& function_name,
const DatumVector& datums, SortOrder order,
NullPlacement null_placement,
const std::shared_ptr<Array>& expected) {
ARROW_SCOPED_TRACE("function = ", function_name);
const std::vector<SortKey> sort_keys{SortKey("foo", order)};
RankQuantileOptions options(sort_keys, null_placement);
ARROW_SCOPED_TRACE("options = ", options.ToString());
for (const auto& datum : datums) {
ASSERT_OK_AND_ASSIGN(auto actual, CallFunction("rank_quantile", {datum}, &options));
ASSERT_OK_AND_ASSIGN(auto actual, CallFunction(function_name, {datum}, &options));
ValidateOutput(actual);
AssertDatumsEqual(expected, actual, /*verbose=*/true);
if (function_name == "rank_normal") {
// Normal PPF results can only be approximate
auto equal_options = EqualOptions().atol(1e-8);
AssertDatumsApproxEqual(expected, actual, /*verbose=*/true, equal_options);
} else {
AssertDatumsEqual(expected, actual, /*verbose=*/true);
}
}
}

void AssertRankQuantile(const DatumVector& datums, SortOrder order,
NullPlacement null_placement, const std::string& expected) {
AssertRankQuantile(datums, order, null_placement, ArrayFromJSON(float64(), expected));
void AssertRankQuantileGeneric(const std::string& function_name, const Datum& datum,
SortOrder order, NullPlacement null_placement,
const std::shared_ptr<Array>& expected) {
AssertRankQuantileGeneric(function_name, DatumVector{datum}, order, null_placement,
expected);
}

void AssertRankQuantile(SortOrder order, NullPlacement null_placement,
const std::shared_ptr<Array>& expected) {
AssertRankQuantile(datums_, order, null_placement, expected);
void AssertRankQuantileGeneric(const std::string& function_name,
const DatumVector& datums, SortOrder order,
NullPlacement null_placement,
const std::string& expected) {
AssertRankQuantileGeneric(function_name, datums, order, null_placement,
ArrayFromJSON(float64(), expected));
}

void AssertRankQuantile(SortOrder order, NullPlacement null_placement,
const std::string& expected) {
AssertRankQuantile(datums_, order, null_placement,
ArrayFromJSON(float64(), expected));
void AssertRankQuantileGeneric(const std::string& function_name, const Datum& datum,
SortOrder order, NullPlacement null_placement,
const std::string& expected) {
AssertRankQuantileGeneric(function_name, DatumVector{datum}, order, null_placement,
ArrayFromJSON(float64(), expected));
}

void AssertRankQuantileGeneric(const std::string& function_name, SortOrder order,
NullPlacement null_placement,
const std::shared_ptr<Array>& expected) {
AssertRankQuantileGeneric(function_name, datums_, order, null_placement, expected);
}

void AssertRankQuantileGeneric(const std::string& function_name, SortOrder order,
NullPlacement null_placement,
const std::string& expected) {
AssertRankQuantileGeneric(function_name, datums_, order, null_placement,
ArrayFromJSON(float64(), expected));
}

template <typename... Args>
void AssertRankQuantile(Args&&... args) {
AssertRankQuantileGeneric("rank_quantile", std::forward<Args>(args)...);
}

template <typename... Args>
void AssertRankNormal(Args&&... args) {
AssertRankQuantileGeneric("rank_normal", std::forward<Args>(args)...);
}

void AssertRankQuantileEmpty(std::shared_ptr<DataType> type) {
for (auto null_placement : AllNullPlacements()) {
for (auto order : AllOrders()) {
AssertRankQuantile({ArrayFromJSON(type, "[]")}, order, null_placement, "[]");
AssertRankQuantile({ArrayFromJSON(type, "[null]")}, order, null_placement,
"[0.5]");
AssertRankQuantile({ArrayFromJSON(type, "[null, null, null]")}, order,
AssertRankQuantile(ArrayFromJSON(type, "[]"), order, null_placement, "[]");
AssertRankQuantile(ArrayFromJSON(type, "[null]"), order, null_placement, "[0.5]");
AssertRankQuantile(ArrayFromJSON(type, "[null, null, null]"), order,
null_placement, "[0.5, 0.5, 0.5]");

AssertRankNormal(ArrayFromJSON(type, "[]"), order, null_placement, "[]");
AssertRankNormal(ArrayFromJSON(type, "[null]"), order, null_placement, "[0.0]");
AssertRankNormal(ArrayFromJSON(type, "[null, null, null]"), order, null_placement,
"[0.0, 0.0, 0.0]");
}
}
}
Expand All @@ -2519,6 +2560,12 @@ class TestRankQuantile : public BaseTestRank {
"[0.3, 0.8, 0.3, 0.8, 0.3]");
AssertRankQuantile(SortOrder::Descending, null_placement,
"[0.7, 0.2, 0.7, 0.2, 0.7]");
AssertRankNormal(SortOrder::Ascending, null_placement,
"[-0.5244005127080409, 0.8416212335729143, -0.5244005127080409, "
"0.8416212335729143, -0.5244005127080409]");
AssertRankNormal(SortOrder::Descending, null_placement,
"[0.5244005127080407, -0.8416212335729142, 0.5244005127080407, "
"-0.8416212335729142, 0.5244005127080407]");
}
}

Expand All @@ -2532,6 +2579,19 @@ class TestRankQuantile : public BaseTestRank {
"[0.3, 0.9, 0.3, 0.7, 0.3]");
AssertRankQuantile(SortOrder::Descending, NullPlacement::AtEnd,
"[0.7, 0.3, 0.7, 0.1, 0.7]");

AssertRankNormal(SortOrder::Ascending, NullPlacement::AtStart,
"[-0.5244005127080409, 0.5244005127080407, -0.5244005127080409, "
"1.2815515655446004, -0.5244005127080409]");
AssertRankNormal(SortOrder::Ascending, NullPlacement::AtEnd,
"[0.5244005127080407, -1.2815515655446004, 0.5244005127080407, "
"-0.5244005127080409, 0.5244005127080407]");
AssertRankNormal(SortOrder::Descending, NullPlacement::AtStart,
"[-0.5244005127080409, 1.2815515655446004, -0.5244005127080409, "
"0.5244005127080407, -0.5244005127080409]");
AssertRankNormal(SortOrder::Descending, NullPlacement::AtEnd,
"[0.5244005127080407, -0.5244005127080409, 0.5244005127080407, "
"-1.2815515655446004, 0.5244005127080407]");
}

void AssertRankQuantileNumeric(std::shared_ptr<DataType> type) {
Expand All @@ -2545,6 +2605,17 @@ class TestRankQuantile : public BaseTestRank {
"[0.95, 0.8, 0.8, 0.6, 0.6, 0.35, 0.35, 0.35, 0.15, 0.05]");
AssertRankQuantile(SortOrder::Descending, null_placement,
"[0.05, 0.2, 0.2, 0.4, 0.4, 0.65, 0.65, 0.65, 0.85, 0.95]");

AssertRankNormal(SortOrder::Ascending, null_placement,
"[1.6448536269514722, 0.8416212335729143, 0.8416212335729143, "
"0.2533471031357997, 0.2533471031357997, -0.38532046640756773, "
"-0.38532046640756773, -0.38532046640756773, -1.0364333894937898, "
"-1.6448536269514729]");
AssertRankNormal(SortOrder::Descending, null_placement,
"[-1.6448536269514729, -0.8416212335729142, -0.8416212335729142, "
"-0.2533471031357997, -0.2533471031357997, 0.38532046640756773, "
"0.38532046640756773, 0.38532046640756773, 1.0364333894937898, "
"1.6448536269514722]");
}

// With nulls
Expand Down
1 change: 1 addition & 0 deletions cpp/src/arrow/util/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ add_arrow_test(utility-test
list_util_test.cc
logger_test.cc
logging_test.cc
math_test.cc
queue_test.cc
range_test.cc
ree_util_test.cc
Expand Down
Loading

0 comments on commit f671856

Please sign in to comment.