[SPARK-18406][CORE] Race between end-of-task and completion iterator read lock release

When a TaskContext is not propagated properly to all child threads for the task, as in the cases reported in this issue, we fail to get the TID from TaskContext, which leaves us unable to release the lock and causes assertion failures. To resolve this, we have to explicitly pass the TID value to the `unlock` method.
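To illustrate the failure mode (a standalone sketch, not part of this patch): a value held in a plain thread-local — here `currentTaskId`, a stand-in for the lookup that `TaskContext.get()` performs — is invisible to a thread the task spawns, so the TID has to be captured on the task's main thread and handed over explicitly.

    import java.util.concurrent.atomic.AtomicReference

    object ThreadLocalSketch {
      // Stand-in for the per-thread task context; it is not visible in threads the task spawns.
      private val currentTaskId = new ThreadLocal[java.lang.Long]

      def main(args: Array[String]): Unit = {
        currentTaskId.set(42L)                                     // "task" thread sets its context
        val captured = Option(currentTaskId.get).map(_.longValue)  // capture while still on this thread

        val seenInChild = new AtomicReference[Option[Long]]
        val child = new Thread(new Runnable {
          override def run(): Unit = {
            // The thread-local is unset here, just like TaskContext in the reported cases,
            // so only the explicitly captured value identifies the owning task.
            seenInChild.set(Option(currentTaskId.get).map(_.longValue))
          }
        })
        child.start()
        child.join()

        assert(seenInChild.get.isEmpty)   // the lookup from the child thread finds nothing
        assert(captured.contains(42L))    // the captured TID is what must be passed explicitly
      }
    }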

Add a new regression test case in `RDDSuite` that fails without this fix.

Author: Xingbo Jiang <xingbo.jiang@databricks.com>

Closes apache#18076 from jiangxb1987/completion-iterator.
jiangxb1987 committed May 24, 2017
1 parent 2f68631 commit aa59b1b
Showing 4 changed files with 44 additions and 12 deletions.
2 changes: 1 addition & 1 deletion core/src/main/scala/org/apache/spark/network/BlockDataManager.scala
@@ -46,5 +46,5 @@ trait BlockDataManager {
   /**
    * Release locks acquired by [[putBlockData()]] and [[getBlockData()]].
    */
-  def releaseLock(blockId: BlockId): Unit
+  def releaseLock(blockId: BlockId, taskAttemptId: Option[Long]): Unit
 }
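For context, a standalone sketch (made-up names, not part of the patch) of the signature pattern used here: the trait requires the new parameter, while the concrete override in `BlockManager` further down gives it a default of `None`, so existing single-argument call sites made through the concrete type keep compiling.

    object DefaultParamSketch {
      trait Releasable {
        // New-style signature: callers must say whether they carry an explicit task attempt id.
        def releaseLock(blockId: String, taskAttemptId: Option[Long]): Unit
      }

      class Manager extends Releasable {
        // The override supplies a default, so existing single-argument calls still compile
        // when made through the concrete type.
        override def releaseLock(blockId: String, taskAttemptId: Option[Long] = None): Unit =
          println(s"release $blockId for task ${taskAttemptId.getOrElse("<current thread>")}")
      }

      def main(args: Array[String]): Unit = {
        val m = new Manager
        m.releaseLock("rdd_0_0")             // old call shape, default None applies
        m.releaseLock("rdd_0_0", Some(42L))  // explicit TID captured on the task thread
      }
    }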
15 changes: 10 additions & 5 deletions core/src/main/scala/org/apache/spark/storage/BlockInfoManager.scala
@@ -281,22 +281,27 @@ private[storage] class BlockInfoManager extends Logging {

   /**
    * Release a lock on the given block.
+   * In case a TaskContext is not propagated properly to all child threads for the task, we fail to
+   * get the TID from TaskContext, so we have to explicitly pass the TID value to release the lock.
+   *
+   * See SPARK-18406 for more discussion of this issue.
    */
-  def unlock(blockId: BlockId): Unit = synchronized {
-    logTrace(s"Task $currentTaskAttemptId releasing lock for $blockId")
+  def unlock(blockId: BlockId, taskAttemptId: Option[TaskAttemptId] = None): Unit = synchronized {
+    val taskId = taskAttemptId.getOrElse(currentTaskAttemptId)
+    logTrace(s"Task $taskId releasing lock for $blockId")
     val info = get(blockId).getOrElse {
       throw new IllegalStateException(s"Block $blockId not found")
     }
     if (info.writerTask != BlockInfo.NO_WRITER) {
       info.writerTask = BlockInfo.NO_WRITER
-      writeLocksByTask.removeBinding(currentTaskAttemptId, blockId)
+      writeLocksByTask.removeBinding(taskId, blockId)
     } else {
       assert(info.readerCount > 0, s"Block $blockId is not locked for reading")
       info.readerCount -= 1
-      val countsForTask = readLocksByTask(currentTaskAttemptId)
+      val countsForTask = readLocksByTask(taskId)
       val newPinCountForTask: Int = countsForTask.remove(blockId, 1) - 1
       assert(newPinCountForTask >= 0,
-        s"Task $currentTaskAttemptId release lock on block $blockId more times than it acquired it")
+        s"Task $taskId release lock on block $blockId more times than it acquired it")
     }
     notifyAll()
   }
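The heart of the change is the first line of the new `unlock` body: an explicitly passed id wins, otherwise the method falls back to the id derived from the calling thread, so existing callers on the task's main thread behave exactly as before. A minimal standalone model of that fallback (illustrative names only, not Spark code):

    object UnlockFallbackSketch {
      // Mirrors `taskAttemptId.getOrElse(currentTaskAttemptId)` in BlockInfoManager.unlock:
      // prefer the id the caller captured on the task thread, else ask the current thread.
      def resolveTaskId(explicitId: Option[Long], idFromCurrentThread: => Long): Long =
        explicitId.getOrElse(idFromCurrentThread)

      def main(args: Array[String]): Unit = {
        // On the task's main thread, the thread-local lookup is correct and no id is passed.
        assert(resolveTaskId(None, idFromCurrentThread = 42L) == 42L)
        // On a child thread the lookup would yield a wrong or sentinel id, so the id captured
        // earlier is passed explicitly and takes precedence.
        assert(resolveTaskId(Some(42L), idFromCurrentThread = -1024L) == 42L)
        println("fallback resolves as expected")
      }
    }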
21 changes: 16 additions & 5 deletions core/src/main/scala/org/apache/spark/storage/BlockManager.scala
@@ -454,14 +454,20 @@ private[spark] class BlockManager(
       case Some(info) =>
         val level = info.level
         logDebug(s"Level for block $blockId is $level")
+        val taskAttemptId = Option(TaskContext.get()).map(_.taskAttemptId())
         if (level.useMemory && memoryStore.contains(blockId)) {
           val iter: Iterator[Any] = if (level.deserialized) {
             memoryStore.getValues(blockId).get
           } else {
             serializerManager.dataDeserializeStream(
               blockId, memoryStore.getBytes(blockId).get.toInputStream())(info.classTag)
           }
-          val ci = CompletionIterator[Any, Iterator[Any]](iter, releaseLock(blockId))
+          // We need to capture the current taskId in case the iterator completion is triggered
+          // from a different thread which does not have TaskContext set; see SPARK-18406 for
+          // discussion.
+          val ci = CompletionIterator[Any, Iterator[Any]](iter, {
+            releaseLock(blockId, taskAttemptId)
+          })
           Some(new BlockResult(ci, DataReadMethod.Memory, info.size))
         } else if (level.useDisk && diskStore.contains(blockId)) {
           val iterToReturn: Iterator[Any] = {
@@ -478,7 +484,9 @@
               serializerManager.dataDeserializeStream(blockId, stream)(info.classTag)
             }
           }
-          val ci = CompletionIterator[Any, Iterator[Any]](iterToReturn, releaseLock(blockId))
+          val ci = CompletionIterator[Any, Iterator[Any]](iterToReturn, {
+            releaseLock(blockId, taskAttemptId)
+          })
           Some(new BlockResult(ci, DataReadMethod.Disk, info.size))
         } else {
           handleLocalReadFailure(blockId)
@@ -654,10 +662,13 @@
   }

   /**
-   * Release a lock on the given block.
+   * Release a lock on the given block with explicit TID.
+   * The param `taskAttemptId` should be passed in case we can't get the correct TID from
+   * TaskContext, for example, the input iterator of a cached RDD iterates to the end in a child
+   * thread.
    */
-  def releaseLock(blockId: BlockId): Unit = {
-    blockInfoManager.unlock(blockId)
+  def releaseLock(blockId: BlockId, taskAttemptId: Option[Long] = None): Unit = {
+    blockInfoManager.unlock(blockId, taskAttemptId)
   }

   /**
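As a stand-in for Spark's `CompletionIterator` (the sketch below is not the real class), this shows why the TID is captured eagerly above: the completion callback runs on whichever thread happens to exhaust the iterator, and it only knows which task it belongs to because the id was baked into the closure beforehand.

    object CompletionCallbackSketch {
      // Minimal completion iterator: runs `f` once, the first time the wrapped iterator is exhausted.
      def onCompletion[A](it: Iterator[A])(f: => Unit): Iterator[A] = new Iterator[A] {
        private var fired = false
        override def hasNext: Boolean = {
          val more = it.hasNext
          if (!more && !fired) { fired = true; f }
          more
        }
        override def next(): A = it.next()
      }

      def main(args: Array[String]): Unit = {
        val capturedTaskAttemptId = 42L   // captured on the "task" thread, as in the patch
        val wrapped = onCompletion(Iterator(1, 2, 3)) {
          // Runs on whichever thread drains the iterator; it identifies the task only via the
          // id that was captured into this closure beforehand.
          println(s"releasing lock for task $capturedTaskAttemptId on ${Thread.currentThread().getName}")
        }
        val child = new Thread(new Runnable {
          override def run(): Unit = wrapped.foreach(_ => ())   // drained on a child thread
        }, "child-thread")
        child.start()
        child.join()
      }
    }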
18 changes: 17 additions & 1 deletion core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
@@ -30,7 +30,7 @@ import org.apache.hadoop.mapred.{FileSplit, TextInputFormat}
 import org.apache.spark._
 import org.apache.spark.api.java.{JavaRDD, JavaSparkContext}
 import org.apache.spark.rdd.RDDSuiteUtils._
-import org.apache.spark.util.Utils
+import org.apache.spark.util.{ThreadUtils, Utils}

 class RDDSuite extends SparkFunSuite with SharedSparkContext {
   var tempDir: File = _
@@ -1082,6 +1082,22 @@ class RDDSuite extends SparkFunSuite with SharedSparkContext {
     assert(totalPartitionCount == 10)
   }

+  test("SPARK-18406: race between end-of-task and completion iterator read lock release") {
+    val rdd = sc.parallelize(1 to 1000, 10)
+    rdd.cache()
+
+    rdd.mapPartitions { iter =>
+      ThreadUtils.runInNewThread("TestThread") {
+        // Iterate to the end of the input iterator, to cause the CompletionIterator completion to
+        // fire outside of the task's main thread.
+        while (iter.hasNext) {
+          iter.next()
+        }
+        iter
+      }
+    }.collect()
+  }
+
   // NOTE
   // Below tests calling sc.stop() have to be the last tests in this suite. If there are tests
   // running after them and if they access sc those tests will fail as sc is already closed, because
