Skip to content

Commit

Permalink
feat(array): Add Presto function array_top_n (#12105)
Browse files Browse the repository at this point in the history
Summary:

Adds Presto function array_top_n as a simple function in Velox. Function uses a temporary vector to store inputted values and heap sorts them up to k values (second input to function).

Updates ArrayFunction.h with struct ArrayTopNFunction and adds new tester function ArrayTopNTest.cpp

Differential Revision: D68031372
  • Loading branch information
peterenescu authored and facebook-github-bot committed Feb 10, 2025
1 parent 183e3fa commit bcab9a8
Show file tree
Hide file tree
Showing 4 changed files with 596 additions and 0 deletions.
159 changes: 159 additions & 0 deletions velox/functions/prestosql/ArrayFunctions.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,14 @@
#include "velox/expression/PrestoCastHooks.h"
#include "velox/functions/Udf.h"
#include "velox/functions/lib/CheckedArithmetic.h"
#include "velox/functions/lib/ComparatorUtil.h"
#include "velox/functions/prestosql/json/SIMDJsonUtil.h"
#include "velox/functions/prestosql/types/JsonType.h"
#include "velox/type/Conversions.h"
#include "velox/type/FloatingPointUtil.h"

#include <queue>

namespace facebook::velox::functions {

template <typename TExecCtx, bool isMax>
Expand Down Expand Up @@ -729,6 +732,162 @@ inline void checkIndexArrayTrim(int64_t size, int64_t arraySize) {
}
}

/// This class implements the array_top_n function.
///
/// DEFINITION:
/// array_top_n(array(T), int) -> array(T)
/// Returns the top n elements of the array in descending order.
template <typename T>
struct ArrayTopNFunction {
VELOX_DEFINE_FUNCTION_TYPES(T);

// Definition for primitives.
template <typename TReturn, typename TInput>
FOLLY_ALWAYS_INLINE void
call(TReturn& result, const TInput& array, int64_t n) {
VELOX_CHECK(
n >= 0, fmt::format("Parameter n: {} to ARRAY_TOP_N is negative", n));

// Define comparator that wraps built-in function for basic primitives or
// calls floating point handler for NaNs.
using facebook::velox::util::floating_point::NaNAwareGreaterThan;
struct SimpleComparator {
bool operator()(
const typename TInput::element_t& a,
const typename TInput::element_t& b) const {
if constexpr (
std::is_same_v<typename TInput::element_t, float> ||
std::is_same_v<typename TInput::element_t, double>) {
return NaNAwareGreaterThan<typename TInput::element_t>{}(a, b);
} else {
return std::greater<typename TInput::element_t>{}(a, b);
}
}
};

// Define min-heap to store the top n elements.
std::priority_queue<
typename TInput::element_t,
std::vector<typename TInput::element_t>,
SimpleComparator>
minHeap;

// Iterate through the array and push elements to the min-heap.
int numNull = 0;
for (const auto& item : array) {
if (item.has_value()) {
if (minHeap.size() < n) {
minHeap.push(item.value());
} else if (!minHeap.empty()) {
if constexpr (
std::is_same_v<typename TInput::element_t, float> ||
std::is_same_v<typename TInput::element_t, double>) {
if (NaNAwareGreaterThan<typename TInput::element_t>{}(
item.value(), minHeap.top())) {
minHeap.push(item.value());
}
} else if (item.value() > minHeap.top()) {
minHeap.push(item.value());
}
}
if (minHeap.size() > n) {
minHeap.pop();
}
} else {
++numNull;
}
}

// Reverse the min-heap to get the top n elements in descending order.
std::vector<typename TInput::element_t> reversed(minHeap.size());
auto index = minHeap.size();
while (!minHeap.empty()) {
reversed[--index] = minHeap.top();
minHeap.pop();
}

// Copy mutated vector to result vector up to minHeap's size items.
for (const auto& item : reversed) {
result.push_back(item);
}

// Backfill nulls if needed.
while (result.size() < n && numNull > 0) {
result.add_null();
--numNull;
}
}

// Generic implementation.
FOLLY_ALWAYS_INLINE void call(
out_type<Array<Orderable<T1>>>& result,
const arg_type<Array<Orderable<T1>>>& array,
const int64_t n) {
VELOX_CHECK(
n >= 0, fmt::format("Parameter n: {} to ARRAY_TOP_N is negative", n));

// Define comparator to compare complex types.
struct ComplexTypeComparator {
const arg_type<Array<Orderable<T1>>>& array;
ComplexTypeComparator(const arg_type<Array<Orderable<T1>>>& array)
: array(array) {}

bool operator()(const int64_t& a, const int64_t& b) const {
static constexpr CompareFlags kFlags = {
.nullHandlingMode =
CompareFlags::NullHandlingMode::kNullAsIndeterminate};
return array[a].value().compare(array[b].value(), kFlags).value() > 0;
}
};

// Iterate through the array and push elements to the min-heap.
std::priority_queue<int64_t, std::vector<int64_t>, ComplexTypeComparator>
minHeap(array);
int numNull = 0;
for (int i = 0; i < array.size(); ++i) {
if (array[i].has_value()) {
if (minHeap.size() < n) {
minHeap.push(i);
} else if (!minHeap.empty()) {
static constexpr CompareFlags kFlags = {
.nullHandlingMode =
CompareFlags::NullHandlingMode::kNullAsIndeterminate};
if (array[i]
.value()
.compare(array[minHeap.top()].value(), kFlags)
.value() > 0) {
minHeap.push(i);
}
}
if (minHeap.size() > n) {
minHeap.pop();
}
} else {
++numNull;
}
}

// Reverse the min-heap to get the top n elements in descending order.
std::vector<int64_t> reversed(minHeap.size());
auto index = minHeap.size();
while (!minHeap.empty()) {
reversed[--index] = minHeap.top();
minHeap.pop();
}

// Copy mutated vector to result vector up to minHeap's size items.
for (const auto& index : reversed) {
result.push_back(array[index].value());
}

// Backfill nulls if needed.
while (result.size() < n && numNull > 0) {
result.add_null();
--numNull;
}
}
};

template <typename T>
struct ArrayTrimFunction {
VELOX_DEFINE_FUNCTION_TYPES(T);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,12 @@ inline void registerArrayTrimFunctions(const std::string& prefix) {
{prefix + "trim_array"});
}

template <typename T>
inline void registerArrayTopNFunction(const std::string& prefix) {
registerFunction<ArrayTopNFunction, Array<T>, Array<T>, int64_t>(
{prefix + "array_top_n"});
}

template <typename T>
inline void registerArrayRemoveNullFunctions(const std::string& prefix) {
registerFunction<ArrayRemoveNullFunction, Array<T>, Array<T>>(
Expand Down Expand Up @@ -241,6 +247,19 @@ void registerArrayFunctions(const std::string& prefix) {
Array<Varchar>,
int64_t>({prefix + "trim_array"});

registerArrayTopNFunction<int8_t>(prefix);
registerArrayTopNFunction<int16_t>(prefix);
registerArrayTopNFunction<int32_t>(prefix);
registerArrayTopNFunction<int64_t>(prefix);
registerArrayTopNFunction<int128_t>(prefix);
registerArrayTopNFunction<float>(prefix);
registerArrayTopNFunction<double>(prefix);
registerArrayTopNFunction<Varchar>(prefix);
registerArrayTopNFunction<Timestamp>(prefix);
registerArrayTopNFunction<Date>(prefix);
registerArrayTopNFunction<Varbinary>(prefix);
registerArrayTopNFunction<Orderable<T1>>(prefix);

registerArrayRemoveNullFunctions<int8_t>(prefix);
registerArrayRemoveNullFunctions<int16_t>(prefix);
registerArrayRemoveNullFunctions<int32_t>(prefix);
Expand Down
Loading

0 comments on commit bcab9a8

Please sign in to comment.