Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SPARK-8029] Robust shuffle writer #9610

Closed
wants to merge 8 commits into from
Closed
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ public void write(Iterator<Product2<K, V>> records) throws IOException {
assert (partitionWriters == null);
if (!records.hasNext()) {
partitionLengths = new long[numPartitions];
shuffleBlockResolver.writeIndexFile(shuffleId, mapId, partitionLengths);
shuffleBlockResolver.writeIndexFileAndCommit(shuffleId, mapId, partitionLengths, null);
mapStatus = MapStatus$.MODULE$.apply(blockManager.shuffleServerId(), partitionLengths);
return;
}
Expand Down Expand Up @@ -155,9 +155,10 @@ public void write(Iterator<Product2<K, V>> records) throws IOException {
writer.commitAndClose();
}

partitionLengths =
writePartitionedFile(shuffleBlockResolver.getDataFile(shuffleId, mapId));
shuffleBlockResolver.writeIndexFile(shuffleId, mapId, partitionLengths);
File output = shuffleBlockResolver.getDataFile(shuffleId, mapId);
File tmp = Utils.tempFileWith(output);
partitionLengths = writePartitionedFile(tmp);
shuffleBlockResolver.writeIndexFileAndCommit(shuffleId, mapId, partitionLengths, tmp);
mapStatus = MapStatus$.MODULE$.apply(blockManager.shuffleServerId(), partitionLengths);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@
import org.apache.spark.executor.ShuffleWriteMetrics;
import org.apache.spark.io.CompressionCodec;
import org.apache.spark.io.CompressionCodec$;
import org.apache.spark.io.LZFCompressionCodec;
import org.apache.spark.memory.TaskMemoryManager;
import org.apache.spark.network.util.LimitedInputStream;
import org.apache.spark.scheduler.MapStatus;
import org.apache.spark.scheduler.MapStatus$;
Expand All @@ -53,7 +53,7 @@
import org.apache.spark.storage.BlockManager;
import org.apache.spark.storage.TimeTrackingOutputStream;
import org.apache.spark.unsafe.Platform;
import org.apache.spark.memory.TaskMemoryManager;
import org.apache.spark.util.Utils;

@Private
public class UnsafeShuffleWriter<K, V> extends ShuffleWriter<K, V> {
Expand Down Expand Up @@ -206,16 +206,18 @@ void closeAndWriteOutput() throws IOException {
final SpillInfo[] spills = sorter.closeAndGetSpills();
sorter = null;
final long[] partitionLengths;
final File output = shuffleBlockResolver.getDataFile(shuffleId, mapId);
final File tmp = Utils.tempFileWith(output);
try {
partitionLengths = mergeSpills(spills);
partitionLengths = mergeSpills(spills, tmp);
} finally {
for (SpillInfo spill : spills) {
if (spill.file.exists() && ! spill.file.delete()) {
logger.error("Error while deleting spill file {}", spill.file.getPath());
}
}
}
shuffleBlockResolver.writeIndexFile(shuffleId, mapId, partitionLengths);
shuffleBlockResolver.writeIndexFileAndCommit(shuffleId, mapId, partitionLengths, tmp);
mapStatus = MapStatus$.MODULE$.apply(blockManager.shuffleServerId(), partitionLengths);
}

Expand Down Expand Up @@ -248,8 +250,7 @@ void forceSorterToSpill() throws IOException {
*
* @return the partition lengths in the merged file.
*/
private long[] mergeSpills(SpillInfo[] spills) throws IOException {
final File outputFile = shuffleBlockResolver.getDataFile(shuffleId, mapId);
private long[] mergeSpills(SpillInfo[] spills, File outputFile) throws IOException {
final boolean compressionEnabled = sparkConf.getBoolean("spark.shuffle.compress", true);
final CompressionCodec compressionCodec = CompressionCodec$.MODULE$.createCodec(sparkConf);
final boolean fastMergeEnabled =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@

package org.apache.spark.shuffle

import java.io.File
import java.util.UUID
import java.util.concurrent.ConcurrentLinkedQueue

import scala.collection.JavaConverters._
Expand Down Expand Up @@ -84,17 +86,8 @@ private[spark] class FileShuffleBlockResolver(conf: SparkConf)
Array.tabulate[DiskBlockObjectWriter](numReducers) { bucketId =>
val blockId = ShuffleBlockId(shuffleId, mapId, bucketId)
val blockFile = blockManager.diskBlockManager.getFile(blockId)
// Because of previous failures, the shuffle file may already exist on this machine.
// If so, remove it.
if (blockFile.exists) {
if (blockFile.delete()) {
logInfo(s"Removed existing shuffle file $blockFile")
} else {
logWarning(s"Failed to remove existing shuffle file $blockFile")
}
}
blockManager.getDiskWriter(blockId, blockFile, serializerInstance, bufferSize,
writeMetrics)
val tmp = new File(blockFile.getAbsolutePath + "." + UUID.randomUUID())
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why not use your new method Utils.withTempFile

blockManager.getDiskWriter(blockId, tmp, serializerInstance, bufferSize, writeMetrics)
}
}
// Creating the file to write to and creating a disk writer both involve interacting with
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,12 @@ import java.io._

import com.google.common.io.ByteStreams

import org.apache.spark.{SparkConf, SparkEnv, Logging}
import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, ManagedBuffer}
import org.apache.spark.network.netty.SparkTransportConf
import org.apache.spark.shuffle.IndexShuffleBlockResolver.NOOP_REDUCE_ID
import org.apache.spark.storage._
import org.apache.spark.util.Utils

import IndexShuffleBlockResolver.NOOP_REDUCE_ID
import org.apache.spark.{SparkEnv, Logging, SparkConf}

/**
* Create and maintain the shuffle blocks' mapping between logic block and physical file location.
Expand All @@ -40,10 +39,17 @@ import IndexShuffleBlockResolver.NOOP_REDUCE_ID
*/
// Note: Changes to the format in this file should be kept in sync with
// org.apache.spark.network.shuffle.ExternalShuffleBlockResolver#getSortBasedShuffleBlockData().
private[spark] class IndexShuffleBlockResolver(conf: SparkConf) extends ShuffleBlockResolver
private[spark] class IndexShuffleBlockResolver(
conf: SparkConf,
_blockManager: BlockManager = null)
extends ShuffleBlockResolver
with Logging {

private lazy val blockManager = SparkEnv.get.blockManager
private lazy val blockManager = if (_blockManager == null) {
SparkEnv.get.blockManager
} else {
_blockManager
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

or

Option(_blockManager).getOrElse(SparkEnv.get.blockManager)


private val transportConf = SparkTransportConf.fromSparkConf(conf)

Expand Down Expand Up @@ -74,14 +80,69 @@ private[spark] class IndexShuffleBlockResolver(conf: SparkConf) extends ShuffleB
}
}

/**
* Check whether there are index file and data file also they are matched with each other, returns
* the lengths of each block in data file, if there are matched, or return null.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

/**
 * Check whether the given index and data files match each other.
 * If so, return the partition lengths in the data file. Otherwise return null.
 */

*/
private def checkIndexAndDataFile(index: File, data: File, blocks: Int): Array[Long] = {
val lengths = new Array[Long](blocks)
if (index.length() == (blocks + 1) * 8) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

instead of the nested if-else's, it might be clearer to write this as:

// can you add a comment here to explain this check
if (index.length() != (blocks + 1) * 8) {
  return null
}

// Read the lengths of blocks
val f = try {
new FileInputStream(index)
} catch {
case e: IOException =>
return null
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should we at least logWarning when this happens? Same in L113

}
val in = new DataInputStream(new BufferedInputStream(f))
try {
// Convert the offsets into lengths of each block
var offset = in.readLong()
if (offset != 0L) {
return null
}
var i = 0
while (i < blocks) {
val off = in.readLong()
lengths(i) = off - offset
offset = off
i += 1
}
} catch {
case e: IOException =>
return null
} finally {
in.close()
}

val size = lengths.reduce(_ + _)
// `length` returns 0 if it not exists.
if (data.length() == size) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

or just

if (data.length() == lengths.sum) {
  lengths
} else {
  // partition lengths don't match
  return null
}

lengths
} else {
null
}
} else {
null
}
}

/**
* Write an index file with the offsets of each block, plus a final offset at the end for the
* end of the output file. This will be used by getBlockData to figure out where each block
* begins and ends.
*
* It will commit the data and index file as an atomic operation, use the existed ones (lengths of
* blocks will be refreshed), or replace them with new ones.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we should add that this modifies the contents of lengths if existing data and index files exist and are matching.

* */
def writeIndexFile(shuffleId: Int, mapId: Int, lengths: Array[Long]): Unit = {
def writeIndexFileAndCommit(
shuffleId: Int,
mapId: Int,
lengths: Array[Long],
dataTmp: File): Unit = {
val indexFile = getIndexFile(shuffleId, mapId)
val out = new DataOutputStream(new BufferedOutputStream(new FileOutputStream(indexFile)))
val indexTmp = Utils.tempFileWith(indexFile)
val out = new DataOutputStream(new BufferedOutputStream(new FileOutputStream(indexTmp)))
Utils.tryWithSafeFinally {
// We take in lengths of each block, need to convert it to offsets.
var offset = 0L
Expand All @@ -93,6 +154,31 @@ private[spark] class IndexShuffleBlockResolver(conf: SparkConf) extends ShuffleB
} {
out.close()
}

val dataFile = getDataFile(shuffleId, mapId)
// Note: there is only one IndexShuffleBlockResolver per executor
synchronized {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you add a big comment explaining what is going on here? Also worth noting that there is only one IndexShuffleBlockResolver per executor.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is this synchronized for atomic renames? If so, can you mention in a comment?

val existedLengths = checkIndexAndDataFile(indexFile, dataFile, lengths.length)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

existingLengths

if (existedLengths != null) {
// Use the lengths of existed output for MapStatus
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

// Another attempt for the same task has already written our map outputs successfully,
// so just use the existing partition lengths and delete our temporary map outputs.

System.arraycopy(existedLengths, 0, lengths, 0, lengths.length)
dataTmp.delete()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

though unlikely, will this throw NPE? dataTmp can be null

indexTmp.delete()
} else {
if (indexFile.exists()) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

// This is the first successful attempt in writing the map outputs for this task,
// so override any existing index and data files with the ones we wrote.

indexFile.delete()
}
if (dataFile.exists()) {
dataFile.delete()
}
if (!indexTmp.renameTo(indexFile)) {
throw new IOException("fail to rename data file " + indexTmp + " to " + indexFile)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

data -> index

}
if (dataTmp != null && dataTmp.exists() && !dataTmp.renameTo(dataFile)) {
throw new IOException("fail to rename data file " + dataTmp + " to " + dataFile)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think there is a particular flaw here, but its a bit hard to follow since its a mix of first-attempt-wins and last-attempt wins. First attempt if there is a data file & index file; last attempt if its only an index file. the problem w/ last-attempt is that this delete will fail on windows if the file is open for reading, I believe. Though we can't get around that because SPARK-4085 always requires us to delete some files that might be open, in which case we hope that we don't run into this race again on the next retry. It would be nice to minimize that case, though. We'd be closer to first-attempt-wins if we always wrote a dataFile, even if its empty when dataTmp == null.

There is also an issue w/ mapStatus & non-deterministic data. It might not matter which output you get, but the mapstatus should be consistent with the data that is read. If attempt 1 writes non-empty outputs a,b,c, and attempt 2 writes non-empty outputs d,e,f (which are not committed), the reduce tasks might get the mapstatus for attempt 2, look for outputs d,e,f, and get nothing but empty blocks. Matei had suggested writing the mapstatus to a file, so that subsequent attempts always return the map status corresponding to the first successful attempt.

}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should we check

checkIndexAndDataFile(indexTmp, dataTmp, lengths.length) != null

what happens if this fails? Should we throw an exception?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This will slowdown the normal path, I think it's not needed.

}
}
}

override def getBlockData(blockId: ShuffleBlockId): ManagedBuffer = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@

package org.apache.spark.shuffle.hash

import java.io.IOException

import org.apache.spark._
import org.apache.spark.executor.ShuffleWriteMetrics
import org.apache.spark.scheduler.MapStatus
Expand Down Expand Up @@ -106,6 +108,28 @@ private[spark] class HashShuffleWriter[K, V](
writer.commitAndClose()
writer.fileSegment().length
}
// rename all shuffle files to final paths
// Note: there is only one ShuffleBlockResolver in executor
shuffleBlockResolver.synchronized {
shuffle.writers.zipWithIndex.foreach { case (writer, i) =>
val output = blockManager.diskBlockManager.getFile(writer.blockId)
if (sizes(i) > 0) {
if (output.exists()) {
// update the size of output for MapStatus
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

or

// Use length of existing file and delete our own temporary one

sizes(i) = output.length()
writer.file.delete()
} else {
if (!writer.file.renameTo(output)) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

// Commit by renaming our temporary file to something the fetcher expects

throw new IOException(s"fail to rename ${writer.file} to $output")
}
}
} else {
if (output.exists()) {
output.delete()
}
}
}
}
MapStatus(blockManager.shuffleServerId, sizes)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,9 @@ package org.apache.spark.shuffle.sort
import org.apache.spark._
import org.apache.spark.executor.ShuffleWriteMetrics
import org.apache.spark.scheduler.MapStatus
import org.apache.spark.shuffle.{IndexShuffleBlockResolver, ShuffleWriter, BaseShuffleHandle}
import org.apache.spark.shuffle.{BaseShuffleHandle, IndexShuffleBlockResolver, ShuffleWriter}
import org.apache.spark.storage.ShuffleBlockId
import org.apache.spark.util.Utils
import org.apache.spark.util.collection.ExternalSorter

