Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Provide custom comparison functions in TimestampWithTimezone #11025

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 33 additions & 15 deletions velox/exec/VectorHasher.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -52,25 +52,32 @@ namespace facebook::velox::exec {
}()

namespace {
template <TypeKind Kind>
template <bool typeProvidesCustomComparison, TypeKind Kind>
uint64_t hashOne(DecodedVector& decoded, vector_size_t index) {
if constexpr (
Kind == TypeKind::ROW || Kind == TypeKind::ARRAY ||
Kind == TypeKind::MAP) {
// Virtual function call for complex type.
return decoded.base()->hashValueAt(decoded.index(index));
}
// Inlined for scalars.
using T = typename KindToFlatVector<Kind>::HashRowType;
if constexpr (std::is_floating_point_v<T>) {
return util::floating_point::NaNAwareHash<T>()(decoded.valueAt<T>(index));
} else {
return folly::hasher<T>()(decoded.valueAt<T>(index));
// Inlined for scalars.
using T = typename KindToFlatVector<Kind>::HashRowType;
T value = decoded.valueAt<T>(index);

if constexpr (typeProvidesCustomComparison) {
return static_cast<const CanProvideCustomComparisonType<Kind>*>(
decoded.base()->type().get())
->hash(value);
} else if constexpr (std::is_floating_point_v<T>) {
return util::floating_point::NaNAwareHash<T>()(value);
} else {
return folly::hasher<T>()(value);
}
}
}
} // namespace

template <TypeKind Kind>
template <bool typeProvidesCustomComparison, TypeKind Kind>
void VectorHasher::hashValues(
const SelectivityVector& rows,
bool mix,
Expand All @@ -79,7 +86,7 @@ void VectorHasher::hashValues(
if (decoded_.isConstantMapping()) {
auto hash = decoded_.isNullAt(rows.begin())
? kNullHash
: hashOne<Kind>(decoded_, rows.begin());
: hashOne<typeProvidesCustomComparison, Kind>(decoded_, rows.begin());
rows.applyToSelected([&](vector_size_t row) {
result[row] = mix ? bits::hashMix(result[row], hash) : hash;
});
Expand All @@ -96,7 +103,7 @@ void VectorHasher::hashValues(
auto baseIndex = decoded_.index(row);
uint64_t hash = cachedHashes_[baseIndex];
if (hash == kNullHash) {
hash = hashOne<Kind>(decoded_, row);
hash = hashOne<typeProvidesCustomComparison, Kind>(decoded_, row);
cachedHashes_[baseIndex] = hash;
}
result[row] = mix ? bits::hashMix(result[row], hash) : hash;
Expand All @@ -107,7 +114,7 @@ void VectorHasher::hashValues(
result[row] = mix ? bits::hashMix(result[row], kNullHash) : kNullHash;
return;
}
auto hash = hashOne<Kind>(decoded_, row);
auto hash = hashOne<typeProvidesCustomComparison, Kind>(decoded_, row);
result[row] = mix ? bits::hashMix(result[row], hash) : hash;
});
}
Expand Down Expand Up @@ -543,8 +550,13 @@ void VectorHasher::hash(
result[row] = mix ? bits::hashMix(result[row], kNullHash) : kNullHash;
});
} else {
VELOX_DYNAMIC_TYPE_DISPATCH(
hashValues, typeKind_, rows, mix, result.data());
if (type_->providesCustomComparison()) {
VELOX_DYNAMIC_TEMPLATE_TYPE_DISPATCH(
hashValues, true, typeKind_, rows, mix, result.data());
} else {
VELOX_DYNAMIC_TEMPLATE_TYPE_DISPATCH(
hashValues, false, typeKind_, rows, mix, result.data());
}
}
}

Expand All @@ -566,8 +578,14 @@ void VectorHasher::precompute(const BaseVector& value) {

const SelectivityVector rows(1, true);
decoded_.decode(value, rows);
precomputedHash_ =
VELOX_DYNAMIC_TYPE_DISPATCH(hashOne, typeKind_, decoded_, 0);

if (type_->providesCustomComparison()) {
precomputedHash_ = VELOX_DYNAMIC_TEMPLATE_TYPE_DISPATCH(
hashOne, true, typeKind_, decoded_, 0);
} else {
precomputedHash_ = VELOX_DYNAMIC_TEMPLATE_TYPE_DISPATCH(
hashOne, false, typeKind_, decoded_, 0);
}
}

