[SPARK-18406][CORE] Race between end-of-task and completion iterator read lock release #18076

Closed
core/src/main/scala/org/apache/spark/storage/BlockInfoManager.scala

@@ -281,22 +281,27 @@ private[storage] class BlockInfoManager extends Logging {
 
   /**
    * Release a lock on the given block.
+   * In case a TaskContext is not propagated properly to all child threads for the task, we fail to
+   * get the TID from TaskContext, so we have to explicitly pass the TID value to release the lock.
+   *
+   * See SPARK-18406 for more discussion of this issue.
    */
-  def unlock(blockId: BlockId): Unit = synchronized {
-    logTrace(s"Task $currentTaskAttemptId releasing lock for $blockId")
+  def unlock(blockId: BlockId, taskAttemptId: Option[TaskAttemptId] = None): Unit = synchronized {
+    val taskId = taskAttemptId.getOrElse(currentTaskAttemptId)
+    logTrace(s"Task $taskId releasing lock for $blockId")
     val info = get(blockId).getOrElse {
       throw new IllegalStateException(s"Block $blockId not found")
     }
     if (info.writerTask != BlockInfo.NO_WRITER) {
       info.writerTask = BlockInfo.NO_WRITER
-      writeLocksByTask.removeBinding(currentTaskAttemptId, blockId)
+      writeLocksByTask.removeBinding(taskId, blockId)
     } else {
       assert(info.readerCount > 0, s"Block $blockId is not locked for reading")
       info.readerCount -= 1
-      val countsForTask = readLocksByTask(currentTaskAttemptId)
+      val countsForTask = readLocksByTask(taskId)
       val newPinCountForTask: Int = countsForTask.remove(blockId, 1) - 1
       assert(newPinCountForTask >= 0,
-        s"Task $currentTaskAttemptId release lock on block $blockId more times than it acquired it")
+        s"Task $taskId release lock on block $blockId more times than it acquired it")
     }
     notifyAll()
   }
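The new doc comment turns on TaskContext not being propagated to child threads. A minimal, self-contained sketch of that mechanism (plain Scala with hypothetical names, not Spark code): a value stored in an ordinary ThreadLocal on the task's main thread is invisible to a thread the task spawns, so a lookup made from the child thread comes back empty, which is exactly the situation in which currentTaskAttemptId would fall back to the non-task sentinel.

```scala
object ThreadLocalSketch {
  // Stand-in for the slot that holds the current TaskContext in Spark.
  private val currentTaskId = new ThreadLocal[java.lang.Long]

  def main(args: Array[String]): Unit = {
    currentTaskId.set(42L)
    println(s"task thread sees:  ${Option(currentTaskId.get)}") // Some(42)

    val child = new Thread(new Runnable {
      // An ordinary ThreadLocal is not inherited: the child's slot is empty.
      override def run(): Unit =
        println(s"child thread sees: ${Option(currentTaskId.get)}") // None
    })
    child.start()
    child.join()
  }
}
```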
29 changes: 22 additions & 7 deletions core/src/main/scala/org/apache/spark/storage/BlockManager.scala
@@ -501,14 +501,18 @@ private[spark] class BlockManager(
       case Some(info) =>
         val level = info.level
         logDebug(s"Level for block $blockId is $level")
+        val taskAttemptId = Option(TaskContext.get()).map(_.taskAttemptId())
+          .getOrElse(BlockInfo.NON_TASK_WRITER)
Contributor: I think we can leave out the .getOrElse here and just pass the Option itself into releaseLock. This avoids exposing BlockInfo.NON_TASK_WRITER here. Not a huge deal, just a minor nit.
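A sketch of what that nit would look like, with simplified stand-in types (hypothetical names, not the committed code): the caller forwards the Option untouched, and the lock-manager side owns the fallback, so the sentinel never leaks out of that layer.

```scala
object OptionThreadingSketch {
  // Stands in for BlockInfo.NON_TASK_WRITER; in the real code the sentinel
  // belongs to the block-info layer.
  private val NonTaskWriter = -1024L

  // The unlock side resolves the fallback itself...
  def unlock(blockId: String, taskAttemptId: Option[Long]): Unit = {
    val tid = taskAttemptId.getOrElse(NonTaskWriter)
    println(s"task $tid releasing lock for $blockId")
  }

  def main(args: Array[String]): Unit = {
    // ...so the caller forwards what it found, with no .getOrElse of its own.
    val fromContext: Option[Long] = Some(7L) // stands in for Option(TaskContext.get()).map(_.taskAttemptId())
    unlock("rdd_0_0", fromContext)
    unlock("rdd_0_1", None) // fallback decided inside unlock
  }
}
```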

         if (level.useMemory && memoryStore.contains(blockId)) {
           val iter: Iterator[Any] = if (level.deserialized) {
             memoryStore.getValues(blockId).get
           } else {
             serializerManager.dataDeserializeStream(
               blockId, memoryStore.getBytes(blockId).get.toInputStream())(info.classTag)
           }
-          val ci = CompletionIterator[Any, Iterator[Any]](iter, releaseLock(blockId))
Contributor: I'd add a one-line comment before this line which references SPARK-18406, something like "We need to capture the current taskId in case the iterator completion is triggered from a different thread which does not have TaskContext set; see SPARK-18406 for discussion", or similar.
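The timing the suggested comment describes deserves spelling out: the completion callback is a by-name block that runs when the iterator is exhausted, possibly on a different thread. A stand-alone sketch (stand-in types, not Spark code) of the difference between reading thread-local state lazily in the callback and capturing it eagerly on the task thread:

```scala
import java.util.concurrent.Executors

object EagerCaptureSketch {
  private val ctx = new ThreadLocal[java.lang.Long]

  def main(args: Array[String]): Unit = {
    ctx.set(42L) // pretend this is the task's main thread

    // Wrong: the ThreadLocal is read only when the callback fires, on whatever thread that is.
    val lazyCallback = () => println(s"lazy:  ${Option(ctx.get)}") // prints None on another thread

    // Right: resolve the id here, on the task thread, and close over the value.
    val capturedId = Option(ctx.get)
    val eagerCallback = () => println(s"eager: $capturedId") // prints Some(42) on any thread

    val pool = Executors.newSingleThreadExecutor()
    pool.submit(new Runnable {
      override def run(): Unit = { lazyCallback(); eagerCallback() }
    })
    pool.shutdown()
  }
}
```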

+          val ci = CompletionIterator[Any, Iterator[Any]](iter, {
+            releaseLock(blockId, Some(taskAttemptId))
+          })
           Some(new BlockResult(ci, DataReadMethod.Memory, info.size))
         } else if (level.useDisk && diskStore.contains(blockId)) {
           val diskData = diskStore.getBytes(blockId)
@@ -525,8 +529,9 @@
               serializerManager.dataDeserializeStream(blockId, stream)(info.classTag)
             }
           }
-          val ci = CompletionIterator[Any, Iterator[Any]](iterToReturn,
-            releaseLockAndDispose(blockId, diskData))
+          val ci = CompletionIterator[Any, Iterator[Any]](iterToReturn, {
+            releaseLockAndDispose(blockId, diskData, Some(taskAttemptId))
+          })
           Some(new BlockResult(ci, DataReadMethod.Disk, info.size))
         } else {
           handleLocalReadFailure(blockId)
@@ -713,8 +718,15 @@ private[spark] class BlockManager(
   /**
    * Release a lock on the given block.
    */
-  def releaseLock(blockId: BlockId): Unit = {
-    blockInfoManager.unlock(blockId)
-  }
+  def releaseLock(blockId: BlockId): Unit = releaseLock(blockId, taskAttemptId = None)
Contributor: Why do we need to overload here? Why not just have a single releaseLock method with a default argument?

Contributor (author): In fact, BlockManager extends BlockDataManager, so it has to override the releaseLock(blockId: BlockId) method; thus we keep this one and implement a new method that accepts the new taskAttemptId argument.

Contributor: I think there's only one implementation of BlockDataManager these days, though? Since that's an internal interface, maybe we could change it there, too?
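To make the constraint in this exchange concrete, a simplified sketch with a stand-in trait (hypothetical names, not the real interface): a method that adds a parameter, even one with a default value, has a different signature and so does not implement the abstract member, which is why the exact one-argument form has to stay.

```scala
trait DataManager {                        // stand-in for BlockDataManager
  def releaseLock(blockId: String): Unit   // the signature implementors must provide
}

class Manager extends DataManager {
  // Satisfies the trait by delegating to the richer overload. Writing instead
  //   def releaseLock(blockId: String, taskAttemptId: Option[Long] = None)
  // would NOT override the abstract method, and the class would fail to compile.
  override def releaseLock(blockId: String): Unit =
    releaseLock(blockId, taskAttemptId = None)

  def releaseLock(blockId: String, taskAttemptId: Option[Long]): Unit =
    println(s"releasing $blockId as task ${taskAttemptId.getOrElse(-1L)}")
}
```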


+  /**
+   * Release a lock on the given block with explicit TID.
+   * This method should be used in case we can't get the correct TID from TaskContext, for example,
+   * the input iterator of a cached RDD iterates to the end in a child thread.
+   */
+  def releaseLock(blockId: BlockId, taskAttemptId: Option[Long]): Unit = {
+    blockInfoManager.unlock(blockId, taskAttemptId)
+  }
 
   /**
@@ -1467,8 +1479,11 @@ private[spark] class BlockManager(
     }
   }
 
-  def releaseLockAndDispose(blockId: BlockId, data: BlockData): Unit = {
-    blockInfoManager.unlock(blockId)
+  def releaseLockAndDispose(
+      blockId: BlockId,
+      data: BlockData,
+      taskAttemptId: Option[Long] = None): Unit = {
+    releaseLock(blockId, taskAttemptId)
     data.dispose()
   }
18 changes: 17 additions & 1 deletion core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
@@ -30,7 +30,7 @@ import org.apache.hadoop.mapred.{FileSplit, TextInputFormat}
 import org.apache.spark._
 import org.apache.spark.api.java.{JavaRDD, JavaSparkContext}
 import org.apache.spark.rdd.RDDSuiteUtils._
-import org.apache.spark.util.Utils
+import org.apache.spark.util.{ThreadUtils, Utils}
 
 class RDDSuite extends SparkFunSuite with SharedSparkContext {
   var tempDir: File = _
 
@@ -1082,6 +1082,22 @@ class RDDSuite extends SparkFunSuite with SharedSparkContext {
     assert(totalPartitionCount == 10)
   }
 
test("SPARK-18406: race between end-of-task and completion iterator read lock release") {
val rdd = sc.parallelize(1 to 1000, 10)
rdd.cache()

rdd.mapPartitions { iter =>
ThreadUtils.runInNewThread("TestThread") {
Contributor: Nice use of this helper method. I wasn't aware of this, but it's pretty nice. I'll use it in my own tests going forward.
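For readers who, like the reviewer, haven't met this helper: its essential behavior is to run a by-name body on a fresh thread, join, and surface the result or failure on the calling thread. A rough approximation follows (a simplification, not Spark's implementation; the real ThreadUtils.runInNewThread also rewrites the stack trace of a propagated exception).

```scala
import java.util.concurrent.atomic.AtomicReference

object RunInNewThreadSketch {
  def runInNewThread[T](threadName: String)(body: => T): T = {
    val result = new AtomicReference[Option[T]](None)
    val failure = new AtomicReference[Option[Throwable]](None)
    val t = new Thread(threadName) {
      override def run(): Unit =
        try result.set(Some(body))
        catch { case e: Throwable => failure.set(Some(e)) }
    }
    t.start()
    t.join()                            // the caller blocks until the body finishes
    failure.get().foreach(e => throw e) // rethrow the child's failure on the caller
    result.get().get
  }
}
```

In the test below, this means the block that drains the iterator runs on "TestThread" while the task's main thread is parked on the join, so the CompletionIterator callback fires on a thread that has no TaskContext set.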

+        // Iterate to the end of the input iterator, to cause the CompletionIterator completion to
+        // fire outside of the task's main thread.
+        while (iter.hasNext) {
+          iter.next()
+        }
+        iter
+      }
+    }.collect()
+  }

   // NOTE
   // Below tests calling sc.stop() have to be the last tests in this suite. If there are tests
   // running after them and if they access sc those tests will fail as sc is already closed, because