Skip to content

Commit

Permalink
Introduce DecimalUtil::adjustSumForOverflow helper method (#7334)
Browse files Browse the repository at this point in the history
Summary:
The sum aggregate function in Spark and Presto uses this logic to adjust the sum
computed via DecimalUtil::addWithOverflow.

Pull Request resolved: #7334

Reviewed By: pedroerp

Differential Revision: D51157206

Pulled By: mbasmanova

fbshipit-source-id: 6cb34fd2dec9190e75fc9f4edcc5f885d19c7c7d
  • Loading branch information
liujiayi771 authored and facebook-github-bot committed Nov 16, 2023
1 parent 0821530 commit 8550a15
Show file tree
Hide file tree
Showing 3 changed files with 103 additions and 17 deletions.
18 changes: 5 additions & 13 deletions velox/functions/lib/aggregates/SumAggregateBase.h
Original file line number Diff line number Diff line change
Expand Up @@ -185,19 +185,11 @@ class DecimalSumAggregate

virtual int128_t computeFinalValue(
functions::aggregate::LongDecimalWithOverflowState* accumulator) final {
// Value is valid if the conditions below are true.
int128_t sum = accumulator->sum;
if ((accumulator->overflow == 1 && accumulator->sum < 0) ||
(accumulator->overflow == -1 && accumulator->sum > 0)) {
sum = static_cast<int128_t>(
DecimalUtil::kOverflowMultiplier * accumulator->overflow +
accumulator->sum);
} else {
VELOX_CHECK(accumulator->overflow == 0, "Decimal overflow");
}

DecimalUtil::valueInRange(sum);
return sum;
auto sum = DecimalUtil::adjustSumForOverflow(
accumulator->sum, accumulator->overflow);
VELOX_USER_CHECK(sum.has_value(), "Decimal overflow");
DecimalUtil::valueInRange(sum.value());
return sum.value();
}
};

