[FLINK-22652][python][table-planner-blink] Support StreamExecPythonGroupWindowAggregate json serialization/deserialization

This closes #15934.
HuangXingBo committed May 17, 2021
1 parent d09745a commit 5eebab4
Showing 11 changed files with 3,256 additions and 14 deletions.
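
For context, the new Python test below exercises this serde path end to end: a SQL statement is compiled into a JSON plan, and that plan is deserialized and executed. A minimal sketch of the round trip, assuming a t_env whose tables and "my_count" UDAF are registered exactly as in the test (getJsonPlan and executeJsonPlan are internal, experimental methods on the wrapped Java TableEnvironment, reached through the _j_tenv handle just as the test does):

    from py4j.java_gateway import get_method

    # Compile the INSERT statement into a JSON representation of the exec
    # plan; the session window aggregate becomes a
    # StreamExecPythonGroupWindowAggregate node inside that plan.
    json_plan = t_env._j_tenv.getJsonPlan(
        "INSERT INTO sink_table "
        "SELECT a, "
        "SESSION_START(rowtime, INTERVAL '30' MINUTE), "
        "SESSION_END(rowtime, INTERVAL '30' MINUTE), "
        "my_count(c) "
        "FROM source_table "
        "GROUP BY a, b, SESSION(rowtime, INTERVAL '30' MINUTE)")

    # Deserialize the plan back into exec nodes and run it. `await` is a
    # Python keyword, so py4j's get_method is used to invoke the Java
    # TableResult#await.
    get_method(t_env._j_tenv.executeJsonPlan(json_plan), "await")()
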
67 changes: 67 additions & 0 deletions flink-python/pyflink/table/tests/test_udaf.py
@@ -785,6 +785,73 @@ def test_session_group_window_over_time(self):
"+I[1, 2018-03-11 03:10:00.0, 2018-03-11 04:10:00.0, 2]",
"+I[1, 2018-03-11 04:20:00.0, 2018-03-11 04:50:00.0, 1]"])

def test_execute_group_window_aggregate_from_json_plan(self):
# create source file path
tmp_dir = self.tempdir
data = [
'1,1,2,2018-03-11 03:10:00',
'3,3,2,2018-03-11 03:10:00',
'2,2,1,2018-03-11 03:10:00',
'1,1,3,2018-03-11 03:40:00',
'1,1,8,2018-03-11 04:20:00',
'2,2,3,2018-03-11 03:30:00'
]
source_path = tmp_dir + '/test_execute_group_window_aggregate_from_json_plan.csv'
sink_path = tmp_dir + '/test_execute_group_window_aggregate_from_json_plan'
with open(source_path, 'w') as fd:
for ele in data:
fd.write(ele + '\n')

source_table = """
CREATE TABLE source_table (
a TINYINT,
b SMALLINT,
c SMALLINT,
rowtime TIMESTAMP(3),
WATERMARK FOR rowtime AS rowtime - INTERVAL '60' MINUTE
) WITH (
'connector' = 'filesystem',
'path' = '%s',
'format' = 'csv'
)
""" % source_path
self.t_env.execute_sql(source_table)

self.t_env.execute_sql("""
CREATE TABLE sink_table (
a BIGINT,
w_start TIMESTAMP(3),
w_end TIMESTAMP(3),
b BIGINT
) WITH (
'connector' = 'filesystem',
'path' = '%s',
'format' = 'csv'
)
""" % sink_path)

self.t_env.create_temporary_function("my_count", CountAggregateFunction())

json_plan = self.t_env._j_tenv.getJsonPlan("INSERT INTO sink_table "
"SELECT a, "
"SESSION_START(rowtime, INTERVAL '30' MINUTE), "
"SESSION_END(rowtime, INTERVAL '30' MINUTE), "
"my_count(c) "
"FROM source_table "
"GROUP BY "
"a, b, SESSION(rowtime, INTERVAL '30' MINUTE)")
from py4j.java_gateway import get_method
get_method(self.t_env._j_tenv.executeJsonPlan(json_plan), "await")()

import glob
lines = [line.strip() for file in glob.glob(sink_path + '/*') for line in open(file, 'r')]
lines.sort()
self.assertEqual(lines,
['1,"2018-03-11 03:10:00","2018-03-11 04:10:00",2',
'1,"2018-03-11 04:20:00","2018-03-11 04:50:00",1',
'2,"2018-03-11 03:10:00","2018-03-11 04:00:00",2',
'3,"2018-03-11 03:10:00","2018-03-11 03:40:00",1'])


if __name__ == '__main__':
import unittest
@@ -41,9 +41,9 @@
import org.apache.flink.table.planner.plan.logical.TumblingGroupWindow;
import org.apache.flink.table.planner.plan.nodes.exec.ExecEdge;
import org.apache.flink.table.planner.plan.nodes.exec.ExecNode;
import org.apache.flink.table.planner.plan.nodes.exec.ExecNodeBase;
import org.apache.flink.table.planner.plan.nodes.exec.InputProperty;
import org.apache.flink.table.planner.plan.nodes.exec.SingleTransformationTranslator;
import org.apache.flink.table.planner.plan.nodes.exec.serde.LogicalWindowJsonDeserializer;
import org.apache.flink.table.planner.plan.nodes.exec.serde.LogicalWindowJsonSerializer;
import org.apache.flink.table.planner.plan.nodes.exec.utils.CommonPythonUtil;
import org.apache.flink.table.planner.plan.utils.AggregateInfoList;
import org.apache.flink.table.planner.plan.utils.KeySelectorUtil;
@@ -66,6 +66,12 @@
import org.apache.flink.table.runtime.util.TimeWindowUtil;
import org.apache.flink.table.types.logical.RowType;

import org.apache.flink.shaded.jackson2.com.fasterxml.jackson.annotation.JsonCreator;
import org.apache.flink.shaded.jackson2.com.fasterxml.jackson.annotation.JsonIgnoreProperties;
import org.apache.flink.shaded.jackson2.com.fasterxml.jackson.annotation.JsonProperty;
import org.apache.flink.shaded.jackson2.com.fasterxml.jackson.databind.annotation.JsonDeserialize;
import org.apache.flink.shaded.jackson2.com.fasterxml.jackson.databind.annotation.JsonSerialize;

import org.apache.calcite.rel.core.AggregateCall;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@@ -75,6 +81,7 @@
import java.time.ZoneId;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;

import static org.apache.flink.table.planner.plan.utils.AggregateUtil.hasRowIntervalType;
import static org.apache.flink.table.planner.plan.utils.AggregateUtil.hasTimeIntervalType;
@@ -84,10 +91,12 @@
import static org.apache.flink.table.planner.plan.utils.AggregateUtil.toDuration;
import static org.apache.flink.table.planner.plan.utils.AggregateUtil.toLong;
import static org.apache.flink.table.planner.plan.utils.AggregateUtil.transformToStreamAggregateInfoList;
import static org.apache.flink.util.Preconditions.checkArgument;
import static org.apache.flink.util.Preconditions.checkNotNull;

/** Stream {@link ExecNode} for group window aggregate (Python user defined aggregate function). */
public class StreamExecPythonGroupWindowAggregate extends ExecNodeBase<RowData>
implements StreamExecNode<RowData>, SingleTransformationTranslator<RowData> {
@JsonIgnoreProperties(ignoreUnknown = true)
public class StreamExecPythonGroupWindowAggregate extends StreamExecAggregateBase {
private static final Logger LOGGER =
LoggerFactory.getLogger(StreamExecPythonGroupWindowAggregate.class);

@@ -100,31 +109,71 @@ public class StreamExecPythonGroupWindowAggregate extends ExecNodeBase<RowData>
"org.apache.flink.table.runtime.operators.python.aggregate."
+ "PythonStreamGroupWindowAggregateOperator";

public static final String FIELD_NAME_WINDOW = "window";
public static final String FIELD_NAME_NAMED_WINDOW_PROPERTIES = "namedWindowProperties";

@JsonProperty(FIELD_NAME_GROUPING)
private final int[] grouping;

@JsonProperty(FIELD_NAME_AGG_CALLS)
private final AggregateCall[] aggCalls;

@JsonProperty(FIELD_NAME_WINDOW)
@JsonSerialize(using = LogicalWindowJsonSerializer.class)
@JsonDeserialize(using = LogicalWindowJsonDeserializer.class)
private final LogicalWindow window;

@JsonProperty(FIELD_NAME_NAMED_WINDOW_PROPERTIES)
private final PlannerNamedWindowProperty[] namedWindowProperties;
private final WindowEmitStrategy emitStrategy;

@JsonProperty(FIELD_NAME_NEED_RETRACTION)
private final boolean needRetraction;

@JsonProperty(FIELD_NAME_GENERATE_UPDATE_BEFORE)
private final boolean generateUpdateBefore;

public StreamExecPythonGroupWindowAggregate(
int[] grouping,
AggregateCall[] aggCalls,
LogicalWindow window,
PlannerNamedWindowProperty[] namedWindowProperties,
WindowEmitStrategy emitStrategy,
boolean generateUpdateBefore,
boolean needRetraction,
InputProperty inputProperty,
RowType outputType,
String description) {
super(Collections.singletonList(inputProperty), outputType, description);
this.grouping = grouping;
this.aggCalls = aggCalls;
this.window = window;
this.namedWindowProperties = namedWindowProperties;
this.emitStrategy = emitStrategy;
this(
grouping,
aggCalls,
window,
namedWindowProperties,
generateUpdateBefore,
needRetraction,
getNewNodeId(),
Collections.singletonList(inputProperty),
outputType,
description);
}

@JsonCreator
public StreamExecPythonGroupWindowAggregate(
@JsonProperty(FIELD_NAME_GROUPING) int[] grouping,
@JsonProperty(FIELD_NAME_AGG_CALLS) AggregateCall[] aggCalls,
@JsonProperty(FIELD_NAME_WINDOW) LogicalWindow window,
@JsonProperty(FIELD_NAME_NAMED_WINDOW_PROPERTIES)
PlannerNamedWindowProperty[] namedWindowProperties,
@JsonProperty(FIELD_NAME_GENERATE_UPDATE_BEFORE) boolean generateUpdateBefore,
@JsonProperty(FIELD_NAME_NEED_RETRACTION) boolean needRetraction,
@JsonProperty(FIELD_NAME_ID) int id,
@JsonProperty(FIELD_NAME_INPUT_PROPERTIES) List<InputProperty> inputProperties,
@JsonProperty(FIELD_NAME_OUTPUT_TYPE) RowType outputType,
@JsonProperty(FIELD_NAME_DESCRIPTION) String description) {
super(id, inputProperties, outputType, description);
checkArgument(inputProperties.size() == 1);
this.grouping = checkNotNull(grouping);
this.aggCalls = checkNotNull(aggCalls);
this.window = checkNotNull(window);
this.namedWindowProperties = checkNotNull(namedWindowProperties);
this.generateUpdateBefore = generateUpdateBefore;
this.needRetraction = needRetraction;
}
@@ -187,6 +236,7 @@ protected Transformation<RowData> translateToPlanInternal(PlannerBase planner) {
Arrays.stream(aggCalls)
.anyMatch(x -> PythonUtil.isPythonAggregate(x, PythonFunctionKind.GENERAL));
OneInputTransformation<RowData, RowData> transform;
WindowEmitStrategy emitStrategy = WindowEmitStrategy.apply(tableConfig, window);
if (isGeneralPythonUDAF) {
final boolean[] aggCallNeedRetractions = new boolean[aggCalls.length];
Arrays.fill(aggCallNeedRetractions, needRetraction);
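
Two details of the serde contract above are worth noting. First, only the fields carrying @JsonProperty annotations round-trip through JSON; emitStrategy is deliberately not among them and is instead recomputed from the table config via WindowEmitStrategy.apply(tableConfig, window) in translateToPlanInternal, which is why the Scala physical node below stops passing it to the constructor. Second, the window field delegates to the dedicated LogicalWindowJsonSerializer/LogicalWindowJsonDeserializer pair, so its JSON layout is owned by those classes. A purely illustrative sketch of the serialized node shape, written as a Python dict with hypothetical values (the field names are the FIELD_NAME_* constants from the class; every value is made up):

    # Hypothetical shape only; values are placeholders, not Flink's actual output.
    serialized_node = {
        "id": 4,                                  # FIELD_NAME_ID
        "grouping": [0],                          # FIELD_NAME_GROUPING
        "aggCalls": ["<serialized agg call>"],    # FIELD_NAME_AGG_CALLS
        "window": {"<custom serde>": "..."},      # FIELD_NAME_WINDOW
        "namedWindowProperties": ["<property>"],  # FIELD_NAME_NAMED_WINDOW_PROPERTIES
        "generateUpdateBefore": False,            # FIELD_NAME_GENERATE_UPDATE_BEFORE
        "needRetraction": False,                  # FIELD_NAME_NEED_RETRACTION
        "inputProperties": ["<input property>"],  # exactly one, per checkArgument
        "outputType": "<row type>",               # FIELD_NAME_OUTPUT_TYPE
        "description": "<node description>",      # FIELD_NAME_DESCRIPTION
    }
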
@@ -75,7 +75,6 @@ class StreamPhysicalPythonGroupWindowAggregate(
aggCalls.toArray,
window,
namedWindowProperties.toArray,
emitStrategy,
generateUpdateBefore,
needRetraction,
InputProperty.DEFAULT,
@@ -44,7 +44,6 @@ public class JsonSerdeCoverageTest {
"StreamExecLegacySink",
"StreamExecPythonGroupAggregate",
"StreamExecWindowTableFunction",
"StreamExecPythonGroupWindowAggregate",
"StreamExecGroupTableAggregate",
"StreamExecPythonGroupTableAggregate",
"StreamExecPythonOverAggregate",
@@ -0,0 +1,161 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.flink.table.planner.plan.nodes.exec.stream;

import org.apache.flink.table.api.TableConfig;
import org.apache.flink.table.api.TableEnvironment;
import org.apache.flink.table.planner.runtime.utils.JavaUserDefinedAggFunctions.TestPythonAggregateFunction;
import org.apache.flink.table.planner.utils.StreamTableTestUtil;
import org.apache.flink.table.planner.utils.TableTestBase;

import org.junit.Before;
import org.junit.Test;

/** Test json serialization/deserialization for group window aggregate. */
public class PythonGroupWindowAggregateJsonPlanTest extends TableTestBase {
private StreamTableTestUtil util;
private TableEnvironment tEnv;

@Before
public void setup() {
util = streamTestUtil(TableConfig.getDefault());
tEnv = util.getTableEnv();

String srcTableDdl =
"CREATE TABLE MyTable (\n"
+ " a INT NOT NULL,\n"
+ " b BIGINT,\n"
+ " c VARCHAR,\n"
+ " `rowtime` AS TO_TIMESTAMP(c),\n"
+ " proctime as PROCTIME(),\n"
+ " WATERMARK for `rowtime` AS `rowtime` - INTERVAL '1' SECOND\n"
+ ") WITH (\n"
+ " 'connector' = 'values')\n";
tEnv.executeSql(srcTableDdl);
tEnv.createTemporarySystemFunction("pyFunc", new TestPythonAggregateFunction());
}

@Test
public void testEventTimeTumbleWindow() {
String sinkTableDdl =
"CREATE TABLE MySink (\n"
+ " b BIGINT,\n"
+ " window_start TIMESTAMP(3),\n"
+ " window_end TIMESTAMP(3),\n"
+ " c BIGINT\n"
+ ") WITH (\n"
+ " 'connector' = 'values')\n";
tEnv.executeSql(sinkTableDdl);
util.verifyJsonPlan(
"insert into MySink select\n"
+ " b,\n"
+ " TUMBLE_START(rowtime, INTERVAL '5' SECOND) as window_start,\n"
+ " TUMBLE_END(rowtime, INTERVAL '5' SECOND) as window_end,\n"
+ " pyFunc(a, a + 1)\n"
+ "FROM MyTable\n"
+ "GROUP BY b, TUMBLE(rowtime, INTERVAL '5' SECOND)");
}

@Test
public void testProcTimeTumbleWindow() {
String sinkTableDdl =
"CREATE TABLE MySink (\n"
+ " b BIGINT,\n"
+ " window_end TIMESTAMP(3),\n"
+ " c BIGINT\n"
+ ") WITH (\n"
+ " 'connector' = 'values')\n";
tEnv.executeSql(sinkTableDdl);
util.verifyJsonPlan(
"insert into MySink select\n"
+ " b,\n"
+ " TUMBLE_END(proctime, INTERVAL '15' MINUTE) as window_end,\n"
+ " pyFunc(a, a + 1)\n"
+ "FROM MyTable\n"
+ "GROUP BY b, TUMBLE(proctime, INTERVAL '15' MINUTE)");
}

@Test
public void testEventTimeHopWindow() {
String sinkTableDdl =
"CREATE TABLE MySink (\n"
+ " b BIGINT,\n"
+ " c BIGINT\n"
+ ") WITH (\n"
+ " 'connector' = 'values')\n";
tEnv.executeSql(sinkTableDdl);
util.verifyJsonPlan(
"insert into MySink select\n"
+ " b,\n"
+ " pyFunc(a, a + 1)\n"
+ "FROM MyTable\n"
+ "GROUP BY b, HOP(rowtime, INTERVAL '5' SECOND, INTERVAL '10' SECOND)");
}

@Test
public void testProcTimeHopWindow() {
String sinkTableDdl =
"CREATE TABLE MySink (\n"
+ " b BIGINT,\n"
+ " c BIGINT\n"
+ ") WITH (\n"
+ " 'connector' = 'values')\n";
tEnv.executeSql(sinkTableDdl);
util.verifyJsonPlan(
"insert into MySink select\n"
+ " b,\n"
+ " pyFunc(a, a + 1)\n"
+ "FROM MyTable\n"
+ "GROUP BY b, HOP(proctime, INTERVAL '5' MINUTE, INTERVAL '10' MINUTE)");
}

@Test
public void testEventTimeSessionWindow() {
String sinkTableDdl =
"CREATE TABLE MySink (\n"
+ " b BIGINT,\n"
+ " c BIGINT\n"
+ ") WITH (\n"
+ " 'connector' = 'values')\n";
tEnv.executeSql(sinkTableDdl);
util.verifyJsonPlan(
"insert into MySink select\n"
+ " b,\n"
+ " pyFunc(a, a + 1)\n"
+ "FROM MyTable\n"
+ "GROUP BY b, Session(rowtime, INTERVAL '10' SECOND)");
}

@Test
public void testProcTimeSessionWindow() {
String sinkTableDdl =
"CREATE TABLE MySink (\n"
+ " b BIGINT,\n"
+ " c BIGINT\n"
+ ") WITH (\n"
+ " 'connector' = 'values')\n";
tEnv.executeSql(sinkTableDdl);
util.verifyJsonPlan(
"insert into MySink select\n"
+ " b,\n"
+ " pyFunc(a, a + 1)\n"
+ "FROM MyTable\n"
+ "GROUP BY b, Session(proctime, INTERVAL '10' MINUTE)");
}
}