Skip to content

Commit

Permalink
Check for overflow in sum aggregate function
Browse files Browse the repository at this point in the history
This patch adds changes to check for overflow on every update
operation in the sum aggregate function. This is only implemented
for integer types as floating points get set to infinity once they
overflow, which is a valid result.

Test Plan:
Verified that this causes no performance regression by running
existing aggregation benchmark.
Also added a unit test for the same.
  • Loading branch information
bikramSingh91 committed Jul 28, 2022
1 parent da8b01a commit a49ed74
Show file tree
Hide file tree
Showing 2 changed files with 142 additions and 15 deletions.
48 changes: 34 additions & 14 deletions velox/functions/prestosql/aggregates/SumAggregate.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
#pragma once

#include "velox/expression/FunctionSignature.h"
#include "velox/functions/prestosql/CheckedArithmeticImpl.h"
#include "velox/functions/prestosql/aggregates/SimpleNumericAggregate.h"

namespace facebook::velox::aggregate {
Expand Down Expand Up @@ -83,8 +84,8 @@ class SumAggregate
group,
rows,
args[0],
[](TAccumulator& result, TInput value) { result += value; },
[](TAccumulator& result, TInput value, int n) { result += n * value; },
&updateSingleValue<TAccumulator>,
&updateDuplicateValues<TAccumulator>,
mayPushdown,
0);
}
Expand All @@ -98,8 +99,8 @@ class SumAggregate
group,
rows,
args[0],
[](ResultType& result, TInput value) { result += value; },
[](ResultType& result, TInput value, int n) { result += n * value; },
&updateSingleValue<ResultType>,
&updateDuplicateValues<ResultType>,
mayPushdown,
0);
}
Expand All @@ -123,18 +124,37 @@ class SumAggregate

if (exec::Aggregate::numNulls_) {
BaseAggregate::template updateGroups<true, TData>(
groups,
rows,
arg,
[](TData& result, TInput value) { result += value; },
false);
groups, rows, arg, &updateSingleValue<TData>, false);
} else {
BaseAggregate::template updateGroups<false, TData>(
groups,
rows,
arg,
[](TData& result, TInput value) { result += value; },
false);
groups, rows, arg, &updateSingleValue<TData>, false);
}
}

private:
/// Update functions that check for overflows for integer types.
/// For floating points, an overflow results in +/- infinity which is a
/// valid output.
template <typename TOutput>
static void updateSingleValue(TOutput& result, TInput value) {
if constexpr (
std::is_same<TOutput, double>::value ||
std::is_same<TOutput, float>::value) {
result += value;
} else {
result = functions::checkedPlus<TOutput>(result, value);
}
}

template <typename TOutput>
static void updateDuplicateValues(TOutput& result, TInput value, int n) {
if constexpr (
std::is_same<TOutput, double>::value ||
std::is_same<TOutput, float>::value) {
result += n * value;
} else {
result = functions::checkedPlus<TOutput>(
result, functions::checkedMultiply<TOutput>(n, value));
}
}
};
Expand Down
109 changes: 108 additions & 1 deletion velox/functions/prestosql/aggregates/tests/SumTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,17 +13,108 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "velox/common/base/tests/GTestUtils.h"
#include "velox/exec/AggregationHook.h"
#include "velox/exec/tests/utils/Cursor.h"
#include "velox/exec/tests/utils/PlanBuilder.h"
#include "velox/functions/prestosql/aggregates/tests/AggregationTestBase.h"

using facebook::velox::exec::test::PlanBuilder;
using namespace facebook::velox::exec::test;

