Skip to content

Commit

Permalink
Check for overflow in sum aggregate function
Browse files Browse the repository at this point in the history
This patch adds changes to check for overflow on every update
operation in the sum aggregate function. This is only implemented
for integer types as floating points get set to infinity once they
overflow, which is a valid result.

Test Plan:
Verified that this causes no performance regression by running
existing aggregation benchmark.
Also added a unit test for the same.
  • Loading branch information
bikramSingh91 committed Jul 14, 2022
1 parent e150299 commit bb32e7f
Show file tree
Hide file tree
Showing 2 changed files with 74 additions and 14 deletions.
40 changes: 40 additions & 0 deletions velox/exec/tests/AggregationTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
* limitations under the License.
*/

#include "velox/common/base/tests/GTestUtils.h"
#include "velox/common/file/FileSystems.h"
#include "velox/dwio/dwrf/test/utils/BatchMaker.h"
#include "velox/exec/Aggregate.h"
Expand Down Expand Up @@ -484,6 +485,27 @@ class AggregationTest : public OperatorTestBase {
DOUBLE(),
VARCHAR()})};
folly::Random::DefaultGenerator rng_;

template <typename InputType, typename ResultType>
void testSumOverflow(bool expectError, const ResultType expectedResult) {
auto expectedVector =
makeFlatVector<ResultType>(std::vector<ResultType>(1, expectedResult));
std::vector<RowVectorPtr> data;
data.push_back(makeRowVector({makeFlatVector<InputType>(
{std::numeric_limits<InputType>::max(),
std::numeric_limits<InputType>::max() - 1})}));
auto plan = PlanBuilder()
.values(data)
.singleAggregation({}, {"sum(c0)"})
.planNode();
CursorParameters params;
params.planNode = plan;
if (expectError) {
VELOX_ASSERT_THROW(readCursor(params, [](auto /*task*/) {}), "overflow");
} else {
AssertQueryBuilder(plan).assertResults(makeRowVector({expectedVector}));
}
}
};

template <>
Expand Down Expand Up @@ -1041,5 +1063,23 @@ TEST_F(AggregationTest, groupingSets) {
"SELECT k1, k2, count(1), sum(a), max(b) FROM tmp GROUP BY ROLLUP (k1, k2)");
}

TEST_F(AggregationTest, overflowSumAggregate) {
// Verify sum aggregate function checks and throws an overflow error when
// appropriate. Since all integer types have output types as int64, overflow
// only occurs if the sum exceeds the max int64 value. For floating points, an
// overflow results in an infinite result but does not throw. Results are
// manually compared instead of comparing with duckDB as it throws an error
// instead when floating points go over limit.
testSumOverflow<int8_t, int64_t>(false, 253);
testSumOverflow<int16_t, int64_t>(false, 65533);
testSumOverflow<int32_t, int64_t>(false, 4294967293);
testSumOverflow<int64_t, int64_t>(true, 0);
// TODO: add this back once sum agg for floats is fixed.
// testSumOverflow<float, float>(false,
// std::numeric_limits<float>::infinity());
testSumOverflow<double, double>(
false, std::numeric_limits<double>::infinity());
}

} // namespace
} // namespace facebook::velox::exec::test
48 changes: 34 additions & 14 deletions velox/functions/prestosql/aggregates/SumAggregate.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
#pragma once

#include "velox/expression/FunctionSignature.h"
#include "velox/functions/prestosql/CheckedArithmetic.h"
#include "velox/functions/prestosql/aggregates/SimpleNumericAggregate.h"

namespace facebook::velox::aggregate {
Expand Down Expand Up @@ -83,8 +84,8 @@ class SumAggregate
group,
rows,
args[0],
[](TAccumulator& result, TInput value) { result += value; },
[](TAccumulator& result, TInput value, int n) { result += n * value; },
&updateSingleValue<TAccumulator>,
&updateDuplicateValues<TAccumulator>,
mayPushdown,
0);
}
Expand All @@ -98,8 +99,8 @@ class SumAggregate
group,
rows,
args[0],
[](ResultType& result, TInput value) { result += value; },
[](ResultType& result, TInput value, int n) { result += n * value; },
&updateSingleValue<ResultType>,
&updateDuplicateValues<ResultType>,
mayPushdown,
0);
}
Expand All @@ -123,18 +124,37 @@ class SumAggregate

if (exec::Aggregate::numNulls_) {
BaseAggregate::template updateGroups<true, TData>(
groups,
rows,
arg,
[](TData& result, TInput value) { result += value; },
false);
groups, rows, arg, &updateSingleValue<TData>, false);
} else {
BaseAggregate::template updateGroups<false, TData>(
groups,
rows,
arg,
[](TData& result, TInput value) { result += value; },
false);
groups, rows, arg, &updateSingleValue<TData>, false);
}
}

private:
/// Update functions that check for overflows for integer types.
/// For floating points, an overflow results in +/- infinity which is a
/// valid output.
template <typename TOutput>
static void updateSingleValue(TOutput& result, TInput value) {
if constexpr (
std::is_same<TOutput, double>::value ||
std::is_same<TOutput, float>::value) {
result += value;
} else {
result = functions::checkedPlus<TOutput>(result, value);
}
}

template <typename TOutput>
static void updateDuplicateValues(TOutput& result, TInput value, int n) {
if constexpr (
std::is_same<TOutput, double>::value ||
std::is_same<TOutput, float>::value) {
result += n * value;
} else {
result = functions::checkedPlus<TOutput>(
result, functions::checkedMultiply<TOutput>(n, value));
}
}
};
Expand Down

0 comments on commit bb32e7f

Please sign in to comment.