Skip to content

Commit

Permalink
[FLINK-21302][table-planner-blink] Fix NPE when use row_number() in o…
Browse files Browse the repository at this point in the history
…ver agg.
  • Loading branch information
beyond1920 committed Apr 20, 2021
1 parent c773047 commit ebe907b
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -515,7 +515,7 @@ object AggregateUtil extends Enumeration {
call,
imperativeFunction,
index,
argIndexes,
buildArgIndexes(argIndexes),
imperativeFunction.getArgumentDataTypes.asScala.toArray,
imperativeFunction.getAccumulatorDataType,
imperativeFunction.getOutputDataType,
Expand All @@ -527,7 +527,7 @@ object AggregateUtil extends Enumeration {
call,
udf,
index,
argIndexes,
buildArgIndexes(argIndexes),
null,
declarativeFunction.getAggBufferTypes,
Array(),
Expand Down Expand Up @@ -611,14 +611,21 @@ object AggregateUtil extends Enumeration {
call,
udf,
index,
argIndexes,
buildArgIndexes(argIndexes),
externalArgTypes,
externalAccTypes,
viewSpecs,
externalResultType,
needsRetraction)
}

private def buildArgIndexes(originalArgIndexes: Array[Int]): Array[Int] =
if (originalArgIndexes != null) {
originalArgIndexes
} else {
Array()
}

/**
* Inserts an COUNT(*) aggregate call if needed. The COUNT(*) aggregate call is used
* to count the number of added and retracted input records.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,36 @@ class OverAggregateITCase(mode: StateBackendMode) extends StreamingWithStateTest
env.getCheckpointConfig.enableUnalignedCheckpoints(false)
}

@Test
def testRowNumberOnOver(): Unit = {
val t = failingDataSource(TestData.tupleData5)
.toTable(tEnv, 'a, 'b, 'c, 'd, 'e, 'proctime.proctime)
tEnv.registerTable("MyTable", t)
val sqlQuery = "SELECT a, ROW_NUMBER() OVER (PARTITION BY a ORDER BY proctime()) FROM MyTable"

val sink = new TestingAppendSink
tEnv.sqlQuery(sqlQuery).toAppendStream[Row].addSink(sink)
env.execute()

val expected = List(
"1,1",
"2,1",
"2,2",
"3,1",
"3,2",
"3,3",
"4,1",
"4,2",
"4,3",
"4,4",
"5,1",
"5,2",
"5,3",
"5,4",
"5,5")
assertEquals(expected, sink.getAppendResults)
}

@Test
def testProcTimeBoundedPartitionedRowsOver(): Unit = {
val t = failingDataSource(TestData.tupleData5)
Expand Down

0 comments on commit ebe907b

Please sign in to comment.