Added tests and cleanup
BryanCutler committed Jun 21, 2019
1 parent 4eaef85 commit 9be0110
Showing 4 changed files with 98 additions and 85 deletions.
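
Both new tests build their DataFrame with one more partition than there are rows, so at least one partition is guaranteed to be empty. A minimal sketch of that setup, assuming an active SparkContext sc and SparkSession spark such as the shared test fixtures provide, with RDD.glom() used only to make the empty slice visible:

    from pyspark.sql import Row

    data = [Row(id=1, x=2), Row(id=1, x=3), Row(id=2, x=4)]
    num_parts = len(data) + 1  # 4 slices for 3 rows, so at least one slice is empty

    rdd = sc.parallelize(data, numSlices=num_parts)
    print(rdd.glom().collect())
    # e.g. [[], [Row(id=1, x=2)], [Row(id=1, x=3)], [Row(id=2, x=4)]]

    # Both tests run their pandas UDFs against a DataFrame built from such an RDD.
    df = spark.createDataFrame(rdd)
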
python/pyspark/sql/tests/test_pandas_udf_grouped_agg.py: 13 additions & 0 deletions
@@ -18,6 +18,7 @@
 import unittest
 
 from pyspark.rdd import PythonEvalType
+from pyspark.sql import Row
 from pyspark.sql.functions import array, explode, col, lit, mean, sum, \
     udf, pandas_udf, PandasUDFType
 from pyspark.sql.types import *
@@ -461,6 +462,18 @@ def test_register_vectorized_udf_basic(self):
         expected = [1, 5]
         self.assertEqual(actual, expected)
 
+    def test_grouped_with_empty_partition(self):
+        data = [Row(id=1, x=2), Row(id=1, x=3), Row(id=2, x=4)]
+        expected = [Row(id=1, sum=5), Row(id=2, x=4)]
+        num_parts = len(data) + 1
+        df = self.spark.createDataFrame(self.sc.parallelize(data, numSlices=num_parts))
+
+        f = pandas_udf(lambda x: x.sum(),
+                       'int', PandasUDFType.GROUPED_AGG)
+
+        result = df.groupBy('id').agg(f(df['x']).alias('sum')).collect()
+        self.assertEqual(result, expected)
+
 
 if __name__ == "__main__":
     from pyspark.sql.tests.test_pandas_udf_grouped_agg import *

python/pyspark/sql/tests/test_pandas_udf_grouped_map.py: 12 additions & 0 deletions
@@ -504,6 +504,18 @@ def test_mixed_scalar_udfs_followed_by_grouby_apply(self):
 
         self.assertEquals(result.collect()[0]['sum'], 165)
 
+    def test_grouped_with_empty_partition(self):
+        data = [Row(id=1, x=2), Row(id=1, x=3), Row(id=2, x=4)]
+        expected = [Row(id=1, x=5), Row(id=1, x=5), Row(id=2, x=4)]
+        num_parts = len(data) + 1
+        df = self.spark.createDataFrame(self.sc.parallelize(data, numSlices=num_parts))
+
+        f = pandas_udf(lambda pdf: pdf.assign(x=pdf['x'].sum()),
+                       'id long, x int', PandasUDFType.GROUPED_MAP)
+
+        result = df.groupBy('id').apply(f).collect()
+        self.assertEqual(result, expected)
+
 
 if __name__ == "__main__":
     from pyspark.sql.tests.test_pandas_udf_grouped_map import *

sql/core/src/main/scala/org/apache/spark/sql/execution/python/AggregateInPandasExec.scala
@@ -105,58 +105,53 @@ case class AggregateInPandasExec(
StructField(s"_$i", dt)
})

inputRDD.mapPartitionsInternal { iter =>
// Map grouped rows to ArrowPythonRunner results, Only execute if partition is not empty
inputRDD.mapPartitionsInternal { iter => if (iter.isEmpty) iter else {
val prunedProj = UnsafeProjection.create(allInputs, child.output)

// Only execute on non-empty partitions
if (iter.nonEmpty) {
val prunedProj = UnsafeProjection.create(allInputs, child.output)

val grouped = if (groupingExpressions.isEmpty) {
// Use an empty unsafe row as a place holder for the grouping key
Iterator((new UnsafeRow(), iter))
} else {
GroupedIterator(iter, groupingExpressions, child.output)
}.map { case (key, rows) =>
(key, rows.map(prunedProj))
}
val grouped = if (groupingExpressions.isEmpty) {
// Use an empty unsafe row as a place holder for the grouping key
Iterator((new UnsafeRow(), iter))
} else {
GroupedIterator(iter, groupingExpressions, child.output)
}.map { case (key, rows) =>
(key, rows.map(prunedProj))
}

val context = TaskContext.get()
val context = TaskContext.get()

// The queue used to buffer input rows so we can drain it to
// combine input with output from Python.
val queue = HybridRowQueue(context.taskMemoryManager(),
new File(Utils.getLocalDir(SparkEnv.get.conf)), groupingExpressions.length)
context.addTaskCompletionListener[Unit] { _ =>
queue.close()
}
// The queue used to buffer input rows so we can drain it to
// combine input with output from Python.
val queue = HybridRowQueue(context.taskMemoryManager(),
new File(Utils.getLocalDir(SparkEnv.get.conf)), groupingExpressions.length)
context.addTaskCompletionListener[Unit] { _ =>
queue.close()
}

// Add rows to queue to join later with the result.
val projectedRowIter = grouped.map { case (groupingKey, rows) =>
queue.add(groupingKey.asInstanceOf[UnsafeRow])
rows
}
// Add rows to queue to join later with the result.
val projectedRowIter = grouped.map { case (groupingKey, rows) =>
queue.add(groupingKey.asInstanceOf[UnsafeRow])
rows
}

val columnarBatchIter = new ArrowPythonRunner(
pyFuncs,
PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF,
argOffsets,
aggInputSchema,
sessionLocalTimeZone,
pythonRunnerConf).compute(projectedRowIter, context.partitionId(), context)

val joinedAttributes =
groupingExpressions.map(_.toAttribute) ++ udfExpressions.map(_.resultAttribute)
val joined = new JoinedRow
val resultProj = UnsafeProjection.create(resultExpressions, joinedAttributes)

columnarBatchIter.map(_.rowIterator.next()).map { aggOutputRow =>
val leftRow = queue.remove()
val joinedRow = joined(leftRow, aggOutputRow)
resultProj(joinedRow)
}
} else {
Iterator.empty
val columnarBatchIter = new ArrowPythonRunner(
pyFuncs,
PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF,
argOffsets,
aggInputSchema,
sessionLocalTimeZone,
pythonRunnerConf).compute(projectedRowIter, context.partitionId(), context)

val joinedAttributes =
groupingExpressions.map(_.toAttribute) ++ udfExpressions.map(_.resultAttribute)
val joined = new JoinedRow
val resultProj = UnsafeProjection.create(resultExpressions, joinedAttributes)

columnarBatchIter.map(_.rowIterator.next()).map { aggOutputRow =>
val leftRow = queue.remove()
val joinedRow = joined(leftRow, aggOutputRow)
resultProj(joinedRow)
}
}
}}
}
}
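
The grouped aggregate path above is what the new test_grouped_with_empty_partition in test_pandas_udf_grouped_agg.py exercises. Roughly the same call outside the unittest harness, as a sketch that reuses the df with an empty partition from the setup above (sum_udf is a throwaway name for illustration; the commented result matches the test's expected values):

    from pyspark.sql.functions import pandas_udf, PandasUDFType

    @pandas_udf('int', PandasUDFType.GROUPED_AGG)
    def sum_udf(x):
        return x.sum()

    # With the empty-partition guard above, empty slices are passed through as-is
    # and no Python runner is started for them.
    df.groupBy('id').agg(sum_udf(df['x']).alias('sum')).collect()
    # [Row(id=1, sum=5), Row(id=2, sum=4)]
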
sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala
@@ -125,45 +125,38 @@ case class FlatMapGroupsInPandasExec(
     val dedupAttributes = nonDupGroupingAttributes ++ dataAttributes
     val dedupSchema = StructType.fromAttributes(dedupAttributes)
 
-    inputRDD.mapPartitionsInternal { iter =>
-
-      // Only execute on non-empty partitions
-      if (iter.nonEmpty) {
-
-        val grouped = if (groupingAttributes.isEmpty) {
-          Iterator(iter)
-        } else {
-          val groupedIter = GroupedIterator(iter, groupingAttributes, child.output)
-          val dedupProj = UnsafeProjection.create(dedupAttributes, child.output)
-          groupedIter.map {
-            case (_, groupedRowIter) => groupedRowIter.map(dedupProj)
-          }
-        }
-
-        val context = TaskContext.get()
-
-        val columnarBatchIter = new ArrowPythonRunner(
-          chainedFunc,
-          PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF,
-          argOffsets,
-          dedupSchema,
-          sessionLocalTimeZone,
-          pythonRunnerConf).compute(grouped, context.partitionId(), context)
-
-        val unsafeProj = UnsafeProjection.create(output, output)
-
-        columnarBatchIter.flatMap { batch =>
-          // Grouped Map UDF returns a StructType column in ColumnarBatch, select the children here
-          val structVector = batch.column(0).asInstanceOf[ArrowColumnVector]
-          val outputVectors = output.indices.map(structVector.getChild)
-          val flattenedBatch = new ColumnarBatch(outputVectors.toArray)
-          flattenedBatch.setNumRows(batch.numRows())
-          flattenedBatch.rowIterator.asScala
-        }.map(unsafeProj)
-
-      } else {
-        Iterator.empty
-      }
-    }
+    // Map grouped rows to ArrowPythonRunner results, Only execute if partition is not empty
+    inputRDD.mapPartitionsInternal { iter => if (iter.isEmpty) iter else {
+      val grouped = if (groupingAttributes.isEmpty) {
+        Iterator(iter)
+      } else {
+        val groupedIter = GroupedIterator(iter, groupingAttributes, child.output)
+        val dedupProj = UnsafeProjection.create(dedupAttributes, child.output)
+        groupedIter.map {
+          case (_, groupedRowIter) => groupedRowIter.map(dedupProj)
+        }
+      }
+
+      val context = TaskContext.get()
+
+      val columnarBatchIter = new ArrowPythonRunner(
+        chainedFunc,
+        PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF,
+        argOffsets,
+        dedupSchema,
+        sessionLocalTimeZone,
+        pythonRunnerConf).compute(grouped, context.partitionId(), context)
+
+      val unsafeProj = UnsafeProjection.create(output, output)
+
+      columnarBatchIter.flatMap { batch =>
+        // Grouped Map UDF returns a StructType column in ColumnarBatch, select the children here
+        val structVector = batch.column(0).asInstanceOf[ArrowColumnVector]
+        val outputVectors = output.indices.map(structVector.getChild)
+        val flattenedBatch = new ColumnarBatch(outputVectors.toArray)
+        flattenedBatch.setNumRows(batch.numRows())
+        flattenedBatch.rowIterator.asScala
+      }.map(unsafeProj)
+    }}
   }
 }
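
The grouped map path gets the same guard and is covered by the new test in test_pandas_udf_grouped_map.py. A matching sketch, again reusing the df with an empty partition and a throwaway UDF name (sum_per_group); the commented result is the test's expected list:

    from pyspark.sql.functions import pandas_udf, PandasUDFType

    @pandas_udf('id long, x int', PandasUDFType.GROUPED_MAP)
    def sum_per_group(pdf):
        # As in the test: replace x in every row of the group with the group's sum.
        return pdf.assign(x=pdf['x'].sum())

    df.groupBy('id').apply(sum_per_group).collect()
    # [Row(id=1, x=5), Row(id=1, x=5), Row(id=2, x=4)]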
