Skip to content

Commit

Permalink
Support custom comparison in Map Aggregations (attempt 2) (#11155)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #11155

This is a second attempt to land the changes in
#11124

The original description:
Building on #11021 this adds support
for custom comparison functions provided by custom types in the map aggregations
map_agg and map_union.

Note that I did not add support in map_union_sum because it currently only supports
Maps with a fixed set of key types, and none of those provide custom comparisons.

New context:
I originally landed this change together with #11119, so it
was reverted along with that change. This particular change did not introduce
any issues, though.

Reviewed By: xiaoxmeng

Differential Revision: D63796041

fbshipit-source-id: 06e7ed5e5f62e8671341c6b938429aa05b23c177
  • Loading branch information
Kevin Wilfong authored and facebook-github-bot committed Oct 4, 2024
1 parent ad7d0cf commit ef28741
Show file tree
Hide file tree
Showing 6 changed files with 287 additions and 19 deletions.
75 changes: 75 additions & 0 deletions velox/functions/prestosql/aggregates/MapAccumulator.h
Original file line number Diff line number Diff line change
Expand Up @@ -253,6 +253,81 @@ struct MapAccumulatorTypeTraits<ComplexType> {

} // namespace detail

/// Adapts detail::MapAccumulator to key types that provide custom comparison.
/// Keys are hashed and compared through the hash() and compare() functions of
/// the key type's CanProvideCustomComparisonType<Kind> interface; every other
/// operation is forwarded unchanged to the wrapped accumulator.
template <TypeKind Kind>
struct CustomComparisonMapAccumulator {
  using NativeType = typename TypeTraits<Kind>::NativeType;

  /// Hash functor that defers to the custom type's hash().
  struct Hash {
    /// Held by reference: the referenced TypePtr must outlive this functor.
    const TypePtr& type;

    size_t operator()(const NativeType& value) const {
      const auto* customType =
          static_cast<const CanProvideCustomComparisonType<Kind>*>(type.get());
      return customType->hash(value);
    }
  };

  /// Equality functor: two keys are equal when the custom type's compare()
  /// returns 0 for them.
  struct EqualTo {
    /// Held by reference: the referenced TypePtr must outlive this functor.
    const TypePtr& type;

    bool operator()(const NativeType& left, const NativeType& right) const {
      const auto* customType =
          static_cast<const CanProvideCustomComparisonType<Kind>*>(type.get());
      return customType->compare(left, right) == 0;
    }
  };

  /// The wrapped MapAccumulator, instantiated with the custom functors above.
  detail::MapAccumulator<NativeType, Hash, EqualTo> base;

  /// @param type Key type; must implement CanProvideCustomComparisonType<Kind>.
  /// The functors keep a reference to it rather than a copy.
  /// @param allocator Passed through to the wrapped accumulator.
  CustomComparisonMapAccumulator(
      const TypePtr& type,
      HashStringAllocator* allocator)
      : base{Hash{type}, EqualTo{type}, allocator} {}

  /// Adds key-value pair if entry with that key doesn't exist yet.
  void insert(
      const DecodedVector& decodedKeys,
      const DecodedVector& decodedValues,
      vector_size_t index,
      HashStringAllocator& allocator) {
    base.insert(decodedKeys, decodedValues, index, allocator);
  }

  /// Returns number of key-value pairs.
  size_t size() const {
    return base.size();
  }

  /// Forwards to the wrapped accumulator's extract().
  void extract(
      const VectorPtr& mapKeys,
      const VectorPtr& mapValues,
      vector_size_t offset) {
    base.extract(mapKeys, mapValues, offset);
  }

  /// Forwards to the wrapped accumulator's extractValues().
  void extractValues(
      const VectorPtr& mapValues,
      vector_size_t offset,
      int32_t mapSize,
      const folly::F14FastMap<int32_t, int32_t>& indices) {
    base.extractValues(mapValues, offset, mapSize, indices);
  }

  /// Forwards to the wrapped accumulator's free().
  void free(HashStringAllocator& allocator) {
    base.free(allocator);
  }
};

/// Accumulator type used for map keys of type T, as selected by
/// detail::MapAccumulatorTypeTraits<T>::AccumulatorType.
template <typename T>
using MapAccumulator =
typename detail::MapAccumulatorTypeTraits<T>::AccumulatorType;
Expand Down
29 changes: 22 additions & 7 deletions velox/functions/prestosql/aggregates/MapAggAggregate.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,14 +22,13 @@ namespace facebook::velox::aggregate::prestosql {
namespace {
// See documentation at
// https://prestodb.io/docs/current/functions/aggregate.html
template <typename K>
class MapAggAggregate : public MapAggregateBase<K> {
template <typename K, typename AccumulatorType = MapAccumulator<K>>
class MapAggAggregate : public MapAggregateBase<K, AccumulatorType> {
public:
explicit MapAggAggregate(TypePtr resultType, bool throwOnNestedNulls = false)
: MapAggregateBase<K>(std::move(resultType)),
throwOnNestedNulls_(throwOnNestedNulls) {}
using Base = MapAggregateBase<K, AccumulatorType>;

using Base = MapAggregateBase<K>;
explicit MapAggAggregate(TypePtr resultType, bool throwOnNestedNulls = false)
: Base(std::move(resultType)), throwOnNestedNulls_(throwOnNestedNulls) {}

bool supportsToIntermediate() const override {
return true;
Expand Down Expand Up @@ -145,6 +144,14 @@ class MapAggAggregate : public MapAggregateBase<K> {
const bool throwOnNestedNulls_;
};

template <TypeKind Kind>
std::unique_ptr<exec::Aggregate> createMapAggAggregateWithCustomCompare(
const TypePtr& resultType) {
return std::make_unique<MapAggAggregate<
typename TypeTraits<Kind>::NativeType,
CustomComparisonMapAccumulator<Kind>>>(resultType);
}

} // namespace

void registerMapAggAggregate(
Expand Down Expand Up @@ -178,7 +185,15 @@ void registerMapAggAggregate(
"{}: unexpected number of arguments",
name);
const bool throwOnNestedNulls = rawInput;
const auto typeKind = resultType->childAt(0)->kind();

const auto keyType = resultType->childAt(0);
const auto typeKind = keyType->kind();

if (keyType->providesCustomComparison()) {
return VELOX_DYNAMIC_SCALAR_TYPE_DISPATCH(
createMapAggAggregateWithCustomCompare, typeKind, resultType);
}

switch (typeKind) {
case TypeKind::BOOLEAN:
return std::make_unique<MapAggAggregate<bool>>(resultType);
Expand Down
7 changes: 3 additions & 4 deletions velox/functions/prestosql/aggregates/MapAggregateBase.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,11 @@

namespace facebook::velox::aggregate::prestosql {

template <typename K>
template <typename K, typename AccumulatorType>
class MapAggregateBase : public exec::Aggregate {
public:
explicit MapAggregateBase(TypePtr resultType) : Aggregate(resultType) {}

using AccumulatorType = MapAccumulator<K>;

int32_t accumulatorFixedWidthSize() const override {
return sizeof(AccumulatorType);
}
Expand Down Expand Up @@ -202,7 +200,8 @@ class MapAggregateBase : public exec::Aggregate {
DecodedVector decodedMaps_;
};

template <template <typename K> class TAggregate>
template <template <typename K, typename Accumulator = MapAccumulator<K>>
class TAggregate>
std::unique_ptr<exec::Aggregate> createMapAggregate(const TypePtr& resultType) {
auto typeKind = resultType->childAt(0)->kind();
switch (typeKind) {
Expand Down
33 changes: 25 additions & 8 deletions velox/functions/prestosql/aggregates/MapUnionAggregate.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,12 @@ namespace facebook::velox::aggregate::prestosql {
namespace {
// See documentation at
// https://prestodb.io/docs/current/functions/aggregate.html
template <typename K>
class MapUnionAggregate : public MapAggregateBase<K> {
template <typename K, typename AccumulatorType = MapAccumulator<K>>
class MapUnionAggregate : public MapAggregateBase<K, AccumulatorType> {
public:
explicit MapUnionAggregate(TypePtr resultType)
: MapAggregateBase<K>(resultType) {}
using Base = MapAggregateBase<K, AccumulatorType>;

explicit MapUnionAggregate(TypePtr resultType) : Base(resultType) {}

bool supportsToIntermediate() const override {
return true;
Expand All @@ -37,7 +38,7 @@ class MapUnionAggregate : public MapAggregateBase<K> {
if (rows.isAllSelected()) {
result = args[0];
} else {
auto* pool = MapAggregateBase<K>::allocator_->pool();
auto* pool = Base::allocator_->pool();
const auto numRows = rows.size();

// Set nulls for rows not present in 'rows'.
Expand All @@ -60,19 +61,26 @@ class MapUnionAggregate : public MapAggregateBase<K> {
const SelectivityVector& rows,
const std::vector<VectorPtr>& args,
bool /*mayPushdown*/) override {
MapAggregateBase<K>::addMapInputToAccumulator(groups, rows, args, false);
Base::addMapInputToAccumulator(groups, rows, args, false);
}

void addSingleGroupRawInput(
char* group,
const SelectivityVector& rows,
const std::vector<VectorPtr>& args,
bool /*mayPushdown*/) override {
MapAggregateBase<K>::addSingleGroupMapInputToAccumulator(
group, rows, args, false);
Base::addSingleGroupMapInputToAccumulator(group, rows, args, false);
}
};

template <TypeKind Kind>
std::unique_ptr<exec::Aggregate> createMapUnionAggregateWithCustomCompare(
const TypePtr& resultType) {
return std::make_unique<MapUnionAggregate<
typename TypeTraits<Kind>::NativeType,
CustomComparisonMapAccumulator<Kind>>>(resultType);
}

} // namespace

void registerMapUnionAggregate(
Expand Down Expand Up @@ -101,6 +109,15 @@ void registerMapUnionAggregate(
VELOX_CHECK_EQ(
argTypes.size(), 1, "{}: unexpected number of arguments", name);

const auto keyType = resultType->childAt(0);

if (keyType->providesCustomComparison()) {
return VELOX_DYNAMIC_SCALAR_TYPE_DISPATCH(
createMapUnionAggregateWithCustomCompare,
keyType->kind(),
resultType);
}

return createMapAggregate<MapUnionAggregate>(resultType);
},
withCompanionFunctions,
Expand Down
81 changes: 81 additions & 0 deletions velox/functions/prestosql/aggregates/tests/MapAggTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
*/
#include "velox/exec/tests/utils/AssertQueryBuilder.h"
#include "velox/functions/lib/aggregates/tests/utils/AggregationTestBase.h"
#include "velox/functions/prestosql/types/TimestampWithTimeZoneType.h"

using namespace facebook::velox::exec;
using namespace facebook::velox::exec::test;
Expand Down Expand Up @@ -551,5 +552,85 @@ TEST_F(MapAggTest, nans) {
{data}, {"c2"}, {"map_agg(c0, c1)"}, {"a0", "c2"}, {expectedResult});
}

TEST_F(MapAggTest, timestampWithTimeZone) {
  // All scenarios aggregate the same ten rows: packed keys, the values they
  // map to, and a grouping column that splits the rows into two groups of
  // five. In the expected results, keys that differ only in the second pack()
  // argument collapse into a single entry holding the first key and value
  // seen, e.g. pack(0, 1) -> 4 is absent because pack(0, 0) -> 1 came first.
  const std::vector<int64_t> packedKeys = {
      pack(0, 0),
      pack(1, 0),
      pack(2, 0),
      pack(0, 1),
      pack(1, 1),
      pack(1, 2),
      pack(2, 2),
      pack(3, 3),
      pack(1, 1),
      pack(3, 0)};
  const std::vector<int32_t> mapValues = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10};
  const std::vector<int32_t> groupIds = {1, 1, 1, 1, 1, 2, 2, 2, 2, 2};

  // Wraps packed values in a flat vector typed TIMESTAMP_WITH_TIME_ZONE().
  const auto makeTimestamps = [this](const std::vector<int64_t>& packed) {
    return makeFlatVector<int64_t>(packed, TIMESTAMP_WITH_TIME_ZONE());
  };

  // ---- Keys are the primitive custom type itself. ----
  const auto primitiveInput = makeRowVector(
      {makeTimestamps(packedKeys),
       makeFlatVector<int32_t>(mapValues),
       makeFlatVector<int32_t>(groupIds)});

  // Global aggregation.
  const auto primitiveGlobalExpected =
      makeRowVector({makeMapVector<int64_t, int32_t>(
          {{{pack(0, 0), 1},
            {pack(1, 0), 2},
            {pack(2, 0), 3},
            {pack(3, 3), 8}}},
          MAP(TIMESTAMP_WITH_TIME_ZONE(), INTEGER()))});

  testAggregations(
      {primitiveInput}, {}, {"map_agg(c0, c1)"}, {primitiveGlobalExpected});

  // Group-by aggregation.
  const auto primitiveGroupedExpected = makeRowVector(
      {makeMapVector<int64_t, int32_t>(
           {{{pack(0, 0), 1}, {pack(1, 0), 2}, {pack(2, 0), 3}},
            {{pack(3, 3), 8}, {pack(1, 2), 6}, {pack(2, 2), 7}}},
           MAP(TIMESTAMP_WITH_TIME_ZONE(), INTEGER())),
       makeFlatVector<int32_t>({1, 2})});

  testAggregations(
      {primitiveInput},
      {"c2"},
      {"map_agg(c0, c1)"},
      {"a0", "c2"},
      {primitiveGroupedExpected});

  // ---- Keys are rows wrapping the custom type (complex key). ----
  const auto rowInput = makeRowVector(
      {makeRowVector({makeTimestamps(packedKeys)}),
       makeFlatVector<int32_t>(mapValues),
       makeFlatVector<int32_t>(groupIds)});

  // Global aggregation.
  const auto rowGlobalExpected = makeRowVector({makeMapVector(
      {0},
      makeRowVector({makeTimestamps(
          {pack(0, 0), pack(1, 0), pack(2, 0), pack(3, 3)})}),
      makeFlatVector<int32_t>({1, 2, 3, 8}))});

  testAggregations({rowInput}, {}, {"map_agg(c0, c1)"}, {rowGlobalExpected});

  // Group-by aggregation.
  const auto rowGroupedExpected = makeRowVector(
      {makeMapVector(
           {0, 3},
           makeRowVector({makeTimestamps(
               {pack(0, 0),
                pack(1, 0),
                pack(2, 0),
                pack(3, 3),
                pack(1, 2),
                pack(2, 2)})}),
           makeFlatVector<int32_t>({1, 2, 3, 8, 6, 7})),
       makeFlatVector<int32_t>({1, 2})});

  testAggregations(
      {rowInput},
      {"c2"},
      {"map_agg(c0, c1)"},
      {"a0", "c2"},
      {rowGroupedExpected});
}

} // namespace
} // namespace facebook::velox::aggregate::test
Loading

0 comments on commit ef28741

Please sign in to comment.