Expand Down
48 changes: 44 additions & 4 deletions velox/type/DecimalUtil.h
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ class DecimalUtil {
static constexpr uint128_t kInt128Mask = (static_cast<uint128_t>(1) << 127);

FOLLY_ALWAYS_INLINE static void valueInRange(int128_t value) {
VELOX_CHECK(
VELOX_USER_CHECK(
(value >= kLongDecimalMin && value <= kLongDecimalMax),
"Decimal overflow. Value '{}' is not in the range of Decimal Type",
value);
Expand Down Expand Up @@ -240,8 +240,8 @@ class DecimalUtil {
*/
inline static int64_t addUnsignedValues(
int128_t& sum,
const int128_t& lhs,
const int128_t& rhs,
int128_t lhs,
int128_t rhs,
bool isResultNegative) {
__uint128_t unsignedSum = (__uint128_t)lhs + (__uint128_t)rhs;
// Ignore overflow value.
Expand All @@ -250,15 +250,29 @@ class DecimalUtil {
return (unsignedSum >> 127);
}

/// Adds two signed 128-bit numbers (int128_t), calculates the sum, and
/// returns the overflow. It can be used to track the number of overflows when
/// adding a batch of input numbers. It takes lhs and rhs as input, and stores
/// their sum in result. overflow == 1 indicates upward overflow. overflow ==
/// -1 indicates downward overflow. overflow == 0 indicates no overflow.
/// Adding negative and non-negative numbers never overflows, so we can
/// directly add them. Adding two negative or two positive numbers may
/// overflow. To add numbers that may overflow, first convert both numbers to
/// unsigned 128-bit numbers (uint128_t), and perform the addition. The highest
/// bit in the result indicates overflow. Adjust the signs of sum and
/// overflow based on the signs of the inputs. The caller must sum up overflow
/// values and call adjustSumForOverflow after processing all inputs.
inline static int64_t
addWithOverflow(int128_t& result, const int128_t& lhs, const int128_t& rhs) {
addWithOverflow(int128_t& result, int128_t lhs, int128_t rhs) {
bool isLhsNegative = lhs < 0;
bool isRhsNegative = rhs < 0;
int64_t overflow = 0;
if (isLhsNegative == isRhsNegative) {
// Both inputs have the same sign.
if (isLhsNegative) {
// Both negative, ignore signs and add.
VELOX_DCHECK_NE(lhs, std::numeric_limits<int128_t>::min());
VELOX_DCHECK_NE(rhs, std::numeric_limits<int128_t>::min());
overflow = addUnsignedValues(result, -lhs, -rhs, true);
overflow = -overflow;
} else {
Expand All @@ -271,6 +285,32 @@ class DecimalUtil {
return overflow;
}

/// Finalizes a sum that was accumulated with addWithOverflow, given the net
/// overflow count collected alongside it.
///
/// addWithOverflow keeps only the lower 127 bits of the running sum, so the
/// accumulated value can be off by exactly one +(1 << 127) or -(1 << 127):
///  - overflow == 1 and sum < 0: the missing +(1 << 127) is added back.
///  - overflow == -1 and sum > 0: the missing -(1 << 127) is applied, i.e.
///    1 << 127 is subtracted from the sum.
///  - overflow == 0: the sum is already correct and is returned unchanged.
///
/// @param sum Running sum produced by addWithOverflow.
/// @param overflow Net overflow count accumulated from addWithOverflow.
/// @return The corrected sum, or std::nullopt when a real overflow occurred
/// and the result cannot be adjusted (any other non-zero overflow
/// combination).
inline static std::optional<int128_t> adjustSumForOverflow(
    int128_t sum,
    int64_t overflow) {
  // No overflow was recorded, so there is nothing to correct.
  if (overflow == 0) {
    return sum;
  }

  // A single overflow whose direction is opposite to the sign of the
  // remaining sum can be folded back into the int128_t range.
  const bool missedUpward = (overflow == 1) && (sum < 0);
  const bool missedDownward = (overflow == -1) && (sum > 0);
  if (!missedUpward && !missedDownward) {
    // The actual overflow occurred.
    return std::nullopt;
  }

  return static_cast<int128_t>(
      DecimalUtil::kOverflowMultiplier * overflow + sum);
}

/*
* Computes average. If there is an overflow value uses the following
* expression to compute the average.
Expand Down
54 changes: 54 additions & 0 deletions velox/type/tests/DecimalTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -193,5 +193,59 @@ TEST(DecimalTest, valueInPrecisionRange) {
DecimalUtil::kLongDecimalMin - 1, LongDecimalType::kMaxPrecision));
}

TEST(DecimalAggregateTest, adjustSumForOverflow) {
  // Running sum paired with the net overflow count reported by
  // DecimalUtil::addWithOverflow. Each scenario below starts from a fresh,
  // zero-initialized instance.
  struct RunningSum {
    int128_t total{0};
    int64_t overflowCount{0};

    // Folds 'value' into the running sum and tracks the reported overflow.
    void accumulate(int128_t value) {
      overflowCount += DecimalUtil::addWithOverflow(total, total, value);
    }

    // Applies DecimalUtil::adjustSumForOverflow to the accumulated state.
    std::optional<int128_t> finalValue() const {
      return DecimalUtil::adjustSumForOverflow(total, overflowCount);
    }
  };

  const auto maxDecimal = DecimalUtil::kLongDecimalMax;
  const auto minDecimal = DecimalUtil::kLongDecimalMin;

  {
    // max + max triggers one upward overflow, and the final sum produced by
    // DecimalUtil::addWithOverflow is negative. The adjustment is expected to
    // recover kLongDecimalMax.
    RunningSum running;
    running.accumulate(maxDecimal);
    running.accumulate(maxDecimal);
    running.accumulate(minDecimal);
    EXPECT_EQ(running.finalValue(), maxDecimal);
  }

  {
    // min + min triggers one downward overflow, and the final sum produced by
    // DecimalUtil::addWithOverflow is positive. The adjustment is expected to
    // recover kLongDecimalMin.
    RunningSum running;
    running.accumulate(minDecimal);
    running.accumulate(minDecimal);
    running.accumulate(maxDecimal);
    EXPECT_EQ(running.finalValue(), minDecimal);
  }

  {
    // The inputs end in an upward overflow that cannot be adjusted, so
    // std::nullopt is expected.
    RunningSum running;
    running.accumulate(maxDecimal);
    running.accumulate(maxDecimal);
    EXPECT_FALSE(running.finalValue().has_value());
  }

  {
    // The inputs end in a downward overflow that cannot be adjusted, so
    // std::nullopt is expected.
    RunningSum running;
    running.accumulate(minDecimal);
    running.accumulate(minDecimal);
    EXPECT_FALSE(running.finalValue().has_value());
  }
}

} // namespace
} // namespace facebook::velox

0 comments on commit 8550a15

Please sign in to comment.