Skip to content

Commit

Permalink
[FLINK-22304][table] Refactor some interfaces for TVF based window to…
Browse files Browse the repository at this point in the history
… improve the extendability

This closes #15745
  • Loading branch information
shuo.cs authored and wuchong committed Apr 29, 2021
1 parent 8ffefda commit 6a0d44d
Show file tree
Hide file tree
Showing 22 changed files with 260 additions and 302 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,11 @@
import org.apache.flink.table.runtime.generated.GeneratedNamespaceAggsHandleFunction;
import org.apache.flink.table.runtime.keyselector.RowDataKeySelector;
import org.apache.flink.table.runtime.operators.aggregate.window.LocalSlicingWindowAggOperator;
import org.apache.flink.table.runtime.operators.aggregate.window.buffers.RecordsWindowBuffer;
import org.apache.flink.table.runtime.operators.aggregate.window.buffers.WindowBuffer;
import org.apache.flink.table.runtime.operators.aggregate.window.combines.LocalAggCombiner;
import org.apache.flink.table.runtime.operators.window.slicing.SliceAssigner;
import org.apache.flink.table.runtime.typeutils.AbstractRowDataSerializer;
import org.apache.flink.table.runtime.typeutils.InternalTypeInfo;
import org.apache.flink.table.runtime.typeutils.PagedTypeSerializer;
import org.apache.flink.table.runtime.typeutils.RowDataSerializer;
Expand All @@ -65,7 +69,6 @@ public class StreamExecLocalWindowAggregate extends StreamExecWindowAggregateBas
private static final long WINDOW_AGG_MEMORY_RATIO = 100;

public static final String FIELD_NAME_WINDOWING = "windowing";
public static final String FIELD_NAME_NAMED_WINDOW_PROPERTIES = "namedWindowProperties";

@JsonProperty(FIELD_NAME_GROUPING)
private final int[] grouping;
Expand Down Expand Up @@ -139,14 +142,17 @@ protected Transformation<RowData> translateToPlanInternal(PlannerBase planner) {
final RowDataKeySelector selector =
KeySelectorUtil.getRowDataSelector(grouping, InternalTypeInfo.of(inputRowType));

PagedTypeSerializer<RowData> keySer =
(PagedTypeSerializer<RowData>) selector.getProducedType().toSerializer();
AbstractRowDataSerializer<RowData> valueSer = new RowDataSerializer(inputRowType);

WindowBuffer.LocalFactory bufferFactory =
new RecordsWindowBuffer.LocalFactory(
keySer, valueSer, new LocalAggCombiner.Factory(generatedAggsHandler));

final OneInputStreamOperator<RowData, RowData> localAggOperator =
new LocalSlicingWindowAggOperator(
selector,
sliceAssigner,
(PagedTypeSerializer<RowData>) selector.getProducedType().toSerializer(),
new RowDataSerializer(inputRowType),
generatedAggsHandler,
shiftTimeZone);
selector, sliceAssigner, bufferFactory, shiftTimeZone);

