Skip to content

Commit

Permalink
Add array_sum Presto function
Browse files Browse the repository at this point in the history
  • Loading branch information
jwyles-ahana committed Aug 2, 2022
1 parent 0baf13f commit bec6db1
Show file tree
Hide file tree
Showing 9 changed files with 479 additions and 0 deletions.
5 changes: 5 additions & 0 deletions velox/docs/functions/array.rst
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,11 @@ Array Functions
SELECT array_join(ARRAY [1, NULL, 2], ",") -- "1,2"
SELECT array_join(ARRAY [1, NULL, 2], ",", "0") -- "1,0,2"

.. function:: array_sum(array(T)) -> bigint/double

Returns the sum of all non-null elements of the array. If there are no non-null elements, returns 0. The behaviour is similar to the aggregation function sum().
T must be coercible to double. Returns bigint if T is coercible to bigint. Otherwise, returns double.

.. function:: cardinality(x) -> bigint

Returns the cardinality (size) of the array ``x``.
Expand Down
36 changes: 36 additions & 0 deletions velox/functions/prestosql/ArrayFunctions.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
#pragma once

#include "velox/functions/Udf.h"
#include "velox/functions/prestosql/CheckedArithmetic.h"
#include "velox/type/Conversions.h"

namespace facebook::velox::functions {
Expand Down Expand Up @@ -280,4 +281,39 @@ struct CombinationsFunction {
}
};

/// Simple-function flavor of array_sum: adds up the elements of one array and
/// writes the total to `out`.
///
/// When TOutput is int64_t every addition goes through checkedPlus (declared
/// in CheckedArithmetic.h; presumably an overflow-checked add -- confirm
/// there). Any other TOutput is accumulated with plain `+=`.
template <typename T>
struct ArraySumFunction {
  VELOX_DEFINE_FUNCTION_TYPES(T)

  /// Null-aware entry point. Elements without a value are skipped, so an
  /// array that is empty or holds only nulls produces 0.
  template <typename TOutput, typename TInput>
  FOLLY_ALWAYS_INLINE void call(TOutput& out, const TInput& array) {
    TOutput total = 0;
    for (const auto& element : array) {
      if (!element.has_value()) {
        continue;
      }
      accumulate(total, *element);
    }
    out = total;
  }

  /// Entry point for inputs known to contain no nulls: every element is
  /// added without a has_value() test.
  template <typename TOutput, typename TInput>
  FOLLY_ALWAYS_INLINE void callNullFree(TOutput& out, const TInput& array) {
    TOutput total = 0;
    for (const auto& element : array) {
      accumulate(total, element);
    }
    out = total;
  }

