From 45185307c8ba6319d4cb1e89b8ae110664e0a63d Mon Sep 17 00:00:00 2001 From: liuxian Date: Fri, 12 May 2017 11:38:50 +0800 Subject: [PATCH] [SPARK-20665][SQL] Bround" and "Round" function return NULL ## What changes were proposed in this pull request? spark-sql>select bround(12.3, 2); spark-sql>NULL For this case, the expected result is 12.3, but it is null. So ,when the second parameter is bigger than "decimal.scala", the result is not we expected. "round" function has the same problem. This PR can solve the problem for both of them. ## How was this patch tested? unit test cases in MathExpressionsSuite and MathFunctionsSuite Author: liuxian Closes #17906 from 10110346/wip_lx_0509. --- .../sql/catalyst/expressions/mathExpressions.scala | 12 ++++++------ .../catalyst/expressions/MathExpressionsSuite.scala | 7 +++---- .../org/apache/spark/sql/MathFunctionsSuite.scala | 13 +++++++++++++ 3 files changed, 22 insertions(+), 10 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala index c4d47ab2084fd..de1a46dc47805 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala @@ -1023,10 +1023,10 @@ abstract class RoundBase(child: Expression, scale: Expression, // not overriding since _scale is a constant int at runtime def nullSafeEval(input1: Any): Any = { - child.dataType match { - case _: DecimalType => + dataType match { + case DecimalType.Fixed(_, s) => val decimal = input1.asInstanceOf[Decimal] - decimal.toPrecision(decimal.precision, _scale, mode).orNull + decimal.toPrecision(decimal.precision, s, mode).orNull case ByteType => BigDecimal(input1.asInstanceOf[Byte]).setScale(_scale, mode).toByte case ShortType => @@ -1055,10 +1055,10 @@ abstract class RoundBase(child: Expression, scale: Expression, override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val ce = child.genCode(ctx) - val evaluationCode = child.dataType match { - case _: DecimalType => + val evaluationCode = dataType match { + case DecimalType.Fixed(_, s) => s""" - if (${ce.value}.changePrecision(${ce.value}.precision(), ${_scale}, + if (${ce.value}.changePrecision(${ce.value}.precision(), ${s}, java.math.BigDecimal.${modeStr})) { ${ev.value} = ${ce.value}; } else { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathExpressionsSuite.scala index 6b5bfac94645c..1555dd1cf58d4 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathExpressionsSuite.scala @@ -546,15 +546,14 @@ class MathExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { val bdResults: Seq[BigDecimal] = Seq(BigDecimal(3.0), BigDecimal(3.1), BigDecimal(3.14), BigDecimal(3.142), BigDecimal(3.1416), BigDecimal(3.14159), BigDecimal(3.141593), BigDecimal(3.1415927)) - // round_scale > current_scale would result in precision increase - // and not allowed by o.a.s.s.types.Decimal.changePrecision, therefore null + (0 to 7).foreach { i => checkEvaluation(Round(bdPi, i), bdResults(i), EmptyRow) checkEvaluation(BRound(bdPi, i), bdResults(i), EmptyRow) } (8 to 10).foreach { scale => - checkEvaluation(Round(bdPi, scale), null, EmptyRow) - checkEvaluation(BRound(bdPi, scale), null, EmptyRow) + checkEvaluation(Round(bdPi, scale), bdPi, EmptyRow) + checkEvaluation(BRound(bdPi, scale), bdPi, EmptyRow) } DataTypeTestUtils.numericTypes.foreach { dataType => diff --git a/sql/core/src/test/scala/org/apache/spark/sql/MathFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/MathFunctionsSuite.scala index 328c5395ec91e..c2d08a06569bf 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/MathFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/MathFunctionsSuite.scala @@ -231,6 +231,19 @@ class MathFunctionsSuite extends QueryTest with SharedSQLContext { Seq(Row(BigDecimal("0E3"), BigDecimal("0E2"), BigDecimal("0E1"), BigDecimal(3), BigDecimal("3.1"), BigDecimal("3.14"), BigDecimal("3.142"))) ) + + val bdPi: BigDecimal = BigDecimal(31415925L, 7) + checkAnswer( + sql(s"SELECT round($bdPi, 7), round($bdPi, 8), round($bdPi, 9), round($bdPi, 10), " + + s"round($bdPi, 100), round($bdPi, 6), round(null, 8)"), + Seq(Row(bdPi, bdPi, bdPi, bdPi, bdPi, BigDecimal("3.141593"), null)) + ) + + checkAnswer( + sql(s"SELECT bround($bdPi, 7), bround($bdPi, 8), bround($bdPi, 9), bround($bdPi, 10), " + + s"bround($bdPi, 100), bround($bdPi, 6), bround(null, 8)"), + Seq(Row(bdPi, bdPi, bdPi, bdPi, bdPi, BigDecimal("3.141592"), null)) + ) } test("round/bround with data frame from a local Seq of Product") {