[FLINK-19449][table-planner] LEAD/LAG cannot work correctly in streaming mode

This closes #15793
JingsongLi authored Apr 28, 2021
1 parent 4885c4c commit c623db4
Showing 16 changed files with 646 additions and 38 deletions.
6 changes: 3 additions & 3 deletions docs/data/sql_functions.yml
@@ -674,10 +674,10 @@ aggregate:
   - sql: ROW_NUMBER()
     description: Assigns a unique, sequential number to each row, starting with one, according to the ordering of rows within the window partition. ROW_NUMBER and RANK are similar. ROW_NUMBER numbers all rows sequentially (for example 1, 2, 3, 4, 5). RANK provides the same numeric value for ties (for example 1, 2, 2, 4, 5).
   - sql: LEAD(expression [, offset] [, default])
-    description: Returns the value of expression at the offsetth row before the current row in the window. The default value of offset is 1 and the default value of default is NULL.
+    description: Returns the value of expression at the offsetth row after the current row in the window. The default value of offset is 1 and the default value of default is NULL.
   - sql: LAG(expression [, offset] [, default])
-    description: Returns the value of expression at the offsetth row after the current row in the window. The default value of offset is 1 and the default value of default is NULL.
+    description: Returns the value of expression at the offsetth row before the current row in the window. The default value of offset is 1 and the default value of default is NULL.
   - sql: FIRST_VALUE(expression)
     description: Returns the first value in an ordered set of values.
   - sql: LAST_VALUE(expression)
     description: Returns the last value in an ordered set of values.
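The corrected entries follow SQL-standard semantics: LEAD looks forward, LAG looks back. A minimal Table API sketch (not part of this commit; the Orders table, its columns, and the datagen options are illustrative assumptions) of the streaming LAG behavior this change enables:

import org.apache.flink.table.api.EnvironmentSettings;
import org.apache.flink.table.api.TableEnvironment;

public class StreamingLagExample {
    public static void main(String[] args) {
        TableEnvironment tEnv =
                TableEnvironment.create(
                        EnvironmentSettings.newInstance().inStreamingMode().build());

        // Hypothetical bounded datagen source so the job terminates on its own.
        tEnv.executeSql(
                "CREATE TABLE Orders (k STRING, v INT, ts TIMESTAMP(3),"
                        + " WATERMARK FOR ts AS ts)"
                        + " WITH ('connector' = 'datagen', 'number-of-rows' = '10')");

        // LAG looks backwards: prev_v is NULL for the first row of each key,
        // afterwards the value of v from the previous row (offset 1, default NULL).
        tEnv.executeSql(
                        "SELECT k, v, LAG(v) OVER (PARTITION BY k ORDER BY ts) AS prev_v"
                                + " FROM Orders")
                .print();
    }
}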
org/apache/flink/table/planner/functions/aggfunctions/LagAggFunction.java (new file)
@@ -0,0 +1,163 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.flink.table.planner.functions.aggfunctions;

import org.apache.flink.api.common.typeutils.TypeSerializer;
import org.apache.flink.table.api.DataTypes;
import org.apache.flink.table.api.TableException;
import org.apache.flink.table.functions.AggregateFunction;
import org.apache.flink.table.runtime.functions.aggregate.BuiltInAggregateFunction;
import org.apache.flink.table.runtime.typeutils.InternalSerializers;
import org.apache.flink.table.runtime.typeutils.LinkedListSerializer;
import org.apache.flink.table.types.DataType;
import org.apache.flink.table.types.logical.LogicalType;
import org.apache.flink.table.types.logical.LogicalTypeRoot;
import org.apache.flink.table.types.utils.DataTypeUtils;

import java.util.Arrays;
import java.util.LinkedList;
import java.util.List;
import java.util.Objects;

/** Lag {@link AggregateFunction}. */
public class LagAggFunction<T> extends BuiltInAggregateFunction<T, LagAggFunction.LagAcc<T>> {

    private final transient DataType[] valueDataTypes;

    @SuppressWarnings("unchecked")
    public LagAggFunction(LogicalType[] valueTypes) {
        this.valueDataTypes =
                Arrays.stream(valueTypes)
                        .map(DataTypeUtils::toInternalDataType)
                        .toArray(DataType[]::new);
        // When a default value is given (third argument) and is not a NULL literal,
        // its conversion class must match that of the value expression.
        if (valueDataTypes.length == 3
                && valueDataTypes[2].getLogicalType().getTypeRoot() != LogicalTypeRoot.NULL) {
            if (valueDataTypes[0].getConversionClass() != valueDataTypes[2].getConversionClass()) {
                throw new TableException(
                        String.format(
                                "Please explicitly cast default value %s to %s.",
                                valueDataTypes[2], valueDataTypes[0]));
            }
        }
    }

    // --------------------------------------------------------------------------------------------
    // Planning
    // --------------------------------------------------------------------------------------------

    @Override
    public List<DataType> getArgumentDataTypes() {
        return Arrays.asList(valueDataTypes);
    }

    @Override
    public DataType getAccumulatorDataType() {
        return DataTypes.STRUCTURED(
                LagAcc.class,
                DataTypes.FIELD("offset", DataTypes.INT()),
                DataTypes.FIELD("defaultValue", valueDataTypes[0]),
                DataTypes.FIELD("buffer", getLinkedListType()));
    }

    @SuppressWarnings({"unchecked", "rawtypes"})
    private DataType getLinkedListType() {
        TypeSerializer<T> serializer =
                InternalSerializers.create(getOutputDataType().getLogicalType());
        return DataTypes.RAW(
                LinkedList.class, (TypeSerializer) new LinkedListSerializer<>(serializer));
    }

    @Override
    public DataType getOutputDataType() {
        return valueDataTypes[0];
    }

    // --------------------------------------------------------------------------------------------
    // Runtime
    // --------------------------------------------------------------------------------------------

    public void accumulate(LagAcc<T> acc, T value) throws Exception {
        // Keep only the most recent offset + 1 values; the head of the buffer is then
        // exactly the row "offset" positions before the current one.
        acc.buffer.add(value);
        while (acc.buffer.size() > acc.offset + 1) {
            acc.buffer.removeFirst();
        }
    }

    public void accumulate(LagAcc<T> acc, T value, int offset) throws Exception {
        if (offset < 0) {
            throw new TableException(String.format("Offset(%d) should be non-negative.", offset));
        }

        acc.offset = offset;
        accumulate(acc, value);
    }

    public void accumulate(LagAcc<T> acc, T value, int offset, T defaultValue) throws Exception {
        acc.defaultValue = defaultValue;
        accumulate(acc, value, offset);
    }

    public void resetAccumulator(LagAcc<T> acc) throws Exception {
        acc.offset = 1;
        acc.defaultValue = null;
        acc.buffer.clear();
    }

    @Override
    public T getValue(LagAcc<T> acc) {
        if (acc.buffer.size() < acc.offset + 1) {
            // Not enough rows seen yet to reach back "offset" positions.
            return acc.defaultValue;
        } else if (acc.buffer.size() == acc.offset + 1) {
            return acc.buffer.getFirst();
        } else {
            throw new TableException("Too many elements: " + acc);
        }
    }

    @Override
    public LagAcc<T> createAccumulator() {
        return new LagAcc<>();
    }

    /** Accumulator for LAG. */
    public static class LagAcc<T> {
        public int offset = 1;
        public T defaultValue = null;
        public LinkedList<T> buffer = new LinkedList<>();

        @Override
        public boolean equals(Object o) {
            if (this == o) {
                return true;
            }
            if (o == null || getClass() != o.getClass()) {
                return false;
            }
            LagAcc<?> lagAcc = (LagAcc<?>) o;
            return offset == lagAcc.offset
                    && Objects.equals(defaultValue, lagAcc.defaultValue)
                    && Objects.equals(buffer, lagAcc.buffer);
        }

        @Override
        public int hashCode() {
            return Objects.hash(offset, defaultValue, buffer);
        }
    }
}
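To see how the bounded buffer above behaves, here is a hedged sketch (not part of the commit; the driver class is hypothetical) that exercises the function directly, outside any Flink job. Plain Strings stand in for the internal StringData the runtime would actually feed in:

import org.apache.flink.table.planner.functions.aggfunctions.LagAggFunction;
import org.apache.flink.table.types.logical.LogicalType;
import org.apache.flink.table.types.logical.VarCharType;

public class LagAccDemo {
    public static void main(String[] args) throws Exception {
        // Single-argument LAG over a VARCHAR expression: offset defaults to 1,
        // default value to NULL.
        LagAggFunction<String> lag =
                new LagAggFunction<>(new LogicalType[] {new VarCharType()});
        LagAggFunction.LagAcc<String> acc = lag.createAccumulator();

        lag.accumulate(acc, "a");
        System.out.println(lag.getValue(acc)); // null (no row one position back yet)

        lag.accumulate(acc, "b");
        System.out.println(lag.getValue(acc)); // a

        lag.accumulate(acc, "c");
        System.out.println(lag.getValue(acc)); // b (buffer keeps only the last 2 values)
    }
}

Because the buffer never holds more than offset + 1 elements, state size stays constant per key regardless of how many rows stream through.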
@@ -145,14 +145,14 @@ protected Transformation<RowData> translateToPlanInternal(PlannerBase planner) {
         final SliceAssigner sliceAssigner = createSliceAssigner(windowing, shiftTimeZone);

         final AggregateInfoList localAggInfoList =
-                AggregateUtil.deriveWindowAggregateInfoList(
+                AggregateUtil.deriveStreamWindowAggregateInfoList(
                         localAggInputRowType, // should use original input here
                         JavaScalaConversionUtil.toScala(Arrays.asList(aggCalls)),
                         windowing.getWindow(),
                         false); // isStateBackendDataViews

         final AggregateInfoList globalAggInfoList =
-                AggregateUtil.deriveWindowAggregateInfoList(
+                AggregateUtil.deriveStreamWindowAggregateInfoList(
                         localAggInputRowType, // should use original input here
                         JavaScalaConversionUtil.toScala(Arrays.asList(aggCalls)),
                         windowing.getWindow(),
@@ -122,7 +122,7 @@ protected Transformation<RowData> translateToPlanInternal(PlannerBase planner) {
         final SliceAssigner sliceAssigner = createSliceAssigner(windowing, shiftTimeZone);

         final AggregateInfoList aggInfoList =
-                AggregateUtil.deriveWindowAggregateInfoList(
+                AggregateUtil.deriveStreamWindowAggregateInfoList(
                         inputRowType,
                         JavaScalaConversionUtil.toScala(Arrays.asList(aggCalls)),
                         windowing.getWindow(),
@@ -143,7 +143,7 @@ protected Transformation<RowData> translateToPlanInternal(PlannerBase planner) {
         // Hopping window requires additional COUNT(*) to determine whether to register next timer
         // through whether the current fired window is empty, see SliceSharedWindowAggProcessor.
         final AggregateInfoList aggInfoList =
-                AggregateUtil.deriveWindowAggregateInfoList(
+                AggregateUtil.deriveStreamWindowAggregateInfoList(
                         inputRowType,
                         JavaScalaConversionUtil.toScala(Arrays.asList(aggCalls)),
                         windowing.getWindow(),
@@ -562,9 +562,10 @@ class FlinkRelMdColumnInterval private extends MetadataHandler[ColumnInterval] {
   def getAggCallFromLocalAgg(
       index: Int,
       aggCalls: Seq[AggregateCall],
-      inputType: RelDataType): AggregateCall = {
+      inputType: RelDataType,
+      isBounded: Boolean): AggregateCall = {
     val outputIndexToAggCallIndexMap = AggregateUtil.getOutputIndexToAggCallIndexMap(
-      aggCalls, inputType)
+      aggCalls, inputType, isBounded)
     if (outputIndexToAggCallIndexMap.containsKey(index)) {
       val realIndex = outputIndexToAggCallIndexMap.get(index)
       aggCalls(realIndex)
@@ -576,9 +577,10 @@ class FlinkRelMdColumnInterval private extends MetadataHandler[ColumnInterval] {
   def getAggCallIndexInLocalAgg(
       index: Int,
       globalAggCalls: Seq[AggregateCall],
-      inputRowType: RelDataType): Integer = {
+      inputRowType: RelDataType,
+      isBounded: Boolean): Integer = {
     val outputIndexToAggCallIndexMap = AggregateUtil.getOutputIndexToAggCallIndexMap(
-      globalAggCalls, inputRowType)
+      globalAggCalls, inputRowType, isBounded)

     outputIndexToAggCallIndexMap.foreach {
       case (k, v) => if (v == index) {
@@ -600,34 +602,37 @@ class FlinkRelMdColumnInterval private extends MetadataHandler[ColumnInterval] {
       case agg: StreamPhysicalGlobalGroupAggregate
         if agg.aggCalls.length > aggCallIndex =>
         val aggCallIndexInLocalAgg = getAggCallIndexInLocalAgg(
-          aggCallIndex, agg.aggCalls, agg.localAggInputRowType)
+          aggCallIndex, agg.aggCalls, agg.localAggInputRowType, isBounded = false)
         if (aggCallIndexInLocalAgg != null) {
           return fmq.getColumnInterval(agg.getInput, groupSet.length + aggCallIndexInLocalAgg)
         } else {
           null
         }
       case agg: StreamPhysicalLocalGroupAggregate =>
-        getAggCallFromLocalAgg(aggCallIndex, agg.aggCalls, agg.getInput.getRowType)
+        getAggCallFromLocalAgg(
+          aggCallIndex, agg.aggCalls, agg.getInput.getRowType, isBounded = false)
       case agg: StreamPhysicalIncrementalGroupAggregate
         if agg.partialAggCalls.length > aggCallIndex =>
         agg.partialAggCalls(aggCallIndex)
       case agg: StreamPhysicalGroupWindowAggregate if agg.aggCalls.length > aggCallIndex =>
         agg.aggCalls(aggCallIndex)
       case agg: BatchPhysicalLocalHashAggregate =>
-        getAggCallFromLocalAgg(aggCallIndex, agg.getAggCallList, agg.getInput.getRowType)
+        getAggCallFromLocalAgg(
+          aggCallIndex, agg.getAggCallList, agg.getInput.getRowType, isBounded = true)
       case agg: BatchPhysicalHashAggregate if agg.isMerge =>
         val aggCallIndexInLocalAgg = getAggCallIndexInLocalAgg(
-          aggCallIndex, agg.getAggCallList, agg.aggInputRowType)
+          aggCallIndex, agg.getAggCallList, agg.aggInputRowType, isBounded = true)
         if (aggCallIndexInLocalAgg != null) {
           return fmq.getColumnInterval(agg.getInput, groupSet.length + aggCallIndexInLocalAgg)
         } else {
           null
         }
       case agg: BatchPhysicalLocalSortAggregate =>
-        getAggCallFromLocalAgg(aggCallIndex, agg.getAggCallList, agg.getInput.getRowType)
+        getAggCallFromLocalAgg(
+          aggCallIndex, agg.getAggCallList, agg.getInput.getRowType, isBounded = true)
       case agg: BatchPhysicalSortAggregate if agg.isMerge =>
         val aggCallIndexInLocalAgg = getAggCallIndexInLocalAgg(
-          aggCallIndex, agg.getAggCallList, agg.aggInputRowType)
+          aggCallIndex, agg.getAggCallList, agg.aggInputRowType, isBounded = true)
         if (aggCallIndexInLocalAgg != null) {
           return fmq.getColumnInterval(agg.getInput, groupSet.length + aggCallIndexInLocalAgg)
         } else {
@@ -63,7 +63,7 @@ class StreamPhysicalGlobalWindowAggregate(
   extends SingleRel(cluster, traitSet, inputRel)
   with StreamPhysicalRel {

-  private lazy val aggInfoList = AggregateUtil.deriveWindowAggregateInfoList(
+  private lazy val aggInfoList = AggregateUtil.deriveStreamWindowAggregateInfoList(
     FlinkTypeFactory.toLogicalRowType(inputRowTypeOfLocalAgg),
     aggCalls,
     windowing.getWindow,
@@ -56,7 +56,7 @@ class StreamPhysicalLocalWindowAggregate(
   extends SingleRel(cluster, traitSet, inputRel)
   with StreamPhysicalRel {

-  private lazy val aggInfoList = AggregateUtil.deriveWindowAggregateInfoList(
+  private lazy val aggInfoList = AggregateUtil.deriveStreamWindowAggregateInfoList(
     FlinkTypeFactory.toLogicalRowType(inputRel.getRowType),
     aggCalls,
     windowing.getWindow,
@@ -56,7 +56,7 @@ class StreamPhysicalWindowAggregate(
   extends SingleRel(cluster, traitSet, inputRel)
   with StreamPhysicalRel {

-  lazy val aggInfoList: AggregateInfoList = AggregateUtil.deriveWindowAggregateInfoList(
+  lazy val aggInfoList: AggregateInfoList = AggregateUtil.deriveStreamWindowAggregateInfoList(
     FlinkTypeFactory.toLogicalRowType(inputRel.getRowType),
     aggCalls,
     windowing.getWindow,
@@ -45,14 +45,16 @@ import scala.collection.JavaConversions._
  * as subclasses of [[SqlAggFunction]] in Calcite but not as [[BridgingSqlAggFunction]]. The factory
  * returns [[DeclarativeAggregateFunction]] or [[BuiltInAggregateFunction]].
  *
- * @param inputType the input rel data type
- * @param orderKeyIdx the indexes of order key (null when is not over agg)
- * @param needRetraction true if need retraction
+ * @param inputRowType the input row type
+ * @param orderKeyIndexes the indexes of the order keys (null when this is not an over aggregate)
+ * @param aggCallNeedRetractions per aggregate call, true if that call needs retraction
+ * @param isBounded true if the input is a bounded source (batch mode)
  */
 class AggFunctionFactory(
     inputRowType: RowType,
     orderKeyIndexes: Array[Int],
-    aggCallNeedRetractions: Array[Boolean]) {
+    aggCallNeedRetractions: Array[Boolean],
+    isBounded: Boolean) {

/**
* The entry point to create an aggregate function from the given [[AggregateCall]].
@@ -94,8 +96,12 @@ class AggFunctionFactory(
       case a: SqlRankFunction if a.getKind == SqlKind.DENSE_RANK =>
         createDenseRankAggFunction(argTypes)

-      case _: SqlLeadLagAggFunction =>
-        createLeadLagAggFunction(argTypes, index)
+      case func: SqlLeadLagAggFunction =>
+        if (isBounded) {
+          createBatchLeadLagAggFunction(argTypes, index)
+        } else {
+          createStreamLeadLagAggFunction(func, argTypes, index)
+        }

       case _: SqlSingleValueAggFunction =>
         createSingleValueAggFunction(argTypes)
@@ -328,7 +334,22 @@ class AggFunctionFactory(
     }
   }

-  private def createLeadLagAggFunction(
+  private def createStreamLeadLagAggFunction(
+      func: SqlLeadLagAggFunction,
+      argTypes: Array[LogicalType],
+      index: Int): UserDefinedFunction = {
+    if (func.getKind == SqlKind.LEAD) {
+      throw new TableException("LEAD Function is not supported in stream mode.")
+    }
+
+    if (aggCallNeedRetractions(index)) {
+      throw new TableException("LAG Function with retraction is not supported in stream mode.")
+    }
+
+    new LagAggFunction(argTypes)
+  }
+
+  private def createBatchLeadLagAggFunction(
       argTypes: Array[LogicalType], index: Int): UserDefinedFunction = {
     argTypes(0).getTypeRoot match {
       case TINYINT =>
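Since the factory now dispatches on isBounded, a streaming query that uses LEAD fails fast at planning time instead of producing wrong results. A minimal sketch (not part of the commit; table and column names are hypothetical, and the exception may surface wrapped depending on the execution path) of the user-visible behavior:

import org.apache.flink.table.api.EnvironmentSettings;
import org.apache.flink.table.api.TableEnvironment;

public class LeadInStreamingExample {
    public static void main(String[] args) {
        TableEnvironment tEnv =
                TableEnvironment.create(
                        EnvironmentSettings.newInstance().inStreamingMode().build());
        tEnv.executeSql(
                "CREATE TABLE Orders (k STRING, v INT, ts TIMESTAMP(3),"
                        + " WATERMARK FOR ts AS ts) WITH ('connector' = 'datagen')");
        try {
            // LEAD needs rows that have not arrived yet, which an unbounded
            // stream cannot provide, so the planner rejects it.
            tEnv.executeSql(
                    "SELECT k, LEAD(v) OVER (PARTITION BY k ORDER BY ts) AS next_v"
                            + " FROM Orders");
        } catch (Exception e) {
            // Expected: a TableException "LEAD Function is not supported in stream mode."
            System.out.println(e.getMessage());
        }
    }
}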