 private:
  // Adds `value` to `total`, choosing the addition flavor at compile time
  // from the output type.
  template <typename TOutput, typename TValue>
  FOLLY_ALWAYS_INLINE static void accumulate(
      TOutput& total,
      const TValue& value) {
    if constexpr (std::is_same_v<TOutput, int64_t>) {
      total = checkedPlus<TOutput>(total, value);
    } else {
      total += value;
    }
  }
};

} // namespace facebook::velox::functions
206 changes: 206 additions & 0 deletions velox/functions/prestosql/ArraySum.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,206 @@
/*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include <folly/container/F14Set.h>

#include "velox/expression/EvalCtx.h"
#include "velox/expression/Expr.h"
#include "velox/expression/VectorFunction.h"
#include "velox/functions/lib/LambdaFunctionUtil.h"
#include "velox/functions/prestosql/CheckedArithmetic.h"

namespace facebook::velox::functions {
namespace {
///
/// Implements the array_sum function.
/// See documentation at https://prestodb.io/docs/current/functions/array.html
///
/// TInput is the native type of the array elements; TOutput is the type of
/// the running total and of the result vector. Null elements are skipped, so
/// an empty or all-null array sums to 0.
///
/// Rows are visited with EvalCtx::applyToSelectedNoThrow rather than
/// SelectivityVector::applyToSelected, so that an exception raised while
/// summing one array (e.g. from checkedPlus on the int64_t path) is handed to
/// the EvalCtx for that row instead of escaping apply() and failing every
/// row in the batch.
///
template <typename TInput, typename TOutput>
class ArraySumFunction : public exec::VectorFunction {
 public:
  /// Sums each selected array straight from the raw buffers of a flat
  /// elements vector.
  ///
  /// @tparam mayHaveNulls When true, `rawNulls` is consulted for every
  /// element; when false it is never read.
  /// @param rows Rows of `arrayVector` to process.
  /// @param arrayVector Supplies offset and size of each array.
  /// @param rawNulls Null bitmap of the elements vector.
  /// @param rawElements Values buffer of the elements vector.
  /// @param resultValues Receives one sum per processed row.
  /// @param context Used for per-row error capture.
  template <bool mayHaveNulls>
  void applyFlat(
      const SelectivityVector& rows,
      ArrayVector* arrayVector,
      const uint64_t* rawNulls,
      const TInput* rawElements,
      FlatVector<TOutput>* resultValues,
      exec::EvalCtx* context) const {
    context->applyToSelectedNoThrow(rows, [&](vector_size_t row) {
      const auto begin = arrayVector->offsetAt(row);
      const auto end = begin + arrayVector->sizeAt(row);
      TOutput sum = 0;
      for (auto i = begin; i < end; ++i) {
        if constexpr (mayHaveNulls) {
          if (bits::isBitNull(rawNulls, i)) {
            continue;
          }
        }
        addTo(sum, rawElements[i]);
      }
      resultValues->set(row, sum);
    });
  }

  /// Same as applyFlat, but reads elements through a decoded view so that
  /// non-flat elements vectors can be handled.
  ///
  /// @tparam mayHaveNulls When true, each element is null-checked via
  /// `elements`; when false the check is skipped.
  template <bool mayHaveNulls>
  void applyNonFlat(
      const SelectivityVector& rows,
      ArrayVector* arrayVector,
      exec::LocalDecodedVector& elements,
      FlatVector<TOutput>* resultValues,
      exec::EvalCtx* context) const {
    context->applyToSelectedNoThrow(rows, [&](vector_size_t row) {
      const auto begin = arrayVector->offsetAt(row);
      const auto end = begin + arrayVector->sizeAt(row);
      TOutput sum = 0;
      for (auto i = begin; i < end; ++i) {
        if constexpr (mayHaveNulls) {
          if (elements->isNullAt(i)) {
            continue;
          }
        }
        addTo(sum, elements->template valueAt<TInput>(i));
      }
      resultValues->set(row, sum);
    });
  }

  /// Computes the sum of args[0] (an array vector) for every row in `rows`
  /// and stores it in `result`.
  void apply(
      const SelectivityVector& rows,
      std::vector<VectorPtr>& args, // Not using const ref so we can reuse args
      const TypePtr& outputType,
      exec::EvalCtx* context,
      VectorPtr* result) const override {
    // Prepare result vector for writing
    BaseVector::ensureWritable(rows, outputType, context->pool(), result);
    auto resultValues = (*result)->template asFlatVector<TOutput>();

    // Acquire the array elements vector.
    // NOTE(review): this assumes args[0] is itself an ArrayVector. Presumably
    // a constant- or dictionary-wrapped input would make as<>() return
    // nullptr and trip the check below -- confirm against what the
    // expression evaluator guarantees for this function's inputs.
    auto arrayVector = args[0]->as<ArrayVector>();
    VELOX_CHECK(arrayVector);
    auto elementsVector = arrayVector->elements();

    if (elementsVector->encoding() == VectorEncoding::Simple::FLAT) {
      // Flat elements: read values and null bits from the raw buffers.
      const TInput* __restrict rawElements =
          elementsVector->as<FlatVector<TInput>>()->rawValues();
      const uint64_t* __restrict rawNulls = elementsVector->rawNulls();

      if (elementsVector->mayHaveNulls()) {
        applyFlat<true>(
            rows, arrayVector, rawNulls, rawElements, resultValues, context);
      } else {
        applyFlat<false>(
            rows, arrayVector, rawNulls, rawElements, resultValues, context);
      }
    } else {
      // Any other encoding: decode all elements once, then index into the
      // decoded view.
      SelectivityVector elementsRows(elementsVector->size());
      exec::LocalDecodedVector elements(context, *elementsVector, elementsRows);

      if (elementsVector->mayHaveNulls()) {
        applyNonFlat<true>(rows, arrayVector, elements, resultValues, context);
      } else {
        applyNonFlat<false>(rows, arrayVector, elements, resultValues, context);
      }
    }
  }

 private:
  // Adds `value` to `sum`. int64_t totals go through checkedPlus; any other
  // output type uses plain addition.
  static void addTo(TOutput& sum, TInput value) {
    if constexpr (std::is_same_v<TOutput, int64_t>) {
      sum = checkedPlus<TOutput>(sum, value);
    } else {
      sum += value;
    }
  }
};

// Create function.
std::shared_ptr<exec::VectorFunction> create(
const std::string& /* name */,
const std::vector<exec::VectorFunctionArg>& inputArgs) {
auto elementType = inputArgs.front().type->childAt(0);

switch (elementType->kind()) {
case TypeKind::TINYINT: {
return std::make_shared<ArraySumFunction<
TypeTraits<TypeKind::TINYINT>::NativeType,
int64_t>>();
}
case TypeKind::SMALLINT: {
return std::make_shared<ArraySumFunction<
TypeTraits<TypeKind::SMALLINT>::NativeType,
int64_t>>();
}
case TypeKind::INTEGER: {
return std::make_shared<ArraySumFunction<
TypeTraits<TypeKind::INTEGER>::NativeType,
int64_t>>();
}
case TypeKind::BIGINT: {
return std::make_shared<ArraySumFunction<
TypeTraits<TypeKind::BIGINT>::NativeType,
int64_t>>();
}
case TypeKind::REAL: {
return std::make_shared<
ArraySumFunction<TypeTraits<TypeKind::REAL>::NativeType, double>>();
}
case TypeKind::DOUBLE: {
return std::make_shared<
ArraySumFunction<TypeTraits<TypeKind::DOUBLE>::NativeType, double>>();
}
default: {
VELOX_FAIL("Unsupported Type")
}
}
}

