Prefetch spark.shuffle.file.buffer
ccat3z committed Nov 28, 2024
1 parent 937be05 commit 7bcaa67
Showing 9 changed files with 57 additions and 25 deletions.
1 change: 1 addition & 0 deletions cpp/core/config/GlutenConfig.h
@@ -73,6 +73,7 @@ const std::string kSparkRedactionRegex = "spark.redaction.regex";
 const std::string kSparkRedactionString = "*********(redacted)";
 
 const std::string kSparkLegacyTimeParserPolicy = "spark.sql.legacy.timeParserPolicy";
+const std::string kShuffleFileBufferSize = "spark.shuffle.file.buffer";
 
 std::unordered_map<std::string, std::string>
 parseConfMap(JNIEnv* env, const uint8_t* planData, const int32_t planDataLength);
7 changes: 7 additions & 0 deletions cpp/core/jni/JniWrapper.cc
@@ -824,6 +824,13 @@ JNIEXPORT jlong JNICALL Java_org_apache_gluten_vectorized_ShuffleWriterJniWrappe
     partitionWriterOptions.codecBackend = getCodecBackend(env, codecBackendJstr);
     partitionWriterOptions.compressionMode = getCompressionMode(env, compressionModeJstr);
   }
+  const auto& conf = ctx->getConfMap();
+  {
+    auto it = conf.find(kShuffleFileBufferSize);
+    if (it != conf.end()) {
+      partitionWriterOptions.shuffleFileBufferSize = static_cast<int64_t>(stoi(it->second));
+    }
+  }
 
   std::unique_ptr<PartitionWriter> partitionWriter;
 
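Note: the added JNI block reads the buffer size out of the plan's conf map, which the JVM side has already normalized to a plain byte count (see the GlutenConfig.scala hunk below). A standalone sketch of the same lookup — the helper name and inlined default are illustrative, the committed code inlines this logic in the shuffle writer wrapper:

```cpp
#include <cstdint>
#include <string>
#include <unordered_map>

// Illustrative helper: resolve spark.shuffle.file.buffer from a conf map,
// falling back to the 32 KiB default defined in Options.h.
int64_t resolveShuffleFileBufferSize(const std::unordered_map<std::string, std::string>& conf) {
  constexpr int64_t kDefault = 32 << 10;  // kDefaultShuffleFileBufferSize
  auto it = conf.find("spark.shuffle.file.buffer");
  if (it == conf.end()) {
    return kDefault;
  }
  // The Scala side converts "32k" etc. to bytes before shipping the conf,
  // so an integer parse is sufficient here.
  return static_cast<int64_t>(std::stoll(it->second));
}
```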
1 change: 1 addition & 0 deletions cpp/core/shuffle/LocalPartitionWriter.cc
@@ -420,6 +420,7 @@ arrow::Status LocalPartitionWriter::mergeSpills(uint32_t partitionId) {
   auto spillIter = spills_.begin();
   while (spillIter != spills_.end()) {
     ARROW_ASSIGN_OR_RAISE(auto st, dataFileOs_->Tell());
+    (*spillIter)->openForRead(options_.shuffleFileBufferSize);
     // Read if partition exists in the spilled file and write to the final file.
     while (auto payload = (*spillIter)->nextPayload(partitionId)) {
       // May trigger spill during compression.
3 changes: 3 additions & 0 deletions cpp/core/shuffle/Options.h
@@ -39,6 +39,7 @@ static constexpr bool kEnableBufferedWrite = true;
 static constexpr bool kDefaultUseRadixSort = true;
 static constexpr int32_t kDefaultSortBufferSize = 4096;
 static constexpr int64_t kDefaultReadBufferSize = 1 << 20;
+static constexpr int64_t kDefaultShuffleFileBufferSize = 32 << 10;
 
 enum ShuffleWriterType { kHashShuffle, kSortShuffle, kRssSortShuffle };
 enum PartitionWriterType { kLocal, kRss };
@@ -86,6 +87,8 @@ struct PartitionWriterOptions {
   int64_t pushBufferMaxSize = kDefaultPushMemoryThreshold;
 
   int64_t sortBufferMaxSize = kDefaultSortBufferThreshold;
+
+  int64_t shuffleFileBufferSize = kDefaultShuffleFileBufferSize;
 };
 
 struct ShuffleWriterMetrics {
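Note: the new `32 << 10` default matches Spark's documented default of 32k for spark.shuffle.file.buffer. A quick standalone check of the arithmetic:

```cpp
#include <cstdint>

// 32 << 10 shifts 32 left by 10 bits, i.e. multiplies by 1024: 32 KiB.
static_assert((INT64_C(32) << 10) == 32768, "kDefaultShuffleFileBufferSize is 32 KiB");
```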
6 changes: 3 additions & 3 deletions cpp/core/shuffle/Spill.cc
@@ -34,7 +34,7 @@ bool Spill::hasNextPayload(uint32_t partitionId) {
 }
 
 std::unique_ptr<Payload> Spill::nextPayload(uint32_t partitionId) {
-  openSpillFile();
+  GLUTEN_CHECK(is_, "openForRead before invoke nextPayload");
   if (!hasNextPayload(partitionId)) {
     return nullptr;
   }
@@ -71,9 +71,9 @@ void Spill::insertPayload(
   }
 }
 
-void Spill::openSpillFile() {
+void Spill::openForRead(uint64_t shuffleFileBufferSize) {
   if (!is_) {
-    GLUTEN_ASSIGN_OR_THROW(is_, MmapFileStream::open(spillFile_));
+    GLUTEN_ASSIGN_OR_THROW(is_, MmapFileStream::open(spillFile_, shuffleFileBufferSize));
     rawIs_ = is_.get();
   }
 }
4 changes: 2 additions & 2 deletions cpp/core/shuffle/Spill.h
@@ -37,6 +37,8 @@ class Spill final {
 
   SpillType type() const;
 
+  void openForRead(uint64_t shuffleFileBufferSize);
+
   bool hasNextPayload(uint32_t partitionId);
 
   std::unique_ptr<Payload> nextPayload(uint32_t partitionId);
@@ -76,7 +78,5 @@ class Spill final {
   int64_t compressTime_;
 
   arrow::io::InputStream* rawIs_;
-
-  void openSpillFile();
 };
 } // namespace gluten
41 changes: 24 additions & 17 deletions cpp/core/shuffle/Utils.cc
@@ -154,6 +154,14 @@ arrow::Status getLengthBufferAndValueBufferStream(
   *compressedLengthPtr = actualLength;
   return arrow::Status::OK();
 }
+
+uint64_t roundUpToPageSize(uint64_t value) {
+  static auto pageSize = static_cast<size_t>(arrow::internal::GetPageSize());
+  static auto pageMask = ~(pageSize - 1);
+  DCHECK_GT(pageSize, 0);
+  DCHECK_EQ(pageMask & pageSize, pageSize);
+  return (value + pageSize - 1) & pageMask;
+}
 } // namespace
 
 arrow::Result<std::shared_ptr<arrow::RecordBatch>> makeCompressedRecordBatch(
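Note: `roundUpToPageSize` relies on the page size being a power of two, so masking with `~(pageSize - 1)` clears the sub-page bits after adding `pageSize - 1`. A worked example under the common assumption of 4 KiB pages:

```cpp
#include <cassert>
#include <cstdint>

int main() {
  // Assumes 4096-byte pages; the real code queries arrow::internal::GetPageSize().
  constexpr uint64_t pageSize = 4096;
  constexpr uint64_t pageMask = ~(pageSize - 1);
  assert(((5000 + pageSize - 1) & pageMask) == 8192);  // 5000 B rounds up to two pages
  assert(((4096 + pageSize - 1) & pageMask) == 4096);  // aligned values are unchanged
  assert(((0 + pageSize - 1) & pageMask) == 0);        // zero stays zero
  return 0;
}
```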
@@ -216,10 +224,10 @@ arrow::Result<std::shared_ptr<arrow::RecordBatch>> makeUncompressedRecordBatch(
   return arrow::RecordBatch::Make(writeSchema, 1, {arrays});
 }
 
-MmapFileStream::MmapFileStream(arrow::internal::FileDescriptor fd, uint8_t* data, int64_t size)
-    : fd_(std::move(fd)), data_(data), size_(size){};
+MmapFileStream::MmapFileStream(arrow::internal::FileDescriptor fd, uint8_t* data, int64_t size, uint64_t prefetchSize)
+    : prefetchSize_(roundUpToPageSize(prefetchSize)), fd_(std::move(fd)), data_(data), size_(size){};
 
-arrow::Result<std::shared_ptr<MmapFileStream>> MmapFileStream::open(const std::string& path) {
+arrow::Result<std::shared_ptr<MmapFileStream>> MmapFileStream::open(const std::string& path, uint64_t prefetchSize) {
   ARROW_ASSIGN_OR_RAISE(auto fileName, arrow::internal::PlatformFilename::FromString(path));
 
   ARROW_ASSIGN_OR_RAISE(auto fd, arrow::internal::FileOpenReadable(fileName));
@@ -230,7 +238,7 @@ arrow::Result<std::shared_ptr<MmapFileStream>> MmapFileStream::open(const std::s
     return arrow::Status::IOError("Memory mapping file failed: ", ::arrow::internal::ErrnoMessage(errno));
   }
 
-  return std::make_shared<MmapFileStream>(std::move(fd), static_cast<uint8_t*>(result), size);
+  return std::make_shared<MmapFileStream>(std::move(fd), static_cast<uint8_t*>(result), size, prefetchSize);
 }
 
 arrow::Result<int64_t> MmapFileStream::actualReadSize(int64_t nbytes) {
@@ -245,12 +253,8 @@ bool MmapFileStream::closed() const {
 };
 
 void MmapFileStream::advance(int64_t length) {
-  static auto pageSize = static_cast<size_t>(arrow::internal::GetPageSize());
-  static auto pageMask = ~(pageSize - 1);
-  DCHECK_GT(pageSize, 0);
-  DCHECK_EQ(pageMask & pageSize, pageSize);
-
-  auto purgeLength = (pos_ - posRetain_) & pageMask;
+  // Dont need data before pos
+  auto purgeLength = (pos_ - posRetain_) / prefetchSize_ * prefetchSize_;
   if (purgeLength > 0) {
     int ret = madvise(data_ + posRetain_, purgeLength, MADV_DONTNEED);
     if (ret != 0) {
@@ -263,17 +267,20 @@ void MmapFileStream::advance(int64_t length) {
 
 }
 
 void MmapFileStream::willNeed(int64_t length) {
-  static auto pageSize = static_cast<size_t>(arrow::internal::GetPageSize());
-  static auto pageMask = ~(pageSize - 1);
-  DCHECK_GT(pageSize, 0);
-  DCHECK_EQ(pageMask & pageSize, pageSize);
+  // Skip if already fetched
+  if (pos_ + length <= posFetch_) {
+    return;
+  }
 
-  auto willNeedPos = pos_ & pageMask;
-  auto willNeedLen = pos_ + length - willNeedPos;
-  int ret = madvise(data_ + willNeedPos, willNeedLen, MADV_WILLNEED);
+  // Round up to multiple of prefetchSize
+  auto fetchLen = ((length + prefetchSize_ - 1) / prefetchSize_) * prefetchSize_;
+  fetchLen = std::min(size_ - pos_, fetchLen);
+  int ret = madvise(data_ + posFetch_, fetchLen, MADV_WILLNEED);
   if (ret != 0) {
     LOG(WARNING) << "madvise willneed failed: " << ::arrow::internal::ErrnoMessage(errno);
   }
 
+  posFetch_ += fetchLen;
 }

 arrow::Status MmapFileStream::Close() {
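Note: taken together, the reworked advance() and willNeed() maintain a sliding madvise window over the memory-mapped spill file: willNeed() prefetches ahead of the read cursor in prefetchSize_ chunks, and advance() drops whole chunks the reader has passed. A condensed, self-contained sketch of the pattern — simplified, with no error handling, and the cursor updates hidden in the collapsed diff lines assumed:

```cpp
#include <sys/mman.h>
#include <algorithm>
#include <cstdint>

// Sketch of the sliding prefetch window; not the committed class.
struct PrefetchWindow {
  uint8_t* data;          // base address of the mmap'd spill file
  int64_t size;           // total file size
  int64_t prefetchSize;   // page-aligned chunk size (spark.shuffle.file.buffer)
  int64_t pos = 0;        // read cursor
  int64_t posFetch = 0;   // bytes already advised MADV_WILLNEED
  int64_t posRetain = 0;  // bytes not yet advised MADV_DONTNEED

  void willNeed(int64_t length) {
    if (pos + length <= posFetch) return;  // already prefetched
    auto fetchLen = (length + prefetchSize - 1) / prefetchSize * prefetchSize;
    fetchLen = std::min(size - pos, fetchLen);
    madvise(data + posFetch, fetchLen, MADV_WILLNEED);  // read ahead
    posFetch += fetchLen;
  }

  void advance(int64_t length) {
    // Release only whole chunks the reader has fully moved past.
    auto purgeLength = (pos - posRetain) / prefetchSize * prefetchSize;
    if (purgeLength > 0) {
      madvise(data + posRetain, purgeLength, MADV_DONTNEED);
      posRetain += purgeLength;
    }
    pos += length;
  }
};
```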
7 changes: 5 additions & 2 deletions cpp/core/shuffle/Utils.h
@@ -74,9 +74,9 @@ std::shared_ptr<arrow::Buffer> zeroLengthNullBuffer();
 // to prefetch and release memory timely.
 class MmapFileStream : public arrow::io::InputStream {
  public:
-  MmapFileStream(arrow::internal::FileDescriptor fd, uint8_t* data, int64_t size);
+  MmapFileStream(arrow::internal::FileDescriptor fd, uint8_t* data, int64_t size, uint64_t prefetchSize);
 
-  static arrow::Result<std::shared_ptr<MmapFileStream>> open(const std::string& path);
+  static arrow::Result<std::shared_ptr<MmapFileStream>> open(const std::string& path, uint64_t prefetchSize = 0);
 
   arrow::Result<int64_t> Tell() const override;
 
@@ -95,10 +95,13 @@ class MmapFileStream : public arrow::io::InputStream {
 
   void willNeed(int64_t length);
 
+  // Page-aligned prefetch size
+  const int64_t prefetchSize_;
   arrow::internal::FileDescriptor fd_;
   uint8_t* data_ = nullptr;
   int64_t size_;
   int64_t pos_ = 0;
+  int64_t posFetch_ = 0;
   int64_t posRetain_ = 0;
 };
 
12 changes: 11 additions & 1 deletion shims/common/src/main/scala/org/apache/gluten/GlutenConfig.scala
@@ -17,7 +17,7 @@
 package org.apache.gluten
 
 import org.apache.spark.internal.Logging
-import org.apache.spark.network.util.ByteUnit
+import org.apache.spark.network.util.{ByteUnit, JavaUtils}
 import org.apache.spark.sql.internal.SQLConf
 
 import com.google.common.collect.ImmutableList
@@ -568,6 +568,7 @@ object GlutenConfig {
   val SPARK_OFFHEAP_SIZE_KEY = "spark.memory.offHeap.size"
   val SPARK_OFFHEAP_ENABLED = "spark.memory.offHeap.enabled"
   val SPARK_REDACTION_REGEX = "spark.redaction.regex"
+  val SPARK_SHUFFLE_FILE_BUFFER = "spark.shuffle.file.buffer"
 
   // For Soft Affinity Scheduling
   // Enable Soft Affinity Scheduling, default value is false
@@ -736,6 +737,15 @@
     )
     keyWithDefault.forEach(e => nativeConfMap.put(e._1, conf.getOrElse(e._1, e._2)))
 
+    conf
+      .get(SPARK_SHUFFLE_FILE_BUFFER)
+      .foreach(
+        v =>
+          nativeConfMap
+            .put(
+              SPARK_SHUFFLE_FILE_BUFFER,
+              (JavaUtils.byteStringAs(v, ByteUnit.KiB) * 1024).toString))
+
     // Backend's dynamic session conf only.
     val confPrefix = prefixOf(backendName)
     conf
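Note: spark.shuffle.file.buffer is a byte-size string whose bare numbers are read as KiB (default 32k), so `JavaUtils.byteStringAs(v, ByteUnit.KiB) * 1024` yields plain bytes for the native layer. A hedged C++ equivalent of that conversion, covering just the common suffixes:

```cpp
#include <cctype>
#include <cstdint>
#include <stdexcept>
#include <string>

// Illustrative re-implementation of the KiB-based conversion; the real code
// delegates to Spark's JavaUtils.byteStringAs.
int64_t shuffleFileBufferToBytes(std::string v) {
  int64_t multiplier = 1024;  // bare numbers are interpreted as KiB
  if (!v.empty() && std::isalpha(static_cast<unsigned char>(v.back()))) {
    switch (std::tolower(static_cast<unsigned char>(v.back()))) {
      case 'k': multiplier = 1024; break;
      case 'm': multiplier = 1024 * 1024; break;
      case 'g': multiplier = 1024 * 1024 * 1024; break;
      default: throw std::invalid_argument("unsupported size suffix");
    }
    v.pop_back();
  }
  return std::stoll(v) * multiplier;
}
// shuffleFileBufferToBytes("32k") == 32768, matching the 32 KiB default.
```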
