From 6235132a8ce64bb12d825d0a65e5dd052d1ee647 Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Tue, 2 May 2017 22:44:27 -0700 Subject: [PATCH] [SPARK-20567] Lazily bind in GenerateExec It is not valid to eagerly bind with the child's output as this causes failures when we attempt to canonicalize the plan (replacing the attribute references with dummies). Author: Michael Armbrust Closes #17838 from marmbrus/fixBindExplode. --- .../spark/sql/execution/GenerateExec.scala | 2 +- .../streaming/StreamingAggregationSuite.scala | 16 ++++++++++++++++ 2 files changed, 17 insertions(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala index 1812a1152cb48..c35e5638e9273 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala @@ -78,7 +78,7 @@ case class GenerateExec( override def outputPartitioning: Partitioning = child.outputPartitioning - val boundGenerator: Generator = BindReferences.bindReference(generator, child.output) + lazy val boundGenerator: Generator = BindReferences.bindReference(generator, child.output) protected override def doExecute(): RDD[InternalRow] = { // boundGenerator.terminate() should be triggered after all of the rows in the partition diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala index f796a4cb4a398..4345a70601c34 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala @@ -69,6 +69,22 @@ class StreamingAggregationSuite extends StateStoreMetricsTest with BeforeAndAfte ) } + test("count distinct") { + val inputData = MemoryStream[(Int, Seq[Int])] + + val aggregated = + inputData.toDF() + .select($"*", explode($"_2") as 'value) + .groupBy($"_1") + .agg(size(collect_set($"value"))) + .as[(Int, Int)] + + testStream(aggregated, Update)( + AddData(inputData, (1, Seq(1, 2))), + CheckLastBatch((1, 2)) + ) + } + test("simple count, complete mode") { val inputData = MemoryStream[Int]