From 25fc81a28a60632b6d9b84daffe1f6fcfa963525 Mon Sep 17 00:00:00 2001 From: Kevin Wilfong Date: Wed, 25 Sep 2024 10:52:26 -0700 Subject: [PATCH 1/2] custom comparison vector hasher --- velox/exec/VectorHasher.cpp | 48 +++++++--- velox/exec/VectorHasher.h | 2 +- velox/exec/tests/VectorHasherTest.cpp | 126 ++++++++++++++++++++++++++ 3 files changed, 160 insertions(+), 16 deletions(-) diff --git a/velox/exec/VectorHasher.cpp b/velox/exec/VectorHasher.cpp index ed16155c18e6..a98e3d9b3d46 100644 --- a/velox/exec/VectorHasher.cpp +++ b/velox/exec/VectorHasher.cpp @@ -52,25 +52,32 @@ namespace facebook::velox::exec { }() namespace { -template +template uint64_t hashOne(DecodedVector& decoded, vector_size_t index) { if constexpr ( Kind == TypeKind::ROW || Kind == TypeKind::ARRAY || Kind == TypeKind::MAP) { // Virtual function call for complex type. return decoded.base()->hashValueAt(decoded.index(index)); - } - // Inlined for scalars. - using T = typename KindToFlatVector::HashRowType; - if constexpr (std::is_floating_point_v) { - return util::floating_point::NaNAwareHash()(decoded.valueAt(index)); } else { - return folly::hasher()(decoded.valueAt(index)); + // Inlined for scalars. + using T = typename KindToFlatVector::HashRowType; + T value = decoded.valueAt(index); + + if constexpr (typeProvidesCustomComparison) { + return static_cast*>( + decoded.base()->type().get()) + ->hash(value); + } else if constexpr (std::is_floating_point_v) { + return util::floating_point::NaNAwareHash()(value); + } else { + return folly::hasher()(value); + } } } } // namespace -template +template void VectorHasher::hashValues( const SelectivityVector& rows, bool mix, @@ -79,7 +86,7 @@ void VectorHasher::hashValues( if (decoded_.isConstantMapping()) { auto hash = decoded_.isNullAt(rows.begin()) ? kNullHash - : hashOne(decoded_, rows.begin()); + : hashOne(decoded_, rows.begin()); rows.applyToSelected([&](vector_size_t row) { result[row] = mix ? 
bits::hashMix(result[row], hash) : hash; }); @@ -96,7 +103,7 @@ void VectorHasher::hashValues( auto baseIndex = decoded_.index(row); uint64_t hash = cachedHashes_[baseIndex]; if (hash == kNullHash) { - hash = hashOne(decoded_, row); + hash = hashOne(decoded_, row); cachedHashes_[baseIndex] = hash; } result[row] = mix ? bits::hashMix(result[row], hash) : hash; @@ -107,7 +114,7 @@ void VectorHasher::hashValues( result[row] = mix ? bits::hashMix(result[row], kNullHash) : kNullHash; return; } - auto hash = hashOne(decoded_, row); + auto hash = hashOne(decoded_, row); result[row] = mix ? bits::hashMix(result[row], hash) : hash; }); } @@ -543,8 +550,13 @@ void VectorHasher::hash( result[row] = mix ? bits::hashMix(result[row], kNullHash) : kNullHash; }); } else { - VELOX_DYNAMIC_TYPE_DISPATCH( - hashValues, typeKind_, rows, mix, result.data()); + if (type_->providesCustomComparison()) { + VELOX_DYNAMIC_TEMPLATE_TYPE_DISPATCH( + hashValues, true, typeKind_, rows, mix, result.data()); + } else { + VELOX_DYNAMIC_TEMPLATE_TYPE_DISPATCH( + hashValues, false, typeKind_, rows, mix, result.data()); + } } } @@ -566,8 +578,14 @@ void VectorHasher::precompute(const BaseVector& value) { const SelectivityVector rows(1, true); decoded_.decode(value, rows); - precomputedHash_ = - VELOX_DYNAMIC_TYPE_DISPATCH(hashOne, typeKind_, decoded_, 0); + + if (type_->providesCustomComparison()) { + precomputedHash_ = VELOX_DYNAMIC_TEMPLATE_TYPE_DISPATCH( + hashOne, true, typeKind_, decoded_, 0); + } else { + precomputedHash_ = VELOX_DYNAMIC_TEMPLATE_TYPE_DISPATCH( + hashOne, false, typeKind_, decoded_, 0); + } } void VectorHasher::analyze( diff --git a/velox/exec/VectorHasher.h b/velox/exec/VectorHasher.h index 339f6deee64c..425fe267bd8a 100644 --- a/velox/exec/VectorHasher.h +++ b/velox/exec/VectorHasher.h @@ -540,7 +540,7 @@ class VectorHasher { return *reinterpret_cast(group + offset); } - template + template void hashValues(const SelectivityVector& rows, bool mix, uint64_t* result); const 
column_index_t channel_; diff --git a/velox/exec/tests/VectorHasherTest.cpp b/velox/exec/tests/VectorHasherTest.cpp index bceb6bf24feb..b98a5e44f69a 100644 --- a/velox/exec/tests/VectorHasherTest.cpp +++ b/velox/exec/tests/VectorHasherTest.cpp @@ -17,6 +17,7 @@ #include #include "velox/common/base/tests/GTestUtils.h" #include "velox/type/Type.h" +#include "velox/type/tests/utils/CustomTypesForTesting.h" #include "velox/vector/tests/utils/VectorTestBase.h" using namespace facebook::velox; @@ -1069,3 +1070,128 @@ TEST_F(VectorHasherTest, typeMismatch) { VELOX_ASSERT_THROW( hasher->decode(*data, rows), "Type mismatch: BIGINT vs. VARCHAR"); } + +void testCustomComparison(const VectorPtr& actual, const VectorPtr& expected) { + // Vector hasher should hash the values in actual to the same hashes as it + // does when hashing the values in expected. + ASSERT_EQ(actual->size(), expected->size()); + + auto vectorHasher = exec::VectorHasher::create(actual->type(), 1); + ASSERT_EQ(vectorHasher->channel(), 1); + + SelectivityVector allRows = SelectivityVector(actual->size()); + + raw_vector actualHashes(actual->size()); + std::fill(actualHashes.begin(), actualHashes.end(), 0); + vectorHasher->decode(*actual, allRows); + vectorHasher->hash(allRows, false, actualHashes); + + raw_vector expectedHashes(expected->size()); + std::fill(expectedHashes.begin(), expectedHashes.end(), 0); + vectorHasher->decode(*expected, allRows); + vectorHasher->hash(allRows, false, expectedHashes); + + for (int32_t i = 0; i < actual->size(); i++) { + EXPECT_EQ(actualHashes[i], expectedHashes[i]) + << "at " << i << " values: " << actual->toString(i) << " " + << expected->toString(i); + } +} + +TEST_F(VectorHasherTest, customComparison) { + // Tests that types that provide custom comparison are hashed using the custom + // hash implementation they provide. 
+ + testCustomComparison( + makeFlatVector( + {0, 1, 256, 257, 512, 513}, BIGINT_TYPE_WITH_CUSTOM_COMPARISON()), + // Different values that are equal mod 256 should hash to the same value. + makeFlatVector( + {0, 1, 0, 1, 0, 1}, BIGINT_TYPE_WITH_CUSTOM_COMPARISON())); +} + +TEST_F(VectorHasherTest, customComparisonArray) { + // Tests that types that provide custom comparison are hashed using the custom + // hash implementation they provide. + + testCustomComparison( + makeNullableArrayVector( + {{0, 1, 2}, + {256, 257, 258}, + {512, 513, 514}, + {3, 4, 5}, + {259, 260, 261}, + {515, 516, 517}, + {std::nullopt}}, + ARRAY(test::BIGINT_TYPE_WITH_CUSTOM_COMPARISON())), + // Different values that are equal mod 256 should hash to the same value. + makeNullableArrayVector( + {{0, 1, 2}, + {0, 1, 2}, + {0, 1, 2}, + {3, 4, 5}, + {3, 4, 5}, + {3, 4, 5}, + {std::nullopt}}, + ARRAY(test::BIGINT_TYPE_WITH_CUSTOM_COMPARISON()))); +} + +TEST_F(VectorHasherTest, customComparisonMap) { + // Tests that types that provide custom comparison are hashed using the custom + // hash implementation they provide. + + testCustomComparison( + makeNullableMapVector( + {std::vector>>{ + {0, 10}, {1, 11}, {2, 12}}, + std::vector>>{ + {256, 266}, {257, 267}, {258, 268}}, + std::vector>>{ + {512, 522}, {513, 523}, {514, 524}}, + std::vector>>{ + {3, 103}, {4, 104}, {5, 105}}, + std::vector>>{ + {259, 359}, {260, 360}, {261, 361}}, + std::vector>>{ + {515, 615}, {516, 616}, {517, 617}}, + std::vector>>{ + {0, std::nullopt}}}, + MAP(test::BIGINT_TYPE_WITH_CUSTOM_COMPARISON(), + test::BIGINT_TYPE_WITH_CUSTOM_COMPARISON())), + // Different values that are equal mod 256 should hash to the same value. 
+ makeNullableMapVector( + {std::vector>>{ + {0, 10}, {1, 11}, {2, 12}}, + std::vector>>{ + {0, 10}, {1, 11}, {2, 12}}, + std::vector>>{ + {0, 10}, {1, 11}, {2, 12}}, + std::vector>>{ + {3, 103}, {4, 104}, {5, 105}}, + std::vector>>{ + {3, 103}, {4, 104}, {5, 105}}, + std::vector>>{ + {3, 103}, {4, 104}, {5, 105}}, + std::vector>>{ + {0, std::nullopt}}}, + MAP(test::BIGINT_TYPE_WITH_CUSTOM_COMPARISON(), + test::BIGINT_TYPE_WITH_CUSTOM_COMPARISON()))); +} + +TEST_F(VectorHasherTest, customComparisonRow) { + // Tests that types that provide custom comparison are hashed using the custom + // hash implementation they provide. + + testCustomComparison( + makeRowVector( + {"a"}, + {makeNullableFlatVector( + {std::nullopt, 0, 1, 256, 257, 512, 513}, + test::BIGINT_TYPE_WITH_CUSTOM_COMPARISON())}), + // Different values that are equal mod 256 should hash to the same value. + makeRowVector( + {"a"}, + {makeNullableFlatVector( + {std::nullopt, 0, 1, 0, 1, 0, 1}, + test::BIGINT_TYPE_WITH_CUSTOM_COMPARISON())})); +} From bec870167f86f2816b63cc447429880dfca3a617 Mon Sep 17 00:00:00 2001 From: Kevin Wilfong Date: Wed, 25 Sep 2024 11:29:06 -0700 Subject: [PATCH 2/2] Provide custom comparison functions in TimestampWithTimezone (#11025) Summary: Pull Request resolved: https://github.com/facebookincubator/velox/pull/11025 https://github.com/facebookincubator/velox/pull/11021 introduced the ability for custom types to provide their own compare and hash semantics. This change updates TimestampWithTimezone to take advantage of this new feature. It is represented as a 64-bit integer, but only the top 52 bits of the value, which represent the millis since epoch in UTC, should be used for the purposes of comparison and hashing. The bottom 12 bits which represent the timezone should be ignored. 
Reviewed By: pansatadru Differential Revision: D62906383 --- .../types/TimestampWithTimeZoneType.h | 65 +++++++++++------- .../tests/TimestampWithTimeZoneTypeTest.cpp | 67 +++++++++++++++++++ 2 files changed, 106 insertions(+), 26 deletions(-) diff --git a/velox/functions/prestosql/types/TimestampWithTimeZoneType.h b/velox/functions/prestosql/types/TimestampWithTimeZoneType.h index 78b56483e862..95d711232c77 100644 --- a/velox/functions/prestosql/types/TimestampWithTimeZoneType.h +++ b/velox/functions/prestosql/types/TimestampWithTimeZoneType.h @@ -22,6 +22,31 @@ namespace facebook::velox { +using TimeZoneKey = int16_t; + +constexpr int32_t kTimezoneMask = 0xFFF; +constexpr int32_t kMillisShift = 12; + +inline int64_t unpackMillisUtc(int64_t dateTimeWithTimeZone) { + return dateTimeWithTimeZone >> kMillisShift; +} + +inline TimeZoneKey unpackZoneKeyId(int64_t dateTimeWithTimeZone) { + return dateTimeWithTimeZone & kTimezoneMask; +} + +inline int64_t pack(int64_t millisUtc, TimeZoneKey timeZoneKey) { + return (millisUtc << kMillisShift) | (timeZoneKey & kTimezoneMask); +} + +inline int64_t pack(const Timestamp& timestamp, TimeZoneKey timeZoneKey) { + return pack(timestamp.toMillis(), timeZoneKey); +} + +inline Timestamp unpackTimestampUtc(int64_t dateTimeWithTimeZone) { + return Timestamp::fromMillis(unpackMillisUtc(dateTimeWithTimeZone)); +} + class TimestampWithTimeZoneCastOperator : public exec::CastOperator { public: static const std::shared_ptr& get() { @@ -56,7 +81,7 @@ class TimestampWithTimeZoneCastOperator : public exec::CastOperator { /// Represents timestamp with time zone as a number of milliseconds since epoch /// and time zone ID. 
class TimestampWithTimeZoneType : public BigintType { - TimestampWithTimeZoneType() = default; + TimestampWithTimeZoneType() : BigintType(true) {} public: static const std::shared_ptr& get() { @@ -72,6 +97,19 @@ class TimestampWithTimeZoneType : public BigintType { return this == &other; } + int32_t compare(const int64_t& left, const int64_t& right) const override { + int64_t leftUnpacked = unpackMillisUtc(left); + int64_t rightUnpacked = unpackMillisUtc(right); + + return leftUnpacked < rightUnpacked ? -1 + : leftUnpacked == rightUnpacked ? 0 + : 1; + } + + uint64_t hash(const int64_t& value) const override { + return folly::hasher()(unpackMillisUtc(value)); + } + const char* name() const override { return "TIMESTAMP WITH TIME ZONE"; } @@ -125,29 +163,4 @@ class TimestampWithTimeZoneTypeFactories : public CustomTypeFactories { void registerTimestampWithTimeZoneType(); -using TimeZoneKey = int16_t; - -constexpr int32_t kTimezoneMask = 0xFFF; -constexpr int32_t kMillisShift = 12; - -inline int64_t unpackMillisUtc(int64_t dateTimeWithTimeZone) { - return dateTimeWithTimeZone >> kMillisShift; -} - -inline TimeZoneKey unpackZoneKeyId(int64_t dateTimeWithTimeZone) { - return dateTimeWithTimeZone & kTimezoneMask; -} - -inline int64_t pack(int64_t millisUtc, int16_t timeZoneKey) { - return (millisUtc << kMillisShift) | (timeZoneKey & kTimezoneMask); -} - -inline int64_t pack(const Timestamp& timestamp, int16_t timeZoneKey) { - return pack(timestamp.toMillis(), timeZoneKey); -} - -inline Timestamp unpackTimestampUtc(int64_t dateTimeWithTimeZone) { - return Timestamp::fromMillis(unpackMillisUtc(dateTimeWithTimeZone)); -} - } // namespace facebook::velox diff --git a/velox/functions/prestosql/types/tests/TimestampWithTimeZoneTypeTest.cpp b/velox/functions/prestosql/types/tests/TimestampWithTimeZoneTypeTest.cpp index 8cfa5f578b6e..65a85fb0d2c3 100644 --- a/velox/functions/prestosql/types/tests/TimestampWithTimeZoneTypeTest.cpp +++ 
b/velox/functions/prestosql/types/tests/TimestampWithTimeZoneTypeTest.cpp @@ -15,6 +15,7 @@ */ #include "velox/functions/prestosql/types/TimestampWithTimeZoneType.h" #include "velox/functions/prestosql/types/tests/TypeTestBase.h" +#include "velox/type/tz/TimeZoneMap.h" namespace facebook::velox::test { @@ -65,4 +66,70 @@ TEST_F(TimestampWithTimeZoneTypeTest, pack) { } } +TEST_F(TimestampWithTimeZoneTypeTest, compare) { + auto compare = [](int32_t expected, + int64_t millis1, + const std::string& tz1, + int64_t millis2, + const std::string& tz2) { + int64_t left = pack(millis1, tz::getTimeZoneID(tz1)); + int64_t right = pack(millis2, tz::getTimeZoneID(tz2)); + + ASSERT_EQ(expected, TIMESTAMP_WITH_TIME_ZONE()->compare(left, right)); + }; + + compare(0, 1639426440000, "+01:00", 1639426440000, "+03:00"); + compare(0, 1639426440000, "+01:00", 1639426440000, "-14:00"); + compare(0, 1639426440000, "+03:00", 1639426440000, "-14:00"); + compare(0, -1639426440000, "+01:00", -1639426440000, "+03:00"); + + compare(-1, 1549770072000, "+01:00", 1639426440000, "+03:00"); + compare(-1, 1549770072000, "+01:00", 1639426440000, "-14:00"); + compare(-1, 1549770072000, "+03:00", 1639426440000, "-14:00"); + compare(-1, -1639426440000, "+01:00", -1539426440000, "+03:00"); + compare(-1, -1639426440000, "+01:00", 1639426440000, "-14:00"); + + compare(1, 1639426440000, "+01:00", 1549770072000, "+03:00"); + compare(1, 1639426440000, "+01:00", 1549770072000, "-14:00"); + compare(1, 1639426440000, "+03:00", 1549770072000, "-14:00"); + compare(1, 1639426440000, "+01:00", -1639426440000, "+03:00"); + compare(1, -1539426440000, "+01:00", -1639426440000, "-14:00"); +} + +TEST_F(TimestampWithTimeZoneTypeTest, hash) { + auto expectHashesEq = [](int64_t millis1, + const std::string& tz1, + int64_t millis2, + const std::string& tz2) { + int64_t left = pack(millis1, tz::getTimeZoneID(tz1)); + int64_t right = pack(millis2, tz::getTimeZoneID(tz2)); + + ASSERT_EQ( + TIMESTAMP_WITH_TIME_ZONE()->hash(left), 
+ TIMESTAMP_WITH_TIME_ZONE()->hash(right)); + }; + + auto expectHashesNeq = [](int64_t millis1, + const std::string& tz1, + int64_t millis2, + const std::string& tz2) { + int64_t left = pack(millis1, tz::getTimeZoneID(tz1)); + int64_t right = pack(millis2, tz::getTimeZoneID(tz2)); + + ASSERT_NE( + TIMESTAMP_WITH_TIME_ZONE()->hash(left), + TIMESTAMP_WITH_TIME_ZONE()->hash(right)); + }; + + expectHashesEq(1639426440000, "+01:00", 1639426440000, "+03:00"); + expectHashesEq(1639426440000, "+01:00", 1639426440000, "-14:00"); + expectHashesEq(1639426440000, "+03:00", 1639426440000, "-14:00"); + expectHashesEq(-1639426440000, "+03:00", -1639426440000, "-14:00"); + + expectHashesNeq(1549770072000, "+01:00", 1639426440000, "+03:00"); + expectHashesNeq(1549770072000, "+01:00", 1639426440000, "-14:00"); + expectHashesNeq(1549770072000, "+03:00", 1639426440000, "-14:00"); + expectHashesNeq(-1639426440000, "+03:00", 1639426440000, "-14:00"); +} + } // namespace facebook::velox::test