namespace facebook::velox::aggregate::test {

namespace {

class SumTest : public AggregationTestBase {};
class SumTest : public AggregationTestBase {
protected:
template <typename InputType, typename ResultType>
void testInputTypeLimits(bool expectOverflow = false) {
std::vector<InputType> underflowTestCase = {
std::numeric_limits<InputType>::min(),
std::numeric_limits<InputType>::min() + 2};
std::vector<InputType> overflowTestCase = {
std::numeric_limits<InputType>::max(),
std::numeric_limits<InputType>::max() - 2};
auto createRowVectorFromSingleValue = [&](InputType value) {
return makeRowVector(
{makeFlatVector<InputType>(std::vector<InputType>(1, value))});
};
for (auto& testCase : {underflowTestCase, overflowTestCase}) {
// Test code path for single values with overflow hit in add.
std::vector<RowVectorPtr> input = {
makeRowVector({makeFlatVector<InputType>(testCase)})};
// Test code path for duplicate values with overflow hit in multiply.
std::vector<RowVectorPtr> inputConstantVector = {
makeRowVector({makeConstant<InputType>(testCase[0] / 3, 4)})};
// Test code path for duplicate values with overflow hit in add.
std::vector<RowVectorPtr> inputHybridVector = {
createRowVectorFromSingleValue(testCase[0]),
makeRowVector({makeConstant<InputType>(testCase[1] / 3, 3)})};
std::vector<core::PlanNodePtr> plansToTest;
// Single Aggregation (raw input in - final result out)
plansToTest.push_back(PlanBuilder()
.values(input)
.singleAggregation({}, {"sum(c0)"})
.planNode());
plansToTest.push_back(PlanBuilder()
.values(inputConstantVector)
.singleAggregation({}, {"sum(c0)"})
.planNode());
plansToTest.push_back(PlanBuilder()
.values(inputHybridVector)
.singleAggregation({}, {"sum(c0)"})
.planNode());
// Partial Aggregation (raw input in - partial result out)
plansToTest.push_back(PlanBuilder()
.values(input)
.partialAggregation({}, {"sum(c0)"})
.planNode());
plansToTest.push_back(PlanBuilder()
.values(inputConstantVector)
.partialAggregation({}, {"sum(c0)"})
.planNode());
plansToTest.push_back(PlanBuilder()
.values(inputHybridVector)
.partialAggregation({}, {"sum(c0)"})
.planNode());
// Final Aggregation (partial result in - final result out):
// To make sure that the overflow occurs in the final aggregation step, we
// create 2 plan fragments and plugging their partially aggregated
// output into a final aggregate plan node. Each of those input fragments
// only have a single input value under the max limit which when added in
// the final step causes an overflow.
auto planNodeIdGenerator =
std::make_shared<exec::test::PlanNodeIdGenerator>();
plansToTest.push_back(
PlanBuilder(planNodeIdGenerator)
.localPartition(
{},
{PlanBuilder(planNodeIdGenerator)
.values({createRowVectorFromSingleValue(testCase[0])})
.partialAggregation({}, {"sum(c0)"})
.planNode(),
PlanBuilder(planNodeIdGenerator)
.values({createRowVectorFromSingleValue(testCase[1])})
.partialAggregation({}, {"sum(c0)"})
.planNode()})
.finalAggregation()
.planNode());
// Run all plan types
CursorParameters params;
for (auto& plan : plansToTest) {
params.planNode = plan;
if (expectOverflow) {
VELOX_ASSERT_THROW(
readCursor(params, [](auto /*task*/) {}), "overflow");
} else {
readCursor(params, [](auto /*task*/) {});
}
}
}
}
};

TEST_F(SumTest, sumTinyint) {
auto rowType = ROW({"c0", "c1"}, {BIGINT(), TINYINT()});
Expand Down Expand Up @@ -268,5 +359,21 @@ TEST_F(SumTest, hook) {
EXPECT_EQ(0, numNulls);
EXPECT_EQ(value, sumRow.sum);
}

TEST_F(SumTest, inputTypeLimits) {
// Verify sum aggregate function checks and throws an overflow error when
// appropriate. Since all integer types have output types as int64, overflow
// only occurs if the sum exceeds the max int64 value. For floating points, an
// overflow results in an infinite result but does not throw. Results are
// manually compared instead of comparing with duckDB as it throws an error
// instead when floating points go over limit.
testInputTypeLimits<int8_t, int64_t>();
testInputTypeLimits<int16_t, int64_t>();
testInputTypeLimits<int32_t, int64_t>();
testInputTypeLimits<int64_t, int64_t>(true);
// TODO: enable this test once Issue #2079 is fixed
// testInputTypeLimits<float, float>();
testInputTypeLimits<double, double>();
}
} // namespace
} // namespace facebook::velox::aggregate::test

0 comments on commit a49ed74

Please sign in to comment.