From b8e13b0aeae4a25e54a4c4b8a97bc2c19edd37bb Mon Sep 17 00:00:00 2001
From: HyukjinKwon
Date: Wed, 31 Jul 2019 22:40:01 +0800
Subject: [PATCH] [SPARK-28153][PYTHON] Use AtomicReference at
 InputFileBlockHolder (to support input_file_name with Python UDF)

## What changes were proposed in this pull request?

This PR proposes to use `AtomicReference` so that parent and child threads can access the same file block holder.

Python UDF expressions are turned into a plan, which then launches a separate thread to consume the input iterator. In that child thread, the iterator calls `InputFileBlockHolder.set` before the parent thread does, and the parent thread is then unable to read the value later:

1. If the child thread happens to call `InputFileBlockHolder.set` first, before the parent's thread local has been initialized (initialization happens on the first call to `ThreadLocal.get()`), the child thread calls its own `initialValue` and creates its own copy.

2. After that, the parent thread calls its own `initialValue` on its first call to `ThreadLocal.get()`.

3. The two threads now hold two different references, so an update in the child is not reflected in the parent.

This PR fixes it by initializing the parent's thread local with an `AtomicReference` holding the file status, so that a single reference is shared within each task and the child thread's updates are visible to the parent. I also tried to explain this a bit more at https://github.com/apache/spark/pull/24958#discussion_r297203041.
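To make the two-references failure mode concrete, here is a minimal, self-contained sketch (not part of this patch; `ThreadLocalDemo` and its fields are made-up names for illustration) contrasting a plain `InheritableThreadLocal` with the `AtomicReference`-based approach:

```scala
import java.util.concurrent.atomic.AtomicReference

// Illustration only: a standalone demo, not Spark code.
object ThreadLocalDemo {
  // The old shape: the value itself lives in the thread local. Because the
  // child thread touches it first (before the parent has initialized its
  // copy), each thread ends up with its own instance via `initialValue`.
  private val plain = new InheritableThreadLocal[Array[String]] {
    override protected def initialValue(): Array[String] = Array("unset")
  }

  // The new shape: the thread local holds an AtomicReference. The parent
  // initializes it up front, so the child inherits the *same* reference
  // and its writes are visible to the parent.
  private val shared = new InheritableThreadLocal[AtomicReference[String]] {
    override protected def initialValue(): AtomicReference[String] =
      new AtomicReference("unset")
  }

  def main(args: Array[String]): Unit = {
    // Force `initialValue` in the parent before spawning the child; this
    // mirrors the InputFileBlockHolder.initialize() call added in Task.scala.
    shared.get()

    val child = new Thread(() => {
      plain.get()(0) = "set by child"   // goes into the child's own copy
      shared.get().set("set by child")  // goes into the shared reference
    })
    child.start()
    child.join()

    println(plain.get()(0))    // prints "unset": the child's write is lost
    println(shared.get().get)  // prints "set by child": visible to the parent
  }
}
```

Note that the ordering matters: `InheritableThreadLocal` only copies entries that already exist in the parent when the child thread is constructed, which is why the patch adds the explicit `initialize()` call at the start of the task.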
## How was this patch tested?

Manually tested, and a unit test was added.

Closes #24958 from HyukjinKwon/SPARK-28153.

Authored-by: HyukjinKwon
Signed-off-by: Wenchen Fan
---
 .../spark/rdd/InputFileBlockHolder.scala    | 29 ++++++++++++++-----
 .../org/apache/spark/scheduler/Task.scala   |  1 +
 python/pyspark/sql/tests/test_functions.py  |  6 ++++
 3 files changed, 29 insertions(+), 7 deletions(-)

diff --git a/core/src/main/scala/org/apache/spark/rdd/InputFileBlockHolder.scala b/core/src/main/scala/org/apache/spark/rdd/InputFileBlockHolder.scala
index ff2f58d81142d..bfe8152d4dee2 100644
--- a/core/src/main/scala/org/apache/spark/rdd/InputFileBlockHolder.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/InputFileBlockHolder.scala
@@ -17,6 +17,8 @@
 
 package org.apache.spark.rdd
 
+import java.util.concurrent.atomic.AtomicReference
+
 import org.apache.spark.unsafe.types.UTF8String
 
 /**
@@ -40,26 +42,33 @@ private[spark] object InputFileBlockHolder {
   /**
    * The thread variable for the name of the current file being read. This is used by
    * the InputFileName function in Spark SQL.
+   *
+   * @note `inputBlock` is somewhat subtle. It guarantees that `initialValue`
+   * is called at the start of a task. Therefore, one atomic reference is created in the task
+   * thread. After that, reads and writes go to the same atomic reference across the parent and
+   * child threads. This is to support a case where a write happens in a child thread
+   * but the read happens in its parent thread, for instance, Python UDF execution. See SPARK-28153.
    */
-  private[this] val inputBlock: InheritableThreadLocal[FileBlock] =
-    new InheritableThreadLocal[FileBlock] {
-      override protected def initialValue(): FileBlock = new FileBlock
+  private[this] val inputBlock: InheritableThreadLocal[AtomicReference[FileBlock]] =
+    new InheritableThreadLocal[AtomicReference[FileBlock]] {
+      override protected def initialValue(): AtomicReference[FileBlock] =
+        new AtomicReference(new FileBlock)
     }
 
   /**
    * Returns the holding file name or empty string if it is unknown.
    */
-  def getInputFilePath: UTF8String = inputBlock.get().filePath
+  def getInputFilePath: UTF8String = inputBlock.get().get().filePath
 
   /**
    * Returns the starting offset of the block currently being read, or -1 if it is unknown.
    */
-  def getStartOffset: Long = inputBlock.get().startOffset
+  def getStartOffset: Long = inputBlock.get().get().startOffset
 
   /**
    * Returns the length of the block being read, or -1 if it is unknown.
    */
-  def getLength: Long = inputBlock.get().length
+  def getLength: Long = inputBlock.get().get().length
 
   /**
    * Sets the thread-local input block.
@@ -68,11 +77,17 @@ private[spark] object InputFileBlockHolder {
     require(filePath != null, "filePath cannot be null")
     require(startOffset >= 0, s"startOffset ($startOffset) cannot be negative")
     require(length >= 0, s"length ($length) cannot be negative")
-    inputBlock.set(new FileBlock(UTF8String.fromString(filePath), startOffset, length))
+    inputBlock.get().set(new FileBlock(UTF8String.fromString(filePath), startOffset, length))
   }
 
   /**
    * Clears the input file block to default value.
    */
   def unset(): Unit = inputBlock.remove()
+
+  /**
+   * Initializes the thread local by explicitly getting the value. It triggers ThreadLocal's
+   * `initialValue` in the parent thread.
+   */
+  def initialize(): Unit = inputBlock.get()
 }
diff --git a/core/src/main/scala/org/apache/spark/scheduler/Task.scala b/core/src/main/scala/org/apache/spark/scheduler/Task.scala
index 9dfbf862a9c57..01828f860bd5e 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/Task.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/Task.scala
@@ -104,6 +104,7 @@ private[spark] abstract class Task[T](
       taskContext
     }
 
+    InputFileBlockHolder.initialize()
     TaskContext.setTaskContext(context)
     taskThread = Thread.currentThread()
 
diff --git a/python/pyspark/sql/tests/test_functions.py b/python/pyspark/sql/tests/test_functions.py
index 64f2fd6a3919f..36e1e8a660f00 100644
--- a/python/pyspark/sql/tests/test_functions.py
+++ b/python/pyspark/sql/tests/test_functions.py
@@ -304,6 +304,12 @@ def test_array_repeat(self):
             df.select(array_repeat("id", lit(3))).toDF("val").collect(),
         )
 
+    def test_input_file_name_udf(self):
+        df = self.spark.read.text('python/test_support/hello/hello.txt')
+        df = df.select(udf(lambda x: x)("value"), input_file_name().alias('file'))
+        file_name = df.collect()[0].file
+        self.assertTrue("python/test_support/hello/hello.txt" in file_name)
+
 
 if __name__ == "__main__":
     import unittest