diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala
index c4335143570fa..f613e342d7d17 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala
@@ -97,11 +97,14 @@ case class ArrowEvalPythonExec(udfs: Seq[PythonUDF], output: Seq[Attribute], chi
         projection(inputRow)
       }

+      val context = TaskContext.get()
+
       val inputIterator = ArrowConverters.toPayloadIterator(
-        projectedRowIter, schema, conf.arrowMaxRecordsPerBatch).
+        projectedRowIter, schema, conf.arrowMaxRecordsPerBatch, context).
         map(_.asPythonSerializable)

-      val context = TaskContext.get()
+      val schemaOut = StructType.fromAttributes(output.drop(child.output.length).zipWithIndex.
+        map { case (attr, i) => attr.withName(s"_$i") })

       // Output iterator for results from Python.
       val outputIterator = new PythonRunner(
@@ -112,7 +115,7 @@ case class ArrowEvalPythonExec(udfs: Seq[PythonUDF], output: Seq[Attribute], chi
       val resultProj = UnsafeProjection.create(output, output)

       val outputRowIterator = ArrowConverters.fromPayloadIterator(
-        outputIterator.map(ArrowPayload(_)))
+        outputIterator.map(new ArrowPayload(_)), schemaOut, context)

       outputRowIterator.map { outputRow =>
         resultProj(joined(queue.remove(), outputRow))
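
For reference, here is a minimal sketch of the `ArrowConverters` signatures implied by the call sites above. These stubs are inferred from this diff alone, not copied from ArrowConverters.scala, so parameter names, visibility modifiers, and the placeholder `ArrowPayload` class are assumptions:

```scala
import org.apache.spark.TaskContext
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.types.StructType

// Inferred stubs only: shapes are derived from the call sites in this
// diff, not from the actual ArrowConverters.scala declarations.
object ArrowConvertersSketch {

  // Hypothetical placeholder so the sketch compiles standalone; the real
  // ArrowPayload wraps serialized Arrow record-batch bytes.
  class ArrowPayload(bytes: Array[Byte])

  // Now threads the TaskContext through, presumably so Arrow allocations
  // can be released via a task-completion listener.
  def toPayloadIterator(
      rowIter: Iterator[InternalRow],
      schema: StructType,
      maxRecordsPerBatch: Int,
      context: TaskContext): Iterator[ArrowPayload] = ???

  // Now takes an explicit output schema (the schemaOut built above) plus
  // the TaskContext, rather than recovering the schema from the payload.
  def fromPayloadIterator(
      payloadIter: Iterator[ArrowPayload],
      schema: StructType,
      context: TaskContext): Iterator[InternalRow] = ???
}
```

Note that `schemaOut` renames the UDF result attributes to `_0`, `_1`, ..., which appears to normalize the output column names positionally for the round trip through Python.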