// Define function signature.
// Produces one signature array(T1) -> T2 per supported element type: the
// integer types map to bigint, real and double map to double. The signatures
// come out in the map's key order.
std::vector<std::shared_ptr<exec::FunctionSignature>> signatures() {
  static const std::map<std::string, std::string> kResultTypeByElementType = {
      {"tinyint", "bigint"},
      {"smallint", "bigint"},
      {"integer", "bigint"},
      {"bigint", "bigint"},
      {"real", "double"},
      {"double", "double"}};

  std::vector<std::shared_ptr<exec::FunctionSignature>> result;
  result.reserve(kResultTypeByElementType.size());
  for (const auto& entry : kResultTypeByElementType) {
    const std::string& elementType = entry.first;
    const std::string& resultType = entry.second;
    const auto arrayType = fmt::format("array({})", elementType);
    result.emplace_back(exec::FunctionSignatureBuilder()
                            .returnType(resultType)
                            .argumentType(arrayType)
                            .build());
  }
  return result;
}
} // namespace

// Register function.
// Ties the tag `udf_array_sum` to the signature list from signatures() and
// the create() factory defined in this file. The SQL-visible name is
// presumably chosen where this tag is registered (not shown here).
VELOX_DECLARE_STATEFUL_VECTOR_FUNCTION(udf_array_sum, signatures(), create);

} // namespace facebook::velox::functions
1 change: 1 addition & 0 deletions velox/functions/prestosql/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ add_library(
ArrayIntersectExcept.cpp
ArrayPosition.cpp
ArraySort.cpp
ArraySum.cpp
ElementAt.cpp
FilterFunctions.cpp
FromUnixTime.cpp
Expand Down
104 changes: 104 additions & 0 deletions velox/functions/prestosql/benchmarks/ArraySumBenchmark.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
/*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <folly/Benchmark.h>
#include "velox/expression/VectorFunction.h"
#include "velox/functions/Macros.h"
#include "velox/functions/lib/LambdaFunctionUtil.h"
#include "velox/functions/lib/benchmarks/FunctionBenchmarkBase.h"
#include "velox/functions/prestosql/registration/RegistrationFunctions.h"

using namespace facebook::velox;
using namespace facebook::velox::exec;
using namespace facebook::velox::functions;

namespace {

class ArraySumBenchmark : public functions::test::FunctionBenchmarkBase {
public:
ArraySumBenchmark() : FunctionBenchmarkBase() {
functions::prestosql::registerArrayFunctions();
functions::prestosql::registerGeneralFunctions();
}

void runInteger(const std::string& functionName) {
folly::BenchmarkSuspender suspender;
vector_size_t size = 10'000;
auto arrayVector = vectorMaker_.arrayVector<int32_t>(
size,
[](auto row) { return row % 5; },
[](auto row) { return row % 23; });

auto rowVector = vectorMaker_.rowVector({arrayVector});
auto exprSet = compileExpression(
fmt::format("{}(c0)", functionName), rowVector->type());
suspender.dismiss();

doRun(exprSet, rowVector);
}

void runIntegerNulls(const std::string& functionName) {
folly::BenchmarkSuspender suspender;
vector_size_t size = 10'000;
auto arrayVector = vectorMaker_.arrayVector<int32_t>(
size,
[](auto row) { return row % 5; },
[](auto row) { return row % 23; },
[](auto row) { return (row % 513) == 0; },
[](auto row) { return (row % 13) == 0; });

auto rowVector = vectorMaker_.rowVector({arrayVector});
auto exprSet = compileExpression(
fmt::format("{}(c0)", functionName), rowVector->type());
suspender.dismiss();

doRun(exprSet, rowVector);
}

void doRun(ExprSet& exprSet, const RowVectorPtr& rowVector) {
int cnt = 0;
for (auto i = 0; i < 100; i++) {
cnt += evaluate(exprSet, rowVector)->size();
}
folly::doNotOptimizeAway(cnt);
}
};

// Baseline for the input without nulls: evaluates the function registered as
// "array_sum_alt" (presumably the simple-function implementation, going by
// the benchmark name -- confirm in the registration code).
BENCHMARK(SimpleFunction) {
ArraySumBenchmark benchmark;
benchmark.runInteger("array_sum_alt");
}

// Same input through "array_sum"; reported relative to the BENCHMARK above.
BENCHMARK_RELATIVE(VectorFunction) {
ArraySumBenchmark benchmark;
benchmark.runInteger("array_sum");
}

// Baseline for the input built with null predicates, using "array_sum_alt".
BENCHMARK(SimpleFunctionNulls) {
ArraySumBenchmark benchmark;
benchmark.runIntegerNulls("array_sum_alt");
}

// Null-containing input through "array_sum"; reported relative to
// SimpleFunctionNulls.
BENCHMARK_RELATIVE(VectorFunctionNulls) {
ArraySumBenchmark benchmark;
benchmark.runIntegerNulls("array_sum");
}

} // namespace

// Entry point: runs every benchmark defined in this binary. The command-line
// arguments are not inspected here.
int main(int, char**) {
  folly::runBenchmarks();
  return 0;
}
6 changes: 6 additions & 0 deletions velox/functions/prestosql/benchmarks/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,12 @@ add_executable(velox_functions_prestosql_benchmarks_array_position
target_link_libraries(velox_functions_prestosql_benchmarks_array_position
${BENCHMARK_DEPENDENCIES})

# Benchmark executable for array_sum, built from ArraySumBenchmark.cpp and
# linked against the shared benchmark dependency list.
add_executable(velox_functions_prestosql_benchmarks_array_sum
ArraySumBenchmark.cpp)

target_link_libraries(velox_functions_prestosql_benchmarks_array_sum
${BENCHMARK_DEPENDENCIES})

add_executable(velox_functions_prestosql_benchmarks_width_bucket
WidthBucketBenchmark.cpp)
target_link_libraries(velox_functions_prestosql_benchmarks_width_bucket
Expand Down
Loading

0 comments on commit bec6db1

Please sign in to comment.