[SPARK-22982] Remove unsafe asynchronous close() call from FileDownloadChannel #20179

Closed
37 changes: 22 additions & 15 deletions core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala
@@ -332,16 +332,14 @@ private[netty] class NettyRpcEnv(

    val pipe = Pipe.open()
    val source = new FileDownloadChannel(pipe.source())
-   try {
+   Utils.tryWithSafeFinallyAndFailureCallbacks(block = {
      val client = downloadClient(parsedUri.getHost(), parsedUri.getPort())
      val callback = new FileDownloadCallback(pipe.sink(), source, client)
      client.stream(parsedUri.getPath(), callback)
-   } catch {
-     case e: Exception =>
-       pipe.sink().close()
-       source.close()
-       throw e
-   }
+   })(catchBlock = {
+     pipe.sink().close()
+     source.close()
+   })

    source
  }
@@ -370,24 +368,33 @@ private[netty] class NettyRpcEnv(
fileDownloadFactory.createClient(host, port)
}

- private class FileDownloadChannel(source: ReadableByteChannel) extends ReadableByteChannel {
+ private class FileDownloadChannel(source: Pipe.SourceChannel) extends ReadableByteChannel {

@volatile private var error: Throwable = _

def setError(e: Throwable): Unit = {
+     // This setError callback is invoked by internal RPC threads in order to propagate remote
+     // exceptions to application-level threads which are reading from this channel. When an
+     // RPC error occurs, the RPC system will call setError() and then will close the
+     // Pipe.SinkChannel corresponding to the other end of the `source` pipe. Closing of the pipe
+     // sink will cause `source.read()` operations to return EOF, unblocking the application-level
+     // reading thread. Thus there is no need to actually call `source.close()` here in the
+     // onError() callback and, in fact, calling it here would be dangerous because the close()
+     // would be asynchronous with respect to the read() call and could trigger race-conditions
+     // that lead to data corruption. See the PR for SPARK-22982 for more details on this topic.
      error = e
-     source.close()
}

override def read(dst: ByteBuffer): Int = {
Contributor: so the caller of read would close the source channel?

Contributor Author: Yes. This currently happens in two places:

    Try(source.read(dst)) match {
+     // See the documentation above in setError(): if an RPC error has occurred then setError()
+     // will be called to propagate the RPC error and then `source`'s corresponding
+     // Pipe.SinkChannel will be closed, unblocking this read. In that case, we want to propagate
+     // the remote RPC exception (and not any exceptions triggered by the pipe close, such as
+     // ChannelClosedException), hence this `error != null` check:
+     case _ if error != null => throw error
      case Success(bytesRead) => bytesRead
-     case Failure(readErr) =>
-       if (error != null) {
-         throw error
-       } else {
-         throw readErr
-       }
+     case Failure(readErr) => throw readErr
    }
}
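The comments in setError() and read() above rely on a standard java.nio behavior: closing the sink end of a Pipe lets readers of the source end drain any buffered bytes and then see EOF, so the reading thread never needs an asynchronous close() from the error path. A minimal standalone sketch of that behavior (illustrative only, not part of the patch):

```scala
import java.nio.ByteBuffer
import java.nio.channels.Pipe

object PipeEofSketch {
  def main(args: Array[String]): Unit = {
    val pipe = Pipe.open()
    val sink = pipe.sink()
    val source = pipe.source()

    // Write a few bytes, then close the sink. This mirrors what the RPC error path does
    // after calling setError(): it closes the Pipe.SinkChannel, not the source.
    sink.write(ByteBuffer.wrap(Array[Byte](1, 2, 3)))
    sink.close()

    // The reader drains whatever was buffered and then gets EOF (-1) instead of blocking.
    val buf = ByteBuffer.allocate(16)
    var n = source.read(buf)
    while (n != -1) {
      n = source.read(buf)
    }
    println(s"pipe delivered ${buf.position()} bytes, then read() returned EOF")

    // The reading side closes its own channel once it is done with it.
    source.close()
  }
}
```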

@@ -18,8 +18,8 @@
package org.apache.spark.shuffle

import java.io._

-import com.google.common.io.ByteStreams
+import java.nio.channels.Channels
+import java.nio.file.Files

import org.apache.spark.{SparkConf, SparkEnv}
import org.apache.spark.internal.Logging
@@ -196,11 +196,24 @@ private[spark] class IndexShuffleBlockResolver(
// find out the consolidated file, then the offset within that from our index
val indexFile = getIndexFile(blockId.shuffleId, blockId.mapId)

-   val in = new DataInputStream(new FileInputStream(indexFile))
+   // SPARK-22982: if this FileInputStream's position is seeked forward by another piece of code
+   // which is incorrectly using our file descriptor then this code will fetch the wrong offsets
+   // (which may cause a reducer to be sent a different reducer's data). The explicit position
+   // checks added here were a useful debugging aid during SPARK-22982 and may help prevent this
+   // class of issue from re-occurring in the future which is why they are left here even though
+   // SPARK-22982 is fixed.
+   val channel = Files.newByteChannel(indexFile.toPath)
+   channel.position(blockId.reduceId * 8)
Contributor: cc @zsxwing I recall you mentioned a performance issue with skipping data in the file channel; do we have this problem here?

Contributor Author: I made sure to incorporate @zsxwing's changes here. The problem originally related to calling skip(), but this change incorporates his fix to explicitly use position() on a FileChannel instead.

Contributor: Sorry, I'm not clear whether the change here is related to the "asynchronous close()" issue?

Contributor: It's used to detect bugs like the "asynchronous close()" issue earlier in the future.

Contributor: I see. Thanks!

Contributor Author: For some more background: the asynchronous close() bug can cause reads from a closed-and-subsequently-reassigned file descriptor number, and in principle this can affect almost any IO operation anywhere in the application. For example, if the closed file descriptor number is immediately recycled by opening a socket, then the invalid read can cause that socket read to miss data (since the data would have been consumed by the invalid reader and won't be delivered to the legitimate new user of the file descriptor).

Given this, I see how it might be puzzling that this patch is adding a check only here. There are two reasons for this:

  1. Many other IO operations have implicit checksumming, such that dropping data due to an invalid read would be detected and cause an exception. For example, many compression codecs have block-level checksumming (and magic numbers at the beginning of the stream), so dropping data (especially at the start of a read) will be detected. This particular shuffle index file, however, does not have mechanisms to detect corruption: skipping forward in the read by a multiple of 8 bytes will still read structurally-valid data (but it will be the wrong data, causing the wrong output to be read from the shuffle data file). A toy sketch of this is shown after this list.

  2. In the investigation which uncovered this bug, the invalid reads were predominantly impacting shuffle index lookups for reading local blocks. In a nutshell, there's a subtle race condition where Janino codegen compilation triggers attempted remote classloading of classes which don't exist, triggering the error-handling / error-propagation paths in FileDownloadChannel and causing the invalid asynchronous close() call to be performed. At the same time that this close() call was being performed, another task from the same stage attempts to read the shuffle index files of local blocks and experiences an invalid read due to the falsely-shared file descriptor.

    This is a very hard-to-trigger bug: we were only able to reproduce it on large clusters with very fast machines and shuffles that contain large numbers of map and reduce tasks (more shuffle blocks means more index file reads and more chances for the race to occur; faster machines increase the likelihood of the race occurring; larger clusters give us more chances for the error to occur). In our reproduction, this race occurred on a microsecond timescale (measured via kernel syscall tracing) and occurred relatively rarely, requiring many iterations until we could trigger a reproduction.
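To make point 1 concrete, here is a toy sketch (not part of the patch; the offsets are made up) of why a silently mispositioned index read still parses cleanly: the index file is a flat sequence of longs in which entries r and r+1 delimit reducer r's byte range in the shuffle data file, so a read that starts 8 bytes too far simply returns some other reducer's range with no error.

```scala
import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream}

object IndexFileSkewSketch {
  def main(args: Array[String]): Unit = {
    // Fake index for 4 reduce partitions: cumulative offsets into the shuffle data file.
    val offsets = Array(0L, 100L, 250L, 400L, 700L)
    val bos = new ByteArrayOutputStream()
    val dos = new DataOutputStream(bos)
    offsets.foreach(o => dos.writeLong(o))
    dos.flush()
    val indexBytes = bos.toByteArray

    // Read the (offset, nextOffset) pair for `reduceId`, optionally simulating a position
    // that some other user of the same file descriptor has already advanced by `skewBytes`.
    def readRange(reduceId: Int, skewBytes: Int): (Long, Long) = {
      val in = new DataInputStream(new ByteArrayInputStream(indexBytes))
      in.skipBytes(reduceId * 8 + skewBytes)
      (in.readLong(), in.readLong())
    }

    println(readRange(reduceId = 1, skewBytes = 0)) // (100,250): reducer 1's correct range
    println(readRange(reduceId = 1, skewBytes = 8)) // (250,400): reducer 2's range, no error raised
  }
}
```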

While investigating, I added these checks so that the index read fails-fast when this issue occurs, which made it significantly easier to reproduce and diagnose the root cause (fixed by the other changes in this patch).

There are a number of interesting details in the story of how we worked from the original high-level data corruption symptom to this low-level IO bug. I'll see about writing up the complete story in a blog post at some point.

+   val in = new DataInputStream(Channels.newInputStream(channel))
    try {
-     ByteStreams.skipFully(in, blockId.reduceId * 8)
      val offset = in.readLong()
      val nextOffset = in.readLong()
+     val actualPosition = channel.position()
+     val expectedPosition = blockId.reduceId * 8 + 16
+     if (actualPosition != expectedPosition) {

Maybe an assert, `assert(actualPosition == expectedPosition, $msg)`, is better for things like this, so we may elide them using compiler flags if desired.

Contributor Author: I considered this, but I don't think there's ever a case where we want to elide this particular check: if we read an incorrect offset here then there's (potentially) no other mechanism to detect this error, leading to silent wrong answers.
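The elision mentioned above is possible because scala.Predef.assert is annotated @elidable(ASSERTION), so scalac flags such as -Xdisable-assertions (or an -Xelide-below threshold) can strip it from the compiled code. A small illustrative sketch of the two variants being discussed (not part of the patch):

```scala
object PositionCheckSketch {
  def check(actualPosition: Long, expectedPosition: Long): Unit = {
    // Elidable variant: disappears entirely when the build elides assertions, so a
    // structurally-valid but wrong index read would once again go unnoticed.
    assert(actualPosition == expectedPosition,
      s"expected $expectedPosition but actual position was $actualPosition")

    // Non-elidable variant (the approach taken in this patch): an explicit throw always runs.
    if (actualPosition != expectedPosition) {
      throw new Exception(s"SPARK-22982: Incorrect channel position after index file reads: " +
        s"expected $expectedPosition but actual position was $actualPosition.")
    }
  }
}
```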

throw new Exception(s"SPARK-22982: Incorrect channel position after index file reads: " +
Contributor: Maybe we'd better change this to some specific Exception type.

Contributor Author: Any suggestions for a better exception subtype? I don't expect this to be a recoverable error and wanted to avoid the possibility that downstream code catches and handles this error. Maybe I should go further and make it a RuntimeException to make it even more fatal?

Contributor: I see. Thanks!

s"expected $expectedPosition but actual position was $actualPosition.")
}
new FileSegmentManagedBuffer(
transportConf,
getDataFile(blockId.shuffleId, blockId.mapId),