Skip to content

Commit

Permalink
Add all_match Presto Functions
Browse files Browse the repository at this point in the history
  • Loading branch information
duanmeng committed Nov 24, 2022
1 parent babb1c3 commit 24d346c
Show file tree
Hide file tree
Showing 6 changed files with 190 additions and 0 deletions.
8 changes: 8 additions & 0 deletions velox/docs/functions/array.rst
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,14 @@
Array Functions
=============================

.. function:: all_match(array(T), function(T, boolean)) -> boolean

Returns whether all elements of an array match the given predicate.

Returns true if every element matches the predicate; an empty array also returns true.
Returns false if one or more elements do not match the predicate.
Returns NULL if the predicate returns NULL for one or more elements and true for all other elements.

.. function:: array_distinct(array(E)) -> array(E)

Remove duplicate values from the input array. ::
Expand Down
118 changes: 118 additions & 0 deletions velox/functions/prestosql/ArrayAllMatch.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
/*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "velox/expression/Expr.h"
#include "velox/expression/VectorFunction.h"
#include "velox/functions/lib/LambdaFunctionUtil.h"
#include "velox/functions/lib/RowsTranslationUtil.h"
#include "velox/vector/FunctionVector.h"

namespace facebook::velox::functions {
namespace {

/// Implements the Presto all_match(array(T), function(T, boolean)) function.
///
/// For every selected row the predicate lambda is evaluated over the elements
/// of that row's array and the per-element outcomes are folded into one value:
///  - false if the predicate returned false for some element;
///  - NULL if no element produced false but at least one produced NULL;
///  - true otherwise, which includes empty arrays.
class AllMatchFunction : public exec::VectorFunction {
 public:
  /// Evaluates all_match for the selected 'rows'.
  ///
  /// @param rows Top-level rows to evaluate.
  /// @param args Exactly two arguments: the array vector and the lambda
  /// (FunctionVector).
  /// @param outputType Not used by this implementation; the result is always
  /// written as BOOLEAN.
  /// @param context Evaluation context used for decoding, memory and the
  /// final selection.
  /// @param result Receives one boolean (or NULL) per selected row.
  void apply(
      const SelectivityVector& rows,
      std::vector<VectorPtr>& args,
      const TypePtr& outputType,
      exec::EvalCtx& context,
      VectorPtr& result) const override {
    VELOX_CHECK_EQ(args.size(), 2);

    // Decode the array argument and obtain a flat array whose raw offsets and
    // sizes are indexed by top-level row number below.
    exec::LocalDecodedVector decodedInput(context, *args[0], rows);
    auto arrays = flattenArray(rows, args[0], *decodedInput.get());
    auto rawOffsets = arrays->rawOffsets();
    auto rawSizes = arrays->rawSizes();

    std::vector<VectorPtr> lambdaArgs = {arrays->elements()};
    auto numElements = arrays->elements()->size();

    // If the current selection is not the final one, translate the final
    // selection into element rows too; otherwise a default-constructed
    // selection is handed to the lambda.
    SelectivityVector finalSelection;
    if (!context.isFinalSelection()) {
      finalSelection = toElementRows<ArrayVector>(
          numElements, *context.finalSelection(), arrays.get());
    }

    VectorPtr predicateResults;
    auto elementToTopLevelRows = getElementToTopLevelRows(
        numElements, rows, arrays.get(), context.pool());

    context.ensureWritable(rows, BOOLEAN(), result);
    auto boolResult = result->asFlatVector<bool>();
    exec::LocalDecodedVector decodedPredicate(context);

    // Walk over the distinct lambdas in the function vector and apply each to
    // the elements of the rows it covers. In most cases there is only one
    // lambda and this loop runs once.
    auto lambdas = args[1]->asUnchecked<FunctionVector>()->iterator(&rows);
    while (auto entry = lambdas.next()) {
      auto elementRows = toElementRows<ArrayVector>(
          numElements, *entry.rows, arrays.get());
      auto wrapCapture = toWrapCapture<ArrayVector>(
          numElements, entry.callable, *entry.rows, arrays);

      entry.callable->apply(
          elementRows,
          finalSelection,
          wrapCapture,
          &context,
          lambdaArgs,
          elementToTopLevelRows,
          &predicateResults);

      decodedPredicate.get()->decode(*predicateResults, elementRows);
      entry.rows->applyToSelected([&](vector_size_t row) {
        const auto begin = rawOffsets[row];
        const auto end = begin + rawSizes[row];

        bool sawNull = false;
        bool sawFalse = false;
        for (auto idx = begin; idx < end; ++idx) {
          if (decodedPredicate->isNullAt(idx)) {
            sawNull = true;
            continue;
          }
          if (!decodedPredicate->valueAt<bool>(idx)) {
            // A single mismatch decides the row regardless of any NULLs.
            sawFalse = true;
            break;
          }
        }

        if (sawFalse) {
          boolResult->set(row, false);
        } else if (sawNull) {
          boolResult->setNull(row, true);
        } else {
          boolResult->set(row, true);
        }
      });
    }
  }

  /// Returns the single supported signature:
  /// array(T), function(T, boolean) -> boolean.
  static std::vector<std::shared_ptr<exec::FunctionSignature>> signatures() {
    auto signature = exec::FunctionSignatureBuilder()
                         .typeVariable("T")
                         .returnType("boolean")
                         .argumentType("array(T)")
                         .argumentType("function(T, boolean)")
                         .build();
    return {signature};
  }
};
} // namespace

// Declares the all_match vector function under the tag 'udf_all_match',
// supplying the signatures from AllMatchFunction::signatures() and an
// AllMatchFunction instance as the implementation.
VELOX_DECLARE_VECTOR_FUNCTION(
udf_all_match,
AllMatchFunction::signatures(),
std::make_unique<AllMatchFunction>());

} // namespace facebook::velox::functions
1 change: 1 addition & 0 deletions velox/functions/prestosql/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ add_subdirectory(window)

add_library(
velox_functions_prestosql_impl
ArrayAllMatch.cpp
ArrayConstructor.cpp
ArrayContains.cpp
ArrayDistinct.cpp
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ inline void registerArrayCombinationsFunctions() {
}

void registerArrayFunctions() {
VELOX_REGISTER_VECTOR_FUNCTION(udf_all_match, "all_match");
VELOX_REGISTER_VECTOR_FUNCTION(udf_array_constructor, "array_constructor");
VELOX_REGISTER_VECTOR_FUNCTION(udf_array_distinct, "array_distinct");
VELOX_REGISTER_VECTOR_FUNCTION(udf_array_duplicates, "array_duplicates");
Expand Down
61 changes: 61 additions & 0 deletions velox/functions/prestosql/tests/ArrayAllMatchTest.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
/*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "velox/functions/prestosql/tests/utils/FunctionBaseTest.h"

using namespace facebook::velox;
using namespace facebook::velox::test;

// Test fixture for the all_match function. Adds no state of its own on top of
// FunctionBaseTest.
class ArrayAllMatchTest : public functions::test::FunctionBaseTest {};

// Verifies all_match with an "is even" predicate over BIGINT arrays, covering
// the empty array, the extreme int64 values and arrays with NULL elements.
TEST_F(ArrayAllMatchTest, bigints) {
  constexpr auto kMax = std::numeric_limits<int64_t>::max();
  constexpr auto kMin = std::numeric_limits<int64_t>::min();

  auto arrays = makeNullableArrayVector<int64_t>({
      {}, // Empty array: trivially true.
      {2},
      {kMax},
      {kMin},
      {std::nullopt, std::nullopt}, // NULL when every element is NULL.
      {2, std::nullopt}, // NULL when some are NULL and the rest match.
      {1, std::nullopt, 2}, // False as soon as one element does not match.
  });

  auto expected = makeNullableFlatVector<bool>(
      {true, true, false, true, std::nullopt, std::nullopt, false});

  auto actual = evaluate<SimpleVector<bool>>(
      "all_match(c0, x -> (x % 2 = 0))", makeRowVector({arrays}));

  assertEqualVectors(expected, actual);
}

// Verifies all_match with an equality predicate over VARCHAR arrays: empty
// array, a matching element, a mix with a mismatch, and a single NULL element.
TEST_F(ArrayAllMatchTest, strings) {
  auto arrays = makeNullableArrayVector<StringView>({
      {},
      {"abc"},
      {"ab", "abc"},
      {std::nullopt},
  });

  auto expected =
      makeNullableFlatVector<bool>({true, true, false, std::nullopt});

  auto actual = evaluate<SimpleVector<bool>>(
      "all_match(c0, x -> (x == 'abc'))", makeRowVector({arrays}));

  assertEqualVectors(expected, actual);
}

// Verifies all_match with a greater-than predicate over DOUBLE arrays: empty
// array, a matching element, a pair with a mismatch, and a single NULL
// element.
TEST_F(ArrayAllMatchTest, doubles) {
  auto arrays = makeNullableArrayVector<double>({
      {},
      {1.2},
      {3.0, 0},
      {std::nullopt},
  });

  auto expected =
      makeNullableFlatVector<bool>({true, true, false, std::nullopt});

  auto actual = evaluate<SimpleVector<bool>>(
      "all_match(c0, x -> (x > 1.1))", makeRowVector({arrays}));

  assertEqualVectors(expected, actual);
}
1 change: 1 addition & 0 deletions velox/functions/prestosql/tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ add_subdirectory(utils)
add_executable(
velox_functions_test
ArithmeticTest.cpp
ArrayAllMatchTest.cpp
ArrayCombinationsTest.cpp
ArrayConstructorTest.cpp
ArrayContainsTest.cpp
Expand Down

0 comments on commit 24d346c

Please sign in to comment.