From 626511953b87747e933e4f64b9fcd4c4776a5c4e Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Wed, 14 Jun 2017 19:18:28 +0800 Subject: [PATCH] [SPARK-20211][SQL][BACKPORT-2.2] Fix the Precision and Scale of Decimal Values when the Input is BigDecimal between -1.0 and 1.0 ### What changes were proposed in this pull request? This PR backports https://github.com/apache/spark/pull/18244 to branch 2.2. --- The precision and scale of decimal values are wrong when the input is a BigDecimal between -1.0 and 1.0. A BigDecimal's precision is the digit count starting from the leftmost nonzero digit, per [Java's BigDecimal definition](https://docs.oracle.com/javase/7/docs/api/java/math/BigDecimal.html). However, our Decimal definition follows the database decimal standard, in which the precision is the total number of digits, including those both to the left and to the right of the decimal point. Thus, this PR fixes the issue by converting between the two definitions: when the BigDecimal's precision does not exceed its scale, the Decimal's precision is set to the scale plus one. Before this PR, the following queries failed: ```SQL select 1 > 0.0001 select floor(0.0001) select ceil(0.0001) ``` ### How was this patch tested? Added test cases. Author: gatorsmile Closes #18297 from gatorsmile/backport18244. 
--- .../org/apache/spark/sql/types/Decimal.scala | 10 +- .../apache/spark/sql/types/DecimalSuite.scala | 10 ++ .../resources/sql-tests/inputs/arithmetic.sql | 24 ++++ .../sql-tests/results/arithmetic.sql.out | 134 +++++++++++++++++- 4 files changed, 176 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala index 80916ee9c5379..1f1fb51addfd8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala @@ -126,7 +126,15 @@ final class Decimal extends Ordered[Decimal] with Serializable { def set(decimal: BigDecimal): Decimal = { this.decimalVal = decimal this.longVal = 0L - this._precision = decimal.precision + if (decimal.precision <= decimal.scale) { + // For Decimal, we expect the precision to be equal to or larger than the scale; however, + // in BigDecimal, the digit count starts from the leftmost nonzero digit of the exact + // result. For example, the precision of 0.01 equals 1 based on that definition, but + // the scale is 2. The expected precision should be 3. 
+ this._precision = decimal.scale + 1 + } else { + this._precision = decimal.precision + } this._scale = decimal.scale this } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DecimalSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DecimalSuite.scala index 93c231e30b49b..144f3d688d402 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DecimalSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DecimalSuite.scala @@ -32,6 +32,16 @@ class DecimalSuite extends SparkFunSuite with PrivateMethodTester { test("creating decimals") { checkDecimal(new Decimal(), "0", 1, 0) + checkDecimal(Decimal(BigDecimal("0.09")), "0.09", 3, 2) + checkDecimal(Decimal(BigDecimal("0.9")), "0.9", 2, 1) + checkDecimal(Decimal(BigDecimal("0.90")), "0.90", 3, 2) + checkDecimal(Decimal(BigDecimal("0.0")), "0.0", 2, 1) + checkDecimal(Decimal(BigDecimal("0")), "0", 1, 0) + checkDecimal(Decimal(BigDecimal("1.0")), "1.0", 2, 1) + checkDecimal(Decimal(BigDecimal("-0.09")), "-0.09", 3, 2) + checkDecimal(Decimal(BigDecimal("-0.9")), "-0.9", 2, 1) + checkDecimal(Decimal(BigDecimal("-0.90")), "-0.90", 3, 2) + checkDecimal(Decimal(BigDecimal("-1.0")), "-1.0", 2, 1) checkDecimal(Decimal(BigDecimal("10.030")), "10.030", 5, 3) checkDecimal(Decimal(BigDecimal("10.030"), 4, 1), "10.0", 4, 1) checkDecimal(Decimal(BigDecimal("-9.95"), 4, 1), "-10.0", 4, 1) diff --git a/sql/core/src/test/resources/sql-tests/inputs/arithmetic.sql b/sql/core/src/test/resources/sql-tests/inputs/arithmetic.sql index f62b10ca0037b..492a405d7ebbd 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/arithmetic.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/arithmetic.sql @@ -32,3 +32,27 @@ select 1 - 2; select 2 * 5; select 5 % 3; select pmod(-7, 3); + +-- math functions +select cot(1); +select cot(null); +select cot(0); +select cot(-1); + +-- ceil and ceiling +select ceiling(0); +select ceiling(1); +select ceil(1234567890123456); +select 
ceiling(1234567890123456); +select ceil(0.01); +select ceiling(-0.10); + +-- floor +select floor(0); +select floor(1); +select floor(1234567890123456); +select floor(0.01); +select floor(-0.10); + +-- comparison operator +select 1 > 0.00001 \ No newline at end of file diff --git a/sql/core/src/test/resources/sql-tests/results/arithmetic.sql.out b/sql/core/src/test/resources/sql-tests/results/arithmetic.sql.out index ce42c016a7100..3811cd2c30986 100644 --- a/sql/core/src/test/resources/sql-tests/results/arithmetic.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/arithmetic.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 28 +-- Number of queries: 44 -- !query 0 @@ -224,3 +224,135 @@ select pmod(-7, 3) struct -- !query 27 output 2 + + +-- !query 28 +select cot(1) +-- !query 28 schema +struct<> +-- !query 28 output +org.apache.spark.sql.AnalysisException +Undefined function: 'cot'. This function is neither a registered temporary function nor a permanent function registered in the database 'default'.; line 1 pos 7 + + +-- !query 29 +select cot(null) +-- !query 29 schema +struct<> +-- !query 29 output +org.apache.spark.sql.AnalysisException +Undefined function: 'cot'. This function is neither a registered temporary function nor a permanent function registered in the database 'default'.; line 1 pos 7 + + +-- !query 30 +select cot(0) +-- !query 30 schema +struct<> +-- !query 30 output +org.apache.spark.sql.AnalysisException +Undefined function: 'cot'. This function is neither a registered temporary function nor a permanent function registered in the database 'default'.; line 1 pos 7 + + +-- !query 31 +select cot(-1) +-- !query 31 schema +struct<> +-- !query 31 output +org.apache.spark.sql.AnalysisException +Undefined function: 'cot'. 
This function is neither a registered temporary function nor a permanent function registered in the database 'default'.; line 1 pos 7 + + +-- !query 32 +select ceiling(0) +-- !query 32 schema +struct +-- !query 32 output +0 + + +-- !query 33 +select ceiling(1) +-- !query 33 schema +struct +-- !query 33 output +1 + + +-- !query 34 +select ceil(1234567890123456) +-- !query 34 schema +struct +-- !query 34 output +1234567890123456 + + +-- !query 35 +select ceiling(1234567890123456) +-- !query 35 schema +struct +-- !query 35 output +1234567890123456 + + +-- !query 36 +select ceil(0.01) +-- !query 36 schema +struct +-- !query 36 output +1 + + +-- !query 37 +select ceiling(-0.10) +-- !query 37 schema +struct +-- !query 37 output +0 + + +-- !query 38 +select floor(0) +-- !query 38 schema +struct +-- !query 38 output +0 + + +-- !query 39 +select floor(1) +-- !query 39 schema +struct +-- !query 39 output +1 + + +-- !query 40 +select floor(1234567890123456) +-- !query 40 schema +struct +-- !query 40 output +1234567890123456 + + +-- !query 41 +select floor(0.01) +-- !query 41 schema +struct +-- !query 41 output +0 + + +-- !query 42 +select floor(-0.10) +-- !query 42 schema +struct +-- !query 42 output +-1 + + +-- !query 43 +select 1 > 0.00001 +-- !query 43 schema +struct<(CAST(1 AS BIGINT) > 0):boolean> +-- !query 43 output +true