return ExecNodeUtil.createOneInputTransformation(
inputTransform,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -532,7 +532,7 @@ object HashAggCodeGenHelper {
val rowDataType = classOf[RowData].getCanonicalName
s"""
|$iteratorType<$rowDataType, $rowDataType> $iteratorTerm =
| $aggregateMapTerm.getEntryIterator();
| $aggregateMapTerm.getEntryIterator(false); // reuse key/value during iterating
|while ($iteratorTerm.advanceNext()) {
| // set result and output
| $reuseGroupKeyTerm = ($rowDataType)$iteratorTerm.getKey();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -716,7 +716,7 @@ class HashWindowCodeGenerator(
val iteratorTerm = CodeGenUtils.newName("iterator")
s"""
|$iteratorType<$rowDataType, $rowDataType> $iteratorTerm =
| $aggregateMapTerm.getEntryIterator();
| $aggregateMapTerm.getEntryIterator(false); // reuse key/value during iterating
|while ($iteratorTerm.advanceNext()) {
| $reuseAggMapKeyTerm = ($rowDataType) $iteratorTerm.getKey();
| $reuseAggBufferTerm = ($rowDataType) $iteratorTerm.getValue();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,16 +26,10 @@
import org.apache.flink.streaming.api.watermark.Watermark;
import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
import org.apache.flink.table.data.RowData;
import org.apache.flink.table.runtime.generated.GeneratedNamespaceAggsHandleFunction;
import org.apache.flink.table.runtime.keyselector.RowDataKeySelector;
import org.apache.flink.table.runtime.operators.aggregate.window.buffers.RecordsWindowBuffer;
import org.apache.flink.table.runtime.operators.aggregate.window.buffers.WindowBuffer;
import org.apache.flink.table.runtime.operators.aggregate.window.combines.LocalAggRecordsCombiner;
import org.apache.flink.table.runtime.operators.window.combines.WindowCombineFunction;
import org.apache.flink.table.runtime.operators.window.slicing.ClockService;
import org.apache.flink.table.runtime.operators.window.slicing.SliceAssigner;
import org.apache.flink.table.runtime.typeutils.AbstractRowDataSerializer;
import org.apache.flink.table.runtime.typeutils.PagedTypeSerializer;

import java.time.ZoneId;
import java.util.TimeZone;
Expand All @@ -55,8 +49,7 @@ public class LocalSlicingWindowAggOperator extends AbstractStreamOperator<RowDat
private final RowDataKeySelector keySelector;
private final SliceAssigner sliceAssigner;
private final long windowInterval;
private final WindowBuffer.Factory windowBufferFactory;
private final WindowCombineFunction.LocalFactory combinerFactory;
private final WindowBuffer.LocalFactory windowBufferFactory;

/**
* The shift timezone of the window, if the proctime or rowtime type is TIMESTAMP_LTZ, the shift
Expand Down Expand Up @@ -88,29 +81,12 @@ public class LocalSlicingWindowAggOperator extends AbstractStreamOperator<RowDat
public LocalSlicingWindowAggOperator(
RowDataKeySelector keySelector,
SliceAssigner sliceAssigner,
PagedTypeSerializer<RowData> keySer,
AbstractRowDataSerializer<RowData> inputSer,
GeneratedNamespaceAggsHandleFunction<Long> genAggsHandler,
ZoneId shiftTimezone) {
this(
keySelector,
sliceAssigner,
new RecordsWindowBuffer.Factory(keySer, inputSer),
new LocalAggRecordsCombiner.Factory(genAggsHandler, keySer),
shiftTimezone);
}

public LocalSlicingWindowAggOperator(
RowDataKeySelector keySelector,
SliceAssigner sliceAssigner,
WindowBuffer.Factory windowBufferFactory,
WindowCombineFunction.LocalFactory combinerFactory,
WindowBuffer.LocalFactory windowBufferFactory,
ZoneId shiftTimezone) {
this.keySelector = keySelector;
this.sliceAssigner = sliceAssigner;
this.windowInterval = sliceAssigner.getSliceEndInterval();
this.windowBufferFactory = windowBufferFactory;
this.combinerFactory = combinerFactory;
this.shiftTimezone = shiftTimezone;
this.useDayLightSaving = TimeZone.getTimeZone(shiftTimezone).useDaylightTime();
}
Expand All @@ -123,14 +99,13 @@ public void open() throws Exception {
collector = new TimestampedCollector<>(output);
collector.eraseTimestamp();

final WindowCombineFunction localCombiner =
combinerFactory.create(getRuntimeContext(), collector);
this.windowBuffer =
windowBufferFactory.create(
getContainingTask(),
getContainingTask().getEnvironment().getMemoryManager(),
computeMemorySize(),
localCombiner,
getRuntimeContext(),
collector,
shiftTimezone);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,16 +18,15 @@

package org.apache.flink.table.runtime.operators.aggregate.window;

import org.apache.flink.api.common.typeutils.TypeSerializer;
import org.apache.flink.table.data.RowData;
import org.apache.flink.table.runtime.generated.GeneratedNamespaceAggsHandleFunction;
import org.apache.flink.table.runtime.operators.aggregate.window.buffers.RecordsWindowBuffer;
import org.apache.flink.table.runtime.operators.aggregate.window.buffers.WindowBuffer;
import org.apache.flink.table.runtime.operators.aggregate.window.combines.AggRecordsCombiner;
import org.apache.flink.table.runtime.operators.aggregate.window.combines.GlobalAggAccCombiner;
import org.apache.flink.table.runtime.operators.aggregate.window.combines.AggCombiner;
import org.apache.flink.table.runtime.operators.aggregate.window.combines.GlobalAggCombiner;
import org.apache.flink.table.runtime.operators.aggregate.window.processors.SliceSharedWindowAggProcessor;
import org.apache.flink.table.runtime.operators.aggregate.window.processors.SliceUnsharedWindowAggProcessor;
import org.apache.flink.table.runtime.operators.window.combines.WindowCombineFunction;
import org.apache.flink.table.runtime.operators.window.combines.RecordsCombiner;
import org.apache.flink.table.runtime.operators.window.slicing.SliceAssigner;
import org.apache.flink.table.runtime.operators.window.slicing.SliceAssigners.HoppingSliceAssigner;
import org.apache.flink.table.runtime.operators.window.slicing.SliceSharedAssigner;
Expand Down Expand Up @@ -64,7 +63,7 @@ public static SlicingWindowAggOperatorBuilder builder() {
private SliceAssigner assigner;
private AbstractRowDataSerializer<RowData> inputSerializer;
private PagedTypeSerializer<RowData> keySerializer;
private TypeSerializer<RowData> accSerializer;
private AbstractRowDataSerializer<RowData> accSerializer;
private GeneratedNamespaceAggsHandleFunction<Long> generatedAggregateFunction;
private GeneratedNamespaceAggsHandleFunction<Long> localGeneratedAggregateFunction;
private GeneratedNamespaceAggsHandleFunction<Long> globalGeneratedAggregateFunction;
Expand Down Expand Up @@ -95,7 +94,7 @@ public SlicingWindowAggOperatorBuilder assigner(SliceAssigner assigner) {

public SlicingWindowAggOperatorBuilder aggregate(
GeneratedNamespaceAggsHandleFunction<Long> generatedAggregateFunction,
TypeSerializer<RowData> accSerializer) {
AbstractRowDataSerializer<RowData> accSerializer) {
this.generatedAggregateFunction = generatedAggregateFunction;
this.accSerializer = accSerializer;
return this;
Expand All @@ -105,7 +104,7 @@ public SlicingWindowAggOperatorBuilder globalAggregate(
GeneratedNamespaceAggsHandleFunction<Long> localGeneratedAggregateFunction,
GeneratedNamespaceAggsHandleFunction<Long> globalGeneratedAggregateFunction,
GeneratedNamespaceAggsHandleFunction<Long> stateGeneratedAggregateFunction,
TypeSerializer<RowData> accSerializer) {
AbstractRowDataSerializer<RowData> accSerializer) {
this.localGeneratedAggregateFunction = localGeneratedAggregateFunction;
this.globalGeneratedAggregateFunction = globalGeneratedAggregateFunction;
this.generatedAggregateFunction = stateGeneratedAggregateFunction;
Expand All @@ -131,28 +130,27 @@ public SlicingWindowAggOperatorBuilder countStarIndex(int indexOfCountStart) {
checkNotNull(keySerializer);
checkNotNull(accSerializer);
checkNotNull(generatedAggregateFunction);
final WindowBuffer.Factory bufferFactory =
new RecordsWindowBuffer.Factory(keySerializer, inputSerializer);
final WindowCombineFunction.Factory combinerFactory;
if (localGeneratedAggregateFunction != null && globalGeneratedAggregateFunction != null) {

boolean isGlobalAgg =
localGeneratedAggregateFunction != null && globalGeneratedAggregateFunction != null;

RecordsCombiner.Factory combinerFactory;
if (isGlobalAgg) {
combinerFactory =
new GlobalAggAccCombiner.Factory(
localGeneratedAggregateFunction,
globalGeneratedAggregateFunction,
keySerializer);
new GlobalAggCombiner.Factory(
localGeneratedAggregateFunction, globalGeneratedAggregateFunction);
} else {
combinerFactory =
new AggRecordsCombiner.Factory(
generatedAggregateFunction, keySerializer, inputSerializer);
combinerFactory = new AggCombiner.Factory(generatedAggregateFunction);
}
final WindowBuffer.Factory bufferFactory =
new RecordsWindowBuffer.Factory(keySerializer, inputSerializer, combinerFactory);

final SlicingWindowProcessor<Long> windowProcessor;
if (assigner instanceof SliceSharedAssigner) {
windowProcessor =
new SliceSharedWindowAggProcessor(
generatedAggregateFunction,
bufferFactory,
combinerFactory,
(SliceSharedAssigner) assigner,
accSerializer,
indexOfCountStart,
Expand All @@ -162,7 +160,6 @@ public SlicingWindowAggOperatorBuilder countStarIndex(int indexOfCountStart) {
new SliceUnsharedWindowAggProcessor(
generatedAggregateFunction,
bufferFactory,
combinerFactory,
(SliceUnsharedAssigner) assigner,
accSerializer,
shiftTimeZone);
Expand Down
Loading

0 comments on commit 6a0d44d

Please sign in to comment.