Skip to content

Commit

Permalink
Check for overflow in sum aggregate function
Browse files Browse the repository at this point in the history
This patch checks for overflow on every update operation in the sum
aggregate function. The check is implemented only for integer types:
floating-point values become +/- infinity when they overflow, which is
a valid result.

Test Plan:
Verified that this causes no performance regression by running
existing aggregation benchmark.
Also added a unit test for the same.
  • Loading branch information
bikramSingh91 committed Jul 21, 2022
1 parent da8b01a commit 089437a
Show file tree
Hide file tree
Showing 2 changed files with 93 additions and 15 deletions.
48 changes: 34 additions & 14 deletions velox/functions/prestosql/aggregates/SumAggregate.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
#pragma once

#include "velox/expression/FunctionSignature.h"
#include "velox/functions/prestosql/CheckedArithmeticImpl.h"
#include "velox/functions/prestosql/aggregates/SimpleNumericAggregate.h"

namespace facebook::velox::aggregate {
Expand Down Expand Up @@ -83,8 +84,8 @@ class SumAggregate
group,
rows,
args[0],
[](TAccumulator& result, TInput value) { result += value; },
[](TAccumulator& result, TInput value, int n) { result += n * value; },
&updateSingleValue<TAccumulator>,
&updateDuplicateValues<TAccumulator>,
mayPushdown,
0);
}
Expand All @@ -98,8 +99,8 @@ class SumAggregate
group,
rows,
args[0],
[](ResultType& result, TInput value) { result += value; },
[](ResultType& result, TInput value, int n) { result += n * value; },
&updateSingleValue<ResultType>,
&updateDuplicateValues<ResultType>,
mayPushdown,
0);
}
Expand All @@ -123,18 +124,37 @@ class SumAggregate

if (exec::Aggregate::numNulls_) {
BaseAggregate::template updateGroups<true, TData>(
groups,
rows,
arg,
[](TData& result, TInput value) { result += value; },
false);
groups, rows, arg, &updateSingleValue<TData>, false);
} else {
BaseAggregate::template updateGroups<false, TData>(
groups,
rows,
arg,
[](TData& result, TInput value) { result += value; },
false);
groups, rows, arg, &updateSingleValue<TData>, false);
}
}

private:
/// Adds one input value to the running sum held in 'result'.
/// Integer accumulators are updated through functions::checkedPlus so that
/// overflows are checked. Float and double accumulators use plain addition,
/// because for them an overflow results in +/- infinity, which is a valid
/// output.
template <typename TOutput>
static void updateSingleValue(TOutput& result, TInput value) {
  constexpr bool kIsFloatingPoint =
      std::is_same_v<TOutput, float> || std::is_same_v<TOutput, double>;
  if constexpr (!kIsFloatingPoint) {
    result = functions::checkedPlus<TOutput>(result, value);
  } else {
    result += value;
  }
}

template <typename TOutput>
static void updateDuplicateValues(TOutput& result, TInput value, int n) {
if constexpr (
std::is_same<TOutput, double>::value ||
std::is_same<TOutput, float>::value) {
result += n * value;
} else {
result = functions::checkedPlus<TOutput>(
result, functions::checkedMultiply<TOutput>(n, value));
}
}
};
Expand Down
60 changes: 59 additions & 1 deletion velox/functions/prestosql/aggregates/tests/SumTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,17 +13,59 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "velox/common/base/tests/GTestUtils.h"
#include "velox/exec/AggregationHook.h"
#include "velox/exec/tests/utils/Cursor.h"
#include "velox/exec/tests/utils/PlanBuilder.h"
#include "velox/functions/prestosql/aggregates/tests/AggregationTestBase.h"

using facebook::velox::exec::test::PlanBuilder;
using namespace facebook::velox::exec::test;

namespace facebook::velox::aggregate::test {

namespace {

class SumTest : public AggregationTestBase {};
/// Fixture for sum aggregate tests. Adds a helper that sums values taken from
/// both ends of an input type's range.
class SumTest : public AggregationTestBase {
 protected:
  /// Runs sum(c0) over two inputs: the two lowest and the two highest values
  /// of InputType. Each input is run through a partial and a single
  /// aggregation step.
  ///
  /// @tparam InputType Type of the input column.
  /// @tparam ResultType Type of the expected sum.
  /// @param expectOverflow If true, every run must throw an error whose
  /// message contains "overflow". Otherwise the result must equal the sum of
  /// the two inputs computed in ResultType.
  template <typename InputType, typename ResultType>
  void testInputTypeLimits(bool expectOverflow = false) {
    using Limits = std::numeric_limits<InputType>;
    // Use lowest() rather than min(): for floating point types min() is the
    // smallest positive normalized value, not the most negative one. For
    // integer types the two are identical.
    std::vector<std::vector<InputType>> inputs = {
        {Limits::lowest(), Limits::lowest() + 1},
        {Limits::max(), Limits::max() - 1}};
    for (auto& input : inputs) {
      std::vector<RowVectorPtr> data;
      data.push_back(makeRowVector({makeFlatVector<InputType>(input)}));

      // Testing these two steps provides enough coverage. Adding the final
      // step involves a more elaborate multi-fragment setup, which would be
      // overkill.
      for (auto& step :
           {core::AggregationNode::Step::kPartial,
            core::AggregationNode::Step::kSingle}) {
        auto plan = PlanBuilder()
                        .values(data)
                        .aggregation({}, {}, {"sum(c0)"}, {}, step, false, {})
                        .planNode();
        CursorParameters params;
        params.planNode = plan;
        if (expectOverflow) {
          VELOX_ASSERT_THROW(
              readCursor(params, [](auto /*task*/) {}), "overflow");
        } else {
          // Widen each operand before adding so the expected value itself is
          // computed in ResultType.
          ResultType expectedOutput = static_cast<ResultType>(input[0]) +
              static_cast<ResultType>(input[1]);
          auto expectedVector = makeFlatVector<ResultType>(
              std::vector<ResultType>(1, expectedOutput));
          assertQuery(plan, makeRowVector({expectedVector}));
        }
      }
    }
  }
};

TEST_F(SumTest, sumTinyint) {
auto rowType = ROW({"c0", "c1"}, {BIGINT(), TINYINT()});
Expand Down Expand Up @@ -268,5 +310,21 @@ TEST_F(SumTest, hook) {
EXPECT_EQ(0, numNulls);
EXPECT_EQ(value, sumRow.sum);
}

TEST_F(SumTest, inputTypeLimits) {
  // The sum aggregate function must throw an overflow error exactly when one
  // is warranted. Every integer input type produces an int64 output, so an
  // overflow happens only when the sum goes past the int64 range. A floating
  // point sum that overflows becomes infinite and does not throw. Expected
  // results are computed by hand rather than with DuckDB, because DuckDB
  // throws an error when floating points go over the limit.
  constexpr bool kFits = false;
  constexpr bool kOverflows = true;

  testInputTypeLimits<int8_t, int64_t>(kFits);
  testInputTypeLimits<int16_t, int64_t>(kFits);
  testInputTypeLimits<int32_t, int64_t>(kFits);
  testInputTypeLimits<int64_t, int64_t>(kOverflows);
  // TODO: enable this test once Issue #2079 is fixed
  // testInputTypeLimits<float, float>();
  testInputTypeLimits<double, double>(kFits);
}
} // namespace
} // namespace facebook::velox::aggregate::test

0 comments on commit 089437a

Please sign in to comment.