diff --git a/velox/functions/prestosql/aggregates/MapUnionSumAggregate.cpp b/velox/functions/prestosql/aggregates/MapUnionSumAggregate.cpp index d86ee3f407d2..f54feab94721 100644 --- a/velox/functions/prestosql/aggregates/MapUnionSumAggregate.cpp +++ b/velox/functions/prestosql/aggregates/MapUnionSumAggregate.cpp @@ -13,10 +13,10 @@ * See the License for the specific language governing permissions and * limitations under the License. */ +#include "velox/exec/AddressableNonNullValueList.h" #include "velox/exec/Aggregate.h" #include "velox/exec/Strings.h" #include "velox/expression/FunctionSignature.h" -#include "velox/functions/lib/CheckedArithmeticImpl.h" #include "velox/functions/prestosql/aggregates/AggregateNames.h" #include "velox/vector/FlatVector.h" @@ -32,7 +32,7 @@ struct Accumulator { AlignedStlAllocator, 16>>::Type; ValuesMap sums; - explicit Accumulator(HashStringAllocator* allocator) + explicit Accumulator(const TypePtr& /*type*/, HashStringAllocator* allocator) : sums{AlignedStlAllocator, 16>(allocator)} {} size_t size() const { @@ -41,18 +41,20 @@ struct Accumulator { void addValues( const MapVector* mapVector, - const SimpleVector* mapKeys, - const SimpleVector* mapValues, + const VectorPtr& mapKeys, + const VectorPtr& mapValues, vector_size_t row, HashStringAllocator* allocator) { + auto keys = mapKeys->template as>(); + auto values = mapValues->template as>(); auto offset = mapVector->offsetAt(row); auto size = mapVector->sizeAt(row); for (auto i = 0; i < size; ++i) { // Ignore null map keys. - if (!mapKeys->isNullAt(offset + i)) { - auto key = mapKeys->valueAt(offset + i); - addValue(key, mapValues, offset + i, mapValues->typeKind()); + if (!keys->isNullAt(offset + i)) { + auto key = keys->valueAt(offset + i); + addValue(key, values, offset + i, values->typeKind()); } } } @@ -94,13 +96,16 @@ struct Accumulator { } vector_size_t extractValues( - FlatVector& mapKeys, - FlatVector& mapValues, + VectorPtr& mapKeys, + VectorPtr& mapValues, vector_size_t offset) { + auto keys = mapKeys->asFlatVector(); + auto values = mapValues->asFlatVector(); + auto index = offset; for (const auto& [key, sum] : sums) { - mapKeys.set(index, key); - mapValues.set(index, sum); + keys->set(index, key); + values->set(index, sum); ++index; } @@ -115,8 +120,10 @@ struct StringViewAccumulator { Strings strings; - explicit StringViewAccumulator(HashStringAllocator* allocator) - : base{allocator} {} + explicit StringViewAccumulator( + const TypePtr& type, + HashStringAllocator* allocator) + : base{type, allocator} {} size_t size() const { return base.size(); @@ -124,17 +131,19 @@ struct StringViewAccumulator { void addValues( const MapVector* mapVector, - const SimpleVector* mapKeys, - const SimpleVector* mapValues, + const VectorPtr& mapKeys, + const VectorPtr& mapValues, vector_size_t row, HashStringAllocator* allocator) { + auto keys = mapKeys->template as>(); + auto values = mapValues->template as>(); auto offset = mapVector->offsetAt(row); auto size = mapVector->sizeAt(row); for (auto i = 0; i < size; ++i) { // Ignore null map keys. - if (!mapKeys->isNullAt(offset + i)) { - auto key = mapKeys->valueAt(offset + i); + if (!keys->isNullAt(offset + i)) { + auto key = keys->valueAt(offset + i); if (!key.isInline()) { auto it = base.sums.find(key); @@ -145,19 +154,95 @@ struct StringViewAccumulator { } } - base.addValue(key, mapValues, offset + i, mapValues->typeKind()); + base.addValue(key, values, offset + i, values->typeKind()); } } } vector_size_t extractValues( - FlatVector& mapKeys, - FlatVector& mapValues, + VectorPtr& mapKeys, + VectorPtr& mapValues, vector_size_t offset) { return base.extractValues(mapKeys, mapValues, offset); } }; +/// Maintains a map with keys of type array, map or struct. +template +struct ComplexTypeAccumulator { + using ValueMap = folly::F14FastMap< + AddressableNonNullValueList::Entry, + int64_t, + AddressableNonNullValueList::Hash, + AddressableNonNullValueList::EqualTo, + AlignedStlAllocator< + std::pair, + 16>>; + + /// A set of pointers to values stored in AddressableNonNullValueList. + ValueMap sums; + + /// Stores unique non-null keys. + AddressableNonNullValueList serializedKeys; + + ComplexTypeAccumulator(const TypePtr& type, HashStringAllocator* allocator) + : sums{ + 0, + AddressableNonNullValueList::Hash{}, + AddressableNonNullValueList::EqualTo{type}, + AlignedStlAllocator< + std::pair, + 16>(allocator)} {} + + void addValues( + const MapVector* mapVector, + const VectorPtr& mapKeys, + const VectorPtr& mapValues, + vector_size_t row, + HashStringAllocator* allocator) { + auto offset = mapVector->offsetAt(row); + auto size = mapVector->sizeAt(row); + auto values = mapValues->template as>(); + + for (auto i = 0; i < size; ++i) { + if (!mapKeys->isNullAt(offset + i)) { + auto entry = + serializedKeys.append(*mapKeys.get(), offset + i, allocator); + + auto it = sums.find(entry); + if (it == sums.end()) { + // New entry. + sums[entry] = values->valueAt(offset + i); + } else { + // Existing entry. + sums[entry] += values->valueAt(offset + i); + } + } + } + } + + vector_size_t extractValues( + VectorPtr& mapKeys, + VectorPtr& mapValues, + vector_size_t offset) { + auto values = mapValues->asFlatVector(); + auto index = offset; + + for (const auto& [position, count] : sums) { + AddressableNonNullValueList::read(position, *mapKeys.get(), index); + values->set(index, count); + ++index; + } + + return sums.size(); + } + + size_t size() const { + return sums.size(); + } +}; + +// Defines unique accumulators dependent on type. template struct AccumulatorTypeTraits { using AccumulatorType = Accumulator; @@ -168,6 +253,12 @@ struct AccumulatorTypeTraits { using AccumulatorType = StringViewAccumulator; }; +template +struct AccumulatorTypeTraits { + using AccumulatorType = ComplexTypeAccumulator; +}; + +// Defines common aggregator. template class MapUnionSumAggregate : public exec::Aggregate { public: @@ -190,12 +281,18 @@ class MapUnionSumAggregate : public exec::Aggregate { VELOX_CHECK(mapVector); mapVector->resize(numGroups); - auto mapKeys = mapVector->mapKeys()->as>(); - auto mapValues = mapVector->mapValues()->as>(); + auto mapKeysPtr = mapVector->mapKeys(); + auto mapValuesPtr = mapVector->mapValues(); auto numElements = countElements(groups, numGroups); - mapKeys->resize(numElements); - mapValues->resize(numElements); + mapVector->mapValues()->as>()->resize(numElements); + + // ComplexType cannot be resized the same. + if constexpr (!std::is_same_v) { + mapVector->mapKeys()->as>()->resize(numElements); + } else { + mapVector->mapKeys()->resize(numElements); + } auto rawNulls = mapVector->mutableRawNulls(); vector_size_t offset = 0; @@ -208,7 +305,7 @@ class MapUnionSumAggregate : public exec::Aggregate { clearNull(rawNulls, i); auto mapSize = value(group)->extractValues( - *mapKeys, *mapValues, offset); + mapKeysPtr, mapValuesPtr, offset); mapVector->setOffsetAndSize(i, offset, mapSize); offset += mapSize; } @@ -227,8 +324,8 @@ class MapUnionSumAggregate : public exec::Aggregate { bool /*mayPushdown*/) override { decodedMaps_.decode(*args[0], rows); auto mapVector = decodedMaps_.base()->template as(); - auto mapKeys = mapVector->mapKeys()->template as>(); - auto mapValues = mapVector->mapValues()->template as>(); + auto mapKeys = mapVector->mapKeys(); + auto mapValues = mapVector->mapValues(); rows.applyToSelected([&](auto row) { if (!decodedMaps_.isNullAt(row)) { @@ -249,8 +346,8 @@ class MapUnionSumAggregate : public exec::Aggregate { bool /* mayPushdown */) override { decodedMaps_.decode(*args[0], rows); auto mapVector = decodedMaps_.base()->template as(); - auto mapKeys = mapVector->mapKeys()->template as>(); - auto mapValues = mapVector->mapValues()->template as>(); + auto mapKeys = mapVector->mapKeys(); + auto mapValues = mapVector->mapValues(); auto groupMap = value(group); @@ -285,7 +382,7 @@ class MapUnionSumAggregate : public exec::Aggregate { folly::Range indices) override { setAllNulls(groups, indices); for (auto index : indices) { - new (groups[index] + offset_) AccumulatorType{allocator_}; + new (groups[index] + offset_) AccumulatorType{resultType_, allocator_}; } } @@ -304,8 +401,8 @@ class MapUnionSumAggregate : public exec::Aggregate { void addMap( AccumulatorType& groupMap, const MapVector* mapVector, - const SimpleVector* mapKeys, - const SimpleVector* mapValues, + const VectorPtr& mapKeys, + const VectorPtr& mapValues, vector_size_t row) const { auto decodedRow = decodedMaps_.index(row); groupMap.addValues(mapVector, mapKeys, mapValues, decodedRow, allocator_); @@ -340,7 +437,8 @@ std::unique_ptr createMapUnionSumAggregate( case TypeKind::DOUBLE: return std::make_unique>(resultType); default: - VELOX_UNREACHABLE(); + VELOX_UNREACHABLE( + "Unexpected value type {}", mapTypeKindToName(valueKind)); } } @@ -350,15 +448,6 @@ void registerMapUnionSumAggregate( const std::string& prefix, bool withCompanionFunctions, bool overwrite) { - const std::vector keyTypes = { - "tinyint", - "smallint", - "integer", - "bigint", - "real", - "double", - "varchar", - "json"}; const std::vector valueTypes = { "tinyint", "smallint", @@ -369,15 +458,14 @@ void registerMapUnionSumAggregate( }; std::vector> signatures; - for (auto keyType : keyTypes) { - for (auto valueType : valueTypes) { - auto mapType = fmt::format("map({},{})", keyType, valueType); - signatures.push_back(exec::AggregateFunctionSignatureBuilder() - .returnType(mapType) - .intermediateType(mapType) - .argumentType(mapType) - .build()); - } + for (auto valueType : valueTypes) { + signatures.push_back( + exec::AggregateFunctionSignatureBuilder() + .comparableTypeVariable("K") + .returnType(fmt::format("map(K,{})", valueType)) + .intermediateType(fmt::format("map(K,{})", valueType)) + .argumentType(fmt::format("map(K,{})", valueType)) + .build()); } auto name = prefix + kMapUnionSum; @@ -395,6 +483,8 @@ void registerMapUnionSumAggregate( auto& mapType = argTypes[0]->asMap(); auto keyTypeKind = mapType.keyType()->kind(); auto valueTypeKind = mapType.valueType()->kind(); + const auto keyType = resultType->childAt(0); + switch (keyTypeKind) { case TypeKind::TINYINT: return createMapUnionSumAggregate( @@ -416,8 +506,14 @@ void registerMapUnionSumAggregate( case TypeKind::VARCHAR: return createMapUnionSumAggregate( valueTypeKind, resultType); + case TypeKind::ARRAY: + case TypeKind::MAP: + case TypeKind::ROW: + return createMapUnionSumAggregate( + valueTypeKind, resultType); default: - VELOX_UNREACHABLE(); + VELOX_UNREACHABLE( + "Unexpected key type {}", mapTypeKindToName(keyTypeKind)); } }, withCompanionFunctions, diff --git a/velox/functions/prestosql/aggregates/tests/MapUnionSumTest.cpp b/velox/functions/prestosql/aggregates/tests/MapUnionSumTest.cpp index 77328923d58f..5d118c78ca3e 100644 --- a/velox/functions/prestosql/aggregates/tests/MapUnionSumTest.cpp +++ b/velox/functions/prestosql/aggregates/tests/MapUnionSumTest.cpp @@ -477,4 +477,54 @@ TEST_F(MapUnionSumTest, nanKeys) { } } // namespace + +TEST_F(MapUnionSumTest, complexType) { + // Verify that NaNs with different binary representations are considered equal + // and deduplicated when used as keys in the output map. + static const auto kNaN = std::numeric_limits::quiet_NaN(); + static const auto kSNaN = std::numeric_limits::signaling_NaN(); + + // Global Aggregation, Complex type(Row) + // The complex input values are: + // [{"key":[1,1],"value":1},{"key":["NaN",2],"value":2},{"key":[2,4],"value":3},{"key":[3,5],"value":4}, + // {"key":["NaN",2],"value":5}, {"key":["NaN",2],"value":6}] + auto data = makeRowVector( + {makeMapVector( + {0, 1, 2, 3, 4, 5}, + makeRowVector( + {makeFlatVector({1, kSNaN, 2, 3, kNaN, kSNaN}), + makeFlatVector({1, 2, 4, 5, 2, 2})}), + makeFlatVector({1, 2, 3, 4, 5, 6})), + makeFlatVector({1, 1, 1, 2, 2, 2})}); + + // The expected result is + // [{"key":[1,1],"value":1},{"key":[2,4],"value":3},{"key":[3,5],"value":4}, + // {"key":["NaN",2],"value":12}] + auto expectedResult = makeRowVector({makeMapVector( + {0}, + makeRowVector( + {makeFlatVector({1, 2, 3, kNaN}), + makeFlatVector({1, 4, 5, 2})}), + makeFlatVector({1, 3, 4, 13}))}); + + testAggregations({data}, {}, {"map_union_sum(c0)"}, {expectedResult}); + + // Group by Aggregation, Complex type(Row) + // The expected result is + // [{"key":[1,1],"value":1},{"key":[2,4],"value":3}, + // {"key":["NaN",2],"value":2}] | 1 + // [{"key":[3,5],"value":4},{"key":["NaN",2],"value":11}] | 2 + expectedResult = makeRowVector( + {makeMapVector( + {0, 3}, + makeRowVector( + {makeFlatVector({1, 2, kNaN, 3, kNaN}), + makeFlatVector({1, 4, 2, 5, 2})}), + makeFlatVector({1, 3, 2, 4, 11})), + makeFlatVector({1, 2})}); + + testAggregations( + {data}, {"c1"}, {"map_union_sum(c0)"}, {"a0", "c1"}, {expectedResult}); +} + } // namespace facebook::velox::aggregate::test