From bb32e7f0305a040f1375b983b93bcdd277d7ab5e Mon Sep 17 00:00:00 2001
From: Bikramjeet Vig
Date: Wed, 13 Jul 2022 17:20:28 -0700
Subject: [PATCH] Check for overflow in sum aggregate function

This patch adds changes to check for overflow on every update
operation in the sum aggregate function. This is only implemented
for integer types as floating points get set to infinity once they
overflow, which is a valid result.

Test Plan:
Verified that this causes no performance regression by running
existing aggregation benchmark. Also added a unit test for the
same.
---
 velox/exec/tests/AggregationTest.cpp          | 40 ++++++++++++++++
 .../prestosql/aggregates/SumAggregate.h       | 48 +++++++++++++------
 2 files changed, 74 insertions(+), 14 deletions(-)

diff --git a/velox/exec/tests/AggregationTest.cpp b/velox/exec/tests/AggregationTest.cpp
index aa7f3491d696c..204d71965d602 100644
--- a/velox/exec/tests/AggregationTest.cpp
+++ b/velox/exec/tests/AggregationTest.cpp
@@ -14,6 +14,7 @@
  * limitations under the License.
  */
 
+#include "velox/common/base/tests/GTestUtils.h"
 #include "velox/common/file/FileSystems.h"
 #include "velox/dwio/dwrf/test/utils/BatchMaker.h"
 #include "velox/exec/Aggregate.h"
@@ -484,6 +485,27 @@ class AggregationTest : public OperatorTestBase {
            DOUBLE(),
            VARCHAR()})};
   folly::Random::DefaultGenerator rng_;
+
+  template <typename InputType, typename ResultType>
+  void testSumOverflow(bool expectError, const ResultType expectedResult) {
+    auto expectedVector =
+        makeFlatVector<ResultType>(std::vector<ResultType>(1, expectedResult));
+    std::vector<RowVectorPtr> data;
+    data.push_back(makeRowVector({makeFlatVector<InputType>(
+        {std::numeric_limits<InputType>::max(),
+         std::numeric_limits<InputType>::max() - 1})}));
+    auto plan = PlanBuilder()
+                    .values(data)
+                    .singleAggregation({}, {"sum(c0)"})
+                    .planNode();
+    CursorParameters params;
+    params.planNode = plan;
+    if (expectError) {
+      VELOX_ASSERT_THROW(readCursor(params, [](auto /*task*/) {}), "overflow");
+    } else {
+      AssertQueryBuilder(plan).assertResults(makeRowVector({expectedVector}));
+    }
+  }
 };
 
 template <>
@@ -1041,5 +1063,23 @@ TEST_F(AggregationTest, groupingSets) {
       "SELECT k1, k2, count(1), sum(a), max(b) FROM tmp GROUP BY ROLLUP (k1, k2)");
 }
 
+TEST_F(AggregationTest, overflowSumAggregate) {
+  // Verify sum aggregate function checks and throws an overflow error when
+  // appropriate. Since all integer types have output types as int64, overflow
+  // only occurs if the sum exceeds the max int64 value. For floating points, an
+  // overflow results in an infinite result but does not throw. Results are
+  // manually compared instead of comparing with duckDB as it throws an error
+  // instead when floating points go over limit.
+  testSumOverflow<int8_t, int64_t>(false, 253);
+  testSumOverflow<int16_t, int64_t>(false, 65533);
+  testSumOverflow<int32_t, int64_t>(false, 4294967293);
+  testSumOverflow<int64_t, int64_t>(true, 0);
+  // TODO: add this back once sum agg for floats is fixed.
+  // testSumOverflow<float, float>(false,
+  // std::numeric_limits<float>::infinity());
+  testSumOverflow<double, double>(
+      false, std::numeric_limits<double>::infinity());
+}
+
 } // namespace
 } // namespace facebook::velox::exec::test
diff --git a/velox/functions/prestosql/aggregates/SumAggregate.h b/velox/functions/prestosql/aggregates/SumAggregate.h
index 937f61a17ec27..82f3c9a6261bc 100644
--- a/velox/functions/prestosql/aggregates/SumAggregate.h
+++ b/velox/functions/prestosql/aggregates/SumAggregate.h
@@ -16,6 +16,7 @@
 #pragma once
 
 #include "velox/expression/FunctionSignature.h"
+#include "velox/functions/prestosql/CheckedArithmetic.h"
 #include "velox/functions/prestosql/aggregates/SimpleNumericAggregate.h"
 
 namespace facebook::velox::aggregate {
@@ -83,8 +84,8 @@ class SumAggregate
         group,
         rows,
         args[0],
-        [](TAccumulator& result, TInput value) { result += value; },
-        [](TAccumulator& result, TInput value, int n) { result += n * value; },
+        &updateSingleValue<TAccumulator>,
+        &updateDuplicateValues<TAccumulator>,
         mayPushdown,
         0);
   }
@@ -98,8 +99,8 @@ class SumAggregate
         group,
         rows,
         args[0],
-        [](ResultType& result, TInput value) { result += value; },
-        [](ResultType& result, TInput value, int n) { result += n * value; },
+        &updateSingleValue<ResultType>,
+        &updateDuplicateValues<ResultType>,
         mayPushdown,
         0);
   }
@@ -123,18 +124,37 @@ class SumAggregate
 
     if (exec::Aggregate::numNulls_) {
       BaseAggregate::template updateGroups<true, TData>(
-          groups,
-          rows,
-          arg,
-          [](TData& result, TInput value) { result += value; },
-          false);
+          groups, rows, arg, &updateSingleValue<TData>, false);
     } else {
       BaseAggregate::template updateGroups<false, TData>(
-          groups,
-          rows,
-          arg,
-          [](TData& result, TInput value) { result += value; },
-          false);
+          groups, rows, arg, &updateSingleValue<TData>, false);
+    }
+  }
+
+ private:
+  /// Update functions that check for overflows for integer types.
+  /// For floating points, an overflow results in +/- infinity which is a
+  /// valid output.
+  template <typename TOutput>
+  static void updateSingleValue(TOutput& result, TInput value) {
+    if constexpr (
+        std::is_same<TOutput, double>::value ||
+        std::is_same<TOutput, float>::value) {
+      result += value;
+    } else {
+      result = functions::checkedPlus<TOutput>(result, value);
+    }
+  }
+
+  template <typename TOutput>
+  static void updateDuplicateValues(TOutput& result, TInput value, int n) {
+    if constexpr (
+        std::is_same<TOutput, double>::value ||
+        std::is_same<TOutput, float>::value) {
+      result += n * value;
+    } else {
+      result = functions::checkedPlus<TOutput>(
+          result, functions::checkedMultiply<TOutput>(n, value));
     }
   }
 };