From 8cac615726b096e62d95acd83e3dea802f4f4c4e Mon Sep 17 00:00:00 2001
From: Shuo Cheng
Date: Tue, 11 May 2021 18:58:58 +0800
Subject: [PATCH] [FLINK-22586][table] Improve the precision derivation for
 decimal arithmetic

This closes #15848
---
 .../logical/utils/LogicalTypeMerging.java     | 62 +++++++++-----
 .../types/inference/TypeStrategiesTest.java   |  2 +-
 .../aggfunctions/AvgAggFunction.java          | 12 ++-
 .../aggfunctions/Sum0AggFunction.java         | 14 +++-
 .../aggfunctions/SumAggFunction.java          | 26 ++++--
 .../SumWithRetractAggFunction.java            | 37 ++++++---
 .../planner/calcite/FlinkTypeSystem.scala     | 80 +++++++++++++++----
 .../functions/MathFunctionsITCase.java        |  2 +-
 .../planner/expressions/DecimalTypeTest.scala | 37 ++++-----
 .../stream/sql/OverAggregateITCase.scala      | 32 +++++++-
 10 files changed, 216 insertions(+), 88 deletions(-)

diff --git a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/logical/utils/LogicalTypeMerging.java b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/logical/utils/LogicalTypeMerging.java
index 611caba67d4c0..dfb1ec1e2ff26 100644
--- a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/logical/utils/LogicalTypeMerging.java
+++ b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/logical/utils/LogicalTypeMerging.java
@@ -113,6 +113,7 @@ public final class LogicalTypeMerging {
             YEAR_MONTH_RES_TO_BOUNDARIES = new HashMap<>();
     private static final Map, YearMonthResolution> YEAR_MONTH_BOUNDARIES_TO_RES = new HashMap<>();
+    private static final int MINIMUM_ADJUSTED_SCALE = 6;

     static {
         addYearMonthMapping(YEAR, YEAR);
@@ -198,50 +199,50 @@ public static Optional<LogicalType> findCommonType(List<LogicalType> types) {
         return Optional.empty();
     }

+    // ========================= Decimal Precision Deriving ==========================
+    // Adopted from "https://docs.microsoft.com/en-us/sql/t-sql/data-types/precision-
+    // scale-and-length-transact-sql"
+    //
+    // Operation    Result Precision                          Result Scale
+    // e1 + e2      max(s1, s2) + max(p1-s1, p2-s2) + 1       max(s1, s2)
+    // e1 - e2      max(s1, s2) + max(p1-s1, p2-s2) + 1       max(s1, s2)
+    // e1 * e2      p1 + p2 + 1                               s1 + s2
+    // e1 / e2      p1 - s1 + s2 + max(6, s1 + p2 + 1)        max(6, s1 + p2 + 1)
+    // e1 % e2      min(p1-s1, p2-s2) + max(s1, s2)           max(s1, s2)
+    //
+    // Also, if the precision / scale are out of the range, the scale may be sacrificed
+    // in order to prevent the truncation of the integer part of the decimals.
+
     /** Finds the result type of a decimal division operation. */
     public static DecimalType findDivisionDecimalType(
             int precision1, int scale1, int precision2, int scale2) {
-        // adopted from
-        // https://docs.microsoft.com/en-us/sql/t-sql/data-types/precision-scale-and-length-transact-sql
         int scale = Math.max(6, scale1 + precision2 + 1);
         int precision = precision1 - scale1 + scale2 + scale;
-        if (precision > DecimalType.MAX_PRECISION) {
-            scale = Math.max(6, DecimalType.MAX_PRECISION - (precision - scale));
-            precision = DecimalType.MAX_PRECISION;
-        }
-        return new DecimalType(false, precision, scale);
+        return adjustPrecisionScale(precision, scale);
     }

     /** Finds the result type of a decimal modulo operation. 
*/ public static DecimalType findModuloDecimalType( int precision1, int scale1, int precision2, int scale2) { - // adopted from Calcite final int scale = Math.max(scale1, scale2); - int precision = - Math.min(precision1 - scale1, precision2 - scale2) + Math.max(scale1, scale2); - precision = Math.min(precision, DecimalType.MAX_PRECISION); - return new DecimalType(false, precision, scale); + int precision = Math.min(precision1 - scale1, precision2 - scale2) + scale; + return adjustPrecisionScale(precision, scale); } /** Finds the result type of a decimal multiplication operation. */ public static DecimalType findMultiplicationDecimalType( int precision1, int scale1, int precision2, int scale2) { - // adopted from Calcite int scale = scale1 + scale2; - scale = Math.min(scale, DecimalType.MAX_PRECISION); - int precision = precision1 + precision2; - precision = Math.min(precision, DecimalType.MAX_PRECISION); - return new DecimalType(false, precision, scale); + int precision = precision1 + precision2 + 1; + return adjustPrecisionScale(precision, scale); } /** Finds the result type of a decimal addition operation. */ public static DecimalType findAdditionDecimalType( int precision1, int scale1, int precision2, int scale2) { - // adopted from Calcite final int scale = Math.max(scale1, scale2); int precision = Math.max(precision1 - scale1, precision2 - scale2) + scale + 1; - precision = Math.min(precision, DecimalType.MAX_PRECISION); - return new DecimalType(false, precision, scale); + return adjustPrecisionScale(precision, scale); } /** Finds the result type of a decimal rounding operation. */ @@ -296,6 +297,27 @@ public static LogicalType findSumAggType(LogicalType argType) { // -------------------------------------------------------------------------------------------- + /** + * Scale adjustment implementation is inspired to SQLServer's one. In particular, when a result + * precision is greater than MAX_PRECISION, the corresponding scale is reduced to prevent the + * integral part of a result from being truncated. + * + *

https://docs.microsoft.com/en-us/sql/t-sql/data-types/precision-scale-and-length-transact-sql + */ + private static DecimalType adjustPrecisionScale(int precision, int scale) { + if (precision <= DecimalType.MAX_PRECISION) { + // Adjustment only needed when we exceed max precision + return new DecimalType(false, precision, scale); + } else { + int digitPart = precision - scale; + // If original scale is less than MINIMUM_ADJUSTED_SCALE, use original scale value; + // otherwise preserve at least MINIMUM_ADJUSTED_SCALE fractional digits + int minScalePart = Math.min(scale, MINIMUM_ADJUSTED_SCALE); + int adjustScale = Math.max(DecimalType.MAX_PRECISION - digitPart, minScalePart); + return new DecimalType(false, DecimalType.MAX_PRECISION, adjustScale); + } + } + private static @Nullable LogicalType findCommonCastableType(List normalizedTypes) { LogicalType resultType = normalizedTypes.get(0); diff --git a/flink-table/flink-table-common/src/test/java/org/apache/flink/table/types/inference/TypeStrategiesTest.java b/flink-table/flink-table-common/src/test/java/org/apache/flink/table/types/inference/TypeStrategiesTest.java index 908060f6adbb1..618f9a559c08c 100644 --- a/flink-table/flink-table-common/src/test/java/org/apache/flink/table/types/inference/TypeStrategiesTest.java +++ b/flink-table/flink-table-common/src/test/java/org/apache/flink/table/types/inference/TypeStrategiesTest.java @@ -197,7 +197,7 @@ public static List testData() { .expectDataType(DataTypes.DECIMAL(11, 8).notNull()), TestSpec.forStrategy("Find a decimal product", TypeStrategies.DECIMAL_TIMES) .inputTypes(DataTypes.DECIMAL(5, 4), DataTypes.DECIMAL(3, 2)) - .expectDataType(DataTypes.DECIMAL(8, 6).notNull()), + .expectDataType(DataTypes.DECIMAL(9, 6).notNull()), TestSpec.forStrategy("Find a decimal modulo", TypeStrategies.DECIMAL_MOD) .inputTypes(DataTypes.DECIMAL(5, 4), DataTypes.DECIMAL(3, 2)) .expectDataType(DataTypes.DECIMAL(5, 4).notNull()), diff --git a/flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/planner/functions/aggfunctions/AvgAggFunction.java b/flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/planner/functions/aggfunctions/AvgAggFunction.java index 8d4a05921dded..7e708b12e2826 100644 --- a/flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/planner/functions/aggfunctions/AvgAggFunction.java +++ b/flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/planner/functions/aggfunctions/AvgAggFunction.java @@ -20,6 +20,7 @@ import org.apache.flink.table.api.DataTypes; import org.apache.flink.table.expressions.Expression; +import org.apache.flink.table.expressions.UnresolvedCallExpression; import org.apache.flink.table.expressions.UnresolvedReferenceExpression; import org.apache.flink.table.types.DataType; import org.apache.flink.table.types.logical.DecimalType; @@ -72,7 +73,7 @@ public Expression[] initialValuesExpressions() { @Override public Expression[] accumulateExpressions() { return new Expression[] { - /* sum = */ ifThenElse(isNull(operand(0)), sum, plus(sum, operand(0))), + /* sum = */ adjustSumType(ifThenElse(isNull(operand(0)), sum, plus(sum, operand(0)))), /* count = */ ifThenElse(isNull(operand(0)), count, plus(count, literal(1L))), }; } @@ -80,7 +81,7 @@ public Expression[] accumulateExpressions() { @Override public Expression[] retractExpressions() { return new Expression[] { - /* sum = */ ifThenElse(isNull(operand(0)), sum, minus(sum, operand(0))), + /* sum = */ adjustSumType(ifThenElse(isNull(operand(0)), 
sum, minus(sum, operand(0)))), /* count = */ ifThenElse(isNull(operand(0)), count, minus(count, literal(1L))), }; } @@ -88,10 +89,15 @@ public Expression[] retractExpressions() { @Override public Expression[] mergeExpressions() { return new Expression[] { - /* sum = */ plus(sum, mergeOperand(sum)), /* count = */ plus(count, mergeOperand(count)) + /* sum = */ adjustSumType(plus(sum, mergeOperand(sum))), + /* count = */ plus(count, mergeOperand(count)) }; } + private UnresolvedCallExpression adjustSumType(UnresolvedCallExpression sumExpr) { + return cast(sumExpr, typeLiteral(getSumType())); + } + /** If all input are nulls, count will be 0 and we will get null after the division. */ @Override public Expression getValueExpression() { diff --git a/flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/planner/functions/aggfunctions/Sum0AggFunction.java b/flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/planner/functions/aggfunctions/Sum0AggFunction.java index 478e40ad9b4f5..98b9dc70d2293 100644 --- a/flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/planner/functions/aggfunctions/Sum0AggFunction.java +++ b/flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/planner/functions/aggfunctions/Sum0AggFunction.java @@ -20,6 +20,7 @@ import org.apache.flink.table.api.DataTypes; import org.apache.flink.table.expressions.Expression; +import org.apache.flink.table.expressions.UnresolvedCallExpression; import org.apache.flink.table.expressions.UnresolvedReferenceExpression; import org.apache.flink.table.types.DataType; import org.apache.flink.table.types.logical.DecimalType; @@ -28,11 +29,13 @@ import java.math.BigDecimal; import static org.apache.flink.table.expressions.ApiExpressionUtils.unresolvedRef; +import static org.apache.flink.table.planner.expressions.ExpressionBuilder.cast; import static org.apache.flink.table.planner.expressions.ExpressionBuilder.ifThenElse; import static org.apache.flink.table.planner.expressions.ExpressionBuilder.isNull; import static org.apache.flink.table.planner.expressions.ExpressionBuilder.literal; import static org.apache.flink.table.planner.expressions.ExpressionBuilder.minus; import static org.apache.flink.table.planner.expressions.ExpressionBuilder.plus; +import static org.apache.flink.table.planner.expressions.ExpressionBuilder.typeLiteral; /** built-in sum0 aggregate function. 
*/ public abstract class Sum0AggFunction extends DeclarativeAggregateFunction { @@ -56,20 +59,25 @@ public DataType[] getAggBufferTypes() { @Override public Expression[] accumulateExpressions() { return new Expression[] { - /* sum0 = */ ifThenElse(isNull(operand(0)), sum0, plus(sum0, operand(0))) + /* sum0 = */ adjustSumType(ifThenElse(isNull(operand(0)), sum0, plus(sum0, operand(0)))) }; } @Override public Expression[] retractExpressions() { return new Expression[] { - /* sum0 = */ ifThenElse(isNull(operand(0)), sum0, minus(sum0, operand(0))) + /* sum0 = */ adjustSumType( + ifThenElse(isNull(operand(0)), sum0, minus(sum0, operand(0)))) }; } @Override public Expression[] mergeExpressions() { - return new Expression[] {/* sum0 = */ plus(sum0, mergeOperand(sum0))}; + return new Expression[] {/* sum0 = */ adjustSumType(plus(sum0, mergeOperand(sum0)))}; + } + + private UnresolvedCallExpression adjustSumType(UnresolvedCallExpression sumExpr) { + return cast(sumExpr, typeLiteral(getResultType())); } @Override diff --git a/flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/planner/functions/aggfunctions/SumAggFunction.java b/flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/planner/functions/aggfunctions/SumAggFunction.java index 8bb96aa75c0f6..4e93800406640 100644 --- a/flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/planner/functions/aggfunctions/SumAggFunction.java +++ b/flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/planner/functions/aggfunctions/SumAggFunction.java @@ -21,16 +21,19 @@ import org.apache.flink.table.api.DataTypes; import org.apache.flink.table.api.TableException; import org.apache.flink.table.expressions.Expression; +import org.apache.flink.table.expressions.UnresolvedCallExpression; import org.apache.flink.table.expressions.UnresolvedReferenceExpression; import org.apache.flink.table.types.DataType; import org.apache.flink.table.types.logical.DecimalType; import org.apache.flink.table.types.logical.utils.LogicalTypeMerging; import static org.apache.flink.table.expressions.ApiExpressionUtils.unresolvedRef; +import static org.apache.flink.table.planner.expressions.ExpressionBuilder.cast; import static org.apache.flink.table.planner.expressions.ExpressionBuilder.ifThenElse; import static org.apache.flink.table.planner.expressions.ExpressionBuilder.isNull; import static org.apache.flink.table.planner.expressions.ExpressionBuilder.nullOf; import static org.apache.flink.table.planner.expressions.ExpressionBuilder.plus; +import static org.apache.flink.table.planner.expressions.ExpressionBuilder.typeLiteral; /** built-in sum aggregate function. 
*/ public abstract class SumAggFunction extends DeclarativeAggregateFunction { @@ -59,10 +62,11 @@ public Expression[] initialValuesExpressions() { @Override public Expression[] accumulateExpressions() { return new Expression[] { - /* sum = */ ifThenElse( - isNull(operand(0)), - sum, - ifThenElse(isNull(sum), operand(0), plus(sum, operand(0)))) + /* sum = */ adjustSumType( + ifThenElse( + isNull(operand(0)), + sum, + ifThenElse(isNull(sum), operand(0), plus(sum, operand(0))))) }; } @@ -75,13 +79,19 @@ public Expression[] retractExpressions() { @Override public Expression[] mergeExpressions() { return new Expression[] { - /* sum = */ ifThenElse( - isNull(mergeOperand(sum)), - sum, - ifThenElse(isNull(sum), mergeOperand(sum), plus(sum, mergeOperand(sum)))) + /* sum = */ adjustSumType( + ifThenElse( + isNull(mergeOperand(sum)), + sum, + ifThenElse( + isNull(sum), mergeOperand(sum), plus(sum, mergeOperand(sum))))) }; } + private UnresolvedCallExpression adjustSumType(UnresolvedCallExpression sumExpr) { + return cast(sumExpr, typeLiteral(getResultType())); + } + @Override public Expression getValueExpression() { return sum; diff --git a/flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/planner/functions/aggfunctions/SumWithRetractAggFunction.java b/flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/planner/functions/aggfunctions/SumWithRetractAggFunction.java index fe0d05029c538..42c275034f3b0 100644 --- a/flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/planner/functions/aggfunctions/SumWithRetractAggFunction.java +++ b/flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/planner/functions/aggfunctions/SumWithRetractAggFunction.java @@ -20,12 +20,14 @@ import org.apache.flink.table.api.DataTypes; import org.apache.flink.table.expressions.Expression; +import org.apache.flink.table.expressions.UnresolvedCallExpression; import org.apache.flink.table.expressions.UnresolvedReferenceExpression; import org.apache.flink.table.types.DataType; import org.apache.flink.table.types.logical.DecimalType; import org.apache.flink.table.types.logical.utils.LogicalTypeMerging; import static org.apache.flink.table.expressions.ApiExpressionUtils.unresolvedRef; +import static org.apache.flink.table.planner.expressions.ExpressionBuilder.cast; import static org.apache.flink.table.planner.expressions.ExpressionBuilder.equalTo; import static org.apache.flink.table.planner.expressions.ExpressionBuilder.ifThenElse; import static org.apache.flink.table.planner.expressions.ExpressionBuilder.isNull; @@ -33,6 +35,7 @@ import static org.apache.flink.table.planner.expressions.ExpressionBuilder.minus; import static org.apache.flink.table.planner.expressions.ExpressionBuilder.nullOf; import static org.apache.flink.table.planner.expressions.ExpressionBuilder.plus; +import static org.apache.flink.table.planner.expressions.ExpressionBuilder.typeLiteral; /** built-in sum aggregate function with retraction. 
*/ public abstract class SumWithRetractAggFunction extends DeclarativeAggregateFunction { @@ -62,10 +65,11 @@ public Expression[] initialValuesExpressions() { @Override public Expression[] accumulateExpressions() { return new Expression[] { - /* sum = */ ifThenElse( - isNull(operand(0)), - sum, - ifThenElse(isNull(sum), operand(0), plus(sum, operand(0)))), + /* sum = */ adjustSumType( + ifThenElse( + isNull(operand(0)), + sum, + ifThenElse(isNull(sum), operand(0), plus(sum, operand(0))))), /* count = */ ifThenElse(isNull(operand(0)), count, plus(count, literal(1L))) }; } @@ -73,11 +77,14 @@ public Expression[] accumulateExpressions() { @Override public Expression[] retractExpressions() { return new Expression[] { - /* sum = */ ifThenElse( - isNull(operand(0)), - sum, + /* sum = */ adjustSumType( ifThenElse( - isNull(sum), minus(zeroLiteral(), operand(0)), minus(sum, operand(0)))), + isNull(operand(0)), + sum, + ifThenElse( + isNull(sum), + minus(zeroLiteral(), operand(0)), + minus(sum, operand(0))))), /* count = */ ifThenElse(isNull(operand(0)), count, minus(count, literal(1L))) }; } @@ -85,14 +92,20 @@ public Expression[] retractExpressions() { @Override public Expression[] mergeExpressions() { return new Expression[] { - /* sum = */ ifThenElse( - isNull(mergeOperand(sum)), - sum, - ifThenElse(isNull(sum), mergeOperand(sum), plus(sum, mergeOperand(sum)))), + /* sum = */ adjustSumType( + ifThenElse( + isNull(mergeOperand(sum)), + sum, + ifThenElse( + isNull(sum), mergeOperand(sum), plus(sum, mergeOperand(sum))))), /* count = */ plus(count, mergeOperand(count)) }; } + private UnresolvedCallExpression adjustSumType(UnresolvedCallExpression sumExpr) { + return cast(sumExpr, typeLiteral(getResultType())); + } + @Override public Expression getValueExpression() { return ifThenElse(equalTo(count, literal(0L)), nullOf(getResultType()), sum); diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/calcite/FlinkTypeSystem.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/calcite/FlinkTypeSystem.scala index 22f0f24fbf1af..3f9bc06b66e58 100644 --- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/calcite/FlinkTypeSystem.scala +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/calcite/FlinkTypeSystem.scala @@ -18,14 +18,12 @@ package org.apache.flink.table.planner.calcite -import org.apache.flink.table.planner.utils.ShortcutUtils import org.apache.flink.table.planner.utils.ShortcutUtils.unwrapTypeFactory -import org.apache.flink.table.runtime.typeutils.TypeCheckUtils -import org.apache.flink.table.types.logical.{DecimalType, LocalZonedTimestampType, LogicalType, TimestampType} +import org.apache.flink.table.types.logical.utils.LogicalTypeMerging +import org.apache.flink.table.types.logical.{DecimalType, LocalZonedTimestampType, TimestampType} -import org.apache.calcite.rel.`type`.{RelDataType, RelDataTypeFactory, RelDataTypeSystemImpl} +import org.apache.calcite.rel.`type`.{RelDataType, RelDataTypeFactory, RelDataTypeFactoryImpl, RelDataTypeSystemImpl} import org.apache.calcite.sql.`type`.{SqlTypeName, SqlTypeUtil} -import org.apache.flink.table.types.logical.utils.LogicalTypeMerging /** * Custom type system for Flink. 
@@ -96,28 +94,76 @@ class FlinkTypeSystem extends RelDataTypeSystemImpl { unwrapTypeFactory(typeFactory).createFieldTypeFromLogicalType(resultType) } - /** - * Calcite's default impl for division is apparently borrowed from T-SQL, - * but the details are a little different, e.g. when Decimal(34,0)/Decimal(10,0) - * To avoid confusion, follow the exact T-SQL behavior. - * Note that for (+-*), Calcite is also different from T-SQL; - * however, Calcite conforms to SQL2003 while T-SQL does not. - * therefore we keep Calcite's behavior on (+-*). - */ + override def deriveDecimalPlusType( + typeFactory: RelDataTypeFactory, + type1: RelDataType, + type2: RelDataType): RelDataType = { + deriveDecimalType(typeFactory, type1, type2, + (p1, s1, p2, s2) => LogicalTypeMerging.findAdditionDecimalType(p1, s1, p2, s2)) + } + + override def deriveDecimalModType( + typeFactory: RelDataTypeFactory, + type1: RelDataType, + type2: RelDataType): RelDataType = { + deriveDecimalType(typeFactory, type1, type2, + (p1, s1, p2, s2) => { + if (s1 == 0 && s2 == 0) { + return type2 + } + LogicalTypeMerging.findModuloDecimalType(p1, s1, p2, s2) + }) + } + override def deriveDecimalDivideType( typeFactory: RelDataTypeFactory, type1: RelDataType, type2: RelDataType): RelDataType = { + deriveDecimalType(typeFactory, type1, type2, + (p1, s1, p2, s2) => LogicalTypeMerging.findDivisionDecimalType(p1, s1, p2, s2)) + } + + override def deriveDecimalMultiplyType( + typeFactory: RelDataTypeFactory, + type1: RelDataType, + type2: RelDataType): RelDataType = { + deriveDecimalType(typeFactory, type1, type2, + (p1, s1, p2, s2) => LogicalTypeMerging.findMultiplicationDecimalType(p1, s1, p2, s2)) + } + + /** + * Use derivation from [[LogicalTypeMerging]] to derive decimal type. + */ + private def deriveDecimalType( + typeFactory: RelDataTypeFactory, + type1: RelDataType, + type2: RelDataType, + deriveImpl: (Int, Int, Int, Int) => DecimalType): RelDataType = { if (SqlTypeUtil.isExactNumeric(type1) && SqlTypeUtil.isExactNumeric(type2) && - (SqlTypeUtil.isDecimal(type1) || SqlTypeUtil.isDecimal(type2))) { - val result = LogicalTypeMerging.findDivisionDecimalType( - type1.getPrecision, type1.getScale, - type2.getPrecision, type2.getScale) + (SqlTypeUtil.isDecimal(type1) || SqlTypeUtil.isDecimal(type2))) { + val decType1 = adjustType(typeFactory, type1) + val decType2 = adjustType(typeFactory, type2) + val result = deriveImpl( + decType1.getPrecision, decType1.getScale, decType2.getPrecision, decType2.getScale) typeFactory.createSqlType(SqlTypeName.DECIMAL, result.getPrecision, result.getScale) } else { null } } + + /** + * Java numeric will always have invalid precision/scale, + * use its default decimal precision/scale instead. 
+ */ + private def adjustType( + typeFactory: RelDataTypeFactory, + relDataType: RelDataType): RelDataType = { + if (RelDataTypeFactoryImpl.isJavaType(relDataType)) { + typeFactory.decimalOf(relDataType) + } else { + relDataType + } + } } object FlinkTypeSystem { diff --git a/flink-table/flink-table-planner-blink/src/test/java/org/apache/flink/table/planner/functions/MathFunctionsITCase.java b/flink-table/flink-table-planner-blink/src/test/java/org/apache/flink/table/planner/functions/MathFunctionsITCase.java index d7abdf7d17811..c7ff747d9a6fd 100644 --- a/flink-table/flink-table-planner-blink/src/test/java/org/apache/flink/table/planner/functions/MathFunctionsITCase.java +++ b/flink-table/flink-table-planner-blink/src/test/java/org/apache/flink/table/planner/functions/MathFunctionsITCase.java @@ -88,7 +88,7 @@ public static List testData() { $("f0").times(6), "f0 * 6", new BigDecimal("9086137920000"), - DataTypes.DECIMAL(29, 0)) + DataTypes.DECIMAL(30, 0)) // DECIMAL(19, 0) * DECIMAL(19, 0) => DECIMAL(38, 0) .testResult( $("f0").times($("f0")), diff --git a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/planner/expressions/DecimalTypeTest.scala b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/planner/expressions/DecimalTypeTest.scala index e62e743625fef..b84106c320f2f 100644 --- a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/planner/expressions/DecimalTypeTest.scala +++ b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/planner/expressions/DecimalTypeTest.scala @@ -410,32 +410,28 @@ class DecimalTypeTest extends ExpressionTestBase { "f3 + f7", "127.65") - // our result type precision is capped at 38 - // SQL2003 $6.26 -- result scale is dictated as max(s1,s2). no approximation allowed. - // calcite -- scale is not reduced; integral part may be reduced. overflow may occur - // (38,10)+(38,28)=>(57,28)=>(38,28) // T-SQL -- scale may be reduced to keep the integral part. approximation may occur // (38,10)+(38,28)=>(57,28)=>(38,9) testAllApis( 'f15 + 'f16, "f15 + f16", - "300.0246913578012345678901234567") + "300.024691358") testAllApis( 'f15 - 'f16, "f15 - f16", - "-100.0000000000012345678901234567") + "-100.000000000") // 10 digits integral part testAllApis( 'f17 + 'f18, "f17 + f18", - "null") + "10000000000.000000000") testAllApis( 'f17 - 'f18, "f17 - f18", - "null") + "10000000000.000000000") // requires 39 digits testAllApis( @@ -456,7 +452,6 @@ class DecimalTypeTest extends ExpressionTestBase { // see calcite ReturnTypes.DECIMAL_PRODUCT // s = s1+s2, p = p1+p2 // both p&s are capped at 38 - // if s>38, result is rounded to s=38, and the integral part can only be zero testAllApis( 'f20 * 'f20, "f20 * f20", @@ -489,33 +484,31 @@ class DecimalTypeTest extends ExpressionTestBase { "f23 * f20", "3.14") - // precision is capped at 38; scale will not be reduced (unless over 38) - // similar to plus&minus, and calcite behavior is different from T-SQL. 
+ // (60,12) => (38,6), minimum scale 6 is preserved + // while sacrificing scale for more space for integral part testAllApis( 'f24 * 'f24, "f24 * f24", - "1.000000000000") + "1.000000") testAllApis( 'f24 * 'f25, "f24 * f25", - "2.0000000000000000") + "2.000000") testAllApis( 'f26 * 'f26, "f26 * f26", - "0.00010000000000000000000000000000000000" + "0.00010000000000000" ) - // scalastyle:off - // we don't have this ridiculous behavior: - // https://blogs.msdn.microsoft.com/sqlprogrammability/2006/03/29/multiplication-and-division-with-numerics/ - // scalastyle:on - + // (76, 20) -> (38, 6), 0.0000006 is rounding to 0.000001 + // refer "https://blogs.msdn.microsoft.com/sqlprogrammability/ + // 2006/03/29/multiplication-and-division-with-numerics/" testAllApis( 'f27 * 'f28, "f27 * f28", - "0.00000060000000000000" + "0.000001" ) // result overflow @@ -525,11 +518,11 @@ class DecimalTypeTest extends ExpressionTestBase { "null" ) - //(60,40)=>(38,38), no space for integral part + //(60,40) => (38,17), scale part is reduced to make more space for integral part testAllApis( 'f30 * 'f30, "f30 * f30", - "null" + "1.00000000000000000" ) } diff --git a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/planner/runtime/stream/sql/OverAggregateITCase.scala b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/planner/runtime/stream/sql/OverAggregateITCase.scala index 208de4a5b1108..acc3ba0c26ced 100644 --- a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/planner/runtime/stream/sql/OverAggregateITCase.scala +++ b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/planner/runtime/stream/sql/OverAggregateITCase.scala @@ -18,6 +18,7 @@ package org.apache.flink.table.planner.runtime.stream.sql +import org.apache.flink.api.java.typeutils.RowTypeInfo import org.apache.flink.api.scala._ import org.apache.flink.table.api._ import org.apache.flink.table.api.bridge.scala._ @@ -25,13 +26,13 @@ import org.apache.flink.table.planner.runtime.utils.StreamingWithStateTestBase.S import org.apache.flink.table.planner.runtime.utils.TimeTestUtil.EventTimeProcessOperator import org.apache.flink.table.planner.runtime.utils.UserDefinedFunctionTestUtils.{CountNullNonNull, CountPairs, LargerThanCount} import org.apache.flink.table.planner.runtime.utils.{StreamingWithStateTestBase, TestData, TestingAppendSink} +import org.apache.flink.table.runtime.typeutils.BigDecimalTypeInfo import org.apache.flink.types.Row import org.junit.Assert._ import org.junit._ import org.junit.runner.RunWith import org.junit.runners.Parameterized - import scala.collection.mutable @RunWith(classOf[Parameterized]) @@ -1131,4 +1132,33 @@ class OverAggregateITCase(mode: StateBackendMode) extends StreamingWithStateTest "B,Hello World,10,7") assertEquals(expected, sink.getAppendResults) } + + @Test + def testDecimalSum0(): Unit = { + val data = new mutable.MutableList[Row] + data.+=(Row.of(BigDecimal(1.11).bigDecimal)) + data.+=(Row.of(BigDecimal(2.22).bigDecimal)) + data.+=(Row.of(BigDecimal(3.33).bigDecimal)) + data.+=(Row.of(BigDecimal(4.44).bigDecimal)) + + env.setParallelism(1) + val rowType = new RowTypeInfo(BigDecimalTypeInfo.of(38, 18)) + val t = failingDataSource(data)(rowType).toTable(tEnv, 'd, 'proctime.proctime) + tEnv.registerTable("T", t) + + val sqlQuery = "select sum(d) over (ORDER BY proctime rows between unbounded preceding " + + "and current row) from T" + + val result = tEnv.sqlQuery(sqlQuery).toAppendStream[Row] + val sink = new 
TestingAppendSink + result.addSink(sink) + env.execute() + + val expected = List( + "1.110000000000000000", + "3.330000000000000000", + "6.660000000000000000", + "11.100000000000000000") + assertEquals(expected, sink.getAppendResults) + } }
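
A quick way to sanity-check the new derivation against the updated test expectations is to replay the rules outside Flink. The sketch below is illustrative only and not part of this patch: the DecimalRules class, its {precision, scale} int-array convention, and the two constants that mirror DecimalType.MAX_PRECISION (38) and the new MINIMUM_ADJUSTED_SCALE (6) are assumptions introduced here for the example.

    import java.util.Arrays;

    // Standalone sketch of the derivation rules added to LogicalTypeMerging; each
    // method returns {precision, scale}. Not Flink API, for illustration only.
    public final class DecimalRules {

        static final int MAX_PRECISION = 38;        // mirrors DecimalType.MAX_PRECISION
        static final int MIN_ADJUSTED_SCALE = 6;    // mirrors MINIMUM_ADJUSTED_SCALE

        // e1 + e2 and e1 - e2: scale = max(s1, s2), precision = max(p1-s1, p2-s2) + scale + 1
        static int[] add(int p1, int s1, int p2, int s2) {
            int scale = Math.max(s1, s2);
            int precision = Math.max(p1 - s1, p2 - s2) + scale + 1;
            return adjust(precision, scale);
        }

        // e1 * e2: precision = p1 + p2 + 1, scale = s1 + s2
        static int[] multiply(int p1, int s1, int p2, int s2) {
            return adjust(p1 + p2 + 1, s1 + s2);
        }

        // e1 / e2: scale = max(6, s1 + p2 + 1), precision = p1 - s1 + s2 + scale
        static int[] divide(int p1, int s1, int p2, int s2) {
            int scale = Math.max(6, s1 + p2 + 1);
            int precision = p1 - s1 + s2 + scale;
            return adjust(precision, scale);
        }

        // e1 % e2: scale = max(s1, s2), precision = min(p1-s1, p2-s2) + scale
        static int[] modulo(int p1, int s1, int p2, int s2) {
            int scale = Math.max(s1, s2);
            int precision = Math.min(p1 - s1, p2 - s2) + scale;
            return adjust(precision, scale);
        }

        // Mirrors adjustPrecisionScale: once the exact precision exceeds 38, keep the
        // integral digits and shrink the scale, but never below min(scale, 6).
        static int[] adjust(int precision, int scale) {
            if (precision <= MAX_PRECISION) {
                return new int[] {precision, scale};
            }
            int digitPart = precision - scale;
            int minScalePart = Math.min(scale, MIN_ADJUSTED_SCALE);
            int adjustedScale = Math.max(MAX_PRECISION - digitPart, minScalePart);
            return new int[] {MAX_PRECISION, adjustedScale};
        }

        public static void main(String[] args) {
            // DECIMAL(5,4) * DECIMAL(3,2): exact (9,6) fits -> [9, 6] (TypeStrategiesTest)
            System.out.println(Arrays.toString(multiply(5, 4, 3, 2)));
            // DECIMAL(38,10) + DECIMAL(38,28): exact (57,28) -> adjusted to [38, 9]
            System.out.println(Arrays.toString(add(38, 10, 38, 28)));
            // DECIMAL(30,20) * DECIMAL(30,20): exact (61,40) -> adjusted to [38, 17]
            System.out.println(Arrays.toString(multiply(30, 20, 30, 20)));
            // DECIMAL(5,4) / DECIMAL(3,2): exact (11,8) fits -> [11, 8] (unchanged case)
            System.out.println(Arrays.toString(divide(5, 4, 3, 2)));
        }
    }

Those outputs line up with the expectations changed in this patch: DECIMAL(9,6) for the product in TypeStrategiesTest, the 9-digit scale behind "300.024691358" in DecimalTypeTest, and DECIMAL(38,17) for f30 * f30, which previously left no room for the integral part and returned null.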
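
The adjustSumType(...) wrappers added to the avg/sum/sum0/sum-with-retract functions follow from the same rules: under the new derivation, adding an operand to an accumulator of the very same decimal type can yield a different scale, so the accumulate/retract/merge expressions would drift away from the declared buffer type; the cast(..., typeLiteral(...)) pins them back to the declared sum or result type. A worked check, assuming a DECIMAL(38,18) accumulator as in the new testDecimalSum0, that could be appended to the main(...) of the sketch above:

    // Assuming a DECIMAL(38,18) sum buffer (the input type used in testDecimalSum0):
    int[] derived = DecimalRules.add(38, 18, 38, 18);   // -> [38, 17]
    // Before this patch the precision was simply capped at 38 and the scale kept at
    // 18, so no cast was needed; with the adjusted scale of 17, the explicit cast in
    // adjustSumType(...) keeps the accumulator and the emitted sum at DECIMAL(38,18),
    // matching the 18-fractional-digit values expected by the test.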
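
One more property of adjustPrecisionScale(...) worth spelling out: when the exact precision exceeds 38, the scale is never reduced below min(original scale, 6), even if the integral part then gets fewer digits than it asked for. The figures below come from applying the formulas (again via the sketch above, appended to its main(...)) to the Decimal(34,0)/Decimal(10,0) example mentioned in the comment removed from FlinkTypeSystem.scala; they are not covered by a test in this patch:

    // DECIMAL(34,0) / DECIMAL(10,0): exact result type (45,11); the adjustment floors
    // the scale at 6, leaving only 32 integral digits -> [38, 6].
    int[] quotient = DecimalRules.divide(34, 0, 10, 0);
    // Quotients that need more than 32 integral digits therefore overflow, which would
    // surface as NULL at runtime; the "result overflow" multiplication case in
    // DecimalTypeTest illustrates the same effect.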