Various changes to the API and state assumptions in writers. (#8)
Proposes the following changes to the API:

- closeAndGetLength() is split into separate close() and getNumBytesWritten() operations.
- openChannel and openStream renamed to toChannel and toStream

Proposes the following changes to the implementation:

- close() in the default implementation now persists the length in the partitionLengths array
- getNumBytesWritten() no longer requires the writer's resources to be closed before it is called
- Don't close the stream in BypassMergeSortShuffleWriter - only close it in DefaultShufflePartitionWriter#close (for consistency with how we treat channels); see the usage sketch below
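
For illustration only, a minimal caller-side sketch of the revised API. It assumes a ShuffleMapOutputWriter named mapOutputWriter already exists, that the bytes for one partition sit in a hypothetical byte[] partitionBytes, and that the usual java.io and org.apache.spark.api.shuffle imports are in place; error handling is simplified compared to BypassMergeSortShuffleWriter.

// Hypothetical usage sketch, not part of this commit; mapOutputWriter and partitionBytes are placeholders.
ShufflePartitionWriter writer = mapOutputWriter.getNextPartitionWriter();
try {
  OutputStream out = writer.toStream();       // callers must not close this stream directly
  out.write(partitionBytes);
} finally {
  writer.close();                             // the writer closes its own stream/channel resources
}
long length = writer.getNumBytesWritten();    // valid before or after close()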
mccheah authored and ifilonenko committed Mar 26, 2019
1 parent 9e3f05c commit 9f6230b
Showing 6 changed files with 123 additions and 74 deletions.
@@ -17,6 +17,7 @@

package org.apache.spark.api.shuffle;

import java.io.Closeable;
import java.io.IOException;
import java.io.OutputStream;
import java.nio.channels.Channels;
@@ -31,12 +32,43 @@
* @since 3.0.0
*/
@Experimental
public interface ShufflePartitionWriter {
OutputStream openStream() throws IOException;
public interface ShufflePartitionWriter extends Closeable {

long closeAndGetLength();
/**
* Returns an underlying {@link OutputStream} that can write bytes to the underlying data store.
* <p>
* Note that the caller must not close this stream itself; instead, close it in
* the implementation of this class's {@link #close()}.
*/
OutputStream toStream() throws IOException;

default WritableByteChannel openChannel() throws IOException {
return Channels.newChannel(openStream());
/**
* Returns an underlying {@link WritableByteChannel} that can write bytes to the underlying data
* store.
* <p>
* Note that the caller must not close this channel itself; instead, close the underlying
* resources in the implementation of this class's {@link #close()}.
*/
default WritableByteChannel toChannel() throws IOException {
return Channels.newChannel(toStream());
}

/**
* Get the number of bytes written by the stream returned by {@link #toStream()} or by
* the channel returned by {@link #toChannel()}.
*/
long getNumBytesWritten();

/**
* Close all resources created by this ShufflePartitionWriter via calls to {@link #toStream()}
* or {@link #toChannel()}.
* <p>
* This must always close any stream returned by {@link #toStream()}.
* <p>
* Note that the default version of {@link #toChannel()} returns a {@link WritableByteChannel}
* that does not itself need to be closed up front; only the underlying output stream given by
* {@link #toStream()} must be closed.
*/
@Override
void close() throws IOException;
}
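
To make the contract concrete, here is a minimal, hypothetical in-memory implementation. It is not part of this commit and is not the default Spark implementation; it only illustrates that toStream() hands out a stream the caller never closes, getNumBytesWritten() can be queried at any time, and close() is the single place resources are released.

import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.OutputStream;
import org.apache.spark.api.shuffle.ShufflePartitionWriter;

// Hypothetical in-memory ShufflePartitionWriter, shown only to illustrate the interface contract.
final class InMemoryShufflePartitionWriter implements ShufflePartitionWriter {
  private final ByteArrayOutputStream buffer = new ByteArrayOutputStream();

  @Override
  public OutputStream toStream() {
    return buffer;                 // callers must not close this; close() below owns it
  }

  @Override
  public long getNumBytesWritten() {
    return buffer.size();          // valid whether or not close() has been called
  }

  @Override
  public void close() throws IOException {
    buffer.close();                // releases everything handed out by toStream()/toChannel()
  }
}

The inherited default toChannel() wraps toStream() with Channels.newChannel, so a channel-based caller needs no extra handling here.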
@@ -130,7 +130,7 @@ final class BypassMergeSortShuffleWriter<K, V> extends ShuffleWriter<K, V> {
public void write(Iterator<Product2<K, V>> records) throws IOException {
assert (partitionWriters == null);
ShuffleMapOutputWriter mapOutputWriter = shuffleWriteSupport
.createMapOutputWriter(shuffleId, mapId, numPartitions);
.createMapOutputWriter(shuffleId, mapId, numPartitions);
try {
if (!records.hasNext()) {
partitionLengths = new long[numPartitions];
@@ -144,11 +144,11 @@ public void write(Iterator<Product2<K, V>> records) throws IOException {
partitionWriterSegments = new FileSegment[numPartitions];
for (int i = 0; i < numPartitions; i++) {
final Tuple2<TempShuffleBlockId, File> tempShuffleBlockIdPlusFile =
blockManager.diskBlockManager().createTempShuffleBlock();
blockManager.diskBlockManager().createTempShuffleBlock();
final File file = tempShuffleBlockIdPlusFile._2();
final BlockId blockId = tempShuffleBlockIdPlusFile._1();
partitionWriters[i] =
blockManager.getDiskWriter(blockId, file, serInstance, fileBufferSize, writeMetrics);
blockManager.getDiskWriter(blockId, file, serInstance, fileBufferSize, writeMetrics);
}
// Creating the file to write to and creating a disk writer both involve interacting with
// the disk, and can take a long time in aggregate when we open many files, so should be
@@ -202,20 +202,22 @@ private long[] writePartitionedData(ShuffleMapOutputWriter mapOutputWriter) thro
for (int i = 0; i < numPartitions; i++) {
final File file = partitionWriterSegments[i].file();
boolean copyThrewException = true;
ShufflePartitionWriter writer = mapOutputWriter.getNextPartitionWriter();
if (transferToEnabled) {
WritableByteChannel outputChannel = writer.openChannel();
if (file.exists()) {
FileInputStream in = new FileInputStream(file);
try (FileChannel inputChannel = in.getChannel()){
Utils.copyFileStreamNIO(inputChannel, outputChannel, 0, inputChannel.size());
copyThrewException = false;
} finally {
Closeables.close(in, copyThrewException);
ShufflePartitionWriter writer = null;
try {
writer = mapOutputWriter.getNextPartitionWriter();
if (transferToEnabled) {
WritableByteChannel outputChannel = writer.toChannel();
if (file.exists()) {
FileInputStream in = new FileInputStream(file);
try (FileChannel inputChannel = in.getChannel()) {
Utils.copyFileStreamNIO(inputChannel, outputChannel, 0, inputChannel.size());
copyThrewException = false;
} finally {
Closeables.close(in, copyThrewException);
}
}
}
} else {
try (OutputStream tempOutputStream = writer.openStream()) {
} else {
OutputStream tempOutputStream = writer.toStream();
if (file.exists()) {
FileInputStream in = new FileInputStream(file);
try {
Expand All @@ -226,11 +228,14 @@ private long[] writePartitionedData(ShuffleMapOutputWriter mapOutputWriter) thro
}
}
}
if (file.exists() && !file.delete()) {
logger.error("Unable to delete file for partition {}", i);
}
} finally {
Closeables.close(writer, copyThrewException);
}
lengths[i] = writer.closeAndGetLength();
if (file.exists() && !file.delete()) {
logger.error("Unable to delete file for partition {}", i);
}

lengths[i] = writer.getNumBytesWritten();
}
} finally {
writeMetrics.incWriteTime(System.nanoTime() - writeStartTime);
@@ -73,7 +73,7 @@ public DefaultShuffleMapOutputWriter(
}

@Override
public ShufflePartitionWriter getNextPartitionWriter() throws IOException {
public ShufflePartitionWriter getNextPartitionWriter() {
return new DefaultShufflePartitionWriter(currPartitionId++);
}

@@ -97,7 +97,7 @@ public void commitAllPartitions() throws IOException {
}

@Override
public void abort(Throwable error) throws IOException {
public void abort(Throwable error) {
try {
cleanUp();
} catch (Exception e) {
@@ -107,7 +107,7 @@ public void abort(Throwable error) throws IOException {
log.warn("Failed to delete temporary shuffle file at {}", outputTempFile.getAbsolutePath());
}
if (!outputFile.delete() && outputFile.exists()) {
log.warn("Failed to delete outputshuffle file at {}", outputFile.getAbsolutePath());
log.warn("Failed to delete output shuffle file at {}", outputFile.getAbsolutePath());
}
}

@@ -154,42 +154,42 @@ private DefaultShufflePartitionWriter(int partitionId) {
}

@Override
public OutputStream openStream() throws IOException {
public OutputStream toStream() throws IOException {
initStream();
stream = new PartitionWriterStream();
return stream;
}

@Override
public long closeAndGetLength() {
public FileChannel toChannel() throws IOException {
initChannel();
currChannelPosition = outputFileChannel.position();
return outputFileChannel;
}

@Override
public long getNumBytesWritten() {
if (outputFileChannel != null && stream == null) {
try {
long newPosition = outputFileChannel.position();
long length = newPosition - currChannelPosition;
partitionLengths[partitionId] = length;
currChannelPosition = newPosition;
return length;
return newPosition - currChannelPosition;
} catch (Exception e) {
log.error("The currPartition is: " + partitionId, e);
throw new IllegalStateException("Attempting to calculate position of file channel", e);
log.error("The currPartition is: {}", partitionId, e);
throw new IllegalStateException("Failed to calculate position of file channel", e);
}
} else if (stream != null) {
return stream.getCount();
} else {
try {
stream.close();
} catch (Exception e) {
throw new IllegalStateException("Attempting to close output stream", e);
}
int length = stream.getCount();
partitionLengths[partitionId] = length;
return length;
return 0;
}
}

@Override
public FileChannel openChannel() throws IOException {
initChannel();
currChannelPosition = outputFileChannel.position();
return outputFileChannel;
public void close() throws IOException {
if (stream != null) {
stream.close();
}
partitionLengths[partitionId] = getNumBytesWritten();
}
}

@@ -218,7 +218,9 @@ public void close() throws IOException {

@Override
public void flush() throws IOException {
outputBufferedFileStream.flush();
if (!isClosed) {
outputBufferedFileStream.flush();
}
}
}
}
38 changes: 22 additions & 16 deletions core/src/main/scala/org/apache/spark/util/Utils.scala
@@ -340,7 +340,11 @@ private[spark] object Utils extends Logging {
output: WritableByteChannel,
startPosition: Long,
bytesToCopy: Long): Unit = {
// val initialPos = output.position()
val outputInitialState = output match {
case outputFileChannel: FileChannel =>
Some((outputFileChannel.position(), outputFileChannel))
case _ => None
}
var count = 0L
// In case transferTo method transferred less data than we have required.
while (count < bytesToCopy) {
@@ -349,21 +353,23 @@
assert(count == bytesToCopy,
s"request to copy $bytesToCopy bytes, but actually copied $count bytes.")

// // Check the position after transferTo loop to see if it is in the right position and
// // give user information if not.
// // Position will not be increased to the expected length after calling transferTo in
// // kernel version 2.6.32, this issue can be seen in
// // https://bugs.openjdk.java.net/browse/JDK-7052359
// // This will lead to stream corruption issue when using sort-based shuffle (SPARK-3948).
// val finalPos = output.position()
// val expectedPos = initialPos + bytesToCopy
// assert(finalPos == expectedPos,
// s"""
// |Current position $finalPos do not equal to expected position $expectedPos
// |after transferTo, please check your kernel version to see if it is 2.6.32,
// |this is a kernel bug which will lead to unexpected behavior when using transferTo.
// |You can set spark.file.transferTo = false to disable this NIO feature.
// """.stripMargin)
// Check the position after transferTo loop to see if it is in the right position and
// give user information if not.
// Position will not be increased to the expected length after calling transferTo in
// kernel version 2.6.32, this issue can be seen in
// https://bugs.openjdk.java.net/browse/JDK-7052359
// This will lead to stream corruption issue when using sort-based shuffle (SPARK-3948).
outputInitialState.foreach { case (initialPos, outputFileChannel) =>
val finalPos = outputFileChannel.position()
val expectedPos = initialPos + bytesToCopy
assert(finalPos == expectedPos,
s"""
|Current position $finalPos do not equal to expected position $expectedPos
|after transferTo, please check your kernel version to see if it is 2.6.32,
|this is a kernel bug which will lead to unexpected behavior when using transferTo.
|You can set spark.file.transferTo = false to disable this NIO feature.
""".stripMargin)
}
}

/**
@@ -134,13 +134,14 @@ class DefaultShuffleMapOutputWriterSuite extends SparkFunSuite with BeforeAndAft
test("writing to an outputstream") {
(0 until NUM_PARTITIONS).foreach{ p =>
val writer = mapOutputWriter.getNextPartitionWriter
val stream = writer.openStream()
val stream = writer.toStream()
data(p).foreach { i => stream.write(i)}
stream.close()
intercept[IllegalStateException] {
stream.write(p)
}
assert(writer.closeAndGetLength() == D_LEN)
assert(writer.getNumBytesWritten() == D_LEN)
writer.close
}
mapOutputWriter.commitAllPartitions()
val partitionLengths = (0 until NUM_PARTITIONS).map { _ => D_LEN.toDouble}.toArray
@@ -152,14 +153,15 @@
test("writing to a channel") {
(0 until NUM_PARTITIONS).foreach{ p =>
val writer = mapOutputWriter.getNextPartitionWriter
val channel = writer.openChannel()
val channel = writer.toChannel()
val byteBuffer = ByteBuffer.allocate(D_LEN * 4)
val intBuffer = byteBuffer.asIntBuffer()
intBuffer.put(data(p))
assert(channel.isOpen)
channel.write(byteBuffer)
// Ints are 4 bytes each, so the byte count is D_LEN * 4
assert(writer.closeAndGetLength == D_LEN * 4)
assert(writer.getNumBytesWritten == D_LEN * 4)
writer.close
}
mapOutputWriter.commitAllPartitions()
val partitionLengths = (0 until NUM_PARTITIONS).map { _ => (D_LEN * 4).toDouble}.toArray
@@ -171,15 +173,16 @@ class DefaultShuffleMapOutputWriterSuite extends SparkFunSuite with BeforeAndAft
test("copyStreams with an outputstream") {
(0 until NUM_PARTITIONS).foreach{ p =>
val writer = mapOutputWriter.getNextPartitionWriter
val stream = writer.openStream()
val stream = writer.toStream()
val byteBuffer = ByteBuffer.allocate(D_LEN * 4)
val intBuffer = byteBuffer.asIntBuffer()
intBuffer.put(data(p))
val in: InputStream = new ByteArrayInputStream(byteBuffer.array())
Utils.copyStream(in, stream, false, false)
in.close()
stream.close()
assert(writer.closeAndGetLength == D_LEN * 4)
assert(writer.getNumBytesWritten == D_LEN * 4)
writer.close
}
mapOutputWriter.commitAllPartitions()
val partitionLengths = (0 until NUM_PARTITIONS).map { _ => (D_LEN * 4).toDouble}.toArray
Expand All @@ -191,7 +194,7 @@ class DefaultShuffleMapOutputWriterSuite extends SparkFunSuite with BeforeAndAft
test("copyStreamsWithNIO with a channel") {
(0 until NUM_PARTITIONS).foreach{ p =>
val writer = mapOutputWriter.getNextPartitionWriter
val channel = writer.openChannel()
val channel = writer.toChannel()
val byteBuffer = ByteBuffer.allocate(D_LEN * 4)
val intBuffer = byteBuffer.asIntBuffer()
intBuffer.put(data(p))
@@ -201,7 +204,8 @@ class DefaultShuffleMapOutputWriterSuite extends SparkFunSuite with BeforeAndAft
val in = new FileInputStream(tempFile)
Utils.copyFileStreamNIO(in.getChannel, channel, 0, D_LEN * 4)
in.close()
assert(writer.closeAndGetLength == D_LEN * 4)
assert(writer.getNumBytesWritten == D_LEN * 4)
writer.close
}
mapOutputWriter.commitAllPartitions()
val partitionLengths = (0 until NUM_PARTITIONS).map { _ => (D_LEN * 4).toDouble}.toArray
@@ -91,6 +91,6 @@ case class InputFileBlockLength() extends LeafExpression with Nondeterministic {
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
val className = InputFileBlockHolder.getClass.getName.stripSuffix("$")
val typeDef = s"final ${CodeGenerator.javaType(dataType)}"
ev.copy(code = code"$typeDef ${ev.value} = $className.getLength();", isNull = FalseLiteral)
ev.copy(code = code"$typeDef ${ev.value} = $className.getNumBytesWritten();", isNull = FalseLiteral)
}
}
