Skip to content

Commit

Permalink
[FLINK-22586][table] Improve the precision derivation for decimal ari…
Browse files Browse the repository at this point in the history
…thmetics
  • Loading branch information
shuo.cs committed May 7, 2021
1 parent 5954122 commit 99a21fd
Show file tree
Hide file tree
Showing 8 changed files with 174 additions and 84 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,7 @@ public final class LogicalTypeMerging {
YEAR_MONTH_RES_TO_BOUNDARIES = new HashMap<>();
private static final Map<List<YearMonthResolution>, YearMonthResolution>
YEAR_MONTH_BOUNDARIES_TO_RES = new HashMap<>();
private static final int MINIMUM_ADJUSTED_SCALE = 6;

static {
addYearMonthMapping(YEAR, YEAR);
Expand Down Expand Up @@ -198,50 +199,50 @@ public static Optional<LogicalType> findCommonType(List<LogicalType> types) {
return Optional.empty();
}

// ========================= Decimal Precision Deriving ==========================
// Adopted from "https://docs.microsoft.com/en-us/sql/t-sql/data-types/precision-
// scale-and-length-transact-sql"
//
// Operation    Result Precision                        Result Scale
// e1 + e2      max(s1, s2) + max(p1-s1, p2-s2) + 1     max(s1, s2)
// e1 - e2      max(s1, s2) + max(p1-s1, p2-s2) + 1     max(s1, s2)
// e1 * e2      p1 + p2 + 1                             s1 + s2
// e1 / e2      p1 - s1 + s2 + max(6, s1 + p2 + 1)      max(6, s1 + p2 + 1)
// e1 % e2      min(p1-s1, p2-s2) + max(s1, s2)         max(s1, s2)
//
// If the derived precision exceeds the maximum supported precision, the scale may be
// reduced in order to prevent truncation of the integer part of the decimal.

/** Finds the result type of a decimal division operation. */
public static DecimalType findDivisionDecimalType(
int precision1, int scale1, int precision2, int scale2) {
// adopted from
// https://docs.microsoft.com/en-us/sql/t-sql/data-types/precision-scale-and-length-transact-sql
int scale = Math.max(6, scale1 + precision2 + 1);
int precision = precision1 - scale1 + scale2 + scale;
if (precision > DecimalType.MAX_PRECISION) {
scale = Math.max(6, DecimalType.MAX_PRECISION - (precision - scale));
precision = DecimalType.MAX_PRECISION;
}
return new DecimalType(false, precision, scale);
return adjustPrecisionScale(precision, scale);
}

/** Finds the result type of a decimal modulo operation. */
public static DecimalType findModuloDecimalType(
int precision1, int scale1, int precision2, int scale2) {
// adopted from Calcite
final int scale = Math.max(scale1, scale2);
int precision =
Math.min(precision1 - scale1, precision2 - scale2) + Math.max(scale1, scale2);
precision = Math.min(precision, DecimalType.MAX_PRECISION);
return new DecimalType(false, precision, scale);
int precision = Math.min(precision1 - scale1, precision2 - scale2) + scale;
return adjustPrecisionScale(precision, scale);
}

/** Finds the result type of a decimal multiplication operation. */
public static DecimalType findMultiplicationDecimalType(
int precision1, int scale1, int precision2, int scale2) {
// adopted from Calcite
int scale = scale1 + scale2;
scale = Math.min(scale, DecimalType.MAX_PRECISION);
int precision = precision1 + precision2;
precision = Math.min(precision, DecimalType.MAX_PRECISION);
return new DecimalType(false, precision, scale);
int precision = precision1 + precision2 + 1;
return adjustPrecisionScale(precision, scale);
}

/** Finds the result type of a decimal addition operation. */
public static DecimalType findAdditionDecimalType(
int precision1, int scale1, int precision2, int scale2) {
// adopted from Calcite
final int scale = Math.max(scale1, scale2);
int precision = Math.max(precision1 - scale1, precision2 - scale2) + scale + 1;
precision = Math.min(precision, DecimalType.MAX_PRECISION);
return new DecimalType(false, precision, scale);
return adjustPrecisionScale(precision, scale);
}

/** Finds the result type of a decimal rounding operation. */
Expand Down Expand Up @@ -296,6 +297,27 @@ public static LogicalType findSumAggType(LogicalType argType) {

// --------------------------------------------------------------------------------------------

/**
 * Creates the {@link DecimalType} for a derived precision / scale pair, shrinking the scale
 * when the derived precision does not fit into {@code DecimalType.MAX_PRECISION}.
 *
 * <p>The adjustment is inspired by SQL Server: when the precision overflows, fractional digits
 * are given up first so that the integral digits ({@code precision - scale}) are kept where
 * possible. The scale is never reduced below {@code min(scale, MINIMUM_ADJUSTED_SCALE)}.
 *
 * <p>https://docs.microsoft.com/en-us/sql/t-sql/data-types/precision-scale-and-length-transact-sql
 *
 * @param precision the precision derived for the arithmetic result, possibly above the maximum
 * @param scale the scale derived for the arithmetic result
 * @return a decimal type with the given values if the precision fits, otherwise a decimal type
 *     with maximum precision and the adjusted scale
 */
private static DecimalType adjustPrecisionScale(int precision, int scale) {
    // Fast path: the derived precision fits, so both values are used unchanged.
    if (precision <= DecimalType.MAX_PRECISION) {
        return new DecimalType(false, precision, scale);
    }

    // Number of digits in front of the decimal point in the derived type.
    final int integerDigits = precision - scale;
    // Lower bound for the adjusted scale: the original scale if it is already smaller than
    // MINIMUM_ADJUSTED_SCALE, otherwise MINIMUM_ADJUSTED_SCALE itself.
    final int scaleFloor = Math.min(scale, MINIMUM_ADJUSTED_SCALE);
    // Fractional digits that remain once the integral digits have been accommodated.
    final int remainingForScale = DecimalType.MAX_PRECISION - integerDigits;
    final int reducedScale = Math.max(remainingForScale, scaleFloor);
    return new DecimalType(false, DecimalType.MAX_PRECISION, reducedScale);
}

private static @Nullable LogicalType findCommonCastableType(List<LogicalType> normalizedTypes) {
LogicalType resultType = normalizedTypes.get(0);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,7 @@ public static List<TestSpec> testData() {
.expectDataType(DataTypes.DECIMAL(11, 8).notNull()),
TestSpec.forStrategy("Find a decimal product", TypeStrategies.DECIMAL_TIMES)
.inputTypes(DataTypes.DECIMAL(5, 4), DataTypes.DECIMAL(3, 2))
.expectDataType(DataTypes.DECIMAL(8, 6).notNull()),
.expectDataType(DataTypes.DECIMAL(9, 6).notNull()),
TestSpec.forStrategy("Find a decimal modulo", TypeStrategies.DECIMAL_MOD)
.inputTypes(DataTypes.DECIMAL(5, 4), DataTypes.DECIMAL(3, 2))
.expectDataType(DataTypes.DECIMAL(5, 4).notNull()),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

import org.apache.flink.table.api.DataTypes;
import org.apache.flink.table.expressions.Expression;
import org.apache.flink.table.expressions.UnresolvedCallExpression;
import org.apache.flink.table.expressions.UnresolvedReferenceExpression;
import org.apache.flink.table.types.DataType;
import org.apache.flink.table.types.logical.DecimalType;
Expand Down Expand Up @@ -72,26 +73,31 @@ public Expression[] initialValuesExpressions() {
@Override
public Expression[] accumulateExpressions() {
return new Expression[] {
/* sum = */ ifThenElse(isNull(operand(0)), sum, plus(sum, operand(0))),
/* sum = */ adjustSumType(ifThenElse(isNull(operand(0)), sum, plus(sum, operand(0)))),
/* count = */ ifThenElse(isNull(operand(0)), count, plus(count, literal(1L))),
};
}

@Override
public Expression[] retractExpressions() {
return new Expression[] {
/* sum = */ ifThenElse(isNull(operand(0)), sum, minus(sum, operand(0))),
/* sum = */ adjustSumType(ifThenElse(isNull(operand(0)), sum, minus(sum, operand(0)))),
/* count = */ ifThenElse(isNull(operand(0)), count, minus(count, literal(1L))),
};
}

@Override
public Expression[] mergeExpressions() {
return new Expression[] {
/* sum = */ plus(sum, mergeOperand(sum)), /* count = */ plus(count, mergeOperand(count))
/* sum = */ adjustSumType(plus(sum, mergeOperand(sum))),
/* count = */ plus(count, mergeOperand(count))
};
}

/**
 * Wraps the given sum expression in a cast to the type returned by {@code getSumType()}.
 *
 * <p>NOTE(review): presumably this keeps the accumulated sum at the declared sum type, since
 * the plus / minus calls inside {@code sumExpr} may derive a different decimal precision;
 * confirm against the decimal type derivation rules.
 *
 * @param sumExpr the unresolved expression computing the updated sum
 * @return a cast call around {@code sumExpr} that targets the sum type
 */
private UnresolvedCallExpression adjustSumType(UnresolvedCallExpression sumExpr) {
return cast(sumExpr, typeLiteral(getSumType()));
}

/** If all input are nulls, count will be 0 and we will get null after the division. */
@Override
public Expression getValueExpression() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,16 +21,19 @@
import org.apache.flink.table.api.DataTypes;
import org.apache.flink.table.api.TableException;
import org.apache.flink.table.expressions.Expression;
import org.apache.flink.table.expressions.UnresolvedCallExpression;
import org.apache.flink.table.expressions.UnresolvedReferenceExpression;
import org.apache.flink.table.types.DataType;
import org.apache.flink.table.types.logical.DecimalType;
import org.apache.flink.table.types.logical.utils.LogicalTypeMerging;

import static org.apache.flink.table.expressions.ApiExpressionUtils.unresolvedRef;
import static org.apache.flink.table.planner.expressions.ExpressionBuilder.cast;
import static org.apache.flink.table.planner.expressions.ExpressionBuilder.ifThenElse;
import static org.apache.flink.table.planner.expressions.ExpressionBuilder.isNull;
import static org.apache.flink.table.planner.expressions.ExpressionBuilder.nullOf;
import static org.apache.flink.table.planner.expressions.ExpressionBuilder.plus;
import static org.apache.flink.table.planner.expressions.ExpressionBuilder.typeLiteral;

/** built-in sum aggregate function. */
public abstract class SumAggFunction extends DeclarativeAggregateFunction {
Expand Down Expand Up @@ -59,10 +62,11 @@ public Expression[] initialValuesExpressions() {
@Override
public Expression[] accumulateExpressions() {
return new Expression[] {
/* sum = */ ifThenElse(
isNull(operand(0)),
sum,
ifThenElse(isNull(sum), operand(0), plus(sum, operand(0))))
/* sum = */ adjustSumType(
ifThenElse(
isNull(operand(0)),
sum,
ifThenElse(isNull(sum), operand(0), plus(sum, operand(0)))))
};
}

Expand All @@ -75,13 +79,19 @@ public Expression[] retractExpressions() {
@Override
public Expression[] mergeExpressions() {
return new Expression[] {
/* sum = */ ifThenElse(
isNull(mergeOperand(sum)),
sum,
ifThenElse(isNull(sum), mergeOperand(sum), plus(sum, mergeOperand(sum))))
/* sum = */ adjustSumType(
ifThenElse(
isNull(mergeOperand(sum)),
sum,
ifThenElse(
isNull(sum), mergeOperand(sum), plus(sum, mergeOperand(sum)))))
};
}

/**
 * Wraps the given sum expression in a cast to the type returned by {@code getResultType()}.
 *
 * <p>NOTE(review): presumably this keeps the accumulated sum at the declared result type,
 * since the plus call inside {@code sumExpr} may derive a different decimal precision; confirm
 * against the decimal type derivation rules.
 *
 * @param sumExpr the unresolved expression computing the updated sum
 * @return a cast call around {@code sumExpr} that targets the result type
 */
private UnresolvedCallExpression adjustSumType(UnresolvedCallExpression sumExpr) {
return cast(sumExpr, typeLiteral(getResultType()));
}

@Override
public Expression getValueExpression() {
return sum;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,19 +20,22 @@

import org.apache.flink.table.api.DataTypes;
import org.apache.flink.table.expressions.Expression;
import org.apache.flink.table.expressions.UnresolvedCallExpression;
import org.apache.flink.table.expressions.UnresolvedReferenceExpression;
import org.apache.flink.table.types.DataType;
import org.apache.flink.table.types.logical.DecimalType;
import org.apache.flink.table.types.logical.utils.LogicalTypeMerging;

import static org.apache.flink.table.expressions.ApiExpressionUtils.unresolvedRef;
import static org.apache.flink.table.planner.expressions.ExpressionBuilder.cast;
import static org.apache.flink.table.planner.expressions.ExpressionBuilder.equalTo;
import static org.apache.flink.table.planner.expressions.ExpressionBuilder.ifThenElse;
import static org.apache.flink.table.planner.expressions.ExpressionBuilder.isNull;
import static org.apache.flink.table.planner.expressions.ExpressionBuilder.literal;
import static org.apache.flink.table.planner.expressions.ExpressionBuilder.minus;
import static org.apache.flink.table.planner.expressions.ExpressionBuilder.nullOf;
import static org.apache.flink.table.planner.expressions.ExpressionBuilder.plus;
import static org.apache.flink.table.planner.expressions.ExpressionBuilder.typeLiteral;

/** built-in sum aggregate function with retraction. */
public abstract class SumWithRetractAggFunction extends DeclarativeAggregateFunction {
Expand Down Expand Up @@ -62,37 +65,47 @@ public Expression[] initialValuesExpressions() {
@Override
public Expression[] accumulateExpressions() {
return new Expression[] {
/* sum = */ ifThenElse(
isNull(operand(0)),
sum,
ifThenElse(isNull(sum), operand(0), plus(sum, operand(0)))),
/* sum = */ adjustSumType(
ifThenElse(
isNull(operand(0)),
sum,
ifThenElse(isNull(sum), operand(0), plus(sum, operand(0))))),
/* count = */ ifThenElse(isNull(operand(0)), count, plus(count, literal(1L)))
};
}

@Override
public Expression[] retractExpressions() {
return new Expression[] {
/* sum = */ ifThenElse(
isNull(operand(0)),
sum,
/* sum = */ adjustSumType(
ifThenElse(
isNull(sum), minus(zeroLiteral(), operand(0)), minus(sum, operand(0)))),
isNull(operand(0)),
sum,
ifThenElse(
isNull(sum),
minus(zeroLiteral(), operand(0)),
minus(sum, operand(0))))),
/* count = */ ifThenElse(isNull(operand(0)), count, minus(count, literal(1L)))
};
}

@Override
public Expression[] mergeExpressions() {
return new Expression[] {
/* sum = */ ifThenElse(
isNull(mergeOperand(sum)),
sum,
ifThenElse(isNull(sum), mergeOperand(sum), plus(sum, mergeOperand(sum)))),
/* sum = */ adjustSumType(
ifThenElse(
isNull(mergeOperand(sum)),
sum,
ifThenElse(
isNull(sum), mergeOperand(sum), plus(sum, mergeOperand(sum))))),
/* count = */ plus(count, mergeOperand(count))
};
}

/**
 * Wraps the given sum expression in a cast to the type returned by {@code getResultType()}.
 *
 * <p>NOTE(review): presumably this keeps the accumulated sum at the declared result type,
 * since the plus / minus calls inside {@code sumExpr} may derive a different decimal
 * precision; confirm against the decimal type derivation rules.
 *
 * @param sumExpr the unresolved expression computing the updated sum
 * @return a cast call around {@code sumExpr} that targets the result type
 */
private UnresolvedCallExpression adjustSumType(UnresolvedCallExpression sumExpr) {
return cast(sumExpr, typeLiteral(getResultType()));
}

@Override
public Expression getValueExpression() {
return ifThenElse(equalTo(count, literal(0L)), nullOf(getResultType()), sum);
Expand Down
Loading

0 comments on commit 99a21fd

Please sign in to comment.