address comments
mgaido91 committed Sep 19, 2018
1 parent 520b64e, commit 27a9ea6
Showing 2 changed files with 28 additions and 3 deletions.

DecimalPrecision.scala
@@ -40,10 +40,13 @@ import org.apache.spark.sql.types._
 * e1 + e2      max(s1, s2) + max(p1-s1, p2-s2) + 1      max(s1, s2)
 * e1 - e2      max(s1, s2) + max(p1-s1, p2-s2) + 1      max(s1, s2)
 * e1 * e2      p1 + p2 + 1                              s1 + s2
-* e1 / e2      p1 - s1 + s2 + max(6, s1 + p2 + 1)       max(6, s1 + p2 + 1)
+* e1 / e2      max(p1-s1+s2, 0) + max(6, s1+adjP2+1)    max(6, s1+adjP2+1)
 * e1 % e2      min(p1-s1, p2-s2) + max(s1, s2)          max(s1, s2)
 * e1 union e2  max(s1, s2) + max(p1-s1, p2-s2)          max(s1, s2)
 *
+* Where adjP2 is p2 - s2 if s2 < 0, p2 otherwise. This adjustment is needed because Spark does not
+* forbid decimals with negative scale, while MS SQL and Hive do.
+*
 * When `spark.sql.decimalOperations.allowPrecisionLoss` is set to true, if the precision / scale
 * needed are out of the range of available values, the scale is reduced up to 6, in order to
 * prevent the truncation of the integer part of the decimals.
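
A note on the new division row: the sketch below (plain Scala, illustrative names only, not Spark API) applies the formula from the table in the precision-loss case, before DecimalType.adjustPrecisionScale is applied.

    // Illustrative sketch of the new e1 / e2 row (precision-loss branch); not Spark API.
    object DivisionRuleSketch {
      private val MinimumAdjustedScale = 6

      // Returns (precision, scale) for e1 / e2 per the table above.
      def divisionResultType(p1: Int, s1: Int, p2: Int, s2: Int): (Int, Int) = {
        val adjP2 = if (s2 < 0) p2 - s2 else p2  // widen p2 when the divisor's scale is negative
        val intDig = math.max(p1 - s1 + s2, 0)   // p1 - s1 + s2 can go negative when s2 < 0
        val scale = math.max(MinimumAdjustedScale, s1 + adjP2 + 1)
        (intDig + scale, scale)
      }

      def main(args: Array[String]): Unit = {
        // Hypothetical operands: e1 with (p1, s1) = (10, 0), e2 with (p2, s2) = (2, -7),
        // e.g. a ten-digit integer divided by a literal like 1.0E+8.
        println(divisionResultType(10, 0, 2, -7))  // prints (13,10)
      }
    }

With adjP2 = 2 - (-7) = 9, the scale comes out to max(6, 0 + 9 + 1) = 10 and the precision to 3 + 10 = 13, enough to hold a quotient such as 12.34567891 exactly.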
@@ -133,12 +136,12 @@ object DecimalPrecision extends TypeCoercionRule {
       val resultType = if (SQLConf.get.decimalOperationsAllowPrecisionLoss) {
         // Precision: p1 - s1 + s2 + max(6, s1 + p2 + 1)
         // Scale: max(6, s1 + p2 + 1)
-        val intDig = p1 - s1 + s2
+        val intDig = max(p1 - s1 + s2, 0) // can be negative if s2 < 0
         val scale = max(DecimalType.MINIMUM_ADJUSTED_SCALE, s1 + adjP2 + 1)
         val prec = intDig + scale
         DecimalType.adjustPrecisionScale(prec, scale)
       } else {
-        var intDig = min(DecimalType.MAX_SCALE, p1 - s1 + s2)
+        var intDig = max(min(DecimalType.MAX_SCALE, p1 - s1 + s2), 0) // can be negative if s2 < 0
         var decDig = min(DecimalType.MAX_SCALE, max(6, s1 + adjP2 + 1))
         val diff = (intDig + decDig) - DecimalType.MAX_SCALE
         if (diff > 0) {
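
Both max(..., 0) clamps above guard the same corner case: when the divisor's scale is negative enough, the old count of integral digits itself went negative, which is not a meaningful number of digits. A minimal illustration with made-up operand stats:

    // Made-up operand stats: e1 with (p1, s1) = (2, 0), e2 with (p2, s2) = (2, -9),
    // e.g. a two-digit integer divided by a literal like 1.2E+10.
    val (p1, s1, p2, s2) = (2, 0, 2, -9)
    val oldIntDig = p1 - s1 + s2               // 2 + (-9) = -7, a nonsensical digit count
    val newIntDig = math.max(p1 - s1 + s2, 0)  // clamped to 0 by this commit
    println((oldIntDig, newIntDig))            // prints (-7,0)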

ArithmeticExpressionSuite.scala
@@ -24,6 +24,7 @@ import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.TypeCheckFailure
 import org.apache.spark.sql.catalyst.dsl.expressions._
 import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext
+import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._
 
 class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper {
@@ -366,4 +367,25 @@ class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper {
     Greatest(Seq(Literal(1), Literal(1))).genCode(ctx2)
     assert(ctx2.inlinedMutableStates.size == 1)
   }
+
+  test("SPARK-25454: decimal operations with negative scale") {
+    val a = Literal(BigDecimal(1234567891))
+    val b = Literal(BigDecimal(100e6))
+    val c = Literal(BigDecimal(123456.7891))
+    assert(b.dataType.isInstanceOf[DecimalType] &&
+      b.dataType.asInstanceOf[DecimalType].scale < 0)
+    Seq("true", "false").foreach { allowPrecLoss =>
+      withSQLConf(SQLConf.DECIMAL_OPERATIONS_ALLOW_PREC_LOSS.key -> allowPrecLoss) {
+        checkEvaluationWithOptimization(Add(a, b), Decimal(BigDecimal(1334567891)))
+        checkEvaluationWithOptimization(Add(b, c), Decimal(BigDecimal(100123456.7891)))
+        checkEvaluationWithOptimization(Subtract(a, b), Decimal(BigDecimal(1134567891)))
+        checkEvaluationWithOptimization(Subtract(b, c), Decimal(BigDecimal(99876543.2109)))
+        checkEvaluationWithOptimization(Multiply(a, b), Decimal(BigDecimal(123456789100000000L)))
+        checkEvaluationWithOptimization(Multiply(b, c), Decimal(BigDecimal(12345678910000L)))
+        checkEvaluationWithOptimization(Divide(a, b), Decimal(BigDecimal(12.34567891)))
+        checkEvaluationWithOptimization(Divide(b, c), Decimal(BigDecimal(810.000007)))
+        checkEvaluationWithOptimization(Divide(c, b), Decimal(BigDecimal(0.001234567891)))
+      }
+    }
+  }
 }
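
A side note on why the assert on b holds: assuming Scala 2.11+ semantics, BigDecimal.apply(d: Double) goes through java.lang.Double.toString, so 100e6 arrives as the string "1.0E8" and java.math.BigDecimal keeps the exponent as a negative scale. A quick sanity check, separate from the suite:

    // Why BigDecimal(100e6) carries a negative scale (sanity check, not part of the suite).
    // "1.0E8" parses to unscaled value 10 with scale -7, since 10 * 10^7 == 1.0E8.
    val s = java.lang.Double.toString(100e6)  // "1.0E8"
    val jbd = new java.math.BigDecimal(s)
    assert(jbd.scale == -7 && jbd.unscaledValue.intValueExact == 10)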
