Skip to content

Commit

Permalink
feat(array): Add Presto function array_top_n (#12105)
Browse files Browse the repository at this point in the history
Summary:

Adds Presto function array_top_n as a simple function in Velox. Function uses a temporary vector to store inputted values and heap sorts them up to k values (second input to function).

Updates ArrayFunction.h with struct ArrayTopNFunction and adds new tester function ArrayTopNTest.cpp

Differential Revision: D68031372
  • Loading branch information
peterenescu authored and facebook-github-bot committed Feb 11, 2025
1 parent b71648f commit ac17ddc
Show file tree
Hide file tree
Showing 4 changed files with 618 additions and 1 deletion.
154 changes: 153 additions & 1 deletion velox/functions/prestosql/ArrayFunctions.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,14 @@
#include "velox/expression/PrestoCastHooks.h"
#include "velox/functions/Udf.h"
#include "velox/functions/lib/CheckedArithmetic.h"
#include "velox/functions/lib/ComparatorUtil.h"
#include "velox/functions/prestosql/json/SIMDJsonUtil.h"
#include "velox/functions/prestosql/types/JsonType.h"
#include "velox/type/Conversions.h"
#include "velox/type/FloatingPointUtil.h"

#include <queue>

namespace facebook::velox::functions {

template <typename TExecCtx, bool isMax>
Expand Down Expand Up @@ -729,13 +732,162 @@ inline void checkIndexArrayTrim(int64_t size, int64_t arraySize) {
}
}

/// This class implements the array_top_n function.
///
/// DEFINITION:
/// array_top_n(array(T), int) -> array(T)
/// Returns the top n elements of the array in descending order.
template <typename T>
struct ArrayTopNFunction {
VELOX_DEFINE_FUNCTION_TYPES(T);

// Definition for primitives.
template <typename TReturn, typename TInput>
FOLLY_ALWAYS_INLINE void
call(TReturn& result, const TInput& array, int32_t n) {
VELOX_CHECK(
n >= 0, fmt::format("Parameter n: {} to ARRAY_TOP_N is negative", n));

// If top n is zero or input array is empty then exit early.
if (n == 0 || array.size() == 0) {
return;
}

// Define comparator that wraps built-in function for basic primitives or
// calls floating point handler for NaNs.
using facebook::velox::util::floating_point::NaNAwareGreaterThan;
struct GreaterThanComparator {
bool operator()(
const typename TInput::element_t& a,
const typename TInput::element_t& b) const {
if constexpr (
std::is_same_v<typename TInput::element_t, float> ||
std::is_same_v<typename TInput::element_t, double>) {
return NaNAwareGreaterThan<typename TInput::element_t>{}(a, b);
} else {
return std::greater<typename TInput::element_t>{}(a, b);
}
}
};

// Define min-heap to store the top n elements.
std::priority_queue<
typename TInput::element_t,
std::vector<typename TInput::element_t>,
GreaterThanComparator>
minHeap;

// Iterate through the array and push elements to the min-heap.
GreaterThanComparator comparator;
int numNull = 0;
for (const auto& item : array) {
if (item.has_value()) {
if (minHeap.size() < n) {
minHeap.push(item.value());
} else if (comparator(item.value(), minHeap.top())) {
minHeap.push(item.value());
minHeap.pop();
}
} else {
++numNull;
}
}

// Reverse the min-heap to get the top n elements in descending order.
std::vector<typename TInput::element_t> reversed(minHeap.size());
auto index = minHeap.size();
while (!minHeap.empty()) {
reversed[--index] = minHeap.top();
minHeap.pop();
}

// Copy mutated vector to result vector up to minHeap's size items.
for (const auto& item : reversed) {
result.push_back(item);
}

// Backfill nulls if needed.
while (result.size() < n && numNull > 0) {
result.add_null();
--numNull;
}
}

// Generic implementation.
FOLLY_ALWAYS_INLINE void call(
out_type<Array<Orderable<T1>>>& result,
const arg_type<Array<Orderable<T1>>>& array,
const int32_t n) {
VELOX_CHECK(
n >= 0, fmt::format("Parameter n: {} to ARRAY_TOP_N is negative", n));

// If top n is zero or input array is empty then exit early.
if (n == 0 || array.size() == 0) {
return;
}

// Define comparator to compare complex types.
struct ComplexTypeComparator {
const arg_type<Array<Orderable<T1>>>& array;
ComplexTypeComparator(const arg_type<Array<Orderable<T1>>>& array)
: array(array) {}

bool operator()(const int32_t& a, const int32_t& b) const {
static constexpr CompareFlags kFlags = {
.nullHandlingMode =
CompareFlags::NullHandlingMode::kNullAsIndeterminate};
return array[a].value().compare(array[b].value(), kFlags).value() > 0;
}
};

// Define min-heap to store the top n elements.
std::priority_queue<int32_t, std::vector<int32_t>, ComplexTypeComparator>
minHeap(array);

// Iterate through the array and push elements to the min-heap.
ComplexTypeComparator comparator(array);
int numNull = 0;
for (int i = 0; i < array.size(); ++i) {
if (array[i].has_value()) {
if (minHeap.size() < n) {
minHeap.push(i);
} else if (comparator(i, minHeap.top())) {
minHeap.push(i);
minHeap.pop();
}
} else {
++numNull;
}
}

// Reverse the min-heap to get the top n elements in descending order.
std::vector<int32_t> reversed(minHeap.size());
auto index = minHeap.size();
while (!minHeap.empty()) {
reversed[--index] = minHeap.top();
minHeap.pop();
}

// Copy mutated vector to result vector up to minHeap's size items.
for (const auto& index : reversed) {
result.push_back(array[index].value());
}

// Backfill nulls if needed.
while (result.size() < n && numNull > 0) {
result.add_null();
--numNull;
}
}
};

template <typename T>
struct ArrayTrimFunction {
VELOX_DEFINE_FUNCTION_TYPES(T);

// Fast path for primitives.
template <typename Out, typename In>
void call(Out& out, const In& inputArray, int64_t size) {
void call(Out& out, const In& inputArray, int32_t size) {
checkIndexArrayTrim(size, inputArray.size());

int64_t end = inputArray.size() - size;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,12 @@ inline void registerArrayTrimFunctions(const std::string& prefix) {
{prefix + "trim_array"});
}

template <typename T>
inline void registerArrayTopNFunction(const std::string& prefix) {
registerFunction<ArrayTopNFunction, Array<T>, Array<T>, int32_t>(
{prefix + "array_top_n"});
}

template <typename T>
inline void registerArrayRemoveNullFunctions(const std::string& prefix) {
registerFunction<ArrayRemoveNullFunction, Array<T>, Array<T>>(
Expand Down Expand Up @@ -241,6 +247,19 @@ void registerArrayFunctions(const std::string& prefix) {
Array<Varchar>,
int64_t>({prefix + "trim_array"});

registerArrayTopNFunction<int8_t>(prefix);
registerArrayTopNFunction<int16_t>(prefix);
registerArrayTopNFunction<int32_t>(prefix);
registerArrayTopNFunction<int64_t>(prefix);
registerArrayTopNFunction<int128_t>(prefix);
registerArrayTopNFunction<float>(prefix);
registerArrayTopNFunction<double>(prefix);
registerArrayTopNFunction<Varchar>(prefix);
registerArrayTopNFunction<Timestamp>(prefix);
registerArrayTopNFunction<Date>(prefix);
registerArrayTopNFunction<Varbinary>(prefix);
registerArrayTopNFunction<Orderable<T1>>(prefix);

registerArrayRemoveNullFunctions<int8_t>(prefix);
registerArrayRemoveNullFunctions<int16_t>(prefix);
registerArrayRemoveNullFunctions<int32_t>(prefix);
Expand Down
Loading

0 comments on commit ac17ddc

Please sign in to comment.