From c623db4dcf7549fed263ab6f92aac49a94a25897 Mon Sep 17 00:00:00 2001
From: Jingsong Lee
Date: Wed, 28 Apr 2021 17:23:38 +0800
Subject: [PATCH] [FLINK-19449][table-planner] LEAD/LAG cannot work correctly in streaming mode

This closes #15793
---
 docs/data/sql_functions.yml                        |   6 +-
 .../aggfunctions/LagAggFunction.java               | 163 ++++++++++++++
 .../StreamExecGlobalWindowAggregate.java           |   4 +-
 .../StreamExecLocalWindowAggregate.java            |   2 +-
 .../stream/StreamExecWindowAggregate.java          |   2 +-
 .../metadata/FlinkRelMdColumnInterval.scala        |  25 ++-
 .../StreamPhysicalGlobalWindowAggregate.scala      |   2 +-
 .../StreamPhysicalLocalWindowAggregate.scala       |   2 +-
 .../StreamPhysicalWindowAggregate.scala            |   2 +-
 .../plan/utils/AggFunctionFactory.scala            |  35 ++-
 .../planner/plan/utils/AggregateUtil.scala         |  27 ++-
 .../aggfunctions/LagAggFunctionTest.java           |  62 ++++++
 .../metadata/FlinkRelMdHandlerTestBase.scala       |   9 +-
 .../stream/sql/OverAggregateITCase.scala           |  68 ++++++
 .../typeutils/LinkedListSerializer.java            | 203 ++++++++++++++++++
 .../typeutils/LinkedListSerializerTest.java        |  72 +++++++
 16 files changed, 646 insertions(+), 38 deletions(-)
 create mode 100644 flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/planner/functions/aggfunctions/LagAggFunction.java
 create mode 100644 flink-table/flink-table-planner-blink/src/test/java/org/apache/flink/table/planner/functions/aggfunctions/LagAggFunctionTest.java
 create mode 100644 flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/typeutils/LinkedListSerializer.java
 create mode 100644 flink-table/flink-table-runtime-blink/src/test/java/org/apache/flink/table/runtime/typeutils/LinkedListSerializerTest.java

diff --git a/docs/data/sql_functions.yml b/docs/data/sql_functions.yml
index 6b7caa159c205..51df9d1bb30de 100644
--- a/docs/data/sql_functions.yml
+++ b/docs/data/sql_functions.yml
@@ -674,10 +674,10 @@ aggregate:
   - sql: ROW_NUMBER()
     description: Assigns a unique, sequential number to each row, starting with one, according to the ordering of rows within the window partition. ROW_NUMBER and RANK are similar. ROW_NUMBER numbers all rows sequentially (for example 1, 2, 3, 4, 5). RANK provides the same numeric value for ties (for example 1, 2, 2, 4, 5).
   - sql: LEAD(expression [, offset] [, default])
-    description: Returns the value of expression at the offsetth row before the current row in the window. The default value of offset is 1 and the default value of default is NULL.
-  - sql: LAG(expression [, offset] [, default])
-    description: Returns the value of expression at the offsetth row after the current row in the window. The default value of offset is 1 and the default value of default is NULL.
-  - sql: FIRST_VALUE(expression)
+    description: Returns the value of expression at the offsetth row after the current row in the window. The default value of offset is 1 and the default value of default is NULL.
+  - sql: LAG(expression [, offset] [, default])
+    description: Returns the value of expression at the offsetth row before the current row in the window. The default value of offset is 1 and the default value of default is NULL.
+  - sql: FIRST_VALUE(expression)
     description: Returns the first value in an ordered set of values.
   - sql: LAST_VALUE(expression)
     description: Returns the last value in an ordered set of values.
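To see the corrected LEAD/LAG semantics end to end, here is a minimal, illustrative sketch against the Table API (the table name, schema, and datagen connector options are assumptions for the example, not part of this patch):

    import org.apache.flink.table.api.EnvironmentSettings;
    import org.apache.flink.table.api.TableEnvironment;

    public class LagQueryExample {
        public static void main(String[] args) {
            TableEnvironment tEnv =
                    TableEnvironment.create(
                            EnvironmentSettings.newInstance().inStreamingMode().build());

            // Illustrative unbounded source; any table with an event-time attribute works.
            tEnv.executeSql(
                    "CREATE TABLE Orders ("
                            + "  user_id INT,"
                            + "  amount BIGINT,"
                            + "  ts TIMESTAMP(3),"
                            + "  WATERMARK FOR ts AS ts - INTERVAL '5' SECOND"
                            + ") WITH ('connector' = 'datagen', 'rows-per-second' = '5')");

            // LAG reaches backwards: the previous amount per user in ts order.
            // Before this patch, the streaming planner could not evaluate this correctly.
            tEnv.executeSql(
                            "SELECT user_id, amount,"
                                    + " LAG(amount) OVER (PARTITION BY user_id ORDER BY ts) AS prev_amount,"
                                    + " LAG(amount, 2, CAST(-1 AS BIGINT)) OVER (PARTITION BY user_id ORDER BY ts) AS prev_prev"
                                    + " FROM Orders")
                    .print();
        }
    }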
diff --git a/flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/planner/functions/aggfunctions/LagAggFunction.java b/flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/planner/functions/aggfunctions/LagAggFunction.java
new file mode 100644
index 0000000000000..2ad9b976a6823
--- /dev/null
+++ b/flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/planner/functions/aggfunctions/LagAggFunction.java
@@ -0,0 +1,163 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.planner.functions.aggfunctions;
+
+import org.apache.flink.api.common.typeutils.TypeSerializer;
+import org.apache.flink.table.api.DataTypes;
+import org.apache.flink.table.api.TableException;
+import org.apache.flink.table.functions.AggregateFunction;
+import org.apache.flink.table.runtime.functions.aggregate.BuiltInAggregateFunction;
+import org.apache.flink.table.runtime.typeutils.InternalSerializers;
+import org.apache.flink.table.runtime.typeutils.LinkedListSerializer;
+import org.apache.flink.table.types.DataType;
+import org.apache.flink.table.types.logical.LogicalType;
+import org.apache.flink.table.types.logical.LogicalTypeRoot;
+import org.apache.flink.table.types.utils.DataTypeUtils;
+
+import java.util.Arrays;
+import java.util.LinkedList;
+import java.util.List;
+import java.util.Objects;
+
+/** Lag {@link AggregateFunction}. */
+public class LagAggFunction<T> extends BuiltInAggregateFunction<T, LagAggFunction.LagAcc<T>> {
+
+    private final transient DataType[] valueDataTypes;
+
+    @SuppressWarnings("unchecked")
+    public LagAggFunction(LogicalType[] valueTypes) {
+        this.valueDataTypes =
+                Arrays.stream(valueTypes)
+                        .map(DataTypeUtils::toInternalDataType)
+                        .toArray(DataType[]::new);
+        if (valueDataTypes.length == 3
+                && valueDataTypes[2].getLogicalType().getTypeRoot() != LogicalTypeRoot.NULL) {
+            if (valueDataTypes[0].getConversionClass() != valueDataTypes[2].getConversionClass()) {
+                throw new TableException(
+                        String.format(
+                                "Please explicitly cast default value %s to %s.",
+                                valueDataTypes[2], valueDataTypes[0]));
+            }
+        }
+    }
+
+    // --------------------------------------------------------------------------------------------
+    // Planning
+    // --------------------------------------------------------------------------------------------
+
+    @Override
+    public List<DataType> getArgumentDataTypes() {
+        return Arrays.asList(valueDataTypes);
+    }
+
+    @Override
+    public DataType getAccumulatorDataType() {
+        return DataTypes.STRUCTURED(
+                LagAcc.class,
+                DataTypes.FIELD("offset", DataTypes.INT()),
+                DataTypes.FIELD("defaultValue", valueDataTypes[0]),
+                DataTypes.FIELD("buffer", getLinkedListType()));
+    }
+
+    @SuppressWarnings({"unchecked", "rawtypes"})
+    private DataType getLinkedListType() {
+        TypeSerializer<T> serializer =
+                InternalSerializers.create(getOutputDataType().getLogicalType());
+        return DataTypes.RAW(
+                LinkedList.class, (TypeSerializer) new LinkedListSerializer<>(serializer));
+    }
+
+    @Override
+    public DataType getOutputDataType() {
+        return valueDataTypes[0];
+    }
+
+    // --------------------------------------------------------------------------------------------
+    // Runtime
+    // --------------------------------------------------------------------------------------------
+
+    public void accumulate(LagAcc<T> acc, T value) throws Exception {
+        acc.buffer.add(value);
+        while (acc.buffer.size() > acc.offset + 1) {
+            acc.buffer.removeFirst();
+        }
+    }
+
+    public void accumulate(LagAcc<T> acc, T value, int offset) throws Exception {
+        if (offset < 0) {
+            throw new TableException(String.format("Offset(%d) must be non-negative.", offset));
+        }
+
+        acc.offset = offset;
+        accumulate(acc, value);
+    }
+
+    public void accumulate(LagAcc<T> acc, T value, int offset, T defaultValue) throws Exception {
+        acc.defaultValue = defaultValue;
+        accumulate(acc, value, offset);
+    }
+
+    public void resetAccumulator(LagAcc<T> acc) throws Exception {
+        acc.offset = 1;
+        acc.defaultValue = null;
+        acc.buffer.clear();
+    }
+
+    @Override
+    public T getValue(LagAcc<T> acc) {
+        if (acc.buffer.size() < acc.offset + 1) {
+            return acc.defaultValue;
+        } else if (acc.buffer.size() == acc.offset + 1) {
+            return acc.buffer.getFirst();
+        } else {
+            throw new TableException("Too many elements: " + acc);
+        }
+    }
+
+    @Override
+    public LagAcc<T> createAccumulator() {
+        return new LagAcc<>();
+    }
+
+    /** Accumulator for LAG. */
+    public static class LagAcc<T> {
+        public int offset = 1;
+        public T defaultValue = null;
+        public LinkedList<T> buffer = new LinkedList<>();
+
+        @Override
+        public boolean equals(Object o) {
+            if (this == o) {
+                return true;
+            }
+            if (o == null || getClass() != o.getClass()) {
+                return false;
+            }
+            LagAcc<?> lagAcc = (LagAcc<?>) o;
+            return offset == lagAcc.offset
+                    && Objects.equals(defaultValue, lagAcc.defaultValue)
+                    && Objects.equals(buffer, lagAcc.buffer);
+        }
+
+        @Override
+        public int hashCode() {
+            return Objects.hash(offset, defaultValue, buffer);
+        }
+    }
+}
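The buffer-trimming logic above is easiest to follow by driving the function by hand. A minimal sketch using only classes from this patch (the wrapper class and method names are illustrative):

    import org.apache.flink.table.data.StringData;
    import org.apache.flink.table.planner.functions.aggfunctions.LagAggFunction;
    import org.apache.flink.table.types.logical.LogicalType;
    import org.apache.flink.table.types.logical.VarCharType;

    public class LagAccumulatorSketch {
        public static void main(String[] args) throws Exception {
            LagAggFunction<StringData> lag =
                    new LagAggFunction<>(new LogicalType[] {new VarCharType()});
            LagAggFunction.LagAcc<StringData> acc = lag.createAccumulator();

            System.out.println(lag.getValue(acc)); // null: buffer empty, defaultValue is null
            lag.accumulate(acc, StringData.fromString("a"));
            System.out.println(lag.getValue(acc)); // null: only one row seen, offset is 1
            lag.accumulate(acc, StringData.fromString("b"));
            System.out.println(lag.getValue(acc)); // a: the row one position back
        }
    }

The accumulator keeps at most offset + 1 elements, so LAG state stays bounded regardless of how long the partition grows.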
diff --git a/flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/stream/StreamExecGlobalWindowAggregate.java b/flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/stream/StreamExecGlobalWindowAggregate.java
index 8df6f2aef10c4..41ab7a2089e1c 100644
--- a/flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/stream/StreamExecGlobalWindowAggregate.java
+++ b/flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/stream/StreamExecGlobalWindowAggregate.java
@@ -145,14 +145,14 @@ protected Transformation<RowData> translateToPlanInternal(PlannerBase planner)
         final SliceAssigner sliceAssigner = createSliceAssigner(windowing, shiftTimeZone);
 
         final AggregateInfoList localAggInfoList =
-                AggregateUtil.deriveWindowAggregateInfoList(
+                AggregateUtil.deriveStreamWindowAggregateInfoList(
                         localAggInputRowType, // should use original input here
                         JavaScalaConversionUtil.toScala(Arrays.asList(aggCalls)),
                         windowing.getWindow(),
                         false); // isStateBackendDataViews
 
         final AggregateInfoList globalAggInfoList =
-                AggregateUtil.deriveWindowAggregateInfoList(
+                AggregateUtil.deriveStreamWindowAggregateInfoList(
                         localAggInputRowType, // should use original input here
                         JavaScalaConversionUtil.toScala(Arrays.asList(aggCalls)),
                         windowing.getWindow(),
diff --git a/flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/stream/StreamExecLocalWindowAggregate.java b/flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/stream/StreamExecLocalWindowAggregate.java
index f333255e53ae9..18f8a8dd21672 100644
--- a/flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/stream/StreamExecLocalWindowAggregate.java
+++ b/flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/stream/StreamExecLocalWindowAggregate.java
@@ -122,7 +122,7 @@ protected Transformation<RowData> translateToPlanInternal(PlannerBase planner)
         final SliceAssigner sliceAssigner = createSliceAssigner(windowing, shiftTimeZone);
 
         final AggregateInfoList aggInfoList =
-                AggregateUtil.deriveWindowAggregateInfoList(
+                AggregateUtil.deriveStreamWindowAggregateInfoList(
                         inputRowType,
                         JavaScalaConversionUtil.toScala(Arrays.asList(aggCalls)),
                         windowing.getWindow(),
diff --git a/flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/stream/StreamExecWindowAggregate.java b/flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/stream/StreamExecWindowAggregate.java
index 913abeec88bc3..322944114663e 100644
--- a/flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/stream/StreamExecWindowAggregate.java
+++ b/flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/stream/StreamExecWindowAggregate.java
@@ -143,7 +143,7 @@ protected Transformation<RowData> translateToPlanInternal(PlannerBase planner)
         // Hopping window requires additional COUNT(*) to determine whether to register next timer
         // through whether the current fired window is empty, see SliceSharedWindowAggProcessor.
         final AggregateInfoList aggInfoList =
-                AggregateUtil.deriveWindowAggregateInfoList(
+                AggregateUtil.deriveStreamWindowAggregateInfoList(
                         inputRowType,
                         JavaScalaConversionUtil.toScala(Arrays.asList(aggCalls)),
                         windowing.getWindow(),
diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdColumnInterval.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdColumnInterval.scala
index 23bd99ce9009c..f7c46411211e6 100644
--- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdColumnInterval.scala
+++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdColumnInterval.scala
@@ -562,9 +562,10 @@ class FlinkRelMdColumnInterval private extends MetadataHandler[ColumnInterval]
     def getAggCallFromLocalAgg(
         index: Int,
         aggCalls: Seq[AggregateCall],
-        inputType: RelDataType): AggregateCall = {
+        inputType: RelDataType,
+        isBounded: Boolean): AggregateCall = {
       val outputIndexToAggCallIndexMap = AggregateUtil.getOutputIndexToAggCallIndexMap(
-        aggCalls, inputType)
+        aggCalls, inputType, isBounded)
       if (outputIndexToAggCallIndexMap.containsKey(index)) {
         val realIndex = outputIndexToAggCallIndexMap.get(index)
         aggCalls(realIndex)
@@ -576,9 +577,10 @@ class FlinkRelMdColumnInterval private extends MetadataHandler[ColumnInterval]
     def getAggCallIndexInLocalAgg(
         index: Int,
         globalAggCalls: Seq[AggregateCall],
-        inputRowType: RelDataType): Integer = {
+        inputRowType: RelDataType,
+        isBounded: Boolean): Integer = {
       val outputIndexToAggCallIndexMap = AggregateUtil.getOutputIndexToAggCallIndexMap(
-        globalAggCalls, inputRowType)
+        globalAggCalls, inputRowType, isBounded)
       outputIndexToAggCallIndexMap.foreach {
         case (k, v) => if (v == index) {
@@ -600,34 +602,37 @@ class FlinkRelMdColumnInterval private extends MetadataHandler[ColumnInterval]
         case agg: StreamPhysicalGlobalGroupAggregate
           if agg.aggCalls.length > aggCallIndex =>
           val aggCallIndexInLocalAgg = getAggCallIndexInLocalAgg(
-            aggCallIndex, agg.aggCalls, agg.localAggInputRowType)
+            aggCallIndex, agg.aggCalls, agg.localAggInputRowType, isBounded = false)
           if (aggCallIndexInLocalAgg != null) {
             return fmq.getColumnInterval(agg.getInput, groupSet.length + aggCallIndexInLocalAgg)
           } else {
             null
           }
         case agg: StreamPhysicalLocalGroupAggregate =>
-          getAggCallFromLocalAgg(aggCallIndex, agg.aggCalls, agg.getInput.getRowType)
+          getAggCallFromLocalAgg(
+            aggCallIndex, agg.aggCalls, agg.getInput.getRowType, isBounded = false)
         case agg: StreamPhysicalIncrementalGroupAggregate
           if agg.partialAggCalls.length > aggCallIndex =>
           agg.partialAggCalls(aggCallIndex)
         case agg: StreamPhysicalGroupWindowAggregate if agg.aggCalls.length > aggCallIndex =>
           agg.aggCalls(aggCallIndex)
         case agg: BatchPhysicalLocalHashAggregate =>
-          getAggCallFromLocalAgg(aggCallIndex, agg.getAggCallList, agg.getInput.getRowType)
+          getAggCallFromLocalAgg(
+            aggCallIndex, agg.getAggCallList, agg.getInput.getRowType, isBounded = true)
         case agg: BatchPhysicalHashAggregate
           if agg.isMerge =>
           val aggCallIndexInLocalAgg = getAggCallIndexInLocalAgg(
-            aggCallIndex, agg.getAggCallList, agg.aggInputRowType)
+            aggCallIndex, agg.getAggCallList, agg.aggInputRowType, isBounded = true)
           if (aggCallIndexInLocalAgg != null) {
             return fmq.getColumnInterval(agg.getInput, groupSet.length + aggCallIndexInLocalAgg)
           } else {
             null
           }
         case agg: BatchPhysicalLocalSortAggregate =>
-          getAggCallFromLocalAgg(aggCallIndex, agg.getAggCallList, agg.getInput.getRowType)
+          getAggCallFromLocalAgg(
+            aggCallIndex, agg.getAggCallList, agg.getInput.getRowType, isBounded = true)
         case agg: BatchPhysicalSortAggregate if agg.isMerge =>
           val aggCallIndexInLocalAgg = getAggCallIndexInLocalAgg(
-            aggCallIndex, agg.getAggCallList, agg.aggInputRowType)
+            aggCallIndex, agg.getAggCallList, agg.aggInputRowType, isBounded = true)
           if (aggCallIndexInLocalAgg != null) {
             return fmq.getColumnInterval(agg.getInput, groupSet.length + aggCallIndexInLocalAgg)
           } else {
diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/stream/StreamPhysicalGlobalWindowAggregate.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/stream/StreamPhysicalGlobalWindowAggregate.scala
index bef2589d02084..bdace617da6ec 100644
--- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/stream/StreamPhysicalGlobalWindowAggregate.scala
+++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/stream/StreamPhysicalGlobalWindowAggregate.scala
@@ -63,7 +63,7 @@ class StreamPhysicalGlobalWindowAggregate(
   extends SingleRel(cluster, traitSet, inputRel)
   with StreamPhysicalRel {
 
-  private lazy val aggInfoList = AggregateUtil.deriveWindowAggregateInfoList(
+  private lazy val aggInfoList = AggregateUtil.deriveStreamWindowAggregateInfoList(
     FlinkTypeFactory.toLogicalRowType(inputRowTypeOfLocalAgg),
     aggCalls,
     windowing.getWindow,
diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/stream/StreamPhysicalLocalWindowAggregate.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/stream/StreamPhysicalLocalWindowAggregate.scala
index 518ccda34a460..2823aab9919ae 100644
--- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/stream/StreamPhysicalLocalWindowAggregate.scala
+++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/stream/StreamPhysicalLocalWindowAggregate.scala
@@ -56,7 +56,7 @@ class StreamPhysicalLocalWindowAggregate(
   extends SingleRel(cluster, traitSet, inputRel)
   with StreamPhysicalRel {
 
-  private lazy val aggInfoList = AggregateUtil.deriveWindowAggregateInfoList(
+  private lazy val aggInfoList = AggregateUtil.deriveStreamWindowAggregateInfoList(
     FlinkTypeFactory.toLogicalRowType(inputRel.getRowType),
     aggCalls,
     windowing.getWindow,
diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/stream/StreamPhysicalWindowAggregate.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/stream/StreamPhysicalWindowAggregate.scala
index 21a1f504e4b68..eaa70e2abef3e 100644
--- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/stream/StreamPhysicalWindowAggregate.scala
+++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/stream/StreamPhysicalWindowAggregate.scala
@@ -56,7 +56,7 @@ class StreamPhysicalWindowAggregate(
   extends SingleRel(cluster, traitSet, inputRel)
   with StreamPhysicalRel {
 
-  lazy val aggInfoList: AggregateInfoList = AggregateUtil.deriveWindowAggregateInfoList(
+  lazy val aggInfoList: AggregateInfoList = AggregateUtil.deriveStreamWindowAggregateInfoList(
     FlinkTypeFactory.toLogicalRowType(inputRel.getRowType),
     aggCalls,
     windowing.getWindow,
diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/utils/AggFunctionFactory.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/utils/AggFunctionFactory.scala
index e271a74d715f8..a2b795be1f421 100644
--- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/utils/AggFunctionFactory.scala
+++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/utils/AggFunctionFactory.scala
@@ -45,14 +45,16 @@ import scala.collection.JavaConversions._
  * as subclasses of [[SqlAggFunction]] in Calcite but not as [[BridgingSqlAggFunction]]. The factory
  * returns [[DeclarativeAggregateFunction]] or [[BuiltInAggregateFunction]].
  *
- * @param inputType the input rel data type
- * @param orderKeyIdx the indexes of order key (null when is not over agg)
- * @param needRetraction true if need retraction
+ * @param inputRowType the input row type
+ * @param orderKeyIndexes the indexes of the order keys (null when this is not an over aggregate)
+ * @param aggCallNeedRetractions true if the aggregate call at the same index needs retraction
+ * @param isBounded true if the input is bounded
  */
 class AggFunctionFactory(
     inputRowType: RowType,
     orderKeyIndexes: Array[Int],
-    aggCallNeedRetractions: Array[Boolean]) {
+    aggCallNeedRetractions: Array[Boolean],
+    isBounded: Boolean) {
 
   /**
    * The entry point to create an aggregate function from the given [[AggregateCall]].
@@ -94,8 +96,12 @@ class AggFunctionFactory(
       case a: SqlRankFunction if a.getKind == SqlKind.DENSE_RANK =>
         createDenseRankAggFunction(argTypes)
 
-      case _: SqlLeadLagAggFunction =>
-        createLeadLagAggFunction(argTypes, index)
+      case func: SqlLeadLagAggFunction =>
+        if (isBounded) {
+          createBatchLeadLagAggFunction(argTypes, index)
+        } else {
+          createStreamLeadLagAggFunction(func, argTypes, index)
+        }
 
       case _: SqlSingleValueAggFunction =>
         createSingleValueAggFunction(argTypes)
@@ -328,7 +334,22 @@ class AggFunctionFactory(
     }
   }
 
-  private def createLeadLagAggFunction(
+  private def createStreamLeadLagAggFunction(
+      func: SqlLeadLagAggFunction,
+      argTypes: Array[LogicalType],
+      index: Int): UserDefinedFunction = {
+    if (func.getKind == SqlKind.LEAD) {
+      throw new TableException("LEAD Function is not supported in stream mode.")
+    }
+
+    if (aggCallNeedRetractions(index)) {
+      throw new TableException("LAG Function with retraction is not supported in stream mode.")
+    }
+
+    new LagAggFunction(argTypes)
+  }
+
+  private def createBatchLeadLagAggFunction(
       argTypes: Array[LogicalType],
       index: Int): UserDefinedFunction = {
     argTypes(0).getTypeRoot match {
       case TINYINT =>
diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/utils/AggregateUtil.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/utils/AggregateUtil.scala
index 9bfcdeb1f69ff..31252382350ba 100644
--- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/utils/AggregateUtil.scala
+++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/utils/AggregateUtil.scala
@@ -153,6 +153,7 @@ object AggregateUtil extends Enumeration {
   def getOutputIndexToAggCallIndexMap(
       aggregateCalls: Seq[AggregateCall],
       inputType: RelDataType,
+      isBounded: Boolean,
       orderKeyIndexes: Array[Int] = null): util.Map[Integer, Integer] = {
     val aggInfos = transformToAggregateInfoList(
       FlinkTypeFactory.toLogicalRowType(inputType),
@@ -161,7 +162,8 @@ object AggregateUtil extends Enumeration {
       orderKeyIndexes,
       needInputCount = false,
       isStateBackedDataViews = false,
-      needDistinctInfo = false).aggInfos
+      needDistinctInfo = false,
+      isBounded).aggInfos
 
     val map = new util.HashMap[Integer, Integer]()
     var outputIndex = 0
@@ -248,7 +250,7 @@ object AggregateUtil extends Enumeration {
       isStateBackendDataViews = true)
   }
 
-  def deriveWindowAggregateInfoList(
+  def deriveStreamWindowAggregateInfoList(
       inputRowType: RowType,
       aggCalls: Seq[AggregateCall],
       windowSpec: WindowSpec,
@@ -271,7 +273,8 @@ object AggregateUtil extends Enumeration {
       orderKeyIndexes = null,
       needInputCount,
       isStateBackendDataViews,
-      needDistinctInfo = true)
+      needDistinctInfo = true,
+      isBounded = false)
   }
 
   def transformToBatchAggregateFunctions(
@@ -287,7 +290,8 @@ object AggregateUtil extends Enumeration {
       orderKeyIndexes,
       needInputCount = false,
       isStateBackedDataViews = false,
-      needDistinctInfo = false).aggInfos
+      needDistinctInfo = false,
+      isBounded = true).aggInfos
 
     val aggFields = aggInfos.map(_.argIndexes)
     val bufferTypes = aggInfos.map(_.externalAccTypes)
@@ -315,7 +319,8 @@ object AggregateUtil extends Enumeration {
       orderKeyIndexes,
       needInputCount = false,
       isStateBackedDataViews = false,
-      needDistinctInfo = false)
+      needDistinctInfo = false,
+      isBounded = true)
   }
 
   def transformToStreamAggregateInfoList(
@@ -332,7 +337,8 @@ object AggregateUtil extends Enumeration {
       orderKeyIndexes = null,
       needInputCount,
       isStateBackendDataViews,
-      needDistinctInfo)
+      needDistinctInfo,
+      isBounded = false)
   }
 
   /**
@@ -355,7 +361,8 @@ object AggregateUtil extends Enumeration {
       orderKeyIndexes: Array[Int],
       needInputCount: Boolean,
       isStateBackedDataViews: Boolean,
-      needDistinctInfo: Boolean): AggregateInfoList = {
+      needDistinctInfo: Boolean,
+      isBounded: Boolean): AggregateInfoList = {
 
     // Step-1:
     // if need inputCount, find count1 in the existed aggregate calls first,
@@ -375,7 +382,11 @@ object AggregateUtil extends Enumeration {
 
     // Step-3:
     // create aggregate information
-    val factory = new AggFunctionFactory(inputRowType, orderKeyIndexes, aggCallNeedRetractions)
+    val factory = new AggFunctionFactory(
+      inputRowType,
+      orderKeyIndexes,
+      aggCallNeedRetractions,
+      isBounded)
     val aggInfos = newAggCalls
       .zipWithIndex
       .map { case (call, index) =>
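With the dispatch above, an unbounded LAG call is rewritten to LagAggFunction, and the two- and three-argument accumulate overloads carry the offset and default value through the accumulator. A hedged sketch of that path, using only classes from this patch (the wrapper class name is illustrative):

    import org.apache.flink.table.data.StringData;
    import org.apache.flink.table.planner.functions.aggfunctions.LagAggFunction;
    import org.apache.flink.table.types.logical.IntType;
    import org.apache.flink.table.types.logical.LogicalType;
    import org.apache.flink.table.types.logical.VarCharType;

    public class LagOffsetDefaultSketch {
        public static void main(String[] args) throws Exception {
            // Argument types: value, offset, default value -- mirroring LAG(c, 2, 'n/a').
            LagAggFunction<StringData> lag =
                    new LagAggFunction<>(
                            new LogicalType[] {
                                new VarCharType(), new IntType(), new VarCharType()
                            });
            LagAggFunction.LagAcc<StringData> acc = lag.createAccumulator();

            lag.accumulate(acc, StringData.fromString("a"), 2, StringData.fromString("n/a"));
            lag.accumulate(acc, StringData.fromString("b"), 2, StringData.fromString("n/a"));
            System.out.println(lag.getValue(acc)); // n/a: fewer than offset + 1 rows buffered
            lag.accumulate(acc, StringData.fromString("c"), 2, StringData.fromString("n/a"));
            System.out.println(lag.getValue(acc)); // a: the row two positions back
        }
    }

LEAD gets no such rewrite: it would have to buffer an unbounded number of future rows, which is why the factory rejects it in stream mode.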
diff --git a/flink-table/flink-table-planner-blink/src/test/java/org/apache/flink/table/planner/functions/aggfunctions/LagAggFunctionTest.java b/flink-table/flink-table-planner-blink/src/test/java/org/apache/flink/table/planner/functions/aggfunctions/LagAggFunctionTest.java
new file mode 100644
index 0000000000000..e3553d82f0121
--- /dev/null
+++ b/flink-table/flink-table-planner-blink/src/test/java/org/apache/flink/table/planner/functions/aggfunctions/LagAggFunctionTest.java
@@ -0,0 +1,62 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.planner.functions.aggfunctions;
+
+import org.apache.flink.table.data.StringData;
+import org.apache.flink.table.functions.AggregateFunction;
+import org.apache.flink.table.types.logical.CharType;
+import org.apache.flink.table.types.logical.IntType;
+import org.apache.flink.table.types.logical.LogicalType;
+import org.apache.flink.table.types.logical.VarCharType;
+
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.List;
+
+import static org.apache.flink.table.data.StringData.fromString;
+
+/** Test for {@link LagAggFunction}. */
+public class LagAggFunctionTest
+        extends AggFunctionTestBase<StringData, LagAggFunction.LagAcc<StringData>> {
+
+    @Override
+    protected List<List<StringData>> getInputValueSets() {
+        return Arrays.asList(
+                Collections.singletonList(fromString("1")),
+                Arrays.asList(fromString("1"), null),
+                Arrays.asList(null, null),
+                Arrays.asList(null, fromString("10")));
+    }
+
+    @Override
+    protected List<StringData> getExpectedResults() {
+        return Arrays.asList(null, fromString("1"), null, null);
+    }
+
+    @Override
+    protected AggregateFunction<StringData, LagAggFunction.LagAcc<StringData>> getAggregator() {
+        return new LagAggFunction<>(
+                new LogicalType[] {new VarCharType(), new IntType(), new CharType()});
+    }
+
+    @Override
+    protected Class<?> getAccClass() {
+        return LagAggFunction.LagAcc.class;
+    }
+}
diff --git a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdHandlerTestBase.scala b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdHandlerTestBase.scala
index f0ecc2422fc26..1dbbb54c354fd 100644
--- a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdHandlerTestBase.scala
+++ b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdHandlerTestBase.scala
@@ -949,7 +949,8 @@ class FlinkRelMdHandlerTestBase {
     val aggFunctionFactory = new AggFunctionFactory(
       FlinkTypeFactory.toLogicalRowType(studentBatchScan.getRowType),
       Array.empty[Int],
-      Array.fill(aggCalls.size())(false))
+      Array.fill(aggCalls.size())(false),
+      false)
     val aggCallToAggFunction = aggCalls.zipWithIndex.map {
       case (call, index) => (call, aggFunctionFactory.createAggFunction(call, index))
     }
@@ -1157,7 +1158,8 @@ class FlinkRelMdHandlerTestBase {
     val aggFunctionFactory = new AggFunctionFactory(
       FlinkTypeFactory.toLogicalRowType(calcOnStudentScan.getRowType),
       Array.empty[Int],
-      Array.fill(aggCalls.size())(false))
+      Array.fill(aggCalls.size())(false),
+      false)
     val aggCallToAggFunction = aggCalls.zipWithIndex.map {
       case (call, index) => (call, aggFunctionFactory.createAggFunction(call, index))
     }
@@ -1324,7 +1326,8 @@ class FlinkRelMdHandlerTestBase {
     val aggFunctionFactory = new AggFunctionFactory(
       FlinkTypeFactory.toLogicalRowType(studentBatchScan.getRowType),
       Array.empty[Int],
-      Array.fill(aggCalls.size())(false))
+      Array.fill(aggCalls.size())(false),
+      false)
     val aggCallToAggFunction = aggCalls.zipWithIndex.map {
       case (call, index) => (call, aggFunctionFactory.createAggFunction(call, index))
     }
diff --git a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/planner/runtime/stream/sql/OverAggregateITCase.scala b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/planner/runtime/stream/sql/OverAggregateITCase.scala
index e81d115a10aee..208de4a5b1108 100644
--- a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/planner/runtime/stream/sql/OverAggregateITCase.scala
+++ b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/planner/runtime/stream/sql/OverAggregateITCase.scala
@@ -55,6 +55,74 @@ class OverAggregateITCase(mode: StateBackendMode) extends StreamingWithStateTest
     env.getCheckpointConfig.enableUnalignedCheckpoints(false)
   }
 
+  @Test
+  def testLagFunction(): Unit = {
+    val sqlQuery = "SELECT a, b, c, " +
+      " LAG(b) OVER(PARTITION BY a ORDER BY rowtime)," +
+      " LAG(b, 2) OVER(PARTITION BY a ORDER BY rowtime)," +
+      " LAG(b, 2, CAST(10086 AS BIGINT)) OVER(PARTITION BY a ORDER BY rowtime) " +
+      "FROM T1"
+
+    val data: Seq[Either[(Long, (Int, Long, String)), Long]] = Seq(
+      Left(14000001L, (1, 1L, "Hi")),
+      Left(14000005L, (1, 2L, "Hi")),
+      Left(14000002L, (1, 3L, "Hello")),
+      Left(14000003L, (1, 4L, "Hello")),
+      Left(14000003L, (1, 5L, "Hello")),
+      Right(14000020L),
+      Left(14000021L, (1, 6L, "Hello world")),
+      Left(14000022L, (1, 7L, "Hello world")),
+      Right(14000030L))
+
+    val source = failingDataSource(data)
+    val t1 = source.transform("TimeAssigner", new EventTimeProcessOperator[(Int, Long, String)])
+      .setParallelism(source.parallelism)
+      .toTable(tEnv, 'a, 'b, 'c, 'rowtime.rowtime)
+
+    tEnv.registerTable("T1", t1)
+
+    val sink = new TestingAppendSink
+    tEnv.sqlQuery(sqlQuery).toAppendStream[Row].addSink(sink)
+    env.execute()
+
+    val expected = List(
+      s"1,1,Hi,null,null,10086",
+      s"1,3,Hello,1,null,10086",
+      s"1,4,Hello,4,3,3",
+      s"1,5,Hello,4,3,3",
+      s"1,2,Hi,5,4,4",
+      s"1,6,Hello world,2,5,5",
+      s"1,7,Hello world,6,2,2")
+    assertEquals(expected.sorted, sink.getAppendResults.sorted)
+  }
+
+  @Test
+  def testLeadFunction(): Unit = {
+    expectedException.expectMessage("LEAD Function is not supported in stream mode")
+
+    val sqlQuery = "SELECT a, b, c, " +
+      " LEAD(b) OVER(PARTITION BY a ORDER BY rowtime)," +
+      " LEAD(b, 2) OVER(PARTITION BY a ORDER BY rowtime)," +
+      " LEAD(b, 2, CAST(10086 AS BIGINT)) OVER(PARTITION BY a ORDER BY rowtime) " +
+      "FROM T1"
+
+    val data: Seq[Either[(Long, (Int, Long, String)), Long]] = Seq(
+      Left(14000001L, (1, 1L, "Hi")),
+      Left(14000003L, (1, 5L, "Hello")),
+      Right(14000020L),
+      Left(14000021L, (1, 6L, "Hello world")),
+      Left(14000022L, (1, 7L, "Hello world")),
+      Right(14000030L))
+    val source = failingDataSource(data)
+    val t1 = source.transform("TimeAssigner", new EventTimeProcessOperator[(Int, Long, String)])
+      .setParallelism(source.parallelism)
+      .toTable(tEnv, 'a, 'b, 'c, 'rowtime.rowtime)
+    tEnv.registerTable("T1", t1)
+    val sink = new TestingAppendSink
+    tEnv.sqlQuery(sqlQuery).toAppendStream[Row].addSink(sink)
+    env.execute()
+  }
+
   @Test
   def testRowNumberOnOver(): Unit = {
     val t = failingDataSource(TestData.tupleData5)
diff --git a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/typeutils/LinkedListSerializer.java b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/typeutils/LinkedListSerializer.java
new file mode 100644
index 0000000000000..df97203d17238
--- /dev/null
+++ b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/typeutils/LinkedListSerializer.java
@@ -0,0 +1,203 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.runtime.typeutils;
+
+import org.apache.flink.annotation.Internal;
+import org.apache.flink.api.common.typeutils.CompositeTypeSerializerSnapshot;
+import org.apache.flink.api.common.typeutils.TypeSerializer;
+import org.apache.flink.api.common.typeutils.TypeSerializerSnapshot;
+import org.apache.flink.core.memory.DataInputView;
+import org.apache.flink.core.memory.DataOutputView;
+
+import java.io.IOException;
+import java.util.LinkedList;
+
+import static org.apache.flink.util.Preconditions.checkNotNull;
+
+/**
+ * A serializer for {@link LinkedList}. The serializer relies on an element serializer for the
+ * serialization of the list's elements.
+ *
+ * @param <T> The type of element in the list.
+ */
+@Internal
+public final class LinkedListSerializer<T> extends TypeSerializer<LinkedList<T>> {
+
+    private static final long serialVersionUID = 1L;
+
+    /** The serializer for the elements of the list. */
+    private final TypeSerializer<T> elementSerializer;
+
+    /**
+     * Creates a list serializer that uses the given serializer to serialize the list's elements.
+     *
+     * @param elementSerializer The serializer for the elements of the list
+     */
+    public LinkedListSerializer(TypeSerializer<T> elementSerializer) {
+        this.elementSerializer = checkNotNull(elementSerializer);
+    }
+
+    // ------------------------------------------------------------------------
+    //  LinkedListSerializer specific properties
+    // ------------------------------------------------------------------------
+
+    /**
+     * Gets the serializer for the elements of the list.
+     *
+     * @return The serializer for the elements of the list
+     */
+    public TypeSerializer<T> getElementSerializer() {
+        return elementSerializer;
+    }
+
+    // ------------------------------------------------------------------------
+    //  Type Serializer implementation
+    // ------------------------------------------------------------------------
+
+    @Override
+    public boolean isImmutableType() {
+        return false;
+    }
+
+    @Override
+    public TypeSerializer<LinkedList<T>> duplicate() {
+        TypeSerializer<T> duplicateElement = elementSerializer.duplicate();
+        return duplicateElement == elementSerializer
+                ? this
+                : new LinkedListSerializer<>(duplicateElement);
+    }
+
+    @Override
+    public LinkedList<T> createInstance() {
+        return new LinkedList<>();
+    }
+
+    @Override
+    public LinkedList<T> copy(LinkedList<T> from) {
+        LinkedList<T> newList = new LinkedList<>();
+        for (T element : from) {
+            newList.add(elementSerializer.copy(element));
+        }
+        return newList;
+    }
+
+    @Override
+    public LinkedList<T> copy(LinkedList<T> from, LinkedList<T> reuse) {
+        return copy(from);
+    }
+
+    @Override
+    public int getLength() {
+        return -1; // var length
+    }
+
+    @Override
+    public void serialize(LinkedList<T> list, DataOutputView target) throws IOException {
+        target.writeInt(list.size());
+        for (T element : list) {
+            elementSerializer.serialize(element, target);
+        }
+    }
+
+    @Override
+    public LinkedList<T> deserialize(DataInputView source) throws IOException {
+        final int size = source.readInt();
+        final LinkedList<T> list = new LinkedList<>();
+        for (int i = 0; i < size; i++) {
+            list.add(elementSerializer.deserialize(source));
+        }
+        return list;
+    }
+
+    @Override
+    public LinkedList<T> deserialize(LinkedList<T> reuse, DataInputView source)
+            throws IOException {
+        return deserialize(source);
+    }
+
+    @Override
+    public void copy(DataInputView source, DataOutputView target) throws IOException {
+        // copy number of elements
+        final int num = source.readInt();
+        target.writeInt(num);
+        for (int i = 0; i < num; i++) {
+            elementSerializer.copy(source, target);
+        }
+    }
+
+    // --------------------------------------------------------------------
+
+    @Override
+    public boolean equals(Object obj) {
+        return obj == this
+                || (obj != null
+                        && obj.getClass() == getClass()
+                        && elementSerializer.equals(
+                                ((LinkedListSerializer<?>) obj).elementSerializer));
+    }
+
+    @Override
+    public int hashCode() {
+        return elementSerializer.hashCode();
+    }
+
+    // --------------------------------------------------------------------------------------------
+    // Serializer configuration snapshot & compatibility
+    // --------------------------------------------------------------------------------------------
+
+    @Override
+    public TypeSerializerSnapshot<LinkedList<T>> snapshotConfiguration() {
+        return new LinkedListSerializerSnapshot<>(this);
+    }
+
+    /** Snapshot class for the {@link LinkedListSerializer}. */
+    public static class LinkedListSerializerSnapshot<T>
+            extends CompositeTypeSerializerSnapshot<LinkedList<T>, LinkedListSerializer<T>> {
+
+        private static final int CURRENT_VERSION = 1;
+
+        /** Constructor for read instantiation. */
+        public LinkedListSerializerSnapshot() {
+            super(LinkedListSerializer.class);
+        }
+
+        /** Constructor to create the snapshot for writing. */
+        public LinkedListSerializerSnapshot(LinkedListSerializer<T> listSerializer) {
+            super(listSerializer);
+        }
+
+        @Override
+        public int getCurrentOuterSnapshotVersion() {
+            return CURRENT_VERSION;
+        }
+
+        @Override
+        protected LinkedListSerializer<T> createOuterSerializerWithNestedSerializers(
+                TypeSerializer<?>[] nestedSerializers) {
+            @SuppressWarnings("unchecked")
+            TypeSerializer<T> elementSerializer = (TypeSerializer<T>) nestedSerializers[0];
+            return new LinkedListSerializer<>(elementSerializer);
+        }
+
+        @Override
+        protected TypeSerializer<?>[] getNestedSerializers(
+                LinkedListSerializer<T> outerSerializer) {
+            return new TypeSerializer<?>[] {outerSerializer.getElementSerializer()};
+        }
+    }
+}
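A quick round trip shows the wire format (a size prefix followed by the elements). DataOutputSerializer and DataInputDeserializer are Flink's standard in-memory views; everything else is from this patch, and the wrapper class name is illustrative:

    import org.apache.flink.api.common.typeutils.base.LongSerializer;
    import org.apache.flink.core.memory.DataInputDeserializer;
    import org.apache.flink.core.memory.DataOutputSerializer;
    import org.apache.flink.table.runtime.typeutils.LinkedListSerializer;

    import java.util.LinkedList;

    public class LinkedListSerializerRoundTrip {
        public static void main(String[] args) throws Exception {
            LinkedListSerializer<Long> serializer =
                    new LinkedListSerializer<>(LongSerializer.INSTANCE);

            LinkedList<Long> list = new LinkedList<>();
            list.add(1L);
            list.add(2L);
            list.add(3L);

            // serialize() writes the size, then each element via the element serializer.
            DataOutputSerializer out = new DataOutputSerializer(64);
            serializer.serialize(list, out);

            DataInputDeserializer in = new DataInputDeserializer(out.getCopyOfBuffer());
            LinkedList<Long> copy = serializer.deserialize(in);
            System.out.println(copy); // [1, 2, 3]
        }
    }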
diff --git a/flink-table/flink-table-runtime-blink/src/test/java/org/apache/flink/table/runtime/typeutils/LinkedListSerializerTest.java b/flink-table/flink-table-runtime-blink/src/test/java/org/apache/flink/table/runtime/typeutils/LinkedListSerializerTest.java
new file mode 100644
index 0000000000000..eea15560e9c4d
--- /dev/null
+++ b/flink-table/flink-table-runtime-blink/src/test/java/org/apache/flink/table/runtime/typeutils/LinkedListSerializerTest.java
@@ -0,0 +1,72 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.runtime.typeutils;
+
+import org.apache.flink.api.common.typeutils.SerializerTestBase;
+import org.apache.flink.api.common.typeutils.TypeSerializer;
+import org.apache.flink.api.common.typeutils.base.LongSerializer;
+
+import java.util.LinkedList;
+import java.util.Random;
+
+/** A test for the {@link LinkedListSerializer}. */
+public class LinkedListSerializerTest extends SerializerTestBase<LinkedList<Long>> {
+
+    @Override
+    protected TypeSerializer<LinkedList<Long>> createSerializer() {
+        return new LinkedListSerializer<>(LongSerializer.INSTANCE);
+    }
+
+    @Override
+    protected int getLength() {
+        return -1;
+    }
+
+    @SuppressWarnings("unchecked")
+    @Override
+    protected Class<LinkedList<Long>> getTypeClass() {
+        return (Class<LinkedList<Long>>) (Class<?>) LinkedList.class;
+    }
+
+    @SuppressWarnings({"rawtypes", "unchecked"})
+    @Override
+    protected LinkedList<Long>[] getTestData() {
+        final Random rnd = new Random(123654789);
+
+        // empty lists
+        final LinkedList<Long> list1 = new LinkedList<>();
+
+        // single element lists
+        final LinkedList<Long> list2 = new LinkedList<>();
+        list2.add(12345L);
+
+        // longer lists
+        final LinkedList<Long> list3 = new LinkedList<>();
+        for (int i = 0; i < rnd.nextInt(200); i++) {
+            list3.add(rnd.nextLong());
+        }
+
+        final LinkedList<Long> list4 = new LinkedList<>();
+        for (int i = 0; i < rnd.nextInt(200); i++) {
+            list4.add(rnd.nextLong());
+        }
+
+        return (LinkedList<Long>[]) new LinkedList[] {list1, list2, list3, list4};
+    }
+}
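The snapshot nested at the end of LinkedListSerializer.java is what makes LAG state restorable across savepoints. A hedged compatibility sketch under the same assumptions as the round trip above:

    import org.apache.flink.api.common.typeutils.TypeSerializerSchemaCompatibility;
    import org.apache.flink.api.common.typeutils.TypeSerializerSnapshot;
    import org.apache.flink.api.common.typeutils.base.LongSerializer;
    import org.apache.flink.table.runtime.typeutils.LinkedListSerializer;

    import java.util.LinkedList;

    public class LinkedListSerializerCompatibility {
        public static void main(String[] args) {
            LinkedListSerializer<Long> serializer =
                    new LinkedListSerializer<>(LongSerializer.INSTANCE);
            TypeSerializerSnapshot<LinkedList<Long>> snapshot =
                    serializer.snapshotConfiguration();

            // Restoring with an equal element serializer is compatible as-is.
            TypeSerializerSchemaCompatibility<LinkedList<Long>> compatibility =
                    snapshot.resolveSchemaCompatibility(
                            new LinkedListSerializer<>(LongSerializer.INSTANCE));
            System.out.println(compatibility.isCompatibleAsIs()); // true
        }
    }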