Skip to content

Commit

Permalink
Add recursive spill for RowNumber
Browse files Browse the repository at this point in the history
  • Loading branch information
duanmeng committed Mar 19, 2024
1 parent 0cd6c0a commit 59fda5e
Show file tree
Hide file tree
Showing 4 changed files with 384 additions and 82 deletions.
4 changes: 2 additions & 2 deletions velox/common/base/SpillConfig.h
Original file line number Diff line number Diff line change
Expand Up @@ -70,8 +70,8 @@ struct SpillConfig {
/// the next level of recursive spilling.
int32_t spillLevel(uint8_t startBitOffset) const;

/// Checks if the given 'startBitOffset' and 'numPartitionBits' has exceeded
/// the max hash join spill limit.
/// Checks if the given 'startBitOffset' has exceeded the max hash join spill
/// limit.
bool exceedSpillLevelLimit(uint8_t startBitOffset) const;

/// A callback function that returns the spill directory path. Implementations
Expand Down
146 changes: 101 additions & 45 deletions velox/exec/RowNumber.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -65,52 +65,26 @@ RowNumber::RowNumber(
resultProjections_.emplace_back(0, inputType->size());
results_.resize(1);
}

if (spillEnabled()) {
setSpillPartitionBits();
}
}

void RowNumber::addInput(RowVectorPtr input) {
const auto numInput = input->size();

if (table_) {
ensureInputFits(input);

if (inputSpiller_ != nullptr) {
spillInput(input, pool());
return;
}

SelectivityVector rows(numInput);
table_->prepareForGroupProbe(
*lookup_,
input,
rows,
false,
BaseHashTable::kNoSpillInputStartPartitionBit);
table_->groupProbe(*lookup_);

// Initialize new partitions with zeros.
for (auto i : lookup_->newGroups) {
setNumRows(lookup_->hits[i], 0);
}
addInputInternal(input);
}

input_ = std::move(input);
}

void RowNumber::addSpillInput() {
const auto numInput = input_->size();
SelectivityVector rows(numInput);
table_->prepareForGroupProbe(
*lookup_, input_, rows, false, spillConfig_->startPartitionBit);
table_->groupProbe(*lookup_);

// Initialize new partitions with zeros.
for (auto i : lookup_->newGroups) {
setNumRows(lookup_->hits[i], 0);
}

// TODO Add support for recursive spilling.
}

