[FLINK-19449][table-planner] LEAD/LAG cannot work correctly in streaming mode

This closes apache#15747
JingsongLi committed Apr 28, 2021
1 parent 8961283 commit 3019fba
Showing 3 changed files with 293 additions and 0 deletions.
LagAggFunction.java
@@ -0,0 +1,163 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.flink.table.planner.functions.aggfunctions;

import org.apache.flink.api.common.typeutils.TypeSerializer;
import org.apache.flink.table.api.DataTypes;
import org.apache.flink.table.api.TableException;
import org.apache.flink.table.functions.AggregateFunction;
import org.apache.flink.table.runtime.functions.aggregate.BuiltInAggregateFunction;
import org.apache.flink.table.runtime.typeutils.InternalSerializers;
import org.apache.flink.table.runtime.typeutils.LinkedListSerializer;
import org.apache.flink.table.types.DataType;
import org.apache.flink.table.types.logical.LogicalType;
import org.apache.flink.table.types.logical.LogicalTypeRoot;
import org.apache.flink.table.types.utils.DataTypeUtils;

import java.util.Arrays;
import java.util.LinkedList;
import java.util.List;
import java.util.Objects;

/** Lag {@link AggregateFunction}. */
public class LagAggFunction<T> extends BuiltInAggregateFunction<T, LagAggFunction.LagAcc<T>> {

    private final transient DataType[] valueDataTypes;

    @SuppressWarnings("unchecked")
    public LagAggFunction(LogicalType[] valueTypes) {
        this.valueDataTypes =
                Arrays.stream(valueTypes)
                        .map(DataTypeUtils::toInternalDataType)
                        .toArray(DataType[]::new);
        if (valueDataTypes.length == 3
                && valueDataTypes[2].getLogicalType().getTypeRoot() != LogicalTypeRoot.NULL) {
            if (valueDataTypes[0].getConversionClass() != valueDataTypes[2].getConversionClass()) {
                throw new TableException(
                        String.format(
                                "Please explicitly cast default value %s to %s.",
                                valueDataTypes[2], valueDataTypes[0]));
            }
        }
    }

    // --------------------------------------------------------------------------------------------
    // Planning
    // --------------------------------------------------------------------------------------------

    @Override
    public List<DataType> getArgumentDataTypes() {
        return Arrays.asList(valueDataTypes);
    }

    @Override
    public DataType getAccumulatorDataType() {
        return DataTypes.STRUCTURED(
                LagAcc.class,
                DataTypes.FIELD("offset", DataTypes.INT()),
                DataTypes.FIELD("defaultValue", valueDataTypes[0]),
                DataTypes.FIELD("buffer", getLinkedListType()));
    }

    @SuppressWarnings({"unchecked", "rawtypes"})
    private DataType getLinkedListType() {
        TypeSerializer<T> serializer =
                InternalSerializers.create(getOutputDataType().getLogicalType());
        return DataTypes.RAW(
                LinkedList.class, (TypeSerializer) new LinkedListSerializer<>(serializer));
    }

    @Override
    public DataType getOutputDataType() {
        return valueDataTypes[0];
    }

    // --------------------------------------------------------------------------------------------
    // Runtime
    // --------------------------------------------------------------------------------------------

    public void accumulate(LagAcc<T> acc, T value) throws Exception {
        acc.buffer.add(value);
        // Keep at most offset + 1 elements: the head is the lagged value.
        while (acc.buffer.size() > acc.offset + 1) {
            acc.buffer.removeFirst();
        }
    }

    public void accumulate(LagAcc<T> acc, T value, int offset) throws Exception {
        if (offset < 0) {
            throw new TableException(
                    String.format("Offset(%d) must be non-negative.", offset));
        }

        acc.offset = offset;
        accumulate(acc, value);
    }

    public void accumulate(LagAcc<T> acc, T value, int offset, T defaultValue) throws Exception {
        acc.defaultValue = defaultValue;
        accumulate(acc, value, offset);
    }

    public void resetAccumulator(LagAcc<T> acc) throws Exception {
        acc.offset = 1;
        acc.defaultValue = null;
        acc.buffer.clear();
    }

    @Override
    public T getValue(LagAcc<T> acc) {
        if (acc.buffer.size() < acc.offset + 1) {
            return acc.defaultValue;
        } else if (acc.buffer.size() == acc.offset + 1) {
            return acc.buffer.getFirst();
        } else {
            throw new TableException("Too many elements: " + acc);
        }
    }

    @Override
    public LagAcc<T> createAccumulator() {
        return new LagAcc<>();
    }

    /** Accumulator for LAG. */
    public static class LagAcc<T> {
        public int offset = 1;
        public T defaultValue = null;
        public LinkedList<T> buffer = new LinkedList<>();

        @Override
        public boolean equals(Object o) {
            if (this == o) {
                return true;
            }
            if (o == null || getClass() != o.getClass()) {
                return false;
            }
            LagAcc<?> lagAcc = (LagAcc<?>) o;
            return offset == lagAcc.offset
                    && Objects.equals(defaultValue, lagAcc.defaultValue)
                    && Objects.equals(buffer, lagAcc.buffer);
        }

        @Override
        public int hashCode() {
            return Objects.hash(offset, defaultValue, buffer);
        }
    }
}
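
For readers skimming the diff, a minimal, hypothetical driver (not part of this commit) that exercises the accumulator semantics above; the class name LagAggFunctionSketch and the direct calls to accumulate/getValue are illustrative only, assuming flink-table-planner is on the classpath:

// Hypothetical driver, not part of this commit: shows how the buffer keeps at
// most offset + 1 elements and how getValue falls back to the default value.
import org.apache.flink.table.planner.functions.aggfunctions.LagAggFunction;
import org.apache.flink.table.types.logical.LogicalType;
import org.apache.flink.table.types.logical.VarCharType;

public class LagAggFunctionSketch {
    public static void main(String[] args) throws Exception {
        // Single-argument LAG: default offset 1, default value null.
        LagAggFunction<String> lag =
                new LagAggFunction<>(new LogicalType[] {new VarCharType()});
        LagAggFunction.LagAcc<String> acc = lag.createAccumulator();

        lag.accumulate(acc, "a");              // buffer: [a]
        System.out.println(lag.getValue(acc)); // null -- fewer than offset + 1 rows
        lag.accumulate(acc, "b");              // buffer: [a, b]
        System.out.println(lag.getValue(acc)); // a -- the previous row's value
        lag.accumulate(acc, "c");              // buffer: [b, c], oldest row evicted
        System.out.println(lag.getValue(acc)); // b
    }
}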
LagAggFunctionTest.java
@@ -0,0 +1,62 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.flink.table.planner.functions.aggfunctions;

import org.apache.flink.table.data.StringData;
import org.apache.flink.table.functions.AggregateFunction;
import org.apache.flink.table.types.logical.CharType;
import org.apache.flink.table.types.logical.IntType;
import org.apache.flink.table.types.logical.LogicalType;
import org.apache.flink.table.types.logical.VarCharType;

import java.util.Arrays;
import java.util.Collections;
import java.util.List;

import static org.apache.flink.table.data.StringData.fromString;

/** Test for {@link LagAggFunction}. */
public class LagAggFunctionTest
        extends AggFunctionTestBase<StringData, LagAggFunction.LagAcc<StringData>> {

    @Override
    protected List<List<StringData>> getInputValueSets() {
        return Arrays.asList(
                Collections.singletonList(fromString("1")),
                Arrays.asList(fromString("1"), null),
                Arrays.asList(null, null),
                Arrays.asList(null, fromString("10")));
    }

    @Override
    protected List<StringData> getExpectedResults() {
        return Arrays.asList(null, fromString("1"), null, null);
    }

    @Override
    protected AggregateFunction<StringData, LagAggFunction.LagAcc<StringData>> getAggregator() {
        return new LagAggFunction<>(
                new LogicalType[] {new VarCharType(), new IntType(), new CharType()});
    }

    @Override
    protected Class<?> getAccClass() {
        return LagAggFunction.LagAcc.class;
    }
}
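
The test above picks CHAR for the default-value argument, which shares StringData as its internal conversion class with VARCHAR, so the constructor's type check passes. A hedged sketch (illustrative class name, not from this commit) of the case that check rejects:

// Hypothetical snippet, not part of this commit: an INT default value for a
// VARCHAR input has a different conversion class, so the constructor throws.
import org.apache.flink.table.api.TableException;
import org.apache.flink.table.planner.functions.aggfunctions.LagAggFunction;
import org.apache.flink.table.types.logical.IntType;
import org.apache.flink.table.types.logical.LogicalType;
import org.apache.flink.table.types.logical.VarCharType;

public class LagDefaultValueCheckSketch {
    public static void main(String[] args) {
        try {
            new LagAggFunction<>(
                    new LogicalType[] {new VarCharType(), new IntType(), new IntType()});
        } catch (TableException e) {
            // e.g. "Please explicitly cast default value INT to VARCHAR(1)."
            System.out.println(e.getMessage());
        }
    }
}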
OverAggregateITCase.scala
@@ -55,6 +55,74 @@ class OverAggregateITCase(mode: StateBackendMode) extends StreamingWithStateTestBase
    env.getCheckpointConfig.enableUnalignedCheckpoints(false)
  }

  @Test
  def testLagFunction(): Unit = {
    val sqlQuery = "SELECT a, b, c, " +
      " LAG(b) OVER(PARTITION BY a ORDER BY rowtime)," +
      " LAG(b, 2) OVER(PARTITION BY a ORDER BY rowtime)," +
      " LAG(b, 2, CAST(10086 AS BIGINT)) OVER(PARTITION BY a ORDER BY rowtime)" +
      " FROM T1"

    val data: Seq[Either[(Long, (Int, Long, String)), Long]] = Seq(
      Left(14000001L, (1, 1L, "Hi")),
      Left(14000005L, (1, 2L, "Hi")),
      Left(14000002L, (1, 3L, "Hello")),
      Left(14000003L, (1, 4L, "Hello")),
      Left(14000003L, (1, 5L, "Hello")),
      Right(14000020L),
      Left(14000021L, (1, 6L, "Hello world")),
      Left(14000022L, (1, 7L, "Hello world")),
      Right(14000030L))

    val source = failingDataSource(data)
    val t1 = source.transform("TimeAssigner", new EventTimeProcessOperator[(Int, Long, String)])
      .setParallelism(source.parallelism)
      .toTable(tEnv, 'a, 'b, 'c, 'rowtime.rowtime)

    tEnv.registerTable("T1", t1)

    val sink = new TestingAppendSink
    tEnv.sqlQuery(sqlQuery).toAppendStream[Row].addSink(sink)
    env.execute()

    val expected = List(
      "1,1,Hi,null,null,10086",
      "1,3,Hello,1,null,10086",
      "1,4,Hello,4,3,3",
      "1,5,Hello,4,3,3",
      "1,2,Hi,5,4,4",
      "1,6,Hello world,2,5,5",
      "1,7,Hello world,6,2,2")
    assertEquals(expected.sorted, sink.getAppendResults.sorted)
  }

  @Test
  def testLeadFunction(): Unit = {
    expectedException.expectMessage("LEAD Function is not supported in stream mode")

    val sqlQuery = "SELECT a, b, c, " +
      " LEAD(b) OVER(PARTITION BY a ORDER BY rowtime)," +
      " LEAD(b, 2) OVER(PARTITION BY a ORDER BY rowtime)," +
      " LEAD(b, 2, CAST(10086 AS BIGINT)) OVER(PARTITION BY a ORDER BY rowtime)" +
      " FROM T1"

    val data: Seq[Either[(Long, (Int, Long, String)), Long]] = Seq(
      Left(14000001L, (1, 1L, "Hi")),
      Left(14000003L, (1, 5L, "Hello")),
      Right(14000020L),
      Left(14000021L, (1, 6L, "Hello world")),
      Left(14000022L, (1, 7L, "Hello world")),
      Right(14000030L))
    val source = failingDataSource(data)
    val t1 = source.transform("TimeAssigner", new EventTimeProcessOperator[(Int, Long, String)])
      .setParallelism(source.parallelism)
      .toTable(tEnv, 'a, 'b, 'c, 'rowtime.rowtime)
    tEnv.registerTable("T1", t1)
    val sink = new TestingAppendSink
    tEnv.sqlQuery(sqlQuery).toAppendStream[Row].addSink(sink)
    env.execute()
  }

  @Test
  def testRowNumberOnOver(): Unit = {
    val t = failingDataSource(TestData.tupleData5)