diff --git a/velox/docs/functions/spark/math.rst b/velox/docs/functions/spark/math.rst index 46be649696cc..2af60b7d2377 100644 --- a/velox/docs/functions/spark/math.rst +++ b/velox/docs/functions/spark/math.rst @@ -27,6 +27,20 @@ Mathematical Functions Returns the result of adding x to y. The types of x and y must be the same. For integral types, overflow results in an error. Corresponds to sparks's operator ``+``. +.. spark:function:: add(x, y) -> decimal + + Returns the result of adding ``x`` to ``y``. The argument types should be DECIMAL, and can have different precisions and scales. + A fast path is used for cases that cannot overflow. For the others, the whole parts and fractional parts of the input decimals are added separately and then combined. + The result precision is the max number of whole digits of the inputs, plus the max scale of the inputs, plus one extra digit for a possible carry. The result scale is the max of the input scales. Both are adjusted when the precision would exceed 38. + Overflow results in null output. Corresponds to Spark's operator ``+``. + + :: + + SELECT CAST(1.1232100 as DECIMAL(38, 7)) + CAST(1 as DECIMAL(10, 0)); -- DECIMAL(38, 6) 2.123210 + SELECT CAST(-999999999999999999999999999.999 as DECIMAL(30, 3)) + CAST(-999999999999999999999999999.999 as DECIMAL(30, 3)); -- DECIMAL(31, 3) -1999999999999999999999999999.998 + SELECT CAST(99999999999999999999999999999999.99998 as DECIMAL(38, 6)) + CAST(-99999999999999999999999999999999.99999 as DECIMAL(38, 5)); -- DECIMAL(38, 6) -0.000010 + SELECT CAST(-99999999999999999999999999999999990.0 as DECIMAL(38, 3)) + CAST(-0.00001 as DECIMAL(38, 7)); -- DECIMAL(38, 6) NULL + .. spark:function:: bin(x) -> varchar Returns the string representation of the long value ``x`` represented in binary. @@ -179,6 +193,18 @@ Mathematical Functions Returns the result of subtracting y from x. The types of x and y must be the same. For integral types, overflow results in an error. Corresponds to Spark's operator ``-``. +.. 
spark:function:: subtract(x, y) -> decimal + + Returns the result of subtracting ``y`` from ``x``. Reuses the logic of add function for decimal type. + Corresponds to Spark's operator ``-``. + + :: + + SELECT CAST(1.1232100 as DECIMAL(38, 7)) - CAST(1 as DECIMAL(10, 0)); -- DECIMAL(38, 6) 0.123210 + SELECT CAST(-999999999999999999999999999.999 as DECIMAL(30, 3)) - CAST(-999999999999999999999999999.999 as DECIMAL(30, 3)); -- DECIMAL(31, 3) 0.000 + SELECT CAST(99999999999999999999999999999999.99998 as DECIMAL(38, 6)) - CAST(-0.00001 as DECIMAL(38, 5)); -- DECIMAL(38, 6) 99999999999999999999999999999999.999990 + SELECT CAST(-99999999999999999999999999999999990.0 as DECIMAL(38, 3)) - CAST(0.00001 as DECIMAL(38, 7)); -- DECIMAL(38, 6) NULL + .. spark:function:: unaryminus(x) -> [same as x] Returns the negative of `x`. Corresponds to Spark's operator ``-``. diff --git a/velox/functions/sparksql/DecimalArithmetic.cpp b/velox/functions/sparksql/DecimalArithmetic.cpp index ffb782da0694..96b7e0cc0200 100644 --- a/velox/functions/sparksql/DecimalArithmetic.cpp +++ b/velox/functions/sparksql/DecimalArithmetic.cpp @@ -33,6 +33,175 @@ std::string getResultScale(std::string precision, std::string scale) { scale); } +// Returns the whole and fraction parts of a decimal value. +template +inline std::pair getWholeAndFraction(T value, uint8_t scale) { + const auto scaleFactor = velox::DecimalUtil::kPowersOfTen[scale]; + const T whole = value / scaleFactor; + return {whole, value - whole * scaleFactor}; +} + +// Increases the scale of input value by 'delta'. Returns the input value if +// delta is not positive. +inline int128_t increaseScale(int128_t in, int16_t delta) { + // No need to consider overflow as 'delta == higher scale - input scale', so + // the scaled value will not exceed the maximum of long decimal. + return delta <= 0 ? 
in : in * velox::DecimalUtil::kPowersOfTen[delta]; +} + +// Scales up the whole part to result scale, and combine it with fraction part +// to produce a full result for decimal add. Checks whether the result +// overflows. +template +inline T +decimalAddResult(T whole, T fraction, uint8_t resultScale, bool& overflow) { + T scaledWhole = DecimalUtil::multiply( + whole, velox::DecimalUtil::kPowersOfTen[resultScale], overflow); + if (FOLLY_UNLIKELY(overflow)) { + return 0; + } + const auto result = scaledWhole + fraction; + if constexpr (std::is_same_v) { + overflow = (result > velox::DecimalUtil::kShortDecimalMax) || + (result < velox::DecimalUtil::kShortDecimalMin); + } else { + overflow = (result > velox::DecimalUtil::kLongDecimalMax) || + (result < velox::DecimalUtil::kLongDecimalMin); + } + return result; +} + +// Reduces the scale of input value by 'delta'. Returns the input value if delta +// is not positive. +template +inline static T reduceScale(T in, int32_t delta) { + if (delta <= 0) { + return in; + } + T result; + bool overflow; + const auto scaleFactor = velox::DecimalUtil::kPowersOfTen[delta]; + if constexpr (std::is_same_v) { + VELOX_DCHECK_LE( + scaleFactor, + std::numeric_limits::max(), + "Scale factor should not exceed the maximum of int64_t."); + } + DecimalUtil::divideWithRoundUp( + result, in, T(scaleFactor), 0, overflow); + VELOX_DCHECK(!overflow); + return result; +} + +// Adds two non-negative values by adding the whole and fraction parts +// separately. +template +inline static TResult addLargeNonNegative( + A a, + B b, + uint8_t aScale, + uint8_t bScale, + uint8_t rScale, + bool& overflow) { + VELOX_DCHECK_GE( + a, 0, "Non-negative value is expected in addLargeNonNegative."); + VELOX_DCHECK_GE( + b, 0, "Non-negative value is expected in addLargeNonNegative."); + + // Separate whole and fraction parts. 
+ const auto [aWhole, aFraction] = getWholeAndFraction(a, aScale); + const auto [bWhole, bFraction] = getWholeAndFraction(b, bScale); + + // Adjust fractional parts to higher scale. + const auto higherScale = std::max(aScale, bScale); + const auto aFractionScaled = + increaseScale((int128_t)aFraction, higherScale - aScale); + const auto bFractionScaled = + increaseScale((int128_t)bFraction, higherScale - bScale); + + int128_t fraction; + bool carryToLeft = false; + const auto carrier = velox::DecimalUtil::kPowersOfTen[higherScale]; + if (aFractionScaled >= carrier - bFractionScaled) { + fraction = aFractionScaled + bFractionScaled - carrier; + carryToLeft = true; + } else { + fraction = aFractionScaled + bFractionScaled; + } + + // Scale up the whole part and scale down the fraction part to combine them. + fraction = reduceScale(TResult(fraction), higherScale - rScale); + const auto whole = TResult(aWhole) + TResult(bWhole) + TResult(carryToLeft); + return decimalAddResult(whole, TResult(fraction), rScale, overflow); +} + +// Adds two opposite values by adding the whole and fraction parts separately. +template +inline static TResult addLargeOpposite( + A a, + B b, + uint8_t aScale, + uint8_t bScale, + int32_t rScale, + bool& overflow) { + VELOX_DCHECK( + (a < 0 && b > 0) || (a > 0 && b < 0), + "One positive and one negative value are expected in addLargeOpposite."); + + // Separate whole and fraction parts. + const auto [aWhole, aFraction] = getWholeAndFraction(a, aScale); + const auto [bWhole, bFraction] = getWholeAndFraction(b, bScale); + + // Adjust fractional parts to higher scale. + const auto higherScale = std::max(aScale, bScale); + const auto aFractionScaled = + increaseScale((int128_t)aFraction, higherScale - aScale); + const auto bFractionScaled = + increaseScale((int128_t)bFraction, higherScale - bScale); + + // No need to consider overflow because two inputs are opposite. 
+ int128_t whole = (int128_t)aWhole + (int128_t)bWhole; + int128_t fraction = aFractionScaled + bFractionScaled; + + // If the whole and fractional parts have different signs, adjust them to the + // same sign. + const auto scaleFactor = velox::DecimalUtil::kPowersOfTen[higherScale]; + if (whole < 0 && fraction > 0) { + whole += 1; + fraction -= scaleFactor; + } else if (whole > 0 && fraction < 0) { + whole -= 1; + fraction += scaleFactor; + } + + // Scale up the whole part and scale down the fraction part to combine them. + fraction = reduceScale(TResult(fraction), higherScale - rScale); + return decimalAddResult(TResult(whole), TResult(fraction), rScale, overflow); +} + +template +inline static TResult addLarge( + A a, + B b, + uint8_t aScale, + uint8_t bScale, + int32_t rScale, + bool& overflow) { + if (a >= 0 && b >= 0) { + // Both non-negative. + return addLargeNonNegative( + a, b, aScale, bScale, rScale, overflow); + } else if (a <= 0 && b <= 0) { + // Both non-positive. + return TResult(-addLargeNonNegative( + A(-a), B(-b), aScale, bScale, rScale, overflow)); + } else { + // One positive and the other negative. 
+ return addLargeOpposite( + a, b, aScale, bScale, rScale, overflow); + } +} + template < typename R /* Result Type */, typename A /* Argument1 */, @@ -195,6 +364,112 @@ class DecimalBaseFunction : public exec::VectorFunction { const uint8_t rScale_; }; +class Addition { + public: + template + inline static void apply( + TResult& r, + A a, + B b, + uint8_t aRescale, + uint8_t bRescale, + uint8_t /* aPrecision */, + uint8_t aScale, + uint8_t /* bPrecision */, + uint8_t bScale, + uint8_t rPrecision, + uint8_t rScale, + bool& overflow) { + if (rPrecision < LongDecimalType::kMaxPrecision) { + const int128_t aRescaled = a * velox::DecimalUtil::kPowersOfTen[aRescale]; + const int128_t bRescaled = b * velox::DecimalUtil::kPowersOfTen[bRescale]; + r = TResult(aRescaled + bRescaled); + } else { + const uint32_t minLeadingZeros = + DecimalUtil::minLeadingZeros(a, b, aRescale, bRescale); + if (minLeadingZeros >= 3) { + // Fast path for no overflow. If both numbers contain at least 3 leading + // zeros, they can be added directly without the risk of overflow. + // The reason is if a number contains at least 2 leading zeros, it is + // ensured that the number fits in the maximum of decimal, because + // '2^126 - 1 < 10^38 - 1'. If both numbers contain at least 3 leading + // zeros, we are guaranteed that the result will have at least 2 leading + // zeros. + int128_t aRescaled = a * velox::DecimalUtil::kPowersOfTen[aRescale]; + int128_t bRescaled = b * velox::DecimalUtil::kPowersOfTen[bRescale]; + r = reduceScale( + TResult(aRescaled + bRescaled), std::max(aScale, bScale) - rScale); + } else { + // The risk of overflow should be considered. Add whole and fraction + // parts separately, and then combine. 
+ r = addLarge(a, b, aScale, bScale, rScale, overflow); + } + } + } + + inline static uint8_t + computeRescaleFactor(uint8_t fromScale, uint8_t toScale, uint8_t rScale = 0) { + return std::max(0, toScale - fromScale); + } + + inline static std::pair computeResultPrecisionScale( + uint8_t aPrecision, + uint8_t aScale, + uint8_t bPrecision, + uint8_t bScale) { + auto precision = std::max(aPrecision - aScale, bPrecision - bScale) + + std::max(aScale, bScale) + 1; + auto scale = std::max(aScale, bScale); + return DecimalUtil::adjustPrecisionScale(precision, scale); + } +}; + +class Subtraction { + public: + template + inline static void apply( + TResult& r, + A a, + B b, + uint8_t aRescale, + uint8_t bRescale, + uint8_t aPrecision, + uint8_t aScale, + uint8_t bPrecision, + uint8_t bScale, + uint8_t rPrecision, + uint8_t rScale, + bool& overflow) { + Addition::apply( + r, + a, + B(-b), + aRescale, + bRescale, + aPrecision, + aScale, + bPrecision, + bScale, + rPrecision, + rScale, + overflow); + } + + inline static uint8_t + computeRescaleFactor(uint8_t fromScale, uint8_t toScale, uint8_t rScale = 0) { + return std::max(0, toScale - fromScale); + } + + inline static std::pair computeResultPrecisionScale( + uint8_t aPrecision, + uint8_t aScale, + uint8_t bPrecision, + uint8_t bScale) { + return Addition::computeResultPrecisionScale( + aPrecision, aScale, bPrecision, bScale); + } +}; + class Multiply { public: // Derive from Arrow. 
@@ -202,8 +477,8 @@ class Multiply { template inline static void apply( R& r, - const A& a, - const B& b, + A a, + B b, uint8_t aRescale, uint8_t bRescale, uint8_t aPrecision, @@ -288,10 +563,10 @@ class Multiply { } inline static std::pair computeResultPrecisionScale( - const uint8_t aPrecision, - const uint8_t aScale, - const uint8_t bPrecision, - const uint8_t bScale) { + uint8_t aPrecision, + uint8_t aScale, + uint8_t bPrecision, + uint8_t bScale) { return DecimalUtil::adjustPrecisionScale( aPrecision + bPrecision + 1, aScale + bScale); } @@ -318,8 +593,8 @@ class Divide { template inline static void apply( R& r, - const A& a, - const B& b, + A a, + B b, uint8_t aRescale, uint8_t /* bRescale */, uint8_t /* aPrecision */, @@ -338,16 +613,38 @@ class Divide { } inline static std::pair computeResultPrecisionScale( - const uint8_t aPrecision, - const uint8_t aScale, - const uint8_t bPrecision, - const uint8_t bScale) { + uint8_t aPrecision, + uint8_t aScale, + uint8_t bPrecision, + uint8_t bScale) { auto scale = std::max(6, aScale + bPrecision + 1); auto precision = aPrecision - aScale + bScale + scale; return DecimalUtil::adjustPrecisionScale(precision, scale); } }; +std::vector> +decimalAddSubtractSignature() { + return { + exec::FunctionSignatureBuilder() + .integerVariable("a_precision") + .integerVariable("a_scale") + .integerVariable("b_precision") + .integerVariable("b_scale") + .integerVariable( + "r_precision", + "min(38, max(a_precision - a_scale, b_precision - b_scale) + max(a_scale, b_scale) + 1)") + .integerVariable( + "r_scale", + getResultScale( + "max(a_precision - a_scale, b_precision - b_scale) + max(a_scale, b_scale) + 1", + "max(a_scale, b_scale)")) + .returnType("DECIMAL(r_precision, r_scale)") + .argumentType("DECIMAL(a_precision, a_scale)") + .argumentType("DECIMAL(b_precision, b_scale)") + .build()}; +} + std::vector> decimalMultiplySignature() { return {exec::FunctionSignatureBuilder() @@ -393,14 +690,16 @@ std::shared_ptr 
createDecimalFunction( const std::string& name, const std::vector& inputArgs, const core::QueryConfig& /*config*/) { - auto aType = inputArgs[0].type; - auto bType = inputArgs[1].type; - auto [aPrecision, aScale] = getDecimalPrecisionScale(*aType); - auto [bPrecision, bScale] = getDecimalPrecisionScale(*bType); - auto [rPrecision, rScale] = Operation::computeResultPrecisionScale( + const auto& aType = inputArgs[0].type; + const auto& bType = inputArgs[1].type; + const auto [aPrecision, aScale] = getDecimalPrecisionScale(*aType); + const auto [bPrecision, bScale] = getDecimalPrecisionScale(*bType); + const auto [rPrecision, rScale] = Operation::computeResultPrecisionScale( aPrecision, aScale, bPrecision, bScale); - uint8_t aRescale = Operation::computeRescaleFactor(aScale, bScale, rScale); - uint8_t bRescale = Operation::computeRescaleFactor(bScale, aScale, rScale); + const uint8_t aRescale = + Operation::computeRescaleFactor(aScale, bScale, rScale); + const uint8_t bRescale = + Operation::computeRescaleFactor(bScale, aScale, rScale); if (aType->isShortDecimal()) { if (bType->isShortDecimal()) { if (rPrecision > ShortDecimalType::kMaxPrecision) { @@ -478,10 +777,19 @@ std::shared_ptr createDecimalFunction( rScale); } } - VELOX_UNSUPPORTED(); } }; // namespace +VELOX_DECLARE_STATEFUL_VECTOR_FUNCTION( + udf_decimal_add, + decimalAddSubtractSignature(), + createDecimalFunction); + +VELOX_DECLARE_STATEFUL_VECTOR_FUNCTION( + udf_decimal_sub, + decimalAddSubtractSignature(), + createDecimalFunction); + VELOX_DECLARE_STATEFUL_VECTOR_FUNCTION( udf_decimal_mul, decimalMultiplySignature(), diff --git a/velox/functions/sparksql/DecimalUtil.h b/velox/functions/sparksql/DecimalUtil.h index b6aa538f3e05..fbe5da77809e 100644 --- a/velox/functions/sparksql/DecimalUtil.h +++ b/velox/functions/sparksql/DecimalUtil.h @@ -35,8 +35,8 @@ class DecimalUtil { /// This method is used only when /// `spark.sql.decimalOperations.allowPrecisionLoss` is set to true. 
inline static std::pair adjustPrecisionScale( - const uint8_t rPrecision, - const uint8_t rScale) { + uint8_t rPrecision, + uint8_t rScale) { if (rPrecision <= LongDecimalType::kMaxPrecision) { return {rPrecision, rScale}; } else { @@ -109,6 +109,30 @@ class DecimalUtil { return value; } + /// Returns the minimum number of leading zeros after scaling up two inputs + /// for certain scales. Inputs are decimal values of bigint or hugeint type. + template + inline static uint32_t + minLeadingZeros(A a, B b, uint8_t aRescale, uint8_t bRescale) { + auto minLeadingZerosAfterRescale = [](int32_t numLeadingZeros, + uint8_t scale) { + if (scale == 0) { + return numLeadingZeros; + } + /// If a value containing 'numLeadingZeros' leading zeros is scaled up by + /// 'scale', the remaining leading zeros depend on the max number of bits + /// that the scaling can add. + return std::max( + numLeadingZeros - kMaxBitsRequiredIncreaseAfterScaling[scale], 0); + }; + + const int32_t aLeadingZeros = minLeadingZerosAfterRescale( + bits::countLeadingZeros(absValue(a)), aRescale); + const int32_t bLeadingZeros = minLeadingZerosAfterRescale( + bits::countLeadingZeros(absValue(b)), bRescale); + return std::min(aLeadingZeros, bLeadingZeros); + } + /// Derives from Arrow BasicDecimal128 Divide. /// https://github.com/apache/arrow/blob/release-12.0.1-rc1/cpp/src/gandiva/precompiled/decimal_ops.cc#L350 /// @@ -120,12 +144,8 @@ class DecimalUtil { /// int256_t as intermediate type, and then convert to real result type with /// overflow flag. 
template - inline static R divideWithRoundUp( - R& r, - const A& a, - const B& b, - uint8_t aRescale, - bool& overflow) { + inline static R + divideWithRoundUp(R& r, A a, B b, uint8_t aRescale, bool& overflow) { if (b == 0) { overflow = true; return R(-1); @@ -192,9 +212,11 @@ class DecimalUtil { } private: - /// We rely on the following formula: - /// bits_required(x * 10^y) <= bits_required(x) + floor(log2(10^y)) + 1 - /// We precompute floor(log2(10^x)) + 1 for x = 0, 1, 2...75, 76 + /// Maintains the max bits that need to be increased for rescaling a value by + /// certain scale. The calculation relies on the following formula: + /// bitsRequired(x * 10^y) <= bitsRequired(x) + floor(log2(10^y)) + 1. + /// This array stores the precomputed 'floor(log2(10^y)) + 1' for y = 0, + /// 1, 2, ..., 75, 76. static constexpr int32_t kMaxBitsRequiredIncreaseAfterScaling[] = { 0, 4, 7, 10, 14, 17, 20, 24, 27, 30, 34, 37, 40, 44, 47, 50, 54, 57, 60, 64, 67, 70, 74, 77, 80, 84, @@ -204,9 +226,7 @@ class DecimalUtil { 216, 220, 223, 226, 230, 233, 236, 240, 243, 246, 250, 253}; template - inline static int32_t maxBitsRequiredAfterScaling( - const A& num, - uint8_t aRescale) { + inline static int32_t maxBitsRequiredAfterScaling(A num, uint8_t aRescale) { auto valueAbs = absValue(num); int32_t numOccupied = sizeof(A) * 8 - bits::countLeadingZeros(valueAbs); return numOccupied + kMaxBitsRequiredIncreaseAfterScaling[aRescale]; diff --git a/velox/functions/sparksql/RegisterArithmetic.cpp b/velox/functions/sparksql/RegisterArithmetic.cpp index 5435cdf66010..08851f9e0d9f 100644 --- a/velox/functions/sparksql/RegisterArithmetic.cpp +++ b/velox/functions/sparksql/RegisterArithmetic.cpp @@ -91,6 +91,8 @@ void registerArithmeticFunctions(const std::string& prefix) { registerFunction({prefix + "log10"}); registerRandFunctions(prefix); + VELOX_REGISTER_VECTOR_FUNCTION(udf_decimal_add, prefix + "add"); + VELOX_REGISTER_VECTOR_FUNCTION(udf_decimal_sub, prefix + "subtract"); 
VELOX_REGISTER_VECTOR_FUNCTION(udf_decimal_mul, prefix + "multiply"); VELOX_REGISTER_VECTOR_FUNCTION(udf_decimal_div, prefix + "divide"); } diff --git a/velox/functions/sparksql/tests/DecimalArithmeticTest.cpp b/velox/functions/sparksql/tests/DecimalArithmeticTest.cpp index a370809a946f..4558c7f4e910 100644 --- a/velox/functions/sparksql/tests/DecimalArithmeticTest.cpp +++ b/velox/functions/sparksql/tests/DecimalArithmeticTest.cpp @@ -43,22 +43,257 @@ class DecimalArithmeticTest : public SparkFunctionBaseTest { assertEqualVectors(expected, result); } - VectorPtr makeLongDecimalVector( - const std::vector& value, - int8_t precision, - int8_t scale) { - if (value.size() == 1) { - return makeConstant( - HugeInt::parse(std::move(value[0])), 1, DECIMAL(precision, scale)); - } - std::vector int128s; - for (auto& v : value) { - int128s.emplace_back(HugeInt::parse(std::move(v))); + void testArithmeticFunction( + const std::string& functionName, + const std::vector& inputs, + const VectorPtr& expected) { + VELOX_USER_CHECK_EQ( + inputs.size(), + 2, + "Two input vectors are needed for arithmetic function test."); + std::vector inputExprs = { + std::make_shared(inputs[0]->type(), "c0"), + std::make_shared(inputs[1]->type(), "c1")}; + auto expr = std::make_shared( + expected->type(), std::move(inputExprs), functionName); + testEncodings(expr, inputs, expected); + } + + VectorPtr makeNullableLongDecimalVector( + const std::vector& values, + const TypePtr& type) { + VELOX_USER_CHECK( + type->isDecimal(), + "Decimal type is needed to create long decimal vector."); + std::vector> numbers; + numbers.reserve(values.size()); + for (const auto& value : values) { + if (value == "null") { + numbers.emplace_back(std::nullopt); + } else { + numbers.emplace_back(HugeInt::parse(value)); + } } - return makeFlatVector(int128s, DECIMAL(precision, scale)); + return makeNullableFlatVector(numbers, type); } }; // namespace +TEST_F(DecimalArithmeticTest, add) { + // Precision < 38. 
+ testArithmeticFunction( + "add", + {makeNullableLongDecimalVector( + {"201", "601", "1366", "999999999999999999999999999999"}, + DECIMAL(30, 3)), + makeNullableLongDecimalVector( + {"301", "901", "9866", "999999999999999999999999999999"}, + DECIMAL(30, 3))}, + makeNullableLongDecimalVector( + {"502", "1502", "11232", "1999999999999999999999999999998"}, + DECIMAL(31, 3))); + + // Min leading zero >= 3. + testArithmeticFunction( + "add", + {makeFlatVector( + std::vector{11232100, 9998888, 12345678, 2135632}, + DECIMAL(38, 7)), + makeFlatVector(std::vector{1, 2, 3, 4}, DECIMAL(10, 0))}, + makeFlatVector( + std::vector{2123210, 2999889, 4234568, 4213563}, + DECIMAL(38, 6))); + + // No carry to left. + testArithmeticFunction( + "add", + {makeNullableLongDecimalVector( + {"9999999999999999999999999999999000000", + "9999999999999999999999999999999900000", + "9999999999999999999999999999999990000", + "9999999999999999999999999999999999000"}, + DECIMAL(38, 5)), + makeFlatVector( + std::vector{100, 99999, 1234, 999}, DECIMAL(38, 7))}, + makeNullableLongDecimalVector( + {"99999999999999999999999999999990000010", + "99999999999999999999999999999999010000", + "99999999999999999999999999999999900123", + "99999999999999999999999999999999990100"}, + DECIMAL(38, 6))); + + // Carry to left. + testArithmeticFunction( + "add", + {makeNullableLongDecimalVector( + {"9999999999999999999999999999999070000", + "9999999999999999999999999999999050000", + "9999999999999999999999999999999870000", + "9999999999999999999999999999999890000"}, + DECIMAL(38, 5)), + makeFlatVector( + std::vector{8000000, 5000000, 8000000, 1999999}, + DECIMAL(38, 7))}, + makeNullableLongDecimalVector( + {"99999999999999999999999999999991500000", + "99999999999999999999999999999991000000", + "99999999999999999999999999999999500000", + "99999999999999999999999999999999100000"}, + DECIMAL(38, 6))); + + // Both -ve. 
+ testArithmeticFunction( + "add", + {makeNullableLongDecimalVector( + {"-201", "-601", "-1366", "-999999999999999999999999999999"}, + DECIMAL(30, 3)), + makeNullableLongDecimalVector( + {"-301", "-901", "-9866", "-999999999999999999999999999999"}, + DECIMAL(30, 3))}, + makeNullableLongDecimalVector( + {"-502", "-1502", "-11232", "-1999999999999999999999999999998"}, + DECIMAL(31, 3))); + + // Overflow when scaling up the whole part. + testArithmeticFunction( + "add", + {makeNullableLongDecimalVector( + {"-99999999999999999999999999999999990000", + "99999999999999999999999999999999999000", + "-99999999999999999999999999999999999900", + "99999999999999999999999999999999999990"}, + DECIMAL(38, 3)), + makeFlatVector( + std::vector{-100, 9999999, -999900, 99999}, + DECIMAL(38, 7))}, + makeNullableLongDecimalVector( + {"null", "null", "null", "null"}, DECIMAL(38, 6))); + + // Ve and -ve. + testArithmeticFunction( + "add", + {makeNullableLongDecimalVector( + {"99999999999999999999999999999989999990", + "-99999999999999999999999999999989999990", + "99999999999999999999999999999999999980", + "-99999999999999999999999999999999999980"}, + DECIMAL(38, 6)), + makeNullableLongDecimalVector( + {"-9999999999999999999999999999998900000", + "9999999999999999999999999999998900000", + "-9999999999999999999999999999999999999", + "9999999999999999999999999999999999999"}, + DECIMAL(38, 5))}, + makeNullableLongDecimalVector( + {"999990", "-999990", "-10", "10"}, DECIMAL(38, 6))); +} + +TEST_F(DecimalArithmeticTest, subtract) { + testArithmeticFunction( + "subtract", + {makeNullableLongDecimalVector( + {"201", "601", "1366", "999999999999999999999999999999"}, + DECIMAL(30, 3)), + makeNullableLongDecimalVector( + {"301", "901", "9866", "-999999999999999999999999999999"}, + DECIMAL(30, 3))}, + makeNullableLongDecimalVector( + {"-100", "-300", "-8500", "1999999999999999999999999999998"}, + DECIMAL(31, 3))); + + // Min leading zero >= 3. 
+ testArithmeticFunction( + "subtract", + {makeFlatVector( + std::vector{11232100, 9998888, 12345678, 2135632}, + DECIMAL(38, 7)), + makeFlatVector(std::vector{1, 2, 3, 4}, DECIMAL(10, 0))}, + makeFlatVector( + std::vector{123210, -1000111, -1765432, -3786437}, + DECIMAL(38, 6))); + + // No carry to left. + testArithmeticFunction( + "subtract", + {makeNullableLongDecimalVector( + {"9999999999999999999999999999999000000", + "9999999999999999999999999999999900000", + "9999999999999999999999999999999990000", + "9999999999999999999999999999999999000"}, + DECIMAL(38, 5)), + makeFlatVector( + std::vector{-100, -99999, -1234, -999}, DECIMAL(38, 7))}, + makeNullableLongDecimalVector( + {"99999999999999999999999999999990000010", + "99999999999999999999999999999999010000", + "99999999999999999999999999999999900123", + "99999999999999999999999999999999990100"}, + DECIMAL(38, 6))); + + // Carry to left. + testArithmeticFunction( + "subtract", + {makeNullableLongDecimalVector( + {"9999999999999999999999999999999070000", + "9999999999999999999999999999999050000", + "9999999999999999999999999999999870000", + "9999999999999999999999999999999890000"}, + DECIMAL(38, 5)), + makeFlatVector( + std::vector{-8000000, -5000000, -8000000, -1999999}, + DECIMAL(38, 7))}, + makeNullableLongDecimalVector( + {"99999999999999999999999999999991500000", + "99999999999999999999999999999991000000", + "99999999999999999999999999999999500000", + "99999999999999999999999999999999100000"}, + DECIMAL(38, 6))); + + // Both -ve. + testArithmeticFunction( + "subtract", + {makeNullableLongDecimalVector( + {"-201", "-601", "-1366", "-999999999999999999999999999999"}, + DECIMAL(30, 3)), + makeNullableLongDecimalVector( + {"-301", "-901", "-9866", "-999999999999999999999999999999"}, + DECIMAL(30, 3))}, + makeNullableLongDecimalVector( + {"100", "300", "8500", "0"}, DECIMAL(31, 3))); + + // Overflow when scaling up the whole part. 
+ testArithmeticFunction( + "subtract", + {makeNullableLongDecimalVector( + {"-99999999999999999999999999999999990000", + "99999999999999999999999999999999999000", + "-99999999999999999999999999999999999900", + "99999999999999999999999999999999999990"}, + DECIMAL(38, 3)), + makeFlatVector( + std::vector{100, -9999999, 999900, -99999}, + DECIMAL(38, 7))}, + makeNullableLongDecimalVector( + {"null", "null", "null", "null"}, DECIMAL(38, 6))); + + // Ve and -ve. + testArithmeticFunction( + "subtract", + {makeNullableLongDecimalVector( + {"99999999999999999999999999999989999990", + "-99999999999999999999999999999989999990", + "99999999999999999999999999999999999980", + "-99999999999999999999999999999999999980"}, + DECIMAL(38, 6)), + makeFlatVector( + std::vector{-1000000, 1000000, -1, 1}, DECIMAL(38, 5))}, + makeNullableLongDecimalVector( + {"99999999999999999999999999999999999990", + "-99999999999999999999999999999999999990", + "99999999999999999999999999999999999990", + "-99999999999999999999999999999999999990"}, + DECIMAL(38, 6))); +} + TEST_F(DecimalArithmeticTest, multiply) { // The result can be obtained by Spark unit test // test("multiply") { @@ -185,8 +420,8 @@ TEST_F(DecimalArithmeticTest, decimalDivTest) { auto shortFlat = makeFlatVector({1000, 2000}, DECIMAL(17, 3)); // Divide short and short, returning long. testDecimalExpr( - makeLongDecimalVector( - {"500000000000000000000", "2000000000000000000000"}, 38, 21), + makeNullableLongDecimalVector( + {"500000000000000000000", "2000000000000000000000"}, DECIMAL(38, 21)), "divide(c0, c1)", {makeFlatVector({500, 4000}, DECIMAL(17, 3)), shortFlat}); @@ -200,15 +435,17 @@ TEST_F(DecimalArithmeticTest, decimalDivTest) { // Divide long and short, returning long. 
testDecimalExpr( - makeLongDecimalVector( - {"20" + std::string(20, '0'), "5" + std::string(20, '0')}, 38, 22), + makeNullableLongDecimalVector( + {"20" + std::string(20, '0'), "5" + std::string(20, '0')}, + DECIMAL(38, 22)), "divide(c0, c1)", {shortFlat, longFlat}); // Divide long and long, returning long. testDecimalExpr( - makeLongDecimalVector( - {"5" + std::string(18, '0'), "3" + std::string(18, '0')}, 38, 18), + makeNullableLongDecimalVector( + {"5" + std::string(18, '0'), "3" + std::string(18, '0')}, + DECIMAL(38, 18)), "divide(c0, c1)", {makeFlatVector({2500, 12000}, DECIMAL(20, 2)), longFlat}); @@ -236,23 +473,24 @@ TEST_F(DecimalArithmeticTest, decimalDivTest) { // checkEvaluation(Divide(l1, l2), null) // } testDecimalExpr( - makeLongDecimalVector({"497512437810945273631840796019900493"}, 38, 6), + makeNullableLongDecimalVector( + {"497512437810945273631840796019900493"}, DECIMAL(38, 6)), "c0 / c1", - {makeLongDecimalVector({std::string(35, '9')}, 35, 6), + {makeNullableLongDecimalVector({std::string(35, '9')}, DECIMAL(35, 6)), makeConstant(201, 1, DECIMAL(20, 3))}); testDecimalExpr( - makeLongDecimalVector( + makeNullableLongDecimalVector( {"1000" + std::string(17, '0'), "500" + std::string(17, '0')}, - 24, - 20), + DECIMAL(24, 20)), "1.00 / c0", {shortFlat}); // Flat and Constant arguments. 
testDecimalExpr( - makeLongDecimalVector( - {"500" + std::string(4, '0'), "1000" + std::string(4, '0')}, 23, 7), + makeNullableLongDecimalVector( + {"500" + std::string(4, '0'), "1000" + std::string(4, '0')}, + DECIMAL(23, 7)), "c0 / 2.00", {shortFlat}); diff --git a/velox/functions/sparksql/tests/DecimalUtilTest.cpp b/velox/functions/sparksql/tests/DecimalUtilTest.cpp index 350bbf526583..34de356ae7ab 100644 --- a/velox/functions/sparksql/tests/DecimalUtilTest.cpp +++ b/velox/functions/sparksql/tests/DecimalUtilTest.cpp @@ -43,4 +43,21 @@ TEST_F(DecimalUtilTest, divideWithRoundUp) { testDivideWithRoundUp( 6, velox::DecimalUtil::kPowersOfTen[17], 20, 6000, false); } + +TEST_F(DecimalUtilTest, minLeadingZeros) { + auto result = + DecimalUtil::minLeadingZeros(10000, 6000000, 10, 12); + ASSERT_EQ(result, 1); + + result = DecimalUtil::minLeadingZeros( + 10000, 6'000'000'000'000'000'000, 10, 12); + ASSERT_EQ(result, 16); + + result = DecimalUtil::minLeadingZeros( + velox::DecimalUtil::kLongDecimalMax, + velox::DecimalUtil::kLongDecimalMin, + 10, + 12); + ASSERT_EQ(result, 0); +} } // namespace facebook::velox::functions::sparksql::test diff --git a/velox/type/DecimalUtil.cpp b/velox/type/DecimalUtil.cpp index 62e3c61565a8..3b48c3c98a72 100644 --- a/velox/type/DecimalUtil.cpp +++ b/velox/type/DecimalUtil.cpp @@ -54,7 +54,7 @@ std::string formatDecimal(uint8_t scale, int128_t unscaledValue) { } } // namespace -std::string DecimalUtil::toString(const int128_t value, const TypePtr& type) { +std::string DecimalUtil::toString(int128_t value, const TypePtr& type) { auto [precision, scale] = getDecimalPrecisionScale(*type); return formatDecimal(scale, value); } @@ -91,7 +91,7 @@ int32_t DecimalUtil::toByteArray(int128_t value, char* out) { void DecimalUtil::computeAverage( int128_t& avg, - const int128_t& sum, + int128_t sum, int64_t count, int64_t overflow) { if (overflow == 0) { diff --git a/velox/type/DecimalUtil.h b/velox/type/DecimalUtil.h index d0a125b84e54..0a6e455e501f 
100644 --- a/velox/type/DecimalUtil.h +++ b/velox/type/DecimalUtil.h @@ -98,7 +98,7 @@ class DecimalUtil { } /// Helper function to convert a decimal value to string. - static std::string toString(const int128_t value, const TypePtr& type); + static std::string toString(int128_t value, const TypePtr& type); template inline static void fillDecimals( @@ -147,11 +147,11 @@ class DecimalUtil { template inline static std::optional rescaleWithRoundUp( - const TInput inputValue, - const int fromPrecision, - const int fromScale, - const int toPrecision, - const int toScale) { + TInput inputValue, + int fromPrecision, + int fromScale, + int toPrecision, + int toScale) { int128_t rescaledValue = inputValue; auto scaleDifference = toScale - fromScale; bool isOverflow = false; @@ -183,10 +183,8 @@ class DecimalUtil { } template - inline static std::optional rescaleInt( - const TInput inputValue, - const int toPrecision, - const int toScale) { + inline static std::optional + rescaleInt(TInput inputValue, int toPrecision, int toScale) { int128_t rescaledValue = static_cast(inputValue); bool isOverflow = __builtin_mul_overflow( rescaledValue, DecimalUtil::kPowersOfTen[toScale], &rescaledValue); @@ -205,8 +203,8 @@ class DecimalUtil { template inline static R divideWithRoundUp( R& r, - const A& a, - const B& b, + A a, + B b, bool noRoundUp, uint8_t aRescale, uint8_t /*bRescale*/) { @@ -312,11 +310,8 @@ class DecimalUtil { } /// avg = (sum + overflow * kOverflowMultiplier) / count - static void computeAverage( - int128_t& avg, - const int128_t& sum, - int64_t count, - int64_t overflow); + static void + computeAverage(int128_t& avg, int128_t sum, int64_t count, int64_t overflow); /// Origins from java side BigInteger#bitLength. ///