From d3f0dc9ca68e1c7a39ae38a4864ca79375840a8a Mon Sep 17 00:00:00 2001 From: Chengcheng Jin Date: Mon, 18 Dec 2023 14:37:20 -0800 Subject: [PATCH] Implement Spark decimal add and subtract (#5791) Summary: Use Arrow Gandiva BasicDecimal128 algorithm to compute value. Arrow implementation: https://github.com/apache/arrow/blob/release-12.0.1-rc1/cpp/src/gandiva/precompiled/decimal_ops.cc#L211-L231 Spark result precision and scale maybe different with Presto, because it will use adjustPrecisionScale to change the precision and scale when precision is beyond 38. And this implement can compute data without overflow in some situation. Pull Request resolved: https://github.com/facebookincubator/velox/pull/5791 Reviewed By: kgpai, pedroerp Differential Revision: D52220210 Pulled By: Yuhta fbshipit-source-id: 96f9be2c36d37d11aa71df4308a6a0c9a3bd50a3 --- velox/docs/functions/spark/math.rst | 26 ++ .../functions/sparksql/DecimalArithmetic.cpp | 348 +++++++++++++++++- velox/functions/sparksql/DecimalUtil.h | 48 ++- .../functions/sparksql/RegisterArithmetic.cpp | 2 + .../sparksql/tests/DecimalArithmeticTest.cpp | 288 +++++++++++++-- .../sparksql/tests/DecimalUtilTest.cpp | 17 + velox/type/DecimalUtil.cpp | 4 +- velox/type/DecimalUtil.h | 29 +- 8 files changed, 684 insertions(+), 78 deletions(-) diff --git a/velox/docs/functions/spark/math.rst b/velox/docs/functions/spark/math.rst index 46be649696cc..2af60b7d2377 100644 --- a/velox/docs/functions/spark/math.rst +++ b/velox/docs/functions/spark/math.rst @@ -27,6 +27,20 @@ Mathematical Functions Returns the result of adding x to y. The types of x and y must be the same. For integral types, overflow results in an error. Corresponds to sparks's operator ``+``. +.. spark:function:: add(x, y) -> decimal + + Returns the result of adding ``x`` to ``y``. The argument types should be DECIMAL, and can have different precisions and scales. + Fast path is implemented for cases that should not overflow. For the others, the whole parts and fractional parts of input decimals are added separately and combined finally. + The result type is calculated with the max precision of input precisions, the max scale of input scales, and one extra digit for possible carrier. + Overflow results in null output. Corresponds to Spark's operator ``+``. + + :: + + SELECT CAST(1.1232100 as DECIMAL(38, 7)) + CAST(1 as DECIMAL(10, 0)); -- DECIMAL(38, 6) 2.123210 + SELECT CAST(-999999999999999999999999999.999 as DECIMAL(30, 3)) + CAST(-999999999999999999999999999.999 as DECIMAL(30, 3)); -- DECIMAL(31, 3) -1999999999999999999999999999.998 + SELECT CAST(99999999999999999999999999999999.99998 as DECIMAL(38, 6)) + CAST(-99999999999999999999999999999999.99999 as DECIMAL(38, 5)); -- DECIMAL(38, 6) -0.000010 + SELECT CAST(-99999999999999999999999999999999990.0 as DECIMAL(38, 3)) + CAST(-0.00001 as DECIMAL(38, 7)); -- DECIMAL(38, 6) NULL + .. spark:function:: bin(x) -> varchar Returns the string representation of the long value ``x`` represented in binary. @@ -179,6 +193,18 @@ Mathematical Functions Returns the result of subtracting y from x. The types of x and y must be the same. For integral types, overflow results in an error. Corresponds to Spark's operator ``-``. +.. spark:function:: subtract(x, y) -> decimal + + Returns the result of subtracting ``y`` from ``x``. Reuses the logic of add function for decimal type. + Corresponds to Spark's operator ``-``. + + :: + + SELECT CAST(1.1232100 as DECIMAL(38, 7)) - CAST(1 as DECIMAL(10, 0)); -- DECIMAL(38, 6) 0.123210 + SELECT CAST(-999999999999999999999999999.999 as DECIMAL(30, 3)) - CAST(-999999999999999999999999999.999 as DECIMAL(30, 3)); -- DECIMAL(31, 3) 0.000 + SELECT CAST(99999999999999999999999999999999.99998 as DECIMAL(38, 6)) - CAST(-0.00001 as DECIMAL(38, 5)); -- DECIMAL(38, 6) 99999999999999999999999999999999.999990 + SELECT CAST(-99999999999999999999999999999999990.0 as DECIMAL(38, 3)) - CAST(0.00001 as DECIMAL(38, 7)); -- DECIMAL(38, 6) NULL + .. spark:function:: unaryminus(x) -> [same as x] Returns the negative of `x`. Corresponds to Spark's operator ``-``. diff --git a/velox/functions/sparksql/DecimalArithmetic.cpp b/velox/functions/sparksql/DecimalArithmetic.cpp index ffb782da0694..96b7e0cc0200 100644 --- a/velox/functions/sparksql/DecimalArithmetic.cpp +++ b/velox/functions/sparksql/DecimalArithmetic.cpp @@ -33,6 +33,175 @@ std::string getResultScale(std::string precision, std::string scale) { scale); } +// Returns the whole and fraction parts of a decimal value. +template +inline std::pair getWholeAndFraction(T value, uint8_t scale) { + const auto scaleFactor = velox::DecimalUtil::kPowersOfTen[scale]; + const T whole = value / scaleFactor; + return {whole, value - whole * scaleFactor}; +} + +// Increases the scale of input value by 'delta'. Returns the input value if +// delta is not positive. +inline int128_t increaseScale(int128_t in, int16_t delta) { + // No need to consider overflow as 'delta == higher scale - input scale', so + // the scaled value will not exceed the maximum of long decimal. + return delta <= 0 ? in : in * velox::DecimalUtil::kPowersOfTen[delta]; +} + +// Scales up the whole part to result scale, and combine it with fraction part +// to produce a full result for decimal add. Checks whether the result +// overflows. +template +inline T +decimalAddResult(T whole, T fraction, uint8_t resultScale, bool& overflow) { + T scaledWhole = DecimalUtil::multiply( + whole, velox::DecimalUtil::kPowersOfTen[resultScale], overflow); + if (FOLLY_UNLIKELY(overflow)) { + return 0; + } + const auto result = scaledWhole + fraction; + if constexpr (std::is_same_v) { + overflow = (result > velox::DecimalUtil::kShortDecimalMax) || + (result < velox::DecimalUtil::kShortDecimalMin); + } else { + overflow = (result > velox::DecimalUtil::kLongDecimalMax) || + (result < velox::DecimalUtil::kLongDecimalMin); + } + return result; +} + +// Reduces the scale of input value by 'delta'. Returns the input value if delta +// is not positive. +template +inline static T reduceScale(T in, int32_t delta) { + if (delta <= 0) { + return in; + } + T result; + bool overflow; + const auto scaleFactor = velox::DecimalUtil::kPowersOfTen[delta]; + if constexpr (std::is_same_v) { + VELOX_DCHECK_LE( + scaleFactor, + std::numeric_limits::max(), + "Scale factor should not exceed the maximum of int64_t."); + } + DecimalUtil::divideWithRoundUp( + result, in, T(scaleFactor), 0, overflow); + VELOX_DCHECK(!overflow); + return result; +} + +// Adds two non-negative values by adding the whole and fraction parts +// separately. +template +inline static TResult addLargeNonNegative( + A a, + B b, + uint8_t aScale, + uint8_t bScale, + uint8_t rScale, + bool& overflow) { + VELOX_DCHECK_GE( + a, 0, "Non-negative value is expected in addLargeNonNegative."); + VELOX_DCHECK_GE( + b, 0, "Non-negative value is expected in addLargeNonNegative."); + + // Separate whole and fraction parts. + const auto [aWhole, aFraction] = getWholeAndFraction(a, aScale); + const auto [bWhole, bFraction] = getWholeAndFraction(b, bScale); + + // Adjust fractional parts to higher scale. + const auto higherScale = std::max(aScale, bScale); + const auto aFractionScaled = + increaseScale((int128_t)aFraction, higherScale - aScale); + const auto bFractionScaled = + increaseScale((int128_t)bFraction, higherScale - bScale); + + int128_t fraction; + bool carryToLeft = false; + const auto carrier = velox::DecimalUtil::kPowersOfTen[higherScale]; + if (aFractionScaled >= carrier - bFractionScaled) { + fraction = aFractionScaled + bFractionScaled - carrier; + carryToLeft = true; + } else { + fraction = aFractionScaled + bFractionScaled; + } + + // Scale up the whole part and scale down the fraction part to combine them. + fraction = reduceScale(TResult(fraction), higherScale - rScale); + const auto whole = TResult(aWhole) + TResult(bWhole) + TResult(carryToLeft); + return decimalAddResult(whole, TResult(fraction), rScale, overflow); +} + +// Adds two opposite values by adding the whole and fraction parts separately. +template +inline static TResult addLargeOpposite( + A a, + B b, + uint8_t aScale, + uint8_t bScale, + int32_t rScale, + bool& overflow) { + VELOX_DCHECK( + (a < 0 && b > 0) || (a > 0 && b < 0), + "One positve and one negative value are expected in addLargeOpposite."); + + // Separate whole and fraction parts. + const auto [aWhole, aFraction] = getWholeAndFraction(a, aScale); + const auto [bWhole, bFraction] = getWholeAndFraction(b, bScale); + + // Adjust fractional parts to higher scale. + const auto higherScale = std::max(aScale, bScale); + const auto aFractionScaled = + increaseScale((int128_t)aFraction, higherScale - aScale); + const auto bFractionScaled = + increaseScale((int128_t)bFraction, higherScale - bScale); + + // No need to consider overflow because two inputs are opposite. + int128_t whole = (int128_t)aWhole + (int128_t)bWhole; + int128_t fraction = aFractionScaled + bFractionScaled; + + // If the whole and fractional parts have different signs, adjust them to the + // same sign. + const auto scaleFactor = velox::DecimalUtil::kPowersOfTen[higherScale]; + if (whole < 0 && fraction > 0) { + whole += 1; + fraction -= scaleFactor; + } else if (whole > 0 && fraction < 0) { + whole -= 1; + fraction += scaleFactor; + } + + // Scale up the whole part and scale down the fraction part to combine them. + fraction = reduceScale(TResult(fraction), higherScale - rScale); + return decimalAddResult(TResult(whole), TResult(fraction), rScale, overflow); +} + +template +inline static TResult addLarge( + A a, + B b, + uint8_t aScale, + uint8_t bScale, + int32_t rScale, + bool& overflow) { + if (a >= 0 && b >= 0) { + // Both non-negative. + return addLargeNonNegative( + a, b, aScale, bScale, rScale, overflow); + } else if (a <= 0 && b <= 0) { + // Both non-positive. + return TResult(-addLargeNonNegative( + A(-a), B(-b), aScale, bScale, rScale, overflow)); + } else { + // One positive and the other negative. + return addLargeOpposite( + a, b, aScale, bScale, rScale, overflow); + } +} + template < typename R /* Result Type */, typename A /* Argument1 */, @@ -195,6 +364,112 @@ class DecimalBaseFunction : public exec::VectorFunction { const uint8_t rScale_; }; +class Addition { + public: + template + inline static void apply( + TResult& r, + A a, + B b, + uint8_t aRescale, + uint8_t bRescale, + uint8_t /* aPrecision */, + uint8_t aScale, + uint8_t /* bPrecision */, + uint8_t bScale, + uint8_t rPrecision, + uint8_t rScale, + bool& overflow) { + if (rPrecision < LongDecimalType::kMaxPrecision) { + const int128_t aRescaled = a * velox::DecimalUtil::kPowersOfTen[aRescale]; + const int128_t bRescaled = b * velox::DecimalUtil::kPowersOfTen[bRescale]; + r = TResult(aRescaled + bRescaled); + } else { + const uint32_t minLeadingZeros = + DecimalUtil::minLeadingZeros(a, b, aRescale, bRescale); + if (minLeadingZeros >= 3) { + // Fast path for no overflow. If both numbers contain at least 3 leading + // zeros, they can be added directly without the risk of overflow. + // The reason is if a number contains at least 2 leading zeros, it is + // ensured that the number fits in the maximum of decimal, because + // '2^126 - 1 < 10^38 - 1'. If both numbers contain at least 3 leading + // zeros, we are guaranteed that the result will have at least 2 leading + // zeros. + int128_t aRescaled = a * velox::DecimalUtil::kPowersOfTen[aRescale]; + int128_t bRescaled = b * velox::DecimalUtil::kPowersOfTen[bRescale]; + r = reduceScale( + TResult(aRescaled + bRescaled), std::max(aScale, bScale) - rScale); + } else { + // The risk of overflow should be considered. Add whole and fraction + // parts separately, and then combine. + r = addLarge(a, b, aScale, bScale, rScale, overflow); + } + } + } + + inline static uint8_t + computeRescaleFactor(uint8_t fromScale, uint8_t toScale, uint8_t rScale = 0) { + return std::max(0, toScale - fromScale); + } + + inline static std::pair computeResultPrecisionScale( + uint8_t aPrecision, + uint8_t aScale, + uint8_t bPrecision, + uint8_t bScale) { + auto precision = std::max(aPrecision - aScale, bPrecision - bScale) + + std::max(aScale, bScale) + 1; + auto scale = std::max(aScale, bScale); + return DecimalUtil::adjustPrecisionScale(precision, scale); + } +}; + +class Subtraction { + public: + template + inline static void apply( + TResult& r, + A a, + B b, + uint8_t aRescale, + uint8_t bRescale, + uint8_t aPrecision, + uint8_t aScale, + uint8_t bPrecision, + uint8_t bScale, + uint8_t rPrecision, + uint8_t rScale, + bool& overflow) { + Addition::apply( + r, + a, + B(-b), + aRescale, + bRescale, + aPrecision, + aScale, + bPrecision, + bScale, + rPrecision, + rScale, + overflow); + } + + inline static uint8_t + computeRescaleFactor(uint8_t fromScale, uint8_t toScale, uint8_t rScale = 0) { + return std::max(0, toScale - fromScale); + } + + inline static std::pair computeResultPrecisionScale( + uint8_t aPrecision, + uint8_t aScale, + uint8_t bPrecision, + uint8_t bScale) { + return Addition::computeResultPrecisionScale( + aPrecision, aScale, bPrecision, bScale); + } +}; + class Multiply { public: // Derive from Arrow. @@ -202,8 +477,8 @@ class Multiply { template inline static void apply( R& r, - const A& a, - const B& b, + A a, + B b, uint8_t aRescale, uint8_t bRescale, uint8_t aPrecision, @@ -288,10 +563,10 @@ class Multiply { } inline static std::pair computeResultPrecisionScale( - const uint8_t aPrecision, - const uint8_t aScale, - const uint8_t bPrecision, - const uint8_t bScale) { + uint8_t aPrecision, + uint8_t aScale, + uint8_t bPrecision, + uint8_t bScale) { return DecimalUtil::adjustPrecisionScale( aPrecision + bPrecision + 1, aScale + bScale); } @@ -318,8 +593,8 @@ class Divide { template inline static void apply( R& r, - const A& a, - const B& b, + A a, + B b, uint8_t aRescale, uint8_t /* bRescale */, uint8_t /* aPrecision */, @@ -338,16 +613,38 @@ class Divide { } inline static std::pair computeResultPrecisionScale( - const uint8_t aPrecision, - const uint8_t aScale, - const uint8_t bPrecision, - const uint8_t bScale) { + uint8_t aPrecision, + uint8_t aScale, + uint8_t bPrecision, + uint8_t bScale) { auto scale = std::max(6, aScale + bPrecision + 1); auto precision = aPrecision - aScale + bScale + scale; return DecimalUtil::adjustPrecisionScale(precision, scale); } }; +std::vector> +decimalAddSubtractSignature() { + return { + exec::FunctionSignatureBuilder() + .integerVariable("a_precision") + .integerVariable("a_scale") + .integerVariable("b_precision") + .integerVariable("b_scale") + .integerVariable( + "r_precision", + "min(38, max(a_precision - a_scale, b_precision - b_scale) + max(a_scale, b_scale) + 1)") + .integerVariable( + "r_scale", + getResultScale( + "max(a_precision - a_scale, b_precision - b_scale) + max(a_scale, b_scale) + 1", + "max(a_scale, b_scale)")) + .returnType("DECIMAL(r_precision, r_scale)") + .argumentType("DECIMAL(a_precision, a_scale)") + .argumentType("DECIMAL(b_precision, b_scale)") + .build()}; +} + std::vector> decimalMultiplySignature() { return {exec::FunctionSignatureBuilder() @@ -393,14 +690,16 @@ std::shared_ptr createDecimalFunction( const std::string& name, const std::vector& inputArgs, const core::QueryConfig& /*config*/) { - auto aType = inputArgs[0].type; - auto bType = inputArgs[1].type; - auto [aPrecision, aScale] = getDecimalPrecisionScale(*aType); - auto [bPrecision, bScale] = getDecimalPrecisionScale(*bType); - auto [rPrecision, rScale] = Operation::computeResultPrecisionScale( + const auto& aType = inputArgs[0].type; + const auto& bType = inputArgs[1].type; + const auto [aPrecision, aScale] = getDecimalPrecisionScale(*aType); + const auto [bPrecision, bScale] = getDecimalPrecisionScale(*bType); + const auto [rPrecision, rScale] = Operation::computeResultPrecisionScale( aPrecision, aScale, bPrecision, bScale); - uint8_t aRescale = Operation::computeRescaleFactor(aScale, bScale, rScale); - uint8_t bRescale = Operation::computeRescaleFactor(bScale, aScale, rScale); + const uint8_t aRescale = + Operation::computeRescaleFactor(aScale, bScale, rScale); + const uint8_t bRescale = + Operation::computeRescaleFactor(bScale, aScale, rScale); if (aType->isShortDecimal()) { if (bType->isShortDecimal()) { if (rPrecision > ShortDecimalType::kMaxPrecision) { @@ -478,10 +777,19 @@ std::shared_ptr createDecimalFunction( rScale); } } - VELOX_UNSUPPORTED(); } }; // namespace +VELOX_DECLARE_STATEFUL_VECTOR_FUNCTION( + udf_decimal_add, + decimalAddSubtractSignature(), + createDecimalFunction); + +VELOX_DECLARE_STATEFUL_VECTOR_FUNCTION( + udf_decimal_sub, + decimalAddSubtractSignature(), + createDecimalFunction); + VELOX_DECLARE_STATEFUL_VECTOR_FUNCTION( udf_decimal_mul, decimalMultiplySignature(), diff --git a/velox/functions/sparksql/DecimalUtil.h b/velox/functions/sparksql/DecimalUtil.h index b6aa538f3e05..fbe5da77809e 100644 --- a/velox/functions/sparksql/DecimalUtil.h +++ b/velox/functions/sparksql/DecimalUtil.h @@ -35,8 +35,8 @@ class DecimalUtil { /// This method is used only when /// `spark.sql.decimalOperations.allowPrecisionLoss` is set to true. inline static std::pair adjustPrecisionScale( - const uint8_t rPrecision, - const uint8_t rScale) { + uint8_t rPrecision, + uint8_t rScale) { if (rPrecision <= LongDecimalType::kMaxPrecision) { return {rPrecision, rScale}; } else { @@ -109,6 +109,30 @@ class DecimalUtil { return value; } + /// Returns the minumum number of leading zeros after scaling up two inputs + /// for certain scales. Inputs are decimal values of bigint or hugeint type. + template + inline static uint32_t + minLeadingZeros(A a, B b, uint8_t aRescale, uint8_t bRescale) { + auto minLeadingZerosAfterRescale = [](int32_t numLeadingZeros, + uint8_t scale) { + if (scale == 0) { + return numLeadingZeros; + } + /// If a value containing 'numLeadingZeros' leading zeros is scaled up by + /// 'scale', the new leading zeros depend on the max bits need to be + /// increased. + return std::max( + numLeadingZeros - kMaxBitsRequiredIncreaseAfterScaling[scale], 0); + }; + + const int32_t aLeadingZeros = minLeadingZerosAfterRescale( + bits::countLeadingZeros(absValue(a)), aRescale); + const int32_t bLeadingZeros = minLeadingZerosAfterRescale( + bits::countLeadingZeros(absValue(b)), bRescale); + return std::min(aLeadingZeros, bLeadingZeros); + } + /// Derives from Arrow BasicDecimal128 Divide. /// https://github.com/apache/arrow/blob/release-12.0.1-rc1/cpp/src/gandiva/precompiled/decimal_ops.cc#L350 /// @@ -120,12 +144,8 @@ class DecimalUtil { /// int256_t as intermediate type, and then convert to real result type with /// overflow flag. template - inline static R divideWithRoundUp( - R& r, - const A& a, - const B& b, - uint8_t aRescale, - bool& overflow) { + inline static R + divideWithRoundUp(R& r, A a, B b, uint8_t aRescale, bool& overflow) { if (b == 0) { overflow = true; return R(-1); @@ -192,9 +212,11 @@ class DecimalUtil { } private: - /// We rely on the following formula: - /// bits_required(x * 10^y) <= bits_required(x) + floor(log2(10^y)) + 1 - /// We precompute floor(log2(10^x)) + 1 for x = 0, 1, 2...75, 76 + /// Maintains the max bits that need to be increased for rescaling a value by + /// certain scale. The calculation relies on the following formula: + /// bitsRequired(x * 10^y) <= bitsRequired(x) + floor(log2(10^y)) + 1. + /// This array stores the precomputed 'floor(log2(10^y)) + 1' for y = 0, + /// 1, 2, ..., 75, 76. static constexpr int32_t kMaxBitsRequiredIncreaseAfterScaling[] = { 0, 4, 7, 10, 14, 17, 20, 24, 27, 30, 34, 37, 40, 44, 47, 50, 54, 57, 60, 64, 67, 70, 74, 77, 80, 84, @@ -204,9 +226,7 @@ class DecimalUtil { 216, 220, 223, 226, 230, 233, 236, 240, 243, 246, 250, 253}; template - inline static int32_t maxBitsRequiredAfterScaling( - const A& num, - uint8_t aRescale) { + inline static int32_t maxBitsRequiredAfterScaling(A num, uint8_t aRescale) { auto valueAbs = absValue(num); int32_t numOccupied = sizeof(A) * 8 - bits::countLeadingZeros(valueAbs); return numOccupied + kMaxBitsRequiredIncreaseAfterScaling[aRescale]; diff --git a/velox/functions/sparksql/RegisterArithmetic.cpp b/velox/functions/sparksql/RegisterArithmetic.cpp index 5435cdf66010..08851f9e0d9f 100644 --- a/velox/functions/sparksql/RegisterArithmetic.cpp +++ b/velox/functions/sparksql/RegisterArithmetic.cpp @@ -91,6 +91,8 @@ void registerArithmeticFunctions(const std::string& prefix) { registerFunction({prefix + "log10"}); registerRandFunctions(prefix); + VELOX_REGISTER_VECTOR_FUNCTION(udf_decimal_add, prefix + "add"); + VELOX_REGISTER_VECTOR_FUNCTION(udf_decimal_sub, prefix + "subtract"); VELOX_REGISTER_VECTOR_FUNCTION(udf_decimal_mul, prefix + "multiply"); VELOX_REGISTER_VECTOR_FUNCTION(udf_decimal_div, prefix + "divide"); } diff --git a/velox/functions/sparksql/tests/DecimalArithmeticTest.cpp b/velox/functions/sparksql/tests/DecimalArithmeticTest.cpp index a370809a946f..4558c7f4e910 100644 --- a/velox/functions/sparksql/tests/DecimalArithmeticTest.cpp +++ b/velox/functions/sparksql/tests/DecimalArithmeticTest.cpp @@ -43,22 +43,257 @@ class DecimalArithmeticTest : public SparkFunctionBaseTest { assertEqualVectors(expected, result); } - VectorPtr makeLongDecimalVector( - const std::vector& value, - int8_t precision, - int8_t scale) { - if (value.size() == 1) { - return makeConstant( - HugeInt::parse(std::move(value[0])), 1, DECIMAL(precision, scale)); - } - std::vector int128s; - for (auto& v : value) { - int128s.emplace_back(HugeInt::parse(std::move(v))); + void testArithmeticFunction( + const std::string& functionName, + const std::vector& inputs, + const VectorPtr& expected) { + VELOX_USER_CHECK_EQ( + inputs.size(), + 2, + "Two input vectors are needed for arithmetic function test."); + std::vector inputExprs = { + std::make_shared(inputs[0]->type(), "c0"), + std::make_shared(inputs[1]->type(), "c1")}; + auto expr = std::make_shared( + expected->type(), std::move(inputExprs), functionName); + testEncodings(expr, inputs, expected); + } + + VectorPtr makeNullableLongDecimalVector( + const std::vector& values, + const TypePtr& type) { + VELOX_USER_CHECK( + type->isDecimal(), + "Decimal type is needed to create long decimal vector."); + std::vector> numbers; + numbers.reserve(values.size()); + for (const auto& value : values) { + if (value == "null") { + numbers.emplace_back(std::nullopt); + } else { + numbers.emplace_back(HugeInt::parse(value)); + } } - return makeFlatVector(int128s, DECIMAL(precision, scale)); + return makeNullableFlatVector(numbers, type); } }; // namespace +TEST_F(DecimalArithmeticTest, add) { + // Precision < 38. + testArithmeticFunction( + "add", + {makeNullableLongDecimalVector( + {"201", "601", "1366", "999999999999999999999999999999"}, + DECIMAL(30, 3)), + makeNullableLongDecimalVector( + {"301", "901", "9866", "999999999999999999999999999999"}, + DECIMAL(30, 3))}, + makeNullableLongDecimalVector( + {"502", "1502", "11232", "1999999999999999999999999999998"}, + DECIMAL(31, 3))); + + // Min leading zero >= 3. + testArithmeticFunction( + "add", + {makeFlatVector( + std::vector{11232100, 9998888, 12345678, 2135632}, + DECIMAL(38, 7)), + makeFlatVector(std::vector{1, 2, 3, 4}, DECIMAL(10, 0))}, + makeFlatVector( + std::vector{2123210, 2999889, 4234568, 4213563}, + DECIMAL(38, 6))); + + // No carry to left. + testArithmeticFunction( + "add", + {makeNullableLongDecimalVector( + {"9999999999999999999999999999999000000", + "9999999999999999999999999999999900000", + "9999999999999999999999999999999990000", + "9999999999999999999999999999999999000"}, + DECIMAL(38, 5)), + makeFlatVector( + std::vector{100, 99999, 1234, 999}, DECIMAL(38, 7))}, + makeNullableLongDecimalVector( + {"99999999999999999999999999999990000010", + "99999999999999999999999999999999010000", + "99999999999999999999999999999999900123", + "99999999999999999999999999999999990100"}, + DECIMAL(38, 6))); + + // Carry to left. + testArithmeticFunction( + "add", + {makeNullableLongDecimalVector( + {"9999999999999999999999999999999070000", + "9999999999999999999999999999999050000", + "9999999999999999999999999999999870000", + "9999999999999999999999999999999890000"}, + DECIMAL(38, 5)), + makeFlatVector( + std::vector{8000000, 5000000, 8000000, 1999999}, + DECIMAL(38, 7))}, + makeNullableLongDecimalVector( + {"99999999999999999999999999999991500000", + "99999999999999999999999999999991000000", + "99999999999999999999999999999999500000", + "99999999999999999999999999999999100000"}, + DECIMAL(38, 6))); + + // Both -ve. + testArithmeticFunction( + "add", + {makeNullableLongDecimalVector( + {"-201", "-601", "-1366", "-999999999999999999999999999999"}, + DECIMAL(30, 3)), + makeNullableLongDecimalVector( + {"-301", "-901", "-9866", "-999999999999999999999999999999"}, + DECIMAL(30, 3))}, + makeNullableLongDecimalVector( + {"-502", "-1502", "-11232", "-1999999999999999999999999999998"}, + DECIMAL(31, 3))); + + // Overflow when scaling up the whole part. + testArithmeticFunction( + "add", + {makeNullableLongDecimalVector( + {"-99999999999999999999999999999999990000", + "99999999999999999999999999999999999000", + "-99999999999999999999999999999999999900", + "99999999999999999999999999999999999990"}, + DECIMAL(38, 3)), + makeFlatVector( + std::vector{-100, 9999999, -999900, 99999}, + DECIMAL(38, 7))}, + makeNullableLongDecimalVector( + {"null", "null", "null", "null"}, DECIMAL(38, 6))); + + // Ve and -ve. + testArithmeticFunction( + "add", + {makeNullableLongDecimalVector( + {"99999999999999999999999999999989999990", + "-99999999999999999999999999999989999990", + "99999999999999999999999999999999999980", + "-99999999999999999999999999999999999980"}, + DECIMAL(38, 6)), + makeNullableLongDecimalVector( + {"-9999999999999999999999999999998900000", + "9999999999999999999999999999998900000", + "-9999999999999999999999999999999999999", + "9999999999999999999999999999999999999"}, + DECIMAL(38, 5))}, + makeNullableLongDecimalVector( + {"999990", "-999990", "-10", "10"}, DECIMAL(38, 6))); +} + +TEST_F(DecimalArithmeticTest, subtract) { + testArithmeticFunction( + "subtract", + {makeNullableLongDecimalVector( + {"201", "601", "1366", "999999999999999999999999999999"}, + DECIMAL(30, 3)), + makeNullableLongDecimalVector( + {"301", "901", "9866", "-999999999999999999999999999999"}, + DECIMAL(30, 3))}, + makeNullableLongDecimalVector( + {"-100", "-300", "-8500", "1999999999999999999999999999998"}, + DECIMAL(31, 3))); + + // Min leading zero >= 3. + testArithmeticFunction( + "subtract", + {makeFlatVector( + std::vector{11232100, 9998888, 12345678, 2135632}, + DECIMAL(38, 7)), + makeFlatVector(std::vector{1, 2, 3, 4}, DECIMAL(10, 0))}, + makeFlatVector( + std::vector{123210, -1000111, -1765432, -3786437}, + DECIMAL(38, 6))); + + // No carry to left. + testArithmeticFunction( + "subtract", + {makeNullableLongDecimalVector( + {"9999999999999999999999999999999000000", + "9999999999999999999999999999999900000", + "9999999999999999999999999999999990000", + "9999999999999999999999999999999999000"}, + DECIMAL(38, 5)), + makeFlatVector( + std::vector{-100, -99999, -1234, -999}, DECIMAL(38, 7))}, + makeNullableLongDecimalVector( + {"99999999999999999999999999999990000010", + "99999999999999999999999999999999010000", + "99999999999999999999999999999999900123", + "99999999999999999999999999999999990100"}, + DECIMAL(38, 6))); + + // Carry to left. + testArithmeticFunction( + "subtract", + {makeNullableLongDecimalVector( + {"9999999999999999999999999999999070000", + "9999999999999999999999999999999050000", + "9999999999999999999999999999999870000", + "9999999999999999999999999999999890000"}, + DECIMAL(38, 5)), + makeFlatVector( + std::vector{-8000000, -5000000, -8000000, -1999999}, + DECIMAL(38, 7))}, + makeNullableLongDecimalVector( + {"99999999999999999999999999999991500000", + "99999999999999999999999999999991000000", + "99999999999999999999999999999999500000", + "99999999999999999999999999999999100000"}, + DECIMAL(38, 6))); + + // Both -ve. + testArithmeticFunction( + "subtract", + {makeNullableLongDecimalVector( + {"-201", "-601", "-1366", "-999999999999999999999999999999"}, + DECIMAL(30, 3)), + makeNullableLongDecimalVector( + {"-301", "-901", "-9866", "-999999999999999999999999999999"}, + DECIMAL(30, 3))}, + makeNullableLongDecimalVector( + {"100", "300", "8500", "0"}, DECIMAL(31, 3))); + + // Overflow when scaling up the whole part. + testArithmeticFunction( + "subtract", + {makeNullableLongDecimalVector( + {"-99999999999999999999999999999999990000", + "99999999999999999999999999999999999000", + "-99999999999999999999999999999999999900", + "99999999999999999999999999999999999990"}, + DECIMAL(38, 3)), + makeFlatVector( + std::vector{100, -9999999, 999900, -99999}, + DECIMAL(38, 7))}, + makeNullableLongDecimalVector( + {"null", "null", "null", "null"}, DECIMAL(38, 6))); + + // Ve and -ve. + testArithmeticFunction( + "subtract", + {makeNullableLongDecimalVector( + {"99999999999999999999999999999989999990", + "-99999999999999999999999999999989999990", + "99999999999999999999999999999999999980", + "-99999999999999999999999999999999999980"}, + DECIMAL(38, 6)), + makeFlatVector( + std::vector{-1000000, 1000000, -1, 1}, DECIMAL(38, 5))}, + makeNullableLongDecimalVector( + {"99999999999999999999999999999999999990", + "-99999999999999999999999999999999999990", + "99999999999999999999999999999999999990", + "-99999999999999999999999999999999999990"}, + DECIMAL(38, 6))); +} + TEST_F(DecimalArithmeticTest, multiply) { // The result can be obtained by Spark unit test // test("multiply") { @@ -185,8 +420,8 @@ TEST_F(DecimalArithmeticTest, decimalDivTest) { auto shortFlat = makeFlatVector({1000, 2000}, DECIMAL(17, 3)); // Divide short and short, returning long. testDecimalExpr( - makeLongDecimalVector( - {"500000000000000000000", "2000000000000000000000"}, 38, 21), + makeNullableLongDecimalVector( + {"500000000000000000000", "2000000000000000000000"}, DECIMAL(38, 21)), "divide(c0, c1)", {makeFlatVector({500, 4000}, DECIMAL(17, 3)), shortFlat}); @@ -200,15 +435,17 @@ TEST_F(DecimalArithmeticTest, decimalDivTest) { // Divide long and short, returning long. testDecimalExpr( - makeLongDecimalVector( - {"20" + std::string(20, '0'), "5" + std::string(20, '0')}, 38, 22), + makeNullableLongDecimalVector( + {"20" + std::string(20, '0'), "5" + std::string(20, '0')}, + DECIMAL(38, 22)), "divide(c0, c1)", {shortFlat, longFlat}); // Divide long and long, returning long. testDecimalExpr( - makeLongDecimalVector( - {"5" + std::string(18, '0'), "3" + std::string(18, '0')}, 38, 18), + makeNullableLongDecimalVector( + {"5" + std::string(18, '0'), "3" + std::string(18, '0')}, + DECIMAL(38, 18)), "divide(c0, c1)", {makeFlatVector({2500, 12000}, DECIMAL(20, 2)), longFlat}); @@ -236,23 +473,24 @@ TEST_F(DecimalArithmeticTest, decimalDivTest) { // checkEvaluation(Divide(l1, l2), null) // } testDecimalExpr( - makeLongDecimalVector({"497512437810945273631840796019900493"}, 38, 6), + makeNullableLongDecimalVector( + {"497512437810945273631840796019900493"}, DECIMAL(38, 6)), "c0 / c1", - {makeLongDecimalVector({std::string(35, '9')}, 35, 6), + {makeNullableLongDecimalVector({std::string(35, '9')}, DECIMAL(35, 6)), makeConstant(201, 1, DECIMAL(20, 3))}); testDecimalExpr( - makeLongDecimalVector( + makeNullableLongDecimalVector( {"1000" + std::string(17, '0'), "500" + std::string(17, '0')}, - 24, - 20), + DECIMAL(24, 20)), "1.00 / c0", {shortFlat}); // Flat and Constant arguments. testDecimalExpr( - makeLongDecimalVector( - {"500" + std::string(4, '0'), "1000" + std::string(4, '0')}, 23, 7), + makeNullableLongDecimalVector( + {"500" + std::string(4, '0'), "1000" + std::string(4, '0')}, + DECIMAL(23, 7)), "c0 / 2.00", {shortFlat}); diff --git a/velox/functions/sparksql/tests/DecimalUtilTest.cpp b/velox/functions/sparksql/tests/DecimalUtilTest.cpp index 350bbf526583..34de356ae7ab 100644 --- a/velox/functions/sparksql/tests/DecimalUtilTest.cpp +++ b/velox/functions/sparksql/tests/DecimalUtilTest.cpp @@ -43,4 +43,21 @@ TEST_F(DecimalUtilTest, divideWithRoundUp) { testDivideWithRoundUp( 6, velox::DecimalUtil::kPowersOfTen[17], 20, 6000, false); } + +TEST_F(DecimalUtilTest, minLeadingZeros) { + auto result = + DecimalUtil::minLeadingZeros(10000, 6000000, 10, 12); + ASSERT_EQ(result, 1); + + result = DecimalUtil::minLeadingZeros( + 10000, 6'000'000'000'000'000'000, 10, 12); + ASSERT_EQ(result, 16); + + result = DecimalUtil::minLeadingZeros( + velox::DecimalUtil::kLongDecimalMax, + velox::DecimalUtil::kLongDecimalMin, + 10, + 12); + ASSERT_EQ(result, 0); +} } // namespace facebook::velox::functions::sparksql::test diff --git a/velox/type/DecimalUtil.cpp b/velox/type/DecimalUtil.cpp index 62e3c61565a8..3b48c3c98a72 100644 --- a/velox/type/DecimalUtil.cpp +++ b/velox/type/DecimalUtil.cpp @@ -54,7 +54,7 @@ std::string formatDecimal(uint8_t scale, int128_t unscaledValue) { } } // namespace -std::string DecimalUtil::toString(const int128_t value, const TypePtr& type) { +std::string DecimalUtil::toString(int128_t value, const TypePtr& type) { auto [precision, scale] = getDecimalPrecisionScale(*type); return formatDecimal(scale, value); } @@ -91,7 +91,7 @@ int32_t DecimalUtil::toByteArray(int128_t value, char* out) { void DecimalUtil::computeAverage( int128_t& avg, - const int128_t& sum, + int128_t sum, int64_t count, int64_t overflow) { if (overflow == 0) { diff --git a/velox/type/DecimalUtil.h b/velox/type/DecimalUtil.h index d0a125b84e54..0a6e455e501f 100644 --- a/velox/type/DecimalUtil.h +++ b/velox/type/DecimalUtil.h @@ -98,7 +98,7 @@ class DecimalUtil { } /// Helper function to convert a decimal value to string. - static std::string toString(const int128_t value, const TypePtr& type); + static std::string toString(int128_t value, const TypePtr& type); template inline static void fillDecimals( @@ -147,11 +147,11 @@ class DecimalUtil { template inline static std::optional rescaleWithRoundUp( - const TInput inputValue, - const int fromPrecision, - const int fromScale, - const int toPrecision, - const int toScale) { + TInput inputValue, + int fromPrecision, + int fromScale, + int toPrecision, + int toScale) { int128_t rescaledValue = inputValue; auto scaleDifference = toScale - fromScale; bool isOverflow = false; @@ -183,10 +183,8 @@ class DecimalUtil { } template - inline static std::optional rescaleInt( - const TInput inputValue, - const int toPrecision, - const int toScale) { + inline static std::optional + rescaleInt(TInput inputValue, int toPrecision, int toScale) { int128_t rescaledValue = static_cast(inputValue); bool isOverflow = __builtin_mul_overflow( rescaledValue, DecimalUtil::kPowersOfTen[toScale], &rescaledValue); @@ -205,8 +203,8 @@ class DecimalUtil { template inline static R divideWithRoundUp( R& r, - const A& a, - const B& b, + A a, + B b, bool noRoundUp, uint8_t aRescale, uint8_t /*bRescale*/) { @@ -312,11 +310,8 @@ class DecimalUtil { } /// avg = (sum + overflow * kOverflowMultiplier) / count - static void computeAverage( - int128_t& avg, - const int128_t& sum, - int64_t count, - int64_t overflow); + static void + computeAverage(int128_t& avg, int128_t sum, int64_t count, int64_t overflow); /// Origins from java side BigInteger#bitLength. ///