Use ContextAwareIterator to stop consuming after the task ends.
ueshin committed Oct 29, 2020
1 parent 2639ad4 commit b34805c
Showing 5 changed files with 83 additions and 4 deletions.
22 changes: 22 additions & 0 deletions python/pyspark/sql/tests/test_pandas_map.py
@@ -15,9 +15,12 @@
# limitations under the License.
#
import os
import shutil
import tempfile
import time
import unittest

from pyspark.sql import Row
from pyspark.testing.sqlutils import ReusedSQLTestCase, have_pandas, have_pyarrow, \
    pandas_requirement_message, pyarrow_requirement_message

@@ -112,6 +115,25 @@ def func(iterator):
        expected = df.collect()
        self.assertEquals(actual, expected)

    # SPARK-33277
    def test_map_in_pandas_with_column_vector(self):
        path = tempfile.mkdtemp()
        shutil.rmtree(path)

        try:
            self.spark.range(0, 200000, 1, 1).write.parquet(path)

            def func(iterator):
                for pdf in iterator:
                    yield pd.DataFrame({'id': [0] * len(pdf)})

            for offheap in ["true", "false"]:
                with self.sql_conf({"spark.sql.columnVector.offheap.enabled": offheap}):
                    self.assertEquals(
                        self.spark.read.parquet(path).mapInPandas(func, 'id long').head(), Row(0))
        finally:
            shutil.rmtree(path)


if __name__ == "__main__":
    from pyspark.sql.tests.test_pandas_map import *  # noqa: F401
19 changes: 19 additions & 0 deletions python/pyspark/sql/tests/test_pandas_udf_scalar.py
@@ -1137,6 +1137,25 @@ def test_datasource_with_udf(self):
        finally:
            shutil.rmtree(path)

    # SPARK-33277
    def test_pandas_udf_with_column_vector(self):
        path = tempfile.mkdtemp()
        shutil.rmtree(path)

        try:
            self.spark.range(0, 200000, 1, 1).write.parquet(path)

            @pandas_udf(LongType())
            def udf(x):
                return pd.Series([0] * len(x))

            for offheap in ["true", "false"]:
                with self.sql_conf({"spark.sql.columnVector.offheap.enabled": offheap}):
                    self.assertEquals(
                        self.spark.read.parquet(path).select(udf('id')).head(), Row(0))
        finally:
            shutil.rmtree(path)


if __name__ == "__main__":
    from pyspark.sql.tests.test_pandas_udf_scalar import *  # noqa: F401
20 changes: 20 additions & 0 deletions python/pyspark/sql/tests/test_udf.py
@@ -674,6 +674,26 @@ def test_udf_cache(self):
        self.assertEqual(df.select(udf(func)("id"))._jdf.queryExecution()
                         .withCachedData().getClass().getSimpleName(), 'InMemoryRelation')

    # SPARK-33277
    def test_udf_with_column_vector(self):
        path = tempfile.mkdtemp()
        shutil.rmtree(path)

        try:
            self.spark.range(0, 100000, 1, 1).write.parquet(path)

            def f(x):
                return 0

            fUdf = udf(f, LongType())

            for offheap in ["true", "false"]:
                with self.sql_conf({"spark.sql.columnVector.offheap.enabled": offheap}):
                    self.assertEquals(
                        self.spark.read.parquet(path).select(fUdf('id')).head(), Row(0))
        finally:
            shutil.rmtree(path)


class UDFInitializationTests(unittest.TestCase):
    def tearDown(self):
sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvalPythonExec.scala
@@ -89,6 +89,7 @@ trait EvalPythonExec extends UnaryExecNode {

    inputRDD.mapPartitions { iter =>
      val context = TaskContext.get()
      val contextAwareIterator = new ContextAwareIterator(iter, context)

      // The queue used to buffer input rows so we can drain it to
      // combine input with output from Python.
@@ -120,7 +121,7 @@
      }.toSeq)

      // Add rows to queue to join later with the result.
      val projectedRowIter = iter.map { inputRow =>
      val projectedRowIter = contextAwareIterator.map { inputRow =>
        queue.add(inputRow.asInstanceOf[UnsafeRow])
        projection(inputRow)
      }
@@ -137,3 +138,19 @@
    }
  }
}

/**
 * A TaskContext aware iterator.
 *
 * As the Python evaluation consumes the parent iterator in a separate thread,
 * it could consume more data from the parent even after the task ends and the parent is closed.
 * Thus, we should use ContextAwareIterator to stop consuming after the task ends.
 */
private[spark] class ContextAwareIterator[IN](iter: Iterator[IN], context: TaskContext)
  extends Iterator[IN] {

  override def hasNext: Boolean =
    !context.isCompleted() && !context.isInterrupted() && iter.hasNext

  override def next(): IN = iter.next()
}
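The `hasNext` guard above is the whole mechanism: once the task is completed or interrupted, the wrapper reports exhaustion even if the parent iterator could still yield rows, so the Python writer thread stops reading from resources that are freed when the task ends (such as the Parquet column vectors exercised by the tests above). As an illustration only — not part of this commit — the same pattern can be sketched outside Spark with a plain stop flag standing in for the `TaskContext` checks; the `StopAwareIterator` name and the `AtomicBoolean` flag are hypothetical:

```scala
import java.util.concurrent.atomic.AtomicBoolean

// Hypothetical stand-in for ContextAwareIterator: the AtomicBoolean plays the
// role of TaskContext.isCompleted()/isInterrupted().
class StopAwareIterator[T](underlying: Iterator[T], stopped: AtomicBoolean)
  extends Iterator[T] {

  // Report exhaustion as soon as the flag is set, even if the underlying
  // iterator could still produce elements.
  override def hasNext: Boolean = !stopped.get() && underlying.hasNext

  override def next(): T = underlying.next()
}

object StopAwareIteratorExample {
  def main(args: Array[String]): Unit = {
    val stopped = new AtomicBoolean(false)
    val wrapped = new StopAwareIterator(Iterator.from(0), stopped)

    println(wrapped.take(3).toList)  // List(0, 1, 2)

    stopped.set(true)                // the "task" ends; its resources go away
    println(wrapped.hasNext)         // false -- no further consumption
  }
}
```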
sql/core/src/main/scala/org/apache/spark/sql/execution/python/MapInPandasExec.scala
@@ -61,16 +61,17 @@ case class MapInPandasExec(
      val pythonRunnerConf = ArrowUtils.getPythonRunnerConfMap(conf)
      val outputTypes = child.schema

      val context = TaskContext.get()
      val contextAwareIterator = new ContextAwareIterator(inputIter, context)

      // Here we wrap it via another row so that Python sides understand it
      // as a DataFrame.
      val wrappedIter = inputIter.map(InternalRow(_))
      val wrappedIter = contextAwareIterator.map(InternalRow(_))

      // DO NOT use iter.grouped(). See BatchIterator.
      val batchIter =
        if (batchSize > 0) new BatchIterator(wrappedIter, batchSize) else Iterator(wrappedIter)

      val context = TaskContext.get()

      val columnarBatchIter = new ArrowPythonRunner(
        chainedFunc,
        PythonEvalType.SQL_MAP_PANDAS_ITER_UDF,
