
Commit 55485a9: address comments

Davies Liu committed Nov 11, 2015
1 parent 9f0d2f9
Showing 6 changed files with 51 additions and 58 deletions.
@@ -21,7 +21,6 @@
 import java.io.FileInputStream;
 import java.io.FileOutputStream;
 import java.io.IOException;
-import java.util.UUID;
 import javax.annotation.Nullable;
 
 import scala.None$;
@@ -126,7 +125,7 @@ public void write(Iterator<Product2<K, V>> records) throws IOException {
     assert (partitionWriters == null);
     if (!records.hasNext()) {
       partitionLengths = new long[numPartitions];
-      shuffleBlockResolver.writeIndexFile(shuffleId, mapId, partitionLengths);
+      shuffleBlockResolver.writeIndexFile(shuffleId, mapId, partitionLengths, null);
       mapStatus = MapStatus$.MODULE$.apply(blockManager.shuffleServerId(), partitionLengths);
       return;
     }
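Note that this empty-output path passes null as the new dataTmp argument: when there are no records there is no temporary data file to commit, and only an index of all-zero partition lengths needs to be written.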
@@ -157,19 +156,9 @@ public void write(Iterator<Product2<K, V>> records) throws IOException {
     }
 
     File output = shuffleBlockResolver.getDataFile(shuffleId, mapId);
-    final File tmp = new File(output.getAbsolutePath() + "." + UUID.randomUUID());
+    File tmp = Utils.tempFileWith(output);
     partitionLengths = writePartitionedFile(tmp);
-    if (!output.exists()) {
-      shuffleBlockResolver.writeIndexFile(shuffleId, mapId, partitionLengths);
-      if (output.exists()) {
-        output.delete();
-      }
-      if (!tmp.renameTo(output)) {
-        throw new IOException("fail to rename data file " + tmp + " to " + output);
-      }
-    } else {
-      tmp.delete();
-    }
+    shuffleBlockResolver.writeIndexFile(shuffleId, mapId, partitionLengths, tmp);
     mapStatus = MapStatus$.MODULE$.apply(blockManager.shuffleServerId(), partitionLengths);
   }
 
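With this change the writer no longer renames its output into place itself: it writes the partitioned output to a temporary file and hands that file to writeIndexFile, which commits the data and index files together. A minimal sketch of the underlying write-to-temp-then-rename idiom, assuming a rename within a single directory behaves atomically (writeAtomically is an illustrative helper, not Spark code):

    import java.io.{File, IOException}
    import java.util.UUID

    object AtomicWriteSketch {
      // Produce the complete content in a temp file that sits next to the target,
      // then rename it into place. Readers see either the old file or the whole
      // new one, never a partially written file.
      def writeAtomically(target: File)(write: File => Unit): Unit = {
        val tmp = new File(target.getAbsolutePath + "." + UUID.randomUUID())
        write(tmp)
        if (target.exists() && !target.delete()) {
          throw new IOException(s"failed to delete existing file $target")
        }
        if (!tmp.renameTo(target)) {
          throw new IOException(s"failed to rename $tmp to $target")
        }
      }
    }

The hunks that follow make the same substitution in the other shuffle writers and in the test that mocks writeIndexFile.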
@@ -21,7 +21,6 @@
 import java.io.*;
 import java.nio.channels.FileChannel;
 import java.util.Iterator;
-import java.util.UUID;
 
 import scala.Option;
 import scala.Product2;
@@ -54,6 +53,7 @@
 import org.apache.spark.storage.BlockManager;
 import org.apache.spark.storage.TimeTrackingOutputStream;
 import org.apache.spark.unsafe.Platform;
+import org.apache.spark.util.Utils;
 
 @Private
 public class UnsafeShuffleWriter<K, V> extends ShuffleWriter<K, V> {
@@ -207,7 +207,7 @@ void closeAndWriteOutput() throws IOException {
     sorter = null;
     final long[] partitionLengths;
     final File output = shuffleBlockResolver.getDataFile(shuffleId, mapId);
-    final File tmp = new File(output.getAbsolutePath() + "." + UUID.randomUUID());
+    final File tmp = Utils.tempFileWith(output);
     try {
       partitionLengths = mergeSpills(spills, tmp);
     } finally {
@@ -217,17 +217,7 @@
         }
       }
     }
-    if (!output.exists()) {
-      shuffleBlockResolver.writeIndexFile(shuffleId, mapId, partitionLengths);
-      if (output.exists()) {
-        output.delete();
-      }
-      if (!tmp.renameTo(output)) {
-        throw new IOException("fail to rename data file " + tmp + " to " + output);
-      }
-    } else {
-      tmp.delete();
-    }
+    shuffleBlockResolver.writeIndexFile(shuffleId, mapId, partitionLengths, tmp);
     mapStatus = MapStatus$.MODULE$.apply(blockManager.shuffleServerId(), partitionLengths);
   }
 
@@ -18,17 +18,15 @@
 package org.apache.spark.shuffle
 
 import java.io._
-import java.util.UUID
 
 import com.google.common.io.ByteStreams
 
-import org.apache.spark.{SparkConf, SparkEnv, Logging}
 import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, ManagedBuffer}
 import org.apache.spark.network.netty.SparkTransportConf
+import org.apache.spark.shuffle.IndexShuffleBlockResolver.NOOP_REDUCE_ID
 import org.apache.spark.storage._
 import org.apache.spark.util.Utils
-
-import IndexShuffleBlockResolver.NOOP_REDUCE_ID
+import org.apache.spark.{Logging, SparkConf, SparkEnv}
 
 /**
  * Create and maintain the shuffle blocks' mapping between logic block and physical file location.
@@ -80,10 +78,10 @@ private[spark] class IndexShuffleBlockResolver(conf: SparkConf) extends ShuffleB
    * end of the output file. This will be used by getBlockData to figure out where each block
    * begins and ends.
    * */
-  def writeIndexFile(shuffleId: Int, mapId: Int, lengths: Array[Long]): Unit = {
+  def writeIndexFile(shuffleId: Int, mapId: Int, lengths: Array[Long], dataTmp: File): Unit = {
     val indexFile = getIndexFile(shuffleId, mapId)
-    val tmp = new File(indexFile.getAbsolutePath + "." + UUID.randomUUID())
-    val out = new DataOutputStream(new BufferedOutputStream(new FileOutputStream(tmp)))
+    val indexTmp = Utils.tempFileWith(indexFile)
+    val out = new DataOutputStream(new BufferedOutputStream(new FileOutputStream(indexTmp)))
     Utils.tryWithSafeFinally {
       // We take in lengths of each block, need to convert it to offsets.
       var offset = 0L
@@ -95,9 +93,28 @@ private[spark] class IndexShuffleBlockResolver(conf: SparkConf) extends ShuffleB
     } {
       out.close()
     }
-    indexFile.deleteOnExit()
-    if (!tmp.renameTo(indexFile)) {
-      throw new IOException(s"fail to rename index file $tmp to $indexFile")
+
+    val dataFile = getDataFile(shuffleId, mapId)
+    synchronized {
+      if (dataFile.exists() && indexFile.exists()) {
+        if (dataTmp != null && dataTmp.exists()) {
+          dataTmp.delete()
+        }
+        indexTmp.delete()
+      } else {
+        if (indexFile.exists()) {
+          indexFile.delete()
+        }
+        if (!indexTmp.renameTo(indexFile)) {
+          throw new IOException("fail to rename data file " + indexTmp + " to " + indexFile)
+        }
+        if (dataFile.exists()) {
+          dataFile.delete()
+        }
+        if (dataTmp != null && dataTmp.exists() && !dataTmp.renameTo(dataFile)) {
+          throw new IOException("fail to rename data file " + dataTmp + " to " + dataFile)
+        }
+      }
     }
   }
 
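Two things happen in this method now. First, as the comment notes, the per-partition lengths are converted into offsets, with one extra entry for the end of the data file, so getBlockData can locate any block from two consecutive index entries. Second, the new synchronized block commits the index and data files together: the first attempt to finish a given map output renames both temporary files into place, while a later attempt that finds both files already present simply deletes its own temporary files. A hedged sketch of the length-to-offset conversion (illustrative only; the real code streams each offset through a DataOutputStream rather than building an array):

    object IndexLayoutSketch {
      // Entry i is where block i starts; entry i + 1 is where it ends.
      def lengthsToOffsets(lengths: Array[Long]): Array[Long] =
        lengths.scanLeft(0L)(_ + _)

      // Example: lengths Array(10, 0, 25) become offsets Array(0, 10, 10, 35),
      // i.e. block 1 is empty and block 2 spans bytes [10, 35) of the data file.
    }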
@@ -17,14 +17,12 @@
 
 package org.apache.spark.shuffle.sort
 
-import java.io.{IOException, File}
-import java.util.UUID
-
 import org.apache.spark._
 import org.apache.spark.executor.ShuffleWriteMetrics
 import org.apache.spark.scheduler.MapStatus
-import org.apache.spark.shuffle.{IndexShuffleBlockResolver, ShuffleWriter, BaseShuffleHandle}
+import org.apache.spark.shuffle.{BaseShuffleHandle, IndexShuffleBlockResolver, ShuffleWriter}
 import org.apache.spark.storage.ShuffleBlockId
+import org.apache.spark.util.Utils
 import org.apache.spark.util.collection.ExternalSorter
 
 private[spark] class SortShuffleWriter[K, V, C](
@@ -69,21 +67,10 @@ private[spark] class SortShuffleWriter[K, V, C](
     // because it just opens a single file, so is typically too fast to measure accurately
     // (see SPARK-3570).
     val output = shuffleBlockResolver.getDataFile(dep.shuffleId, mapId)
-    val tmp = new File(output.getAbsolutePath + "." + UUID.randomUUID())
+    val tmp = Utils.tempFileWith(output)
     val blockId = ShuffleBlockId(dep.shuffleId, mapId, IndexShuffleBlockResolver.NOOP_REDUCE_ID)
     val partitionLengths = sorter.writePartitionedFile(blockId, tmp)
-    if (!output.exists()) {
-      shuffleBlockResolver.writeIndexFile(dep.shuffleId, mapId, partitionLengths)
-      if (output.exists()) {
-        output.delete()
-      }
-      if (!tmp.renameTo(output)) {
-        throw new IOException("fail to rename data file " + tmp + " to " + output)
-      }
-    } else {
-      tmp.delete()
-    }
-
+    shuffleBlockResolver.writeIndexFile(dep.shuffleId, mapId, partitionLengths, tmp)
     mapStatus = MapStatus(blockManager.shuffleServerId, partitionLengths)
   }
 
12 changes: 9 additions & 3 deletions core/src/main/scala/org/apache/spark/util/Utils.scala
@@ -21,16 +21,16 @@ import java.io._
 import java.lang.management.ManagementFactory
 import java.net._
 import java.nio.ByteBuffer
-import java.util.{Properties, Locale, Random, UUID}
 import java.util.concurrent._
+import java.util.{Locale, Properties, Random, UUID}
 import javax.net.ssl.HttpsURLConnection
 
 import scala.collection.JavaConverters._
 import scala.collection.Map
 import scala.collection.mutable.ArrayBuffer
 import scala.io.Source
 import scala.reflect.ClassTag
-import scala.util.{Failure, Success, Try}
+import scala.util.Try
 import scala.util.control.{ControlThrowable, NonFatal}
 
 import com.google.common.io.{ByteStreams, Files}
@@ -42,7 +42,6 @@ import org.apache.hadoop.security.UserGroupInformation
 import org.apache.log4j.PropertyConfigurator
 import org.eclipse.jetty.util.MultiException
 import org.json4s._
-
 import tachyon.TachyonURI
 import tachyon.client.{TachyonFS, TachyonFile}

@@ -2169,6 +2168,13 @@
     val resource = createResource
     try f.apply(resource) finally resource.close()
   }
+
+  /**
+   * Returns a path of temporary file which is in the same directory with `path`.
+   */
+  def tempFileWith(path: File): File = {
+    new File(path.getAbsolutePath + "." + UUID.randomUUID())
+  }
 }
 
 /**
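tempFileWith simply appends a random UUID suffix, so the temporary file lives in the same directory (and on the same filesystem) as its target, which is what lets the later renameTo succeed. A hedged sketch of how the shuffle writers above use it; shuffleBlockResolver, shuffleId, mapId and writePartitionedOutput are assumed to be in scope, and only writeIndexFile and tempFileWith come from this commit:

    // Illustrative only -- mirrors the SortShuffleWriter hunk above.
    val output = shuffleBlockResolver.getDataFile(shuffleId, mapId)          // final shuffle data file
    val tmp = Utils.tempFileWith(output)                                     // e.g. <output>.<random-uuid>
    val partitionLengths = writePartitionedOutput(tmp)                       // hypothetical: write map output to tmp
    shuffleBlockResolver.writeIndexFile(shuffleId, mapId, partitionLengths, tmp)  // commits index + data together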
@@ -170,9 +170,13 @@ public OutputStream answer(InvocationOnMock invocation) throws Throwable {
       @Override
       public Void answer(InvocationOnMock invocationOnMock) throws Throwable {
         partitionSizesInMergedFile = (long[]) invocationOnMock.getArguments()[2];
+        File tmp = (File) invocationOnMock.getArguments()[3];
+        mergedOutputFile.delete();
+        tmp.renameTo(mergedOutputFile);
         return null;
       }
-    }).when(shuffleBlockResolver).writeIndexFile(anyInt(), anyInt(), any(long[].class));
+    }).when(shuffleBlockResolver)
+      .writeIndexFile(anyInt(), anyInt(), any(long[].class), any(File.class));
 
     when(diskBlockManager.createTempShuffleBlock()).thenAnswer(
       new Answer<Tuple2<TempShuffleBlockId, File>>() {