private[spark] class SortShuffleWriter[K, V, C](
Expand Down Expand Up @@ -65,11 +66,11 @@ private[spark] class SortShuffleWriter[K, V, C](
// Don't bother including the time to open the merged output file in the shuffle write time,
// because it just opens a single file, so is typically too fast to measure accurately
// (see SPARK-3570).
val outputFile = shuffleBlockResolver.getDataFile(dep.shuffleId, mapId)
val output = shuffleBlockResolver.getDataFile(dep.shuffleId, mapId)
val tmp = Utils.tempFileWith(output)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can you call these outputTmp or something so it's slightly easier to follow? (here and other places)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is only one file output here, I think it's obvious

val blockId = ShuffleBlockId(dep.shuffleId, mapId, IndexShuffleBlockResolver.NOOP_REDUCE_ID)
val partitionLengths = sorter.writePartitionedFile(blockId, outputFile)
shuffleBlockResolver.writeIndexFile(dep.shuffleId, mapId, partitionLengths)

val partitionLengths = sorter.writePartitionedFile(blockId, tmp)
shuffleBlockResolver.writeIndexFileAndCommit(dep.shuffleId, mapId, partitionLengths, tmp)
mapStatus = MapStatus(blockManager.shuffleServerId, partitionLengths)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,10 @@ import java.io._
import java.nio.{ByteBuffer, MappedByteBuffer}

import scala.collection.mutable.{ArrayBuffer, HashMap}
import scala.concurrent.{ExecutionContext, Await, Future}
import scala.concurrent.duration._
import scala.util.control.NonFatal
import scala.concurrent.{Await, ExecutionContext, Future}
import scala.util.Random
import scala.util.control.NonFatal

import sun.nio.ch.DirectBuffer

Expand All @@ -38,9 +38,8 @@ import org.apache.spark.network.netty.SparkTransportConf
import org.apache.spark.network.shuffle.ExternalShuffleClient
import org.apache.spark.network.shuffle.protocol.ExecutorShuffleInfo
import org.apache.spark.rpc.RpcEnv
import org.apache.spark.serializer.{SerializerInstance, Serializer}
import org.apache.spark.serializer.{Serializer, SerializerInstance}
import org.apache.spark.shuffle.ShuffleManager
import org.apache.spark.shuffle.hash.HashShuffleManager
import org.apache.spark.util._

private[spark] sealed trait BlockValues
Expand Down Expand Up @@ -660,7 +659,7 @@ private[spark] class BlockManager(
val compressStream: OutputStream => OutputStream = wrapForCompression(blockId, _)
val syncWrites = conf.getBoolean("spark.shuffle.sync", false)
new DiskBlockObjectWriter(file, serializerInstance, bufferSize, compressStream,
syncWrites, writeMetrics)
syncWrites, writeMetrics, blockId)
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,14 +34,15 @@ import org.apache.spark.util.Utils
* reopened again.
*/
private[spark] class DiskBlockObjectWriter(
file: File,
val file: File,
serializerInstance: SerializerInstance,
bufferSize: Int,
compressStream: OutputStream => OutputStream,
syncWrites: Boolean,
// These write metrics concurrently shared with other active DiskBlockObjectWriters who
// are themselves performing writes. All updates must be relative.
writeMetrics: ShuffleWriteMetrics)
writeMetrics: ShuffleWriteMetrics,
val blockId: BlockId = null)
extends OutputStream
with Logging {

Expand Down
12 changes: 9 additions & 3 deletions core/src/main/scala/org/apache/spark/util/Utils.scala
Original file line number Diff line number Diff line change
Expand Up @@ -21,16 +21,16 @@ import java.io._
import java.lang.management.ManagementFactory
import java.net._
import java.nio.ByteBuffer
import java.util.{Properties, Locale, Random, UUID}
import java.util.concurrent._
import java.util.{Locale, Properties, Random, UUID}
import javax.net.ssl.HttpsURLConnection

import scala.collection.JavaConverters._
import scala.collection.Map
import scala.collection.mutable.ArrayBuffer
import scala.io.Source
import scala.reflect.ClassTag
import scala.util.{Failure, Success, Try}
import scala.util.Try
import scala.util.control.{ControlThrowable, NonFatal}

import com.google.common.io.{ByteStreams, Files}
Expand All @@ -42,7 +42,6 @@ import org.apache.hadoop.security.UserGroupInformation
import org.apache.log4j.PropertyConfigurator
import org.eclipse.jetty.util.MultiException
import org.json4s._

import tachyon.TachyonURI
import tachyon.client.{TachyonFS, TachyonFile}

Expand Down Expand Up @@ -2169,6 +2168,13 @@ private[spark] object Utils extends Logging {
val resource = createResource
try f.apply(resource) finally resource.close()
}

/**
* Returns a path of temporary file which is in the same directory with `path`.
*/
def tempFileWith(path: File): File = {
new File(path.getAbsolutePath + "." + UUID.randomUUID())
}
}

/**
Expand Down
Loading