void RowNumber::noMoreInput() {
Operator::noMoreInput();

Expand All @@ -135,6 +109,8 @@ void RowNumber::restoreNextSpillPartition() {
if (hashTableIt != spillHashTablePartitionSet_.end()) {
spillHashTableReader_ = hashTableIt->second->createUnorderedReader(pool());

setSpillPartitionBits(&(it->first));

RowVectorPtr data;
while (spillHashTableReader_->nextBatch(data)) {
// 'data' contains partition-by keys and count. Transform 'data' to match
Expand Down Expand Up @@ -168,7 +144,7 @@ void RowNumber::restoreNextSpillPartition() {
spillInputPartitionSet_.erase(it);

spillInputReader_->nextBatch(input_);
addSpillInput();
addInputInternal(input_, true);
}

void RowNumber::ensureInputFits(const RowVectorPtr& input) {
Expand Down Expand Up @@ -251,7 +227,19 @@ FlatVector<int64_t>& RowNumber::getOrCreateRowNumberVector(vector_size_t size) {

RowVectorPtr RowNumber::getOutput() {
if (input_ == nullptr) {
return nullptr;
if (spillInputReader_ == nullptr) {
return nullptr;
}

if (yield_) {
VELOX_CHECK_NULL(input_);
return nullptr;
}

recursiveSpillInput();
if (input_ == nullptr) {
return nullptr;
}
}

if (!table_) {
Expand Down Expand Up @@ -307,7 +295,7 @@ RowVectorPtr RowNumber::getOutput() {

if (spillInputReader_ != nullptr) {
if (spillInputReader_->nextBatch(input_)) {
addSpillInput();
addInputInternal(input_, true);
} else {
input_ = nullptr;
spillInputReader_ = nullptr;
Expand Down Expand Up @@ -368,8 +356,14 @@ void RowNumber::reclaim(
return;
}

if (inputSpiller_ != nullptr) {
// Already spilled.
if (exceededMaxSpillLevelLimit_) {
LOG(WARNING) << "Exceeded row spill level limit: "
<< spillConfig_->maxSpillLevel
<< ", and abandon spilling for memory pool: "
<< pool()->name();
common::SpillStats spillStats;
spillStats.spillMaxLevelExceededCount = 1;
Operator::recordSpillStats(spillStats);
return;
}

Expand All @@ -380,8 +374,6 @@ SpillPartitionNumSet RowNumber::spillHashTable() {
// TODO Replace joinPartitionBits and Spiller::Type::kHashJoinBuild.
VELOX_CHECK_NOT_NULL(table_);

const auto& spillConfig = spillConfig_.value();

auto columnTypes = table_->rows()->columnTypes();
auto tableType = ROW(std::move(columnTypes));

Expand All @@ -390,7 +382,7 @@ SpillPartitionNumSet RowNumber::spillHashTable() {
table_->rows(),
tableType,
spillPartitionBits_,
&spillConfig);
&spillConfig_.value());

hashTableSpiller->spill();
hashTableSpiller->finishSpill(spillHashTablePartitionSet_);
Expand Down Expand Up @@ -429,16 +421,11 @@ void RowNumber::setupInputSpiller(

void RowNumber::spill() {
VELOX_CHECK(spillEnabled());
VELOX_CHECK_NULL(inputSpiller_);

spillPartitionBits_ = HashBitRange(
spillConfig_->startPartitionBit,
spillConfig_->startPartitionBit + spillConfig_->numPartitionBits);

const auto spillPartitionSet = spillHashTable();
VELOX_CHECK_EQ(table_->numDistinct(), 0);

setupInputSpiller(spillPartitionSet);

if (input_ != nullptr) {
spillInput(input_, memory::spillMemoryPool());
input_ = nullptr;
Expand Down Expand Up @@ -488,4 +475,73 @@ void RowNumber::spillInput(
}
}

BlockingReason RowNumber::isBlocked(ContinueFuture* /* unused */) {
const auto reason =
yield_ ? BlockingReason::kYield : BlockingReason::kNotBlocked;
yield_ = false;
return reason;
}

void RowNumber::addInputInternal(const RowVectorPtr& input, bool fromSpill) {
if (fromSpill) {
ensureInputFits(input);
if (input == nullptr) {
return;
}
}

const auto numInput = input->size();
SelectivityVector rows(numInput);

table_->prepareForGroupProbe(
*lookup_,
input,
rows,
false,
fromSpill ? spillConfig_->startPartitionBit
: BaseHashTable::kNoSpillInputStartPartitionBit);
table_->groupProbe(*lookup_);

// Initialize new partitions with zeros.
for (auto i : lookup_->newGroups) {
setNumRows(lookup_->hits[i], 0);
}
}

void RowNumber::recursiveSpillInput() {
RowVectorPtr input;
while (spillInputReader_->nextBatch(input)) {
spillInput(input, pool());

if (operatorCtx_->driver()->shouldYield()) {
yield_ = true;
return;
}
}

inputSpiller_->finishSpill(spillInputPartitionSet_);
recordSpillStats(inputSpiller_->stats());
spillInputReader_ = nullptr;

removeEmptyPartitions(spillInputPartitionSet_);
restoreNextSpillPartition();
}

void RowNumber::setSpillPartitionBits(
const SpillPartitionId* restoredPartitionId) {
const auto startPartitionBitOffset = restoredPartitionId == nullptr
? spillConfig_->startPartitionBit
: restoredPartitionId->partitionBitOffset() +
spillConfig_->numPartitionBits;
if (spillConfig_->exceedSpillLevelLimit(startPartitionBitOffset)) {
exceededMaxSpillLevelLimit_ = true;
return;
}

exceededMaxSpillLevelLimit_ = false;
spillPartitionBits_ = HashBitRange(
startPartitionBitOffset,
startPartitionBitOffset + spillConfig_->numPartitionBits);
}

} // namespace facebook::velox::exec
26 changes: 22 additions & 4 deletions velox/exec/RowNumber.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,7 @@ class RowNumber : public Operator {
return !noMoreInput_ && !finishedEarly_;
}

BlockingReason isBlocked(ContinueFuture* /* unused */) override {
return BlockingReason::kNotBlocked;
}
BlockingReason isBlocked(ContinueFuture* /* unused */) override;

bool isFinished() override {
return (noMoreInput_ && input_ == nullptr &&
Expand All @@ -64,7 +62,9 @@ class RowNumber : public Operator {

void spill();

void addSpillInput();
// Probes the hash 'table_' with input. If 'fromSpill' is true, the input is
// read from the spilled input, otherwise from the source.
void addInputInternal(const RowVectorPtr& input, bool fromSpill = false);

void restoreNextSpillPartition();

Expand All @@ -78,6 +78,18 @@ class RowNumber : public Operator {

FlatVector<int64_t>& getOrCreateRowNumberVector(vector_size_t size);

// Used by recursive spill processing to read the spilled input data from the
// previous spill run through 'spillInputReader_' and then spill them back
// into a number of sub-partitions. After that, the function restores one of
// the newly spilled partitions and resets 'spillInputReader_' accordingly.
void recursiveSpillInput();

// Set 'spillPartitionBits_' used for (recursive) spill.
// If 'restoredPartitionId' is not nullptr, use it to set the
// 'spillPartitionBits_', otherwise use 'spillConfig_'.
void setSpillPartitionBits(
const SpillPartitionId* restoredPartitionId = nullptr);

const std::optional<int32_t> limit_;
const bool generateRowNumber_;

Expand Down Expand Up @@ -117,5 +129,11 @@ class RowNumber : public Operator {

// Used to calculate the spill partition numbers of the inputs.
std::unique_ptr<HashPartitionFunction> spillHashFunction_;

// The cpu may be voluntarily yield after running too long when processing
// input from spilled file.
bool yield_;

bool exceededMaxSpillLevelLimit_{false};
};
} // namespace facebook::velox::exec
Loading

0 comments on commit 59fda5e

Please sign in to comment.