[FLINK-19449][table-planner] LEAD/LAG cannot work correctly in streaming mode #15747

Merged · 4 commits · Apr 26, 2021
Changes from all commits
docs/data/sql_functions.yml (3 additions, 3 deletions)
@@ -674,10 +674,10 @@ aggregate:
   - sql: ROW_NUMBER()
     description: Assigns a unique, sequential number to each row, starting with one, according to the ordering of rows within the window partition. ROW_NUMBER and RANK are similar. ROW_NUMBER numbers all rows sequentially (for example 1, 2, 3, 4, 5). RANK provides the same numeric value for ties (for example 1, 2, 2, 4, 5).
   - sql: LEAD(expression [, offset] [, default])
     description: Returns the value of expression at the offsetth row before the current row in the window. The default value of offset is 1 and the default value of default is NULL.
-  - sql: LAG(expression [, offset] [, default])
-    description: Returns the value of expression at the offsetth row after the current row in the window. The default value of offset is 1 and the default value of default is NULL.
-  - sql: FIRST_VALUE(expression)
+  - sql: LAG(expression [, offset] [, default])
+    description: Returns the value of expression at the offsetth row before the current row in the window. The default value of offset is 1 and the default value of default is NULL.
+  - sql: FIRST_VALUE(expression)
     description: Returns the first value in an ordered set of values.
   - sql: LAST_VALUE(expression)
     description: Returns the last value in an ordered set of values.
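To make the corrected LAG semantics concrete, here is a minimal, hypothetical sketch of the kind of streaming query this PR enables via the Table API. The table name, schema, and the datagen connector are illustrative assumptions, not part of the change:

```java
import org.apache.flink.table.api.EnvironmentSettings;
import org.apache.flink.table.api.TableEnvironment;

public class LagExample {
    public static void main(String[] args) {
        // Streaming-mode environment; before this PR the planner reused the
        // batch LEAD/LAG aggregate here, producing incorrect results.
        TableEnvironment tEnv =
                TableEnvironment.create(
                        EnvironmentSettings.newInstance().inStreamingMode().build());

        // Hypothetical unbounded source, for illustration only.
        tEnv.executeSql(
                "CREATE TABLE clicks (user_id INT, ts TIMESTAMP(3), url STRING, "
                        + "WATERMARK FOR ts AS ts) WITH ('connector' = 'datagen')");

        // Per the corrected docs: LAG(url) returns the url of the previous row
        // in the partition; the first row per user gets the default (NULL).
        // Runs until cancelled, since the source is unbounded.
        tEnv.executeSql(
                        "SELECT user_id, url, "
                                + "LAG(url) OVER (PARTITION BY user_id ORDER BY ts) AS prev_url "
                                + "FROM clicks")
                .print();
    }
}
```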
LagAggFunction.java (new file)
@@ -0,0 +1,163 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.flink.table.planner.functions.aggfunctions;

import org.apache.flink.api.common.typeutils.TypeSerializer;
import org.apache.flink.table.api.DataTypes;
import org.apache.flink.table.api.TableException;
import org.apache.flink.table.functions.AggregateFunction;
import org.apache.flink.table.runtime.functions.aggregate.BuiltInAggregateFunction;
import org.apache.flink.table.runtime.typeutils.InternalSerializers;
import org.apache.flink.table.runtime.typeutils.LinkedListSerializer;
import org.apache.flink.table.types.DataType;
import org.apache.flink.table.types.logical.LogicalType;
import org.apache.flink.table.types.logical.LogicalTypeRoot;
import org.apache.flink.table.types.utils.DataTypeUtils;

import java.util.Arrays;
import java.util.LinkedList;
import java.util.List;
import java.util.Objects;

/** Lag {@link AggregateFunction}. */
public class LagAggFunction<T> extends BuiltInAggregateFunction<T, LagAggFunction.LagAcc<T>> {

    private final transient DataType[] valueDataTypes;

    @SuppressWarnings("unchecked")
    public LagAggFunction(LogicalType[] valueTypes) {
        this.valueDataTypes =
                Arrays.stream(valueTypes)
                        .map(DataTypeUtils::toInternalDataType)
                        .toArray(DataType[]::new);
        // Arguments are (expression, offset, default): if an explicit non-NULL
        // default is given, its conversion class must match the expression's.
        if (valueDataTypes.length == 3
                && valueDataTypes[2].getLogicalType().getTypeRoot() != LogicalTypeRoot.NULL) {
            if (valueDataTypes[0].getConversionClass() != valueDataTypes[2].getConversionClass()) {
                throw new TableException(
                        String.format(
                                "Please explicitly cast default value %s to %s.",
                                // The target type is the expression type (index 0),
                                // which is what the check above compares against.
                                valueDataTypes[2], valueDataTypes[0]));
            }
        }
    }

    // --------------------------------------------------------------------------------------------
    // Planning
    // --------------------------------------------------------------------------------------------

    @Override
    public List<DataType> getArgumentDataTypes() {
        return Arrays.asList(valueDataTypes);
    }

    @Override
    public DataType getAccumulatorDataType() {
        return DataTypes.STRUCTURED(
                LagAcc.class,
                DataTypes.FIELD("offset", DataTypes.INT()),
                DataTypes.FIELD("defaultValue", valueDataTypes[0]),
                DataTypes.FIELD("buffer", getLinkedListType()));
    }

    @SuppressWarnings({"unchecked", "rawtypes"})
    private DataType getLinkedListType() {
        TypeSerializer<T> serializer =
                InternalSerializers.create(getOutputDataType().getLogicalType());
        return DataTypes.RAW(
                LinkedList.class, (TypeSerializer) new LinkedListSerializer<>(serializer));
    }

    @Override
    public DataType getOutputDataType() {
        return valueDataTypes[0];
    }

    // --------------------------------------------------------------------------------------------
    // Runtime
    // --------------------------------------------------------------------------------------------

    public void accumulate(LagAcc<T> acc, T value) throws Exception {
        acc.buffer.add(value);
        // Keep at most offset + 1 elements; older rows can never be the result.
        while (acc.buffer.size() > acc.offset + 1) {
            acc.buffer.removeFirst();
        }
    }

    public void accumulate(LagAcc<T> acc, T value, int offset) throws Exception {
        if (offset < 0) {
            throw new TableException(
                    String.format("Offset(%d) must be non-negative.", offset));
        }

        acc.offset = offset;
Review comment (Contributor): better to validate offset > -1 and give a meaningful exception message.

        accumulate(acc, value);
    }

    public void accumulate(LagAcc<T> acc, T value, int offset, T defaultValue) throws Exception {
Review comment (Contributor): ditto.

Reply (Contributor, author): This method invokes the method above, so the validation applies here as well.

        acc.defaultValue = defaultValue;
        accumulate(acc, value, offset);
    }

    public void resetAccumulator(LagAcc<T> acc) throws Exception {
        acc.offset = 1;
        acc.defaultValue = null;
        acc.buffer.clear();
    }

    @Override
    public T getValue(LagAcc<T> acc) {
        if (acc.buffer.size() < acc.offset + 1) {
            // Not enough rows seen yet to look back `offset` rows.
            return acc.defaultValue;
        } else if (acc.buffer.size() == acc.offset + 1) {
            // The head of the buffer is exactly `offset` rows before the current row.
            return acc.buffer.getFirst();
        } else {
            throw new TableException("Too many elements: " + acc);
        }
    }

    @Override
    public LagAcc<T> createAccumulator() {
        return new LagAcc<>();
    }

    /** Accumulator for LAG. */
    public static class LagAcc<T> {
        public int offset = 1;
        public T defaultValue = null;
        public LinkedList<T> buffer = new LinkedList<>();

        @Override
        public boolean equals(Object o) {
            if (this == o) {
                return true;
            }
            if (o == null || getClass() != o.getClass()) {
                return false;
            }
            LagAcc<?> lagAcc = (LagAcc<?>) o;
            return offset == lagAcc.offset
                    && Objects.equals(defaultValue, lagAcc.defaultValue)
                    && Objects.equals(buffer, lagAcc.buffer);
        }

        @Override
        public int hashCode() {
            return Objects.hash(offset, defaultValue, buffer);
        }
    }
}
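To see why the buffer never needs more than offset + 1 elements, here is a small, self-contained simulation of the accumulate/getValue logic above (plain Java, no Flink dependencies; it mirrors the class but is only a sketch):

```java
import java.util.LinkedList;

public class LagBufferDemo {
    static final int OFFSET = 2; // simulates LAG(x, 2)

    public static void main(String[] args) {
        LinkedList<String> buffer = new LinkedList<>();
        String defaultValue = null;

        for (String value : new String[] {"a", "b", "c", "d"}) {
            // accumulate: append, then trim to at most OFFSET + 1 elements,
            // exactly like LagAggFunction#accumulate.
            buffer.add(value);
            while (buffer.size() > OFFSET + 1) {
                buffer.removeFirst();
            }
            // getValue: not enough rows yet -> default; otherwise the head of
            // the buffer is the row OFFSET positions before the current one.
            String result =
                    buffer.size() < OFFSET + 1 ? defaultValue : buffer.getFirst();
            System.out.println("row=" + value + " -> LAG=" + result);
        }
        // Prints: a -> null, b -> null, c -> a, d -> b
    }
}
```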
@@ -145,14 +145,14 @@ protected Transformation<RowData> translateToPlanInternal(PlannerBase planner) {
     final SliceAssigner sliceAssigner = createSliceAssigner(windowing, shiftTimeZone);

     final AggregateInfoList localAggInfoList =
-            AggregateUtil.deriveWindowAggregateInfoList(
+            AggregateUtil.deriveStreamWindowAggregateInfoList(
                     localAggInputRowType, // should use original input here
                     JavaScalaConversionUtil.toScala(Arrays.asList(aggCalls)),
                     windowing.getWindow(),
                     false); // isStateBackendDataViews

     final AggregateInfoList globalAggInfoList =
-            AggregateUtil.deriveWindowAggregateInfoList(
+            AggregateUtil.deriveStreamWindowAggregateInfoList(
                     localAggInputRowType, // should use original input here
                     JavaScalaConversionUtil.toScala(Arrays.asList(aggCalls)),
                     windowing.getWindow(),
@@ -122,7 +122,7 @@ protected Transformation<RowData> translateToPlanInternal(PlannerBase planner) {
     final SliceAssigner sliceAssigner = createSliceAssigner(windowing, shiftTimeZone);

     final AggregateInfoList aggInfoList =
-            AggregateUtil.deriveWindowAggregateInfoList(
+            AggregateUtil.deriveStreamWindowAggregateInfoList(
                     inputRowType,
                     JavaScalaConversionUtil.toScala(Arrays.asList(aggCalls)),
                     windowing.getWindow(),
@@ -143,7 +143,7 @@ protected Transformation<RowData> translateToPlanInternal(PlannerBase planner) {
     // Hopping windows require an additional COUNT(*) to determine whether to register the next
     // timer, based on whether the currently fired window is empty; see SliceSharedWindowAggProcessor.
     final AggregateInfoList aggInfoList =
-            AggregateUtil.deriveWindowAggregateInfoList(
+            AggregateUtil.deriveStreamWindowAggregateInfoList(
                     inputRowType,
                     JavaScalaConversionUtil.toScala(Arrays.asList(aggCalls)),
                     windowing.getWindow(),
@@ -562,9 +562,10 @@ class FlinkRelMdColumnInterval private extends MetadataHandler[ColumnInterval] {
   def getAggCallFromLocalAgg(
       index: Int,
       aggCalls: Seq[AggregateCall],
-      inputType: RelDataType): AggregateCall = {
+      inputType: RelDataType,
+      isBounded: Boolean): AggregateCall = {
     val outputIndexToAggCallIndexMap = AggregateUtil.getOutputIndexToAggCallIndexMap(
-      aggCalls, inputType)
+      aggCalls, inputType, isBounded)
     if (outputIndexToAggCallIndexMap.containsKey(index)) {
       val realIndex = outputIndexToAggCallIndexMap.get(index)
       aggCalls(realIndex)
@@ -576,9 +577,10 @@ class FlinkRelMdColumnInterval private extends MetadataHandler[ColumnInterval] {
   def getAggCallIndexInLocalAgg(
       index: Int,
       globalAggCalls: Seq[AggregateCall],
-      inputRowType: RelDataType): Integer = {
+      inputRowType: RelDataType,
+      isBounded: Boolean): Integer = {
     val outputIndexToAggCallIndexMap = AggregateUtil.getOutputIndexToAggCallIndexMap(
-      globalAggCalls, inputRowType)
+      globalAggCalls, inputRowType, isBounded)

     outputIndexToAggCallIndexMap.foreach {
       case (k, v) => if (v == index) {
@@ -600,34 +602,37 @@ class FlinkRelMdColumnInterval private extends MetadataHandler[ColumnInterval] {
       case agg: StreamPhysicalGlobalGroupAggregate
         if agg.aggCalls.length > aggCallIndex =>
         val aggCallIndexInLocalAgg = getAggCallIndexInLocalAgg(
-          aggCallIndex, agg.aggCalls, agg.localAggInputRowType)
+          aggCallIndex, agg.aggCalls, agg.localAggInputRowType, isBounded = false)
         if (aggCallIndexInLocalAgg != null) {
           return fmq.getColumnInterval(agg.getInput, groupSet.length + aggCallIndexInLocalAgg)
         } else {
           null
         }
       case agg: StreamPhysicalLocalGroupAggregate =>
-        getAggCallFromLocalAgg(aggCallIndex, agg.aggCalls, agg.getInput.getRowType)
+        getAggCallFromLocalAgg(
+          aggCallIndex, agg.aggCalls, agg.getInput.getRowType, isBounded = false)
       case agg: StreamPhysicalIncrementalGroupAggregate
         if agg.partialAggCalls.length > aggCallIndex =>
         agg.partialAggCalls(aggCallIndex)
       case agg: StreamPhysicalGroupWindowAggregate if agg.aggCalls.length > aggCallIndex =>
         agg.aggCalls(aggCallIndex)
       case agg: BatchPhysicalLocalHashAggregate =>
-        getAggCallFromLocalAgg(aggCallIndex, agg.getAggCallList, agg.getInput.getRowType)
+        getAggCallFromLocalAgg(
+          aggCallIndex, agg.getAggCallList, agg.getInput.getRowType, isBounded = true)
       case agg: BatchPhysicalHashAggregate if agg.isMerge =>
         val aggCallIndexInLocalAgg = getAggCallIndexInLocalAgg(
-          aggCallIndex, agg.getAggCallList, agg.aggInputRowType)
+          aggCallIndex, agg.getAggCallList, agg.aggInputRowType, isBounded = true)
         if (aggCallIndexInLocalAgg != null) {
           return fmq.getColumnInterval(agg.getInput, groupSet.length + aggCallIndexInLocalAgg)
         } else {
           null
         }
       case agg: BatchPhysicalLocalSortAggregate =>
-        getAggCallFromLocalAgg(aggCallIndex, agg.getAggCallList, agg.getInput.getRowType)
+        getAggCallFromLocalAgg(
+          aggCallIndex, agg.getAggCallList, agg.getInput.getRowType, isBounded = true)
       case agg: BatchPhysicalSortAggregate if agg.isMerge =>
         val aggCallIndexInLocalAgg = getAggCallIndexInLocalAgg(
-          aggCallIndex, agg.getAggCallList, agg.aggInputRowType)
+          aggCallIndex, agg.getAggCallList, agg.aggInputRowType, isBounded = true)
         if (aggCallIndexInLocalAgg != null) {
           return fmq.getColumnInterval(agg.getInput, groupSet.length + aggCallIndexInLocalAgg)
         } else {
@@ -63,7 +63,7 @@ class StreamPhysicalGlobalWindowAggregate(
   extends SingleRel(cluster, traitSet, inputRel)
   with StreamPhysicalRel {

-  private lazy val aggInfoList = AggregateUtil.deriveWindowAggregateInfoList(
+  private lazy val aggInfoList = AggregateUtil.deriveStreamWindowAggregateInfoList(
     FlinkTypeFactory.toLogicalRowType(inputRowTypeOfLocalAgg),
     aggCalls,
     windowing.getWindow,
@@ -56,7 +56,7 @@ class StreamPhysicalLocalWindowAggregate(
   extends SingleRel(cluster, traitSet, inputRel)
   with StreamPhysicalRel {

-  private lazy val aggInfoList = AggregateUtil.deriveWindowAggregateInfoList(
+  private lazy val aggInfoList = AggregateUtil.deriveStreamWindowAggregateInfoList(
     FlinkTypeFactory.toLogicalRowType(inputRel.getRowType),
     aggCalls,
     windowing.getWindow,
@@ -56,7 +56,7 @@ class StreamPhysicalWindowAggregate(
   extends SingleRel(cluster, traitSet, inputRel)
   with StreamPhysicalRel {

-  lazy val aggInfoList: AggregateInfoList = AggregateUtil.deriveWindowAggregateInfoList(
+  lazy val aggInfoList: AggregateInfoList = AggregateUtil.deriveStreamWindowAggregateInfoList(
     FlinkTypeFactory.toLogicalRowType(inputRel.getRowType),
     aggCalls,
     windowing.getWindow,
@@ -45,14 +45,16 @@ import scala.collection.JavaConversions._
  * as subclasses of [[SqlAggFunction]] in Calcite but not as [[BridgingSqlAggFunction]]. The factory
  * returns [[DeclarativeAggregateFunction]] or [[BuiltInAggregateFunction]].
  *
- * @param inputType the input rel data type
- * @param orderKeyIdx the indexes of order key (null when is not over agg)
- * @param needRetraction true if need retraction
+ * @param inputRowType the input row type
+ * @param orderKeyIndexes the indexes of the order keys (null when this is not an over aggregate)
+ * @param aggCallNeedRetractions per-call flags that are true if the aggregate call needs retraction
+ * @param isBounded true if the input source is bounded
  */
class AggFunctionFactory(
Review comment (Contributor): minor: the @param names in the doc comment do not match the constructor parameters.
     inputRowType: RowType,
     orderKeyIndexes: Array[Int],
-    aggCallNeedRetractions: Array[Boolean]) {
+    aggCallNeedRetractions: Array[Boolean],
+    isBounded: Boolean) {

  /**
   * The entry point to create an aggregate function from the given [[AggregateCall]].

@@ -94,8 +96,12 @@ class AggFunctionFactory(
       case a: SqlRankFunction if a.getKind == SqlKind.DENSE_RANK =>
         createDenseRankAggFunction(argTypes)

-      case _: SqlLeadLagAggFunction =>
-        createLeadLagAggFunction(argTypes, index)
+      case func: SqlLeadLagAggFunction =>
+        if (isBounded) {
+          createBatchLeadLagAggFunction(argTypes, index)
+        } else {
+          createStreamLeadLagAggFunction(func, argTypes, index)
+        }

       case _: SqlSingleValueAggFunction =>
         createSingleValueAggFunction(argTypes)
@@ -328,7 +334,22 @@
     }
   }

-  private def createLeadLagAggFunction(
+  private def createStreamLeadLagAggFunction(
+      func: SqlLeadLagAggFunction,
+      argTypes: Array[LogicalType],
+      index: Int): UserDefinedFunction = {
+    if (func.getKind == SqlKind.LEAD) {
+      throw new TableException("LEAD Function is not supported in stream mode.")
+    }
+
+    if (aggCallNeedRetractions(index)) {
+      throw new TableException("LAG Function with retraction is not supported in stream mode.")
+    }
+
+    new LagAggFunction(argTypes)
+  }
+
+  private def createBatchLeadLagAggFunction(
       argTypes: Array[LogicalType], index: Int): UserDefinedFunction = {
     argTypes(0).getTypeRoot match {
       case TINYINT =>
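Finally, a hedged sketch of what the new dispatch means at the API surface: on an unbounded stream the planner now routes LAG to LagAggFunction, while LEAD fails at planning time. The table definition is hypothetical, and the exact wrapping of the exception is an assumption:

```java
import org.apache.flink.table.api.EnvironmentSettings;
import org.apache.flink.table.api.TableEnvironment;

public class LeadLagDispatchDemo {
    public static void main(String[] args) {
        TableEnvironment tEnv =
                TableEnvironment.create(
                        EnvironmentSettings.newInstance().inStreamingMode().build());
        // Hypothetical unbounded source, for illustration only.
        tEnv.executeSql(
                "CREATE TABLE t (id INT, ts TIMESTAMP(3), v STRING, "
                        + "WATERMARK FOR ts AS ts) WITH ('connector' = 'datagen')");
        try {
            // LEAD must see future rows, which an unbounded stream cannot
            // provide, so createStreamLeadLagAggFunction rejects it.
            tEnv.executeSql("SELECT LEAD(v) OVER (PARTITION BY id ORDER BY ts) FROM t");
        } catch (Exception e) {
            // Expected: a TableException saying
            // "LEAD Function is not supported in stream mode."
            System.out.println(e.getMessage());
        }
    }
}
```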