[SPARK-33277][PYSPARK][SQL][2.4] Use ContextAwareIterator to stop consuming after the task ends

### What changes were proposed in this pull request?

This is a backport of #30177.

As the Python evaluation consumes the parent iterator in a separate thread, it could consume more data from the parent even after the task ends and the parent is closed. Thus, we should use `ContextAwareIterator` to stop consuming after the task ends.
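Conceptually, the fix gates every `hasNext` call on the task's lifecycle. Below is a minimal sketch of that guard only for illustration; the actual `ContextAwareIterator` added by this patch appears in the `EvalPythonExec` diff further down, and the helper name `guard` is not part of Spark:

```scala
import org.apache.spark.TaskContext

// Illustrative sketch: wrap a parent iterator so that no more elements are
// pulled once the owning task has completed or been interrupted.
def guard[T](parent: Iterator[T], context: TaskContext): Iterator[T] = new Iterator[T] {
  override def hasNext: Boolean =
    !context.isCompleted() && !context.isInterrupted() && parent.hasNext

  override def next(): T = parent.next()
}
```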

### Why are the changes needed?

Python/Pandas UDF right after off-heap vectorized reader could cause executor crash.

For example:

```py
import shutil
import tempfile

from pyspark.sql.functions import udf
from pyspark.sql.types import LongType

# Any non-existent target path works; write.parquet() fails if the path already exists.
path = tempfile.mkdtemp()
shutil.rmtree(path)

spark.range(0, 100000, 1, 1).write.parquet(path)

spark.conf.set("spark.sql.columnVector.offheap.enabled", True)

def f(x):
    return 0

fUdf = udf(f, LongType())

spark.read.parquet(path).select(fUdf('id')).head()
```

This is because the Python evaluation consumes the parent iterator in a separate thread, and it can keep consuming data from the parent even after the task ends and the parent is closed. If an off-heap column vector exists in the parent iterator, this can cause a segmentation fault that crashes the executor.

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

Added tests, and verified manually.

Closes #30218 from ueshin/issues/SPARK-33277/2.4/python_pandas_udf.

Authored-by: Takuya UESHIN <ueshin@databricks.com>
Signed-off-by: HyukjinKwon <gurwls223@apache.org>
ueshin authored and HyukjinKwon committed Nov 2, 2020
1 parent a32178c commit cabf957
Showing 2 changed files with 59 additions and 1 deletion.
42 changes: 42 additions & 0 deletions python/pyspark/sql/tests.py
@@ -3628,6 +3628,26 @@ def test_udf_in_subquery(self):
        finally:
            self.spark.catalog.dropTempView("v")

    # SPARK-33277
    def test_udf_with_column_vector(self):
        path = tempfile.mkdtemp()
        shutil.rmtree(path)

        try:
            self.spark.range(0, 100000, 1, 1).write.parquet(path)

            def f(x):
                return 0

            fUdf = udf(f, LongType())

            for offheap in ["true", "false"]:
                with self.sql_conf({"spark.sql.columnVector.offheap.enabled": offheap}):
                    self.assertEquals(
                        self.spark.read.parquet(path).select(fUdf('id')).head(), Row(0))
        finally:
            shutil.rmtree(path)


class HiveSparkSubmitTests(SparkSubmitTests):

@@ -5575,6 +5595,28 @@ def test_datasource_with_udf(self):
        finally:
            shutil.rmtree(path)

    # SPARK-33277
    def test_pandas_udf_with_column_vector(self):
        import pandas as pd
        from pyspark.sql.functions import pandas_udf

        path = tempfile.mkdtemp()
        shutil.rmtree(path)

        try:
            self.spark.range(0, 200000, 1, 1).write.parquet(path)

            @pandas_udf(LongType())
            def udf(x):
                return pd.Series([0] * len(x))

            for offheap in ["true", "false"]:
                with self.sql_conf({"spark.sql.columnVector.offheap.enabled": offheap}):
                    self.assertEquals(
                        self.spark.read.parquet(path).select(udf('id')).head(), Row(0))
        finally:
            shutil.rmtree(path)


@unittest.skipIf(
    not _have_pandas or not _have_pyarrow,
18 changes: 17 additions & 1 deletion sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvalPythonExec.scala
@@ -88,6 +88,7 @@ abstract class EvalPythonExec(udfs: Seq[PythonUDF], output: Seq[Attribute], chil

    inputRDD.mapPartitions { iter =>
      val context = TaskContext.get()
      val contextAwareIterator = new ContextAwareIterator(iter, context)

      // The queue used to buffer input rows so we can drain it to
      // combine input with output from Python.
@@ -119,7 +120,7 @@ abstract class EvalPythonExec(udfs: Seq[PythonUDF], output: Seq[Attribute], chil
      })

      // Add rows to queue to join later with the result.
      val projectedRowIter = iter.map { inputRow =>
      val projectedRowIter = contextAwareIterator.map { inputRow =>
        queue.add(inputRow.asInstanceOf[UnsafeRow])
        projection(inputRow)
      }
@@ -136,3 +137,18 @@ abstract class EvalPythonExec(udfs: Seq[PythonUDF], output: Seq[Attribute], chil
    }
  }
}

/**
 * A TaskContext aware iterator.
 *
 * As the Python evaluation consumes the parent iterator in a separate thread,
 * it could consume more data from the parent even after the task ends and the parent is closed.
 * Thus, we should use ContextAwareIterator to stop consuming after the task ends.
 */
class ContextAwareIterator[IN](iter: Iterator[IN], context: TaskContext) extends Iterator[IN] {

  override def hasNext: Boolean =
    !context.isCompleted() && !context.isInterrupted() && iter.hasNext

  override def next(): IN = iter.next()
}