void VectorHasher::analyze(
Expand Down
2 changes: 1 addition & 1 deletion velox/exec/VectorHasher.h
Original file line number Diff line number Diff line change
Expand Up @@ -540,7 +540,7 @@ class VectorHasher {
return *reinterpret_cast<const T*>(group + offset);
}

template <TypeKind Kind>
template <bool typeProvidesCustomComparison, TypeKind Kind>
void hashValues(const SelectivityVector& rows, bool mix, uint64_t* result);

const column_index_t channel_;
Expand Down
126 changes: 126 additions & 0 deletions velox/exec/tests/VectorHasherTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#include <gtest/gtest.h>
#include "velox/common/base/tests/GTestUtils.h"
#include "velox/type/Type.h"
#include "velox/type/tests/utils/CustomTypesForTesting.h"
#include "velox/vector/tests/utils/VectorTestBase.h"

using namespace facebook::velox;
Expand Down Expand Up @@ -1069,3 +1070,128 @@ TEST_F(VectorHasherTest, typeMismatch) {
VELOX_ASSERT_THROW(
hasher->decode(*data, rows), "Type mismatch: BIGINT vs. VARCHAR");
}

void testCustomComparison(const VectorPtr& actual, const VectorPtr& expected) {
// Vector hasher should hash the values in actual to the same hashes as it
// does when hashing the values in expected.
ASSERT_EQ(actual->size(), expected->size());

auto vectorHasher = exec::VectorHasher::create(actual->type(), 1);
ASSERT_EQ(vectorHasher->channel(), 1);

SelectivityVector allRows = SelectivityVector(actual->size());

raw_vector<uint64_t> actualHashes(actual->size());
std::fill(actualHashes.begin(), actualHashes.end(), 0);
vectorHasher->decode(*actual, allRows);
vectorHasher->hash(allRows, false, actualHashes);

raw_vector<uint64_t> expectedHashes(expected->size());
std::fill(expectedHashes.begin(), expectedHashes.end(), 0);
vectorHasher->decode(*expected, allRows);
vectorHasher->hash(allRows, false, expectedHashes);

for (int32_t i = 0; i < actual->size(); i++) {
EXPECT_EQ(actualHashes[i], expectedHashes[i])
<< "at " << i << " values: " << actual->toString(i) << " "
<< expected->toString(i);
}
}

TEST_F(VectorHasherTest, customComparison) {
  // Tests that flat vectors of a type that provides custom comparison are
  // hashed using the custom hash implementation the type provides.
  //
  // The type is spelled with the test:: qualifier, matching the other
  // customComparison* tests in this file.
  testCustomComparison(
      makeFlatVector<int64_t>(
          {0, 1, 256, 257, 512, 513},
          test::BIGINT_TYPE_WITH_CUSTOM_COMPARISON()),
      // Different values that are equal mod 256 should hash to the same value.
      makeFlatVector<int64_t>(
          {0, 1, 0, 1, 0, 1}, test::BIGINT_TYPE_WITH_CUSTOM_COMPARISON()));
}

TEST_F(VectorHasherTest, customComparisonArray) {
  // Arrays whose element type provides custom comparison must be hashed using
  // the custom hash implementation of that element type.
  const auto arrayType = ARRAY(test::BIGINT_TYPE_WITH_CUSTOM_COMPARISON());

  auto actual = makeNullableArrayVector<int64_t>(
      {{0, 1, 2},
       {256, 257, 258},
       {512, 513, 514},
       {3, 4, 5},
       {259, 260, 261},
       {515, 516, 517},
       {std::nullopt}},
      arrayType);

  // Different values that are equal mod 256 should hash to the same value.
  auto expected = makeNullableArrayVector<int64_t>(
      {{0, 1, 2},
       {0, 1, 2},
       {0, 1, 2},
       {3, 4, 5},
       {3, 4, 5},
       {3, 4, 5},
       {std::nullopt}},
      arrayType);

  testCustomComparison(actual, expected);
}

TEST_F(VectorHasherTest, customComparisonMap) {
  // Maps whose key and value types provide custom comparison must be hashed
  // using the custom hash implementation of those types.
  using Entries = std::vector<std::pair<int64_t, std::optional<int64_t>>>;

  const auto mapType =
      MAP(test::BIGINT_TYPE_WITH_CUSTOM_COMPARISON(),
          test::BIGINT_TYPE_WITH_CUSTOM_COMPARISON());

  auto actual = makeNullableMapVector<int64_t, int64_t>(
      {Entries{{0, 10}, {1, 11}, {2, 12}},
       Entries{{256, 266}, {257, 267}, {258, 268}},
       Entries{{512, 522}, {513, 523}, {514, 524}},
       Entries{{3, 103}, {4, 104}, {5, 105}},
       Entries{{259, 359}, {260, 360}, {261, 361}},
       Entries{{515, 615}, {516, 616}, {517, 617}},
       Entries{{0, std::nullopt}}},
      mapType);

  // Different values that are equal mod 256 should hash to the same value.
  auto expected = makeNullableMapVector<int64_t, int64_t>(
      {Entries{{0, 10}, {1, 11}, {2, 12}},
       Entries{{0, 10}, {1, 11}, {2, 12}},
       Entries{{0, 10}, {1, 11}, {2, 12}},
       Entries{{3, 103}, {4, 104}, {5, 105}},
       Entries{{3, 103}, {4, 104}, {5, 105}},
       Entries{{3, 103}, {4, 104}, {5, 105}},
       Entries{{0, std::nullopt}}},
      mapType);

  testCustomComparison(actual, expected);
}

TEST_F(VectorHasherTest, customComparisonRow) {
  // Rows whose child type provides custom comparison must be hashed using the
  // custom hash implementation of that child type.
  const auto childType = test::BIGINT_TYPE_WITH_CUSTOM_COMPARISON();

  auto actualChild = makeNullableFlatVector<int64_t>(
      {std::nullopt, 0, 1, 256, 257, 512, 513}, childType);

  // Different values that are equal mod 256 should hash to the same value.
  auto expectedChild = makeNullableFlatVector<int64_t>(
      {std::nullopt, 0, 1, 0, 1, 0, 1}, childType);

  testCustomComparison(
      makeRowVector({"a"}, {actualChild}),
      makeRowVector({"a"}, {expectedChild}));
}
65 changes: 39 additions & 26 deletions velox/functions/prestosql/types/TimestampWithTimeZoneType.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,31 @@

namespace facebook::velox {

/// Integer type used for the time zone key stored in a packed value.
using TimeZoneKey = int16_t;

/// Mask selecting the low 12 bits of a packed value, which hold the time zone
/// key.
constexpr int32_t kTimezoneMask = 0xFFF;

/// Number of low bits occupied by the time zone key; the milliseconds are
/// stored in the bits above them.
constexpr int32_t kMillisShift = 12;

/// Returns the milliseconds component of a packed value by dropping the low
/// 'kMillisShift' bits.
/// NOTE(review): for negative inputs this relies on '>>' being an arithmetic
/// (sign-preserving) shift, which is implementation-defined before C++20.
inline int64_t unpackMillisUtc(int64_t dateTimeWithTimeZone) {
  return dateTimeWithTimeZone >> kMillisShift;
}

/// Returns the time zone key stored in the low 12 bits of a packed value.
inline TimeZoneKey unpackZoneKeyId(int64_t dateTimeWithTimeZone) {
  // The masked value is in [0, 0xFFF], so the narrowing cast is lossless.
  return static_cast<TimeZoneKey>(dateTimeWithTimeZone & kTimezoneMask);
}

/// Combines 'millisUtc' and 'timeZoneKey' into a single 64-bit value: the
/// milliseconds are shifted left by 'kMillisShift' and the low 12 bits of the
/// key fill the vacated bits.
inline int64_t pack(int64_t millisUtc, TimeZoneKey timeZoneKey) {
  // Shift in the unsigned domain: left-shifting a negative signed value is
  // undefined behavior before C++20, and 'millisUtc' may be negative.
  const auto shiftedMillis = static_cast<uint64_t>(millisUtc) << kMillisShift;
  const auto zoneBits = static_cast<uint64_t>(timeZoneKey & kTimezoneMask);
  return static_cast<int64_t>(shiftedMillis | zoneBits);
}

/// Packs the millisecond value of 'timestamp' (as returned by toMillis())
/// together with 'timeZoneKey' by delegating to the integer overload of pack.
inline int64_t pack(const Timestamp& timestamp, TimeZoneKey timeZoneKey) {
  const int64_t millisUtc = timestamp.toMillis();
  return pack(millisUtc, timeZoneKey);
}

/// Extracts the milliseconds component of a packed value and converts it to a
/// Timestamp via Timestamp::fromMillis.
inline Timestamp unpackTimestampUtc(int64_t dateTimeWithTimeZone) {
  const int64_t millisUtc = unpackMillisUtc(dateTimeWithTimeZone);
  return Timestamp::fromMillis(millisUtc);
}

class TimestampWithTimeZoneCastOperator : public exec::CastOperator {
public:
static const std::shared_ptr<const CastOperator>& get() {
Expand Down Expand Up @@ -56,7 +81,7 @@ class TimestampWithTimeZoneCastOperator : public exec::CastOperator {
/// Represents timestamp with time zone as a number of milliseconds since epoch
/// and time zone ID.
class TimestampWithTimeZoneType : public BigintType {
TimestampWithTimeZoneType() = default;
TimestampWithTimeZoneType() : BigintType(true) {}

public:
static const std::shared_ptr<const TimestampWithTimeZoneType>& get() {
Expand All @@ -72,6 +97,19 @@ class TimestampWithTimeZoneType : public BigintType {
return this == &other;
}

/// Three-way comparison of two packed values based solely on their
/// milliseconds component; the time zone bits do not take part.
/// Returns -1 if 'left' is earlier, 1 if it is later and 0 otherwise.
int32_t compare(const int64_t& left, const int64_t& right) const override {
  const int64_t leftMillis = unpackMillisUtc(left);
  const int64_t rightMillis = unpackMillisUtc(right);

  if (leftMillis < rightMillis) {
    return -1;
  }
  if (leftMillis > rightMillis) {
    return 1;
  }
  return 0;
}

/// Hashes only the milliseconds component of a packed value, so values that
/// differ just in their time zone bits produce the same hash.
uint64_t hash(const int64_t& value) const override {
  const int64_t millisUtc = unpackMillisUtc(value);
  return folly::hasher<int64_t>()(millisUtc);
}

const char* name() const override {
return "TIMESTAMP WITH TIME ZONE";
}
Expand Down Expand Up @@ -125,29 +163,4 @@ class TimestampWithTimeZoneTypeFactories : public CustomTypeFactories {

void registerTimestampWithTimeZoneType();

using TimeZoneKey = int16_t;

constexpr int32_t kTimezoneMask = 0xFFF;
constexpr int32_t kMillisShift = 12;

inline int64_t unpackMillisUtc(int64_t dateTimeWithTimeZone) {
return dateTimeWithTimeZone >> kMillisShift;
}

inline TimeZoneKey unpackZoneKeyId(int64_t dateTimeWithTimeZone) {
return dateTimeWithTimeZone & kTimezoneMask;
}

inline int64_t pack(int64_t millisUtc, int16_t timeZoneKey) {
return (millisUtc << kMillisShift) | (timeZoneKey & kTimezoneMask);
}

inline int64_t pack(const Timestamp& timestamp, int16_t timeZoneKey) {
return pack(timestamp.toMillis(), timeZoneKey);
}

inline Timestamp unpackTimestampUtc(int64_t dateTimeWithTimeZone) {
return Timestamp::fromMillis(unpackMillisUtc(dateTimeWithTimeZone));
}

} // namespace facebook::velox
Loading
Loading