[SPARK-21475][Core] Use NIO's Files API to replace FileInputStream/FileOutputStream in some critical paths #18684

Closed · wants to merge 3 commits
Changes from 2 commits
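The pattern applied throughout the diff below swaps the FileInputStream/FileOutputStream constructors for their java.nio.file.Files and FileChannel equivalents; on older JDKs the stream classes carry finalize() overrides that add GC pressure, which the NIO-returned streams avoid. A minimal sketch of the substitution, using a hypothetical /tmp/example.bin path purely for illustration:

import java.io.File;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.nio.channels.FileChannel;
import java.nio.file.Files;
import java.nio.file.StandardOpenOption;

public class NioMigrationSketch {
  public static void main(String[] args) throws IOException {
    File file = new File("/tmp/example.bin");  // hypothetical path, for illustration only

    // Old: new FileOutputStream(file) / new FileInputStream(file)
    // New: the factory methods on java.nio.file.Files
    try (OutputStream out = Files.newOutputStream(file.toPath())) {
      out.write(42);
    }
    try (InputStream in = Files.newInputStream(file.toPath())) {
      System.out.println(in.read());
    }

    // Old: new FileInputStream(file).getChannel()
    // New: open the channel directly, stating the access mode explicitly
    try (FileChannel channel = FileChannel.open(file.toPath(), StandardOpenOption.READ)) {
      System.out.println(channel.size());
    }
  }
}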
@@ -18,14 +18,16 @@
package org.apache.spark.network.buffer;

import java.io.File;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.RandomAccessFile;
import java.nio.ByteBuffer;
import java.nio.channels.FileChannel;
import java.nio.file.Files;
import java.nio.file.StandardOpenOption;

import com.google.common.base.Objects;
import com.google.common.collect.ImmutableSet;
import com.google.common.io.ByteStreams;
import io.netty.channel.DefaultFileRegion;

@@ -93,9 +95,9 @@ public ByteBuffer nioByteBuffer() throws IOException {

@Override
public InputStream createInputStream() throws IOException {
FileInputStream is = null;
InputStream is = null;
try {
is = new FileInputStream(file);
is = Files.newInputStream(file.toPath());
ByteStreams.skipFully(is, offset);
return new LimitedInputStream(is, length);
} catch (IOException e) {
@@ -132,7 +134,8 @@ public Object convertToNetty() throws IOException {
if (conf.lazyFileDescriptor()) {
return new DefaultFileRegion(file, offset, length);
} else {
FileChannel fileChannel = new FileInputStream(file).getChannel();
FileChannel fileChannel = FileChannel.open(file.toPath(),
ImmutableSet.of(StandardOpenOption.READ));
Contributor

Why do we need a new set for this? Should we call:

FileChannel fileChannel = FileChannel.open(file.toPath(), StandardOpenOption.READ)

?

Contributor Author

They're actually the same; the one you mentioned is just a simple wrapper of the former. I will change to yours for simplicity.
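For reference, the two calls discussed above open the channel identically: the varargs FileChannel.open(Path, OpenOption...) overload simply wraps its arguments in a set and delegates to the Set-based overload. A small standalone sketch showing both forms, using a hypothetical temporary file and EnumSet in place of Guava's ImmutableSet:

import java.io.IOException;
import java.nio.channels.FileChannel;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.StandardOpenOption;
import java.util.EnumSet;

public class OpenOptionForms {
  public static void main(String[] args) throws IOException {
    Path path = Files.createTempFile("spark-21475-demo", ".bin");  // hypothetical file

    // Set-based overload, as in this revision of the patch
    try (FileChannel c1 = FileChannel.open(path, EnumSet.of(StandardOpenOption.READ))) {
      System.out.println("set-based: " + c1.size());
    }

    // Varargs overload suggested in the review; it wraps the options in a set internally
    try (FileChannel c2 = FileChannel.open(path, StandardOpenOption.READ)) {
      System.out.println("varargs: " + c2.size());
    }

    Files.deleteIfExists(path);
  }
}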

return new DefaultFileRegion(fileChannel, offset, length);
}
}
@@ -18,11 +18,11 @@
package org.apache.spark.network.shuffle;

import java.io.File;
import java.io.FileOutputStream;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.channels.Channels;
import java.nio.channels.WritableByteChannel;
import java.nio.file.Files;
import java.util.Arrays;

import org.slf4j.Logger;
@@ -165,7 +165,7 @@ private class DownloadCallback implements StreamCallback {

DownloadCallback(int chunkIndex) throws IOException {
this.targetFile = tempShuffleFileManager.createTempShuffleFile();
this.channel = Channels.newChannel(new FileOutputStream(targetFile));
this.channel = Channels.newChannel(Files.newOutputStream(targetFile.toPath()));
this.chunkIndex = chunkIndex;
}

@@ -19,10 +19,10 @@

import java.io.DataInputStream;
import java.io.File;
import java.io.FileInputStream;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.LongBuffer;
import java.nio.file.Files;

/**
* Keeps the index information for a particular map output
@@ -38,7 +38,7 @@ public ShuffleIndexInformation(File indexFile) throws IOException {
offsets = buffer.asLongBuffer();
DataInputStream dis = null;
try {
dis = new DataInputStream(new FileInputStream(indexFile));
dis = new DataInputStream(Files.newInputStream(indexFile.toPath()));
dis.readFully(buffer.array());
} finally {
if (dis != null) {
@@ -18,9 +18,9 @@
package org.apache.spark.shuffle.sort;

import java.io.File;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.IOException;
import java.nio.channels.FileChannel;
import static java.nio.file.StandardOpenOption.*;
import javax.annotation.Nullable;

import scala.None$;
@@ -30,6 +30,7 @@
import scala.collection.Iterator;

import com.google.common.annotations.VisibleForTesting;
import com.google.common.collect.ImmutableSet;
import com.google.common.io.Closeables;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@@ -188,17 +189,20 @@ private long[] writePartitionedFile(File outputFile) throws IOException {
return lengths;
}

final FileOutputStream out = new FileOutputStream(outputFile, true);
final FileChannel out = FileChannel.open(outputFile.toPath(),
ImmutableSet.of(WRITE, APPEND, CREATE));
Contributor

final FileChannel out = FileChannel.open(outputFile.toPath(), WRITE, APPEND, CREATE)

?
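Either spelling opens the file with the same semantics as the replaced new FileOutputStream(outputFile, true): create the file if it is missing and append on write. A minimal standalone sketch of the channel-based form, with a placeholder file name:

import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.channels.FileChannel;
import java.nio.charset.StandardCharsets;
import java.nio.file.Paths;
import static java.nio.file.StandardOpenOption.*;

public class AppendChannelSketch {
  public static void main(String[] args) throws IOException {
    // Roughly equivalent to: new FileOutputStream("output.bin", true).getChannel()
    // "output.bin" is a placeholder name for illustration.
    try (FileChannel out = FileChannel.open(Paths.get("output.bin"), WRITE, APPEND, CREATE)) {
      out.write(ByteBuffer.wrap("appended line\n".getBytes(StandardCharsets.UTF_8)));
    }
  }
}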

final long writeStartTime = System.nanoTime();
boolean threwException = true;
try {
for (int i = 0; i < numPartitions; i++) {
final File file = partitionWriterSegments[i].file();
if (file.exists()) {
final FileInputStream in = new FileInputStream(file);
final FileChannel in = FileChannel.open(file.toPath(), ImmutableSet.of(READ));
boolean copyThrewException = true;
try {
lengths[i] = Utils.copyStream(in, out, false, transferToEnabled);
long size = in.size();
Utils.copyFileStreamNIO(in, out, 0, size);
lengths[i] = size;
Contributor

Should we modify Utils.copyStream() to support this?

Contributor Author

There are a lot of other places using Utils#copyStream; it would be better not to change that API.
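For background on the channel-based copy used here: FileChannel-to-FileChannel copies are typically built on transferTo, which can use zero-copy transfers where the platform supports it. The sketch below is an illustrative stand-in for a copyFileStreamNIO-style helper under that assumption, not Spark's actual implementation; the file names are placeholders only:

import java.io.IOException;
import java.nio.channels.FileChannel;
import java.nio.file.Paths;
import static java.nio.file.StandardOpenOption.*;

public class ChannelCopySketch {
  // Illustrative stand-in for a copyFileStreamNIO-style helper; not Spark's implementation.
  static void copy(FileChannel input, FileChannel output, long startPosition, long bytesToCopy)
      throws IOException {
    long count = 0L;
    while (count < bytesToCopy) {
      // transferTo may copy fewer bytes than requested, so loop until done.
      count += input.transferTo(startPosition + count, bytesToCopy - count, output);
    }
  }

  public static void main(String[] args) throws IOException {
    // "src.bin" and "dst.bin" are placeholder file names.
    try (FileChannel in = FileChannel.open(Paths.get("src.bin"), READ);
         FileChannel out = FileChannel.open(Paths.get("dst.bin"), WRITE, CREATE, APPEND)) {
      copy(in, out, 0, in.size());
    }
  }
}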

copyThrewException = false;
} finally {
Closeables.close(in, copyThrewException);
@@ -20,6 +20,7 @@
import javax.annotation.Nullable;
import java.io.*;
import java.nio.channels.FileChannel;
import static java.nio.file.StandardOpenOption.*;
import java.util.Iterator;

import scala.Option;
@@ -29,6 +30,7 @@
import scala.reflect.ClassTag$;

import com.google.common.annotations.VisibleForTesting;
import com.google.common.collect.ImmutableSet;
import com.google.common.io.ByteStreams;
import com.google.common.io.Closeables;
import com.google.common.io.Files;
@@ -290,7 +292,7 @@ private long[] mergeSpills(SpillInfo[] spills, File outputFile) throws IOException {
final boolean encryptionEnabled = blockManager.serializerManager().encryptionEnabled();
try {
if (spills.length == 0) {
new FileOutputStream(outputFile).close(); // Create an empty file
java.nio.file.Files.newOutputStream(outputFile.toPath()).close(); // Create an empty file
return new long[partitioner.numPartitions()];
} else if (spills.length == 1) {
// Here, we don't need to perform any metrics updates because the bytes written to this
@@ -367,7 +369,7 @@ private long[] mergeSpillsWithFileStream(
final InputStream[] spillInputStreams = new InputStream[spills.length];

final OutputStream bos = new BufferedOutputStream(
new FileOutputStream(outputFile),
java.nio.file.Files.newOutputStream(outputFile.toPath()),
outputBufferSizeInBytes);
// Use a counting output stream to avoid having to close the underlying file and ask
// the file system for its size after each partition is written.
@@ -442,11 +444,12 @@ private long[] mergeSpillsWithTransferTo(SpillInfo[] spills, File outputFile) throws IOException {
boolean threwException = true;
try {
for (int i = 0; i < spills.length; i++) {
spillInputChannels[i] = new FileInputStream(spills[i].file).getChannel();
spillInputChannels[i] = FileChannel.open(spills[i].file.toPath(), ImmutableSet.of(READ));
}
// This file needs to be opened in append mode in order to work around a Linux kernel bug that
// affects transferTo; see SPARK-3948 for more details.
mergedFileOutputChannel = new FileOutputStream(outputFile, true).getChannel();
mergedFileOutputChannel = FileChannel.open(outputFile.toPath(),
ImmutableSet.of(WRITE, CREATE, APPEND));

long bytesWrittenToMergedFile = 0;
for (int partition = 0; partition < numPartitions; partition++) {
@@ -18,6 +18,7 @@
package org.apache.spark.shuffle

import java.io._
import java.nio.file.Files

import com.google.common.io.ByteStreams

@@ -141,7 +142,8 @@ private[spark] class IndexShuffleBlockResolver(
val indexFile = getIndexFile(shuffleId, mapId)
val indexTmp = Utils.tempFileWith(indexFile)
try {
val out = new DataOutputStream(new BufferedOutputStream(new FileOutputStream(indexTmp)))
val out = new DataOutputStream(
new BufferedOutputStream(Files.newOutputStream(indexTmp.toPath)))
Utils.tryWithSafeFinally {
// We take in lengths of each block, need to convert it to offsets.
var offset = 0L
@@ -196,7 +198,7 @@ private[spark] class IndexShuffleBlockResolver(
// find out the consolidated file, then the offset within that from our index
val indexFile = getIndexFile(blockId.shuffleId, blockId.mapId)

val in = new DataInputStream(new FileInputStream(indexFile))
val in = new DataInputStream(Files.newInputStream(indexFile.toPath))
try {
ByteStreams.skipFully(in, blockId.reduceId * 8)
val offset = in.readLong()
@@ -18,12 +18,15 @@
package org.apache.spark.util.collection

import java.io._
import java.nio.channels.{Channels, FileChannel}
import java.nio.file.StandardOpenOption
import java.util.Comparator

import scala.collection.BufferedIterator
import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer

import com.google.common.collect.ImmutableSet
import com.google.common.io.ByteStreams

import org.apache.spark.{SparkEnv, TaskContext}
@@ -460,7 +463,7 @@
)

private var batchIndex = 0 // Which batch we're in
private var fileStream: FileInputStream = null
private var fileChannel: FileChannel = null

// An intermediate stream that reads from exactly one batch
// This guards against pre-fetching and other arbitrary behavior of higher level streams
@@ -477,22 +480,23 @@
if (batchIndex < batchOffsets.length - 1) {
if (deserializeStream != null) {
deserializeStream.close()
fileStream.close()
fileChannel.close()
deserializeStream = null
fileStream = null
fileChannel = null
}

val start = batchOffsets(batchIndex)
fileStream = new FileInputStream(file)
fileStream.getChannel.position(start)
fileChannel = FileChannel.open(file.toPath, ImmutableSet.of(StandardOpenOption.READ))
fileChannel.position(start)
batchIndex += 1

val end = batchOffsets(batchIndex)

assert(end >= start, "start = " + start + ", end = " + end +
", batchOffsets = " + batchOffsets.mkString("[", ", ", "]"))

val bufferedStream = new BufferedInputStream(ByteStreams.limit(fileStream, end - start))
val bufferedStream = new BufferedInputStream(
ByteStreams.limit(Channels.newInputStream(fileChannel), end - start))
val wrappedStream = serializerManager.wrapStream(blockId, bufferedStream)
ser.deserializeStream(wrappedStream)
} else {
@@ -552,9 +556,9 @@ class ExternalAppendOnlyMap[K, V, C](
ds.close()
deserializeStream = null
}
if (fileStream != null) {
fileStream.close()
fileStream = null
if (fileChannel != null) {
fileChannel.close()
fileChannel = null
}
if (file.exists()) {
if (!file.delete()) {
@@ -18,11 +18,14 @@
package org.apache.spark.util.collection

import java.io._
import java.nio.channels.{Channels, FileChannel}
import java.nio.file.StandardOpenOption
import java.util.Comparator

import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer

import com.google.common.collect.ImmutableSet
import com.google.common.io.ByteStreams

import org.apache.spark._
@@ -492,7 +495,7 @@ private[spark] class ExternalSorter[K, V, C](

// Intermediate file and deserializer streams that read from exactly one batch
// This guards against pre-fetching and other arbitrary behavior of higher level streams
var fileStream: FileInputStream = null
var fileChannel: FileChannel = null
var deserializeStream = nextBatchStream() // Also sets fileStream

var nextItem: (K, C) = null
@@ -505,22 +508,23 @@
if (batchId < batchOffsets.length - 1) {
if (deserializeStream != null) {
deserializeStream.close()
fileStream.close()
fileChannel.close()
deserializeStream = null
fileStream = null
fileChannel = null
}

val start = batchOffsets(batchId)
fileStream = new FileInputStream(spill.file)
fileStream.getChannel.position(start)
fileChannel = FileChannel.open(spill.file.toPath, ImmutableSet.of(StandardOpenOption.READ))
fileChannel.position(start)
batchId += 1

val end = batchOffsets(batchId)

assert(end >= start, "start = " + start + ", end = " + end +
", batchOffsets = " + batchOffsets.mkString("[", ", ", "]"))

val bufferedStream = new BufferedInputStream(ByteStreams.limit(fileStream, end - start))
val bufferedStream = new BufferedInputStream(
ByteStreams.limit(Channels.newInputStream(fileChannel), end - start))

val wrappedStream = serializerManager.wrapStream(spill.blockId, bufferedStream)
serInstance.deserializeStream(wrappedStream)
@@ -610,7 +614,7 @@ private[spark] class ExternalSorter[K, V, C](
batchId = batchOffsets.length // Prevent reading any other batch
val ds = deserializeStream
deserializeStream = null
fileStream = null
fileChannel = null
if (ds != null) {
ds.close()
}