Skip to content

Commit

Permalink
feat: Add Presto function array_top_n (facebookincubator#12105)
Browse files Browse the repository at this point in the history
Summary:

Adds Presto function array_top_n as a simple function in Velox. Function uses a temporary vector to store inputted values and heap sorts them up to k values (second input to function).

Updates ArrayFunction.h with struct ArrayTopNFunction and adds new tester function ArrayTopNTest.cpp

Differential Revision: D68031372
  • Loading branch information
peterenescu authored and facebook-github-bot committed Jan 29, 2025
1 parent a1b4ee7 commit 3c0a4e3
Show file tree
Hide file tree
Showing 4 changed files with 533 additions and 0 deletions.
117 changes: 117 additions & 0 deletions velox/functions/prestosql/ArrayFunctions.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,14 @@
#include "velox/expression/PrestoCastHooks.h"
#include "velox/functions/Udf.h"
#include "velox/functions/lib/CheckedArithmetic.h"
#include "velox/functions/lib/ComparatorUtil.h"
#include "velox/functions/prestosql/json/SIMDJsonUtil.h"
#include "velox/functions/prestosql/types/JsonType.h"
#include "velox/type/Conversions.h"
#include "velox/type/FloatingPointUtil.h"

#include <queue>

namespace facebook::velox::functions {

template <typename TExecCtx, bool isMax>
Expand Down Expand Up @@ -729,6 +732,120 @@ inline void checkIndexArrayTrim(int64_t size, int64_t arraySize) {
}
}

/// This class implements the array_top_n function.
///
/// DEFINITION:
/// array_top_n(array(T), int) -> array(T)
/// Returns the top n elements of the array in descending order.
template <typename T>
struct ArrayTopNFunction {
VELOX_DEFINE_FUNCTION_TYPES(T);

// Definition for primitives.
template <typename TReturn, typename TInput>
FOLLY_ALWAYS_INLINE void
call(TReturn& result, const TInput& array, int64_t n) {
VELOX_CHECK(
n >= 0, fmt::format("Parameter n: {} to ARRAY_TOP_N is negative", n));

// Define min-heap to store the top n elements.
std::priority_queue<
typename TInput::element_t,
std::vector<typename TInput::element_t>,
std::greater<>>
minHeap;

// Iterate through the array and push elements to the min-heap.
int numNull = 0;
for (const auto& item : array) {
if (item.has_value()) {
minHeap.push(item.value());
if (minHeap.size() > n) {
minHeap.pop();
}
} else {
++numNull;
}
}

// Reverse the min-heap to get the top n elements in descending order.
std::vector<typename TInput::element_t> reversed(minHeap.size());
auto index = minHeap.size();
while (!minHeap.empty()) {
reversed[--index] = minHeap.top();
minHeap.pop();
}

// Copy mutated vector to result vector up to minHeap's size items.
for (const auto& item : reversed) {
result.push_back(item);
}

// Backfill nulls if needed.
while (result.size() < n && numNull > 0) {
result.add_null();
--numNull;
}
}

// Generic implementation.
FOLLY_ALWAYS_INLINE void call(
out_type<Array<Orderable<T1>>>& result,
const arg_type<Array<Orderable<T1>>>& array,
const int64_t n) {
VELOX_CHECK(
n >= 0, fmt::format("Parameter n: {} to ARRAY_TOP_N is negative", n));

// Define comparator to compare complex types.
struct ComplexTypeComparator {
const arg_type<Array<Orderable<T1>>>& array;
ComplexTypeComparator(const arg_type<Array<Orderable<T1>>>& array)
: array(array) {}

bool operator()(const int64_t& a, const int64_t& b) const {
static constexpr CompareFlags kFlags = {
.nullHandlingMode =
CompareFlags::NullHandlingMode::kNullAsIndeterminate};
return array[a].value().compare(array[b].value(), kFlags).value() > 0;
}
};

// Iterate through the array and push elements to the min-heap.
std::priority_queue<int64_t, std::vector<int64_t>, ComplexTypeComparator>
minHeap(array);
int numNull = 0;
for (int i = 0; i < array.size(); ++i) {
if (array[i].has_value()) {
minHeap.push(i);
if (minHeap.size() > n) {
minHeap.pop();
}
} else {
++numNull;
}
}

// Reverse the min-heap to get the top n elements in descending order.
std::vector<int64_t> reversed(minHeap.size());
auto index = minHeap.size();
while (!minHeap.empty()) {
reversed[--index] = minHeap.top();
minHeap.pop();
}

// Copy mutated vector to result vector up to minHeap's size items.
for (const auto& index : reversed) {
result.push_back(array[index].value());
}

// Backfill nulls if needed.
while (result.size() < n && numNull > 0) {
result.add_null();
--numNull;
}
}
};

template <typename T>
struct ArrayTrimFunction {
VELOX_DEFINE_FUNCTION_TYPES(T);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,12 @@ inline void registerArrayTrimFunctions(const std::string& prefix) {
{prefix + "trim_array"});
}

template <typename T>
inline void registerArrayTopNFunction(const std::string& prefix) {
registerFunction<ArrayTopNFunction, Array<T>, Array<T>, int64_t>(
{prefix + "array_top_n"});
}

template <typename T>
inline void registerArrayRemoveNullFunctions(const std::string& prefix) {
registerFunction<ArrayRemoveNullFunction, Array<T>, Array<T>>(
Expand Down Expand Up @@ -241,6 +247,19 @@ void registerArrayFunctions(const std::string& prefix) {
Array<Varchar>,
int64_t>({prefix + "trim_array"});

registerArrayTopNFunction<int8_t>(prefix);
registerArrayTopNFunction<int16_t>(prefix);
registerArrayTopNFunction<int32_t>(prefix);
registerArrayTopNFunction<int64_t>(prefix);
registerArrayTopNFunction<int128_t>(prefix);
registerArrayTopNFunction<float>(prefix);
registerArrayTopNFunction<double>(prefix);
registerArrayTopNFunction<Varchar>(prefix);
registerArrayTopNFunction<Timestamp>(prefix);
registerArrayTopNFunction<Date>(prefix);
registerArrayTopNFunction<Varbinary>(prefix);
registerArrayTopNFunction<Orderable<T1>>(prefix);

registerArrayRemoveNullFunctions<int8_t>(prefix);
registerArrayRemoveNullFunctions<int16_t>(prefix);
registerArrayRemoveNullFunctions<int32_t>(prefix);
Expand Down
Loading

0 comments on commit 3c0a4e3

Please sign in to comment.