Skip to content

Commit

Permalink
fix test
Browse files Browse the repository at this point in the history
  • Loading branch information
shuo.cs committed May 11, 2021
1 parent 0573b65 commit 6b61c49
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

import org.apache.flink.table.api.DataTypes;
import org.apache.flink.table.expressions.Expression;
import org.apache.flink.table.expressions.UnresolvedCallExpression;
import org.apache.flink.table.expressions.UnresolvedReferenceExpression;
import org.apache.flink.table.types.DataType;
import org.apache.flink.table.types.logical.DecimalType;
Expand All @@ -28,11 +29,13 @@
import java.math.BigDecimal;

import static org.apache.flink.table.expressions.ApiExpressionUtils.unresolvedRef;
import static org.apache.flink.table.planner.expressions.ExpressionBuilder.cast;
import static org.apache.flink.table.planner.expressions.ExpressionBuilder.ifThenElse;
import static org.apache.flink.table.planner.expressions.ExpressionBuilder.isNull;
import static org.apache.flink.table.planner.expressions.ExpressionBuilder.literal;
import static org.apache.flink.table.planner.expressions.ExpressionBuilder.minus;
import static org.apache.flink.table.planner.expressions.ExpressionBuilder.plus;
import static org.apache.flink.table.planner.expressions.ExpressionBuilder.typeLiteral;

/** built-in sum0 aggregate function. */
public abstract class Sum0AggFunction extends DeclarativeAggregateFunction {
Expand All @@ -56,20 +59,25 @@ public DataType[] getAggBufferTypes() {
@Override
public Expression[] accumulateExpressions() {
return new Expression[] {
/* sum0 = */ ifThenElse(isNull(operand(0)), sum0, plus(sum0, operand(0)))
/* sum0 = */ adjustSumType(ifThenElse(isNull(operand(0)), sum0, plus(sum0, operand(0))))
};
}

@Override
public Expression[] retractExpressions() {
return new Expression[] {
/* sum0 = */ ifThenElse(isNull(operand(0)), sum0, minus(sum0, operand(0)))
/* sum0 = */ adjustSumType(
ifThenElse(isNull(operand(0)), sum0, minus(sum0, operand(0))))
};
}

@Override
public Expression[] mergeExpressions() {
return new Expression[] {/* sum0 = */ plus(sum0, mergeOperand(sum0))};
return new Expression[] {/* sum0 = */ adjustSumType(plus(sum0, mergeOperand(sum0)))};
}

private UnresolvedCallExpression adjustSumType(UnresolvedCallExpression sumExpr) {
return cast(sumExpr, typeLiteral(getResultType()));
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,20 +18,21 @@

package org.apache.flink.table.planner.runtime.stream.sql

import org.apache.flink.api.java.typeutils.RowTypeInfo
import org.apache.flink.api.scala._
import org.apache.flink.table.api._
import org.apache.flink.table.api.bridge.scala._
import org.apache.flink.table.planner.runtime.utils.StreamingWithStateTestBase.StateBackendMode
import org.apache.flink.table.planner.runtime.utils.TimeTestUtil.EventTimeProcessOperator
import org.apache.flink.table.planner.runtime.utils.UserDefinedFunctionTestUtils.{CountNullNonNull, CountPairs, LargerThanCount}
import org.apache.flink.table.planner.runtime.utils.{StreamingWithStateTestBase, TestData, TestingAppendSink}
import org.apache.flink.table.runtime.typeutils.BigDecimalTypeInfo
import org.apache.flink.types.Row

import org.junit.Assert._
import org.junit._
import org.junit.runner.RunWith
import org.junit.runners.Parameterized

import scala.collection.mutable

@RunWith(classOf[Parameterized])
Expand Down Expand Up @@ -1131,4 +1132,33 @@ class OverAggregateITCase(mode: StateBackendMode) extends StreamingWithStateTest
"B,Hello World,10,7")
assertEquals(expected, sink.getAppendResults)
}

@Test
def testDecimalSum0(): Unit = {
val data = new mutable.MutableList[Row]
data.+=(Row.of(BigDecimal(1.11).bigDecimal))
data.+=(Row.of(BigDecimal(2.22).bigDecimal))
data.+=(Row.of(BigDecimal(3.33).bigDecimal))
data.+=(Row.of(BigDecimal(4.44).bigDecimal))

env.setParallelism(1)
val rowType = new RowTypeInfo(BigDecimalTypeInfo.of(38, 18))
val t = failingDataSource(data)(rowType).toTable(tEnv, 'd, 'proctime.proctime)
tEnv.registerTable("T", t)

val sqlQuery = "select sum(d) over (ORDER BY proctime rows between unbounded preceding " +
"and current row) from T"

val result = tEnv.sqlQuery(sqlQuery).toAppendStream[Row]
val sink = new TestingAppendSink
result.addSink(sink)
env.execute()

val expected = List(
"1.110000000000000000",
"3.330000000000000000",
"6.660000000000000000",
"11.100000000000000000")
assertEquals(expected, sink.getAppendResults)
}
}

0 comments on commit 6b61c49

Please sign in to comment.