From 7a9b14181a7e11a093f6745330dea83c0c0bd95a Mon Sep 17 00:00:00 2001 From: Kevin Wilfong Date: Thu, 26 Sep 2024 16:11:51 -0700 Subject: [PATCH] Support custom comparison in Presto's IN function (#11032) Summary: Pull Request resolved: https://github.com/facebookincubator/velox/pull/11032 Building on https://github.com/facebookincubator/velox/pull/11021 this adds support for custom comparison functions provided by custom types in Presto's IN function. I was able to reuse the ComplexTypeInPredicate and the support for custom comparisons already present in BaseVector. This diff is largely just renaming ComplexTypeInPredicate to VectorSetInPredicate (to clarify it's not just for complex types anymore) and adding an if statement to identify the case where providesCustomComparison() is true for the element type (and of course updating the tests). Making TimestampWithTimeZone a special case of bigint (comparing the millis) in the future might give a performance boost if this shows up as a bottleneck. Reviewed By: xiaoxmeng Differential Revision: D62994557 fbshipit-source-id: 82d3eb2c3d24118f7f555b3f739810c91c58fc32 --- velox/functions/prestosql/InPredicate.cpp | 38 ++++--- .../prestosql/tests/InPredicateTest.cpp | 103 +++++++++++------- 2 files changed, 84 insertions(+), 57 deletions(-) diff --git a/velox/functions/prestosql/InPredicate.cpp b/velox/functions/prestosql/InPredicate.cpp index 86d7fdeacb20..0819cd0d69f1 100644 --- a/velox/functions/prestosql/InPredicate.cpp +++ b/velox/functions/prestosql/InPredicate.cpp @@ -20,36 +20,39 @@ namespace facebook::velox::functions { namespace { +// This implements InPredicate using a set over VectorValues (pairs of +// BaseVector, index). Can be used in place of Filters for Types not supported +// by Filters or when custom comparisons are needed. // Returns NULL if // - input value is NULL // - in-list is NULL or empty // - input value doesn't have an exact match, but has an indeterminate match in // the in-list. 
E.g., 'array[null] in (array[1])' or 'array[1] in // (array[null])'. -class ComplexTypeInPredicate : public exec::VectorFunction { +class VectorSetInPredicate : public exec::VectorFunction { public: - struct ComplexValue { + struct VectorValue { BaseVector* vector; vector_size_t index; }; - struct ComplexValueHash { - size_t operator()(ComplexValue value) const { + struct VectorValueHash { + size_t operator()(VectorValue value) const { return value.vector->hashValueAt(value.index); } }; - struct ComplexValueEqualTo { - bool operator()(ComplexValue left, ComplexValue right) const { + struct VectorValueEqualTo { + bool operator()(VectorValue left, VectorValue right) const { return left.vector->equalValueAt(right.vector, left.index, right.index); } }; - using ComplexSet = - folly::F14FastSet; + using VectorSet = + folly::F14FastSet; - ComplexTypeInPredicate( - ComplexSet uniqueValues, + VectorSetInPredicate( + VectorSet uniqueValues, bool hasNull, VectorPtr originalValues) : uniqueValues_{std::move(uniqueValues)}, @@ -58,7 +61,7 @@ class ComplexTypeInPredicate : public exec::VectorFunction { static std::shared_ptr create(const VectorPtr& values, vector_size_t offset, vector_size_t size) { - ComplexSet uniqueValues; + VectorSet uniqueValues; bool hasNull = false; for (auto i = offset; i < offset + size; i++) { @@ -68,7 +71,7 @@ class ComplexTypeInPredicate : public exec::VectorFunction { uniqueValues.insert({values.get(), i}); } - return std::make_shared( + return std::make_shared( std::move(uniqueValues), hasNull, values); } @@ -126,7 +129,7 @@ class ComplexTypeInPredicate : public exec::VectorFunction { // Set of unique values to check against. This set doesn't include any value // that is null or contains null. - const ComplexSet uniqueValues_; + const VectorSet uniqueValues_; // Boolean indicating whether one of the value was null or contained null. 
const bool hasNull_; @@ -339,10 +342,15 @@ class InPredicate : public exec::VectorFunction { } const auto& elements = arrayVector->elements(); + const auto& elementType = elements->type(); + + if (elementType->providesCustomComparison()) { + return VectorSetInPredicate::create(elements, offset, size); + } std::pair, bool> filter; - switch (inListType->childAt(0)->kind()) { + switch (elementType->kind()) { case TypeKind::HUGEINT: filter = createHugeintValuesFilter(elements, offset, size); break; @@ -384,7 +392,7 @@ class InPredicate : public exec::VectorFunction { case TypeKind::MAP: [[fallthrough]]; case TypeKind::ROW: - return ComplexTypeInPredicate::create(elements, offset, size); + return VectorSetInPredicate::create(elements, offset, size); default: VELOX_UNSUPPORTED( "Unsupported in-list type for IN predicate: {}", diff --git a/velox/functions/prestosql/tests/InPredicateTest.cpp b/velox/functions/prestosql/tests/InPredicateTest.cpp index b08558148513..ca8dfb71db53 100644 --- a/velox/functions/prestosql/tests/InPredicateTest.cpp +++ b/velox/functions/prestosql/tests/InPredicateTest.cpp @@ -14,7 +14,10 @@ * limitations under the License. 
*/ #include "velox/common/base/tests/GTestUtils.h" +#include "velox/functions/lib/DateTimeFormatter.h" #include "velox/functions/prestosql/tests/utils/FunctionBaseTest.h" +#include "velox/functions/prestosql/types/TimestampWithTimeZoneType.h" +#include "velox/type/tz/TimeZoneMap.h" using namespace facebook::velox::test; using namespace facebook::velox::functions::test; @@ -25,32 +28,32 @@ namespace { class InPredicateTest : public FunctionBaseTest { protected: template - std::string getInList( + ArrayVectorPtr getInList( std::vector> input, - const TypePtr& type = CppToType::create()) { + const TypePtr& type) { FlatVectorPtr flatVec = makeNullableFlatVector(input, type); - std::string inList; - auto len = flatVec->size(); - auto toString = [&](vector_size_t idx) { - if (type->isDecimal()) { - if (flatVec->isNullAt(idx)) { - return std::string("null"); - } - return fmt::format( - "cast({} as {})", flatVec->toString(idx), type->toString()); - } - return flatVec->toString(idx); - }; - for (auto i = 0; i < len - 1; i++) { - inList += fmt::format("{}, ", toString(i)); - } - inList += toString(len - 1); - return inList; + return makeArrayVector({0, flatVec->size()}, flatVec); + } + + core::TypedExprPtr makeInExpression( + const std::string& probe, + const ArrayVectorPtr& inList, + const TypePtr& type) { + return std::make_shared( + BOOLEAN(), + std::vector{ + std::make_shared(type, probe), + std::make_shared(inList)}, + "in"); } template - void testValues(const TypePtr type = CppToType::create()) { + void testValues( + const TypePtr type = CppToType::create(), + std::function valueAt = [](auto row) { + return row % 17; + }) { if (type->isDecimal()) { this->options_.parseDecimalAsDouble = false; } @@ -58,17 +61,17 @@ class InPredicateTest : public FunctionBaseTest { memory::memoryManager()->addLeafPool()}; const vector_size_t size = 1'000; - auto inList = getInList({1, 3, 5}, type); + auto inList = getInList({valueAt(1), valueAt(3), valueAt(5)}, type); auto vector = 
makeFlatVector( - size, [](auto row) { return row % 17; }, nullptr, type); + size, [&](auto row) { return valueAt(row); }, nullptr, type); auto vectorWithNulls = makeFlatVector( - size, [](auto row) { return row % 17; }, nullEvery(7), type); + size, [&](auto row) { return valueAt(row); }, nullEvery(7), type); auto rowVector = makeRowVector({vector, vectorWithNulls}); // no nulls auto result = evaluate>( - fmt::format("c0 IN ({})", inList), rowVector); + makeInExpression("c0", inList, type), rowVector); auto expected = makeFlatVector(size, [](auto row) { auto n = row % 17; return n == 1 || n == 3 || n == 5; @@ -78,7 +81,7 @@ class InPredicateTest : public FunctionBaseTest { // some nulls result = evaluate>( - fmt::format("c1 IN ({})", inList), rowVector); + makeInExpression("c1", inList, type), rowVector); expected = makeFlatVector( size, [](auto row) { @@ -91,9 +94,10 @@ class InPredicateTest : public FunctionBaseTest { // null values in the in-list // The results can be either true or null, but not false. 
- inList = getInList({1, 3, std::nullopt, 5}, type); + inList = + getInList({valueAt(1), valueAt(3), std::nullopt, valueAt(5)}, type); result = evaluate>( - fmt::format("c0 IN ({})", inList), rowVector); + makeInExpression("c0", inList, type), rowVector); expected = makeFlatVector( size, [](auto /* row */) { return true; }, @@ -105,7 +109,7 @@ class InPredicateTest : public FunctionBaseTest { assertEqualVectors(expected, result); result = evaluate>( - fmt::format("c1 IN ({})", inList), rowVector); + makeInExpression("c1", inList, type), rowVector); expected = makeFlatVector( size, [](auto /* row */) { return true; }, @@ -116,9 +120,9 @@ class InPredicateTest : public FunctionBaseTest { assertEqualVectors(expected, result); - inList = getInList({2, std::nullopt}, type); + inList = getInList({valueAt(2), std::nullopt}, type); result = evaluate>( - fmt::format("c0 IN ({})", inList), rowVector); + makeInExpression("c0", inList, type), rowVector); expected = makeFlatVector( size, [](auto /* row */) { return true; }, @@ -130,7 +134,7 @@ class InPredicateTest : public FunctionBaseTest { assertEqualVectors(expected, result); result = evaluate>( - fmt::format("c1 IN ({})", inList), rowVector); + makeInExpression("c1", inList, type), rowVector); expected = makeFlatVector( size, [](auto /* row */) { return true; }, @@ -173,9 +177,9 @@ class InPredicateTest : public FunctionBaseTest { rowVector = makeRowVector({dict}); - inList = getInList({2, 5, 9}, type); + inList = getInList({valueAt(2), valueAt(5), valueAt(9)}, type); result = evaluate>( - fmt::format("c0 IN ({})", inList), rowVector); + makeInExpression("c0", inList, type), rowVector); assertEqualVectors(expected, result); // an in list with nulls only is always null. 
@@ -186,12 +190,16 @@ class InPredicateTest : public FunctionBaseTest { } template - void testConstantValues(const TypePtr type = CppToType::create()) { + void testConstantValues( + const TypePtr type = CppToType::create(), + std::function valueAt = [](auto row) { + return row % 17; + }) { const vector_size_t size = 1'000; auto rowVector = makeRowVector( - {makeConstant(static_cast(123), size, type), + {makeConstant(valueAt(123), size, type), BaseVector::createNullConstant(type, size, pool())}); - auto inList = getInList({1, 3, 5}, type); + auto inList = getInList({valueAt(1), valueAt(3), valueAt(5)}, type); auto constTrue = makeConstant(true, size); auto constFalse = makeConstant(false, size); @@ -199,24 +207,24 @@ class InPredicateTest : public FunctionBaseTest { // a miss auto result = evaluate>( - fmt::format("c0 IN ({})", inList), rowVector); + makeInExpression("c0", inList, type), rowVector); assertEqualVectors(constFalse, result); // null result = evaluate>( - fmt::format("c1 IN ({})", inList), rowVector); + makeInExpression("c1", inList, type), rowVector); assertEqualVectors(constNull, result); // a hit - inList = getInList({1, 123, 5}, type); + inList = getInList({valueAt(1), valueAt(123), valueAt(5)}, type); result = evaluate>( - fmt::format("c0 IN ({})", inList), rowVector); + makeInExpression("c0", inList, type), rowVector); assertEqualVectors(constTrue, result); // a miss that is a null - inList = getInList({1, std::nullopt, 5}, type); + inList = getInList({valueAt(1), std::nullopt, valueAt(5)}, type); result = evaluate>( - fmt::format("c0 IN ({})", inList), rowVector); + makeInExpression("c1", inList, type), rowVector); assertEqualVectors(constNull, result); } @@ -1120,5 +1128,16 @@ TEST_F(InPredicateTest, nans) { testNaNs(); } +TEST_F(InPredicateTest, TimestampWithTimeZone) { + // The millis ranges from 0-16, but after every 17th row we increment the time + // zone ID, so that no two rows have the same millis and time zone. 
However, + // by the semantics of TimestampWithTimeZone's comparison, it's the same 17 + // values repeated. + auto valueAt = [](auto row) { return pack(row % 17, row / 17); }; + + testValues(TIMESTAMP_WITH_TIME_ZONE(), valueAt); + testConstantValues(TIMESTAMP_WITH_TIME_ZONE(), valueAt); +} + } // namespace } // namespace facebook::velox::functions