From 8550a15cd6463d18ce1c871708acb448a2188133 Mon Sep 17 00:00:00 2001 From: "joey.ljy" Date: Thu, 16 Nov 2023 04:07:33 -0800 Subject: [PATCH] Introduce DecimalUtil::adjustSumForOverflow helper method (#7334) Summary: The sum aggregate function in Spark and Presto uses this logic to adjust the sum computed via DecimalUtil::addWithOverflow. Pull Request resolved: https://github.com/facebookincubator/velox/pull/7334 Reviewed By: pedroerp Differential Revision: D51157206 Pulled By: mbasmanova fbshipit-source-id: 6cb34fd2dec9190e75fc9f4edcc5f885d19c7c7d --- .../lib/aggregates/SumAggregateBase.h | 18 ++----- velox/type/DecimalUtil.h | 48 +++++++++++++++-- velox/type/tests/DecimalTest.cpp | 54 +++++++++++++++++++ 3 files changed, 103 insertions(+), 17 deletions(-) diff --git a/velox/functions/lib/aggregates/SumAggregateBase.h b/velox/functions/lib/aggregates/SumAggregateBase.h index a159ae7f1345..ddb7af71aa63 100644 --- a/velox/functions/lib/aggregates/SumAggregateBase.h +++ b/velox/functions/lib/aggregates/SumAggregateBase.h @@ -185,19 +185,11 @@ class DecimalSumAggregate virtual int128_t computeFinalValue( functions::aggregate::LongDecimalWithOverflowState* accumulator) final { - // Value is valid if the conditions below are true. - int128_t sum = accumulator->sum; - if ((accumulator->overflow == 1 && accumulator->sum < 0) || - (accumulator->overflow == -1 && accumulator->sum > 0)) { - sum = static_cast( - DecimalUtil::kOverflowMultiplier * accumulator->overflow + - accumulator->sum); - } else { - VELOX_CHECK(accumulator->overflow == 0, "Decimal overflow"); - } - - DecimalUtil::valueInRange(sum); - return sum; + auto sum = DecimalUtil::adjustSumForOverflow( + accumulator->sum, accumulator->overflow); + VELOX_USER_CHECK(sum.has_value(), "Decimal overflow"); + DecimalUtil::valueInRange(sum.value()); + return sum.value(); } }; diff --git a/velox/type/DecimalUtil.h b/velox/type/DecimalUtil.h index 99ab2d2e68d2..71827a4ac611 100644 --- a/velox/type/DecimalUtil.h +++ b/velox/type/DecimalUtil.h @@ -83,7 +83,7 @@ class DecimalUtil { static constexpr uint128_t kInt128Mask = (static_cast(1) << 127); FOLLY_ALWAYS_INLINE static void valueInRange(int128_t value) { - VELOX_CHECK( + VELOX_USER_CHECK( (value >= kLongDecimalMin && value <= kLongDecimalMax), "Decimal overflow. Value '{}' is not in the range of Decimal Type", value); @@ -240,8 +240,8 @@ class DecimalUtil { */ inline static int64_t addUnsignedValues( int128_t& sum, - const int128_t& lhs, - const int128_t& rhs, + int128_t lhs, + int128_t rhs, bool isResultNegative) { __uint128_t unsignedSum = (__uint128_t)lhs + (__uint128_t)rhs; // Ignore overflow value. @@ -250,8 +250,20 @@ class DecimalUtil { return (unsignedSum >> 127); } + /// Adds two signed 128-bit numbers (int128_t), calculates the sum, and + /// returns the overflow. It can be used to track the number of overflow when + /// adding a batch of input numbers. It takes lhs and rhs as input, and stores + /// their sum in result. overflow == 1 indicates upward overflow. overflow == + /// -1 indicates downward overflow. overflow == 0 indicates no overflow. + /// Adding negative and non-negative numbers never overflows, so we can + /// directly add them. Adding two negative or two positive numbers may + /// overflow. To add numbers that may overflow, first convert both numbers to + /// unsigned 128-bit number (uint128_t), and perform the addition. The highest + /// bits in the result indicates overflow. Adjust the signs of sum and + /// overflow based on the signs of the inputs. The caller must sum up overflow + /// values and call adjustSumForOverflow after processing all inputs. inline static int64_t - addWithOverflow(int128_t& result, const int128_t& lhs, const int128_t& rhs) { + addWithOverflow(int128_t& result, int128_t lhs, int128_t rhs) { bool isLhsNegative = lhs < 0; bool isRhsNegative = rhs < 0; int64_t overflow = 0; @@ -259,6 +271,8 @@ class DecimalUtil { // Both inputs of same time. if (isLhsNegative) { // Both negative, ignore signs and add. + VELOX_DCHECK_NE(lhs, std::numeric_limits::min()); + VELOX_DCHECK_NE(rhs, std::numeric_limits::min()); overflow = addUnsignedValues(result, -lhs, -rhs, true); overflow = -overflow; } else { @@ -271,6 +285,32 @@ class DecimalUtil { return overflow; } + /// Corrects the sum result calculated using addWithOverflow. Since the sum + /// calculated by addWithOverflow only retains the lower 127 bits, + /// it may miss one calculation of +(1 << 127) or -(1 << 127). + /// Therefore, we need to make the following adjustments: + /// 1. If overflow = 1 && sum < 0, the calculation missed +(1 << 127). + /// Add 1 << 127 to the sum. + /// 2. If overflow = -1 && sum > 0, the calculation missed -(1 << 127). + /// Subtract 1 << 127 to the sum. + /// If an overflow indeed occurs and the result cannot be adjusted, + /// it will return std::nullopt. + inline static std::optional adjustSumForOverflow( + int128_t sum, + int64_t overflow) { + // Value is valid if the conditions below are true. + if ((overflow == 1 && sum < 0) || (overflow == -1 && sum > 0)) { + return static_cast( + DecimalUtil::kOverflowMultiplier * overflow + sum); + } + if (overflow != 0) { + // The actual overflow occurred. + return std::nullopt; + } + + return sum; + } + /* * Computes average. If there is an overflow value uses the following * expression to compute the average. diff --git a/velox/type/tests/DecimalTest.cpp b/velox/type/tests/DecimalTest.cpp index 10de56c3aece..f47e9717726f 100644 --- a/velox/type/tests/DecimalTest.cpp +++ b/velox/type/tests/DecimalTest.cpp @@ -193,5 +193,59 @@ TEST(DecimalTest, valueInPrecisionRange) { DecimalUtil::kLongDecimalMin - 1, LongDecimalType::kMaxPrecision)); } +TEST(DecimalAggregateTest, adjustSumForOverflow) { + struct SumWithOverflow { + int128_t sum{0}; + int64_t overflow{0}; + + void add(int128_t input) { + overflow += DecimalUtil::addWithOverflow(sum, sum, input); + } + + std::optional adjustedSum() const { + return DecimalUtil::adjustSumForOverflow(sum, overflow); + } + + void reset() { + sum = 0; + overflow = 0; + } + }; + + SumWithOverflow accumulator; + // kLongDecimalMax + kLongDecimalMax will trigger one upward overflow, and the + // final sum result calculated by DecimalUtil::addWithOverflow is negative. + // DecimalUtil::adjustSumForOverflow can adjust the sum to kLongDecimalMax + // correctly. + accumulator.add(DecimalUtil::kLongDecimalMax); + accumulator.add(DecimalUtil::kLongDecimalMax); + accumulator.add(DecimalUtil::kLongDecimalMin); + EXPECT_EQ(accumulator.adjustedSum(), DecimalUtil::kLongDecimalMax); + + accumulator.reset(); + // kLongDecimalMin + kLongDecimalMin will trigger one downward overflow, and + // the final sum result calculated by DecimalUtil::addWithOverflow is + // positive. DecimalUtil::adjustSumForOverflow can adjust the sum to + // kLongDecimalMin correctly. + accumulator.add(DecimalUtil::kLongDecimalMin); + accumulator.add(DecimalUtil::kLongDecimalMin); + accumulator.add(DecimalUtil::kLongDecimalMax); + EXPECT_EQ(accumulator.adjustedSum(), DecimalUtil::kLongDecimalMin); + + accumulator.reset(); + // These inputs will eventually trigger an upward overflow, and + // DecimalUtil::adjustSumForOverflow will return std::nullopt. + accumulator.add(DecimalUtil::kLongDecimalMax); + accumulator.add(DecimalUtil::kLongDecimalMax); + EXPECT_FALSE(accumulator.adjustedSum().has_value()); + + accumulator.reset(); + // These inputs will eventually trigger a downward overflow, and + // DecimalUtil::adjustSumForOverflow will return std::nullopt. + accumulator.add(DecimalUtil::kLongDecimalMin); + accumulator.add(DecimalUtil::kLongDecimalMin); + EXPECT_FALSE(accumulator.adjustedSum().has_value()); +} + } // namespace } // namespace facebook::velox