fix: fix the memory reclaim bytes for hash join (facebookincubator#11642)

Summary:

Both the hash join build and probe sides perform a coordinated spill, so we shouldn't report the reclaimed bytes
from a single operator but from the plan node. Also, a probe-side spill might spill the built table from the build
side, in which case the memory is actually reclaimed from the build-side pool rather than the probe-side pool.

This PR also removes the unused wait-for-spill state from hash build.

Differential Revision: D66437719
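
To illustrate the accounting change, here is a minimal standalone sketch; the MiniPool type and reclaimFromNode helper are hypothetical and are not the Velox API. Reporting the reservation delta measured on the join node's pool captures bytes freed from either child pool, whereas the bytes a single operator reports can understate what was actually freed, e.g. when a probe-side spill drops the built table held in the build-side pool.

#include <cstdint>
#include <functional>
#include <iostream>
#include <vector>

// Hypothetical stand-in for a memory pool node with child pools.
struct MiniPool {
  int64_t reserved{0};
  std::vector<MiniPool*> children;

  // Reservation of this pool plus all of its children.
  int64_t reservedBytes() const {
    int64_t total = reserved;
    for (const auto* child : children) {
      total += child->reservedBytes();
    }
    return total;
  }
};

// Mirrors the idea of the fix: snapshot the node-level reservation before and
// after the child reclaims and report the difference.
int64_t reclaimFromNode(MiniPool& node, const std::function<void()>& reclaimChildren) {
  const int64_t before = node.reservedBytes();
  reclaimChildren();
  return before - node.reservedBytes();
}

int main() {
  MiniPool build;
  build.reserved = 512;
  MiniPool probe;
  probe.reserved = 128;
  MiniPool node;
  node.children = {&build, &probe};

  // A probe-initiated spill that also drops the built table: the freed bytes sit
  // in the build child's reservation, so counting only what the probe child
  // itself reports would understate the reclaim.
  const int64_t reclaimed = reclaimFromNode(node, [&] {
    probe.reserved -= 128; // probe rows spilled
    build.reserved -= 512; // built table spilled on the probe-side request
  });
  std::cout << "reclaimed bytes: " << reclaimed << "\n"; // prints 640
  return 0;
}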
xiaoxmeng authored and facebook-github-bot committed Nov 26, 2024
1 parent c395c55 commit 01f2833
Showing 6 changed files with 263 additions and 33 deletions.
2 changes: 0 additions & 2 deletions velox/exec/Driver.cpp
@@ -1140,8 +1140,6 @@ std::string blockingReasonToString(BlockingReason reason) {
return "kWaitForMemory";
case BlockingReason::kWaitForConnector:
return "kWaitForConnector";
case BlockingReason::kWaitForSpill:
return "kWaitForSpill";
case BlockingReason::kYield:
return "kYield";
case BlockingReason::kWaitForArbitration:
2 changes: 2 additions & 0 deletions velox/exec/Driver.h
@@ -205,6 +205,8 @@ enum class BlockingReason {
kWaitForConnector,
/// The build operator is blocked waiting for all its peers to stop so that a
/// group spill can run on all of them.
///
/// TODO: remove this after Prestissimo is updated.
kWaitForSpill,
/// Some operators (like Table Scan) may run long loops and can 'voluntarily'
/// exit them because Task requested to yield or stop or after a certain time.
13 changes: 0 additions & 13 deletions velox/exec/HashBuild.cpp
@@ -35,8 +35,6 @@ BlockingReason fromStateToBlockingReason(HashBuild::State state) {
return BlockingReason::kNotBlocked;
case HashBuild::State::kYield:
return BlockingReason::kYield;
case HashBuild::State::kWaitForSpill:
return BlockingReason::kWaitForSpill;
case HashBuild::State::kWaitForBuild:
return BlockingReason::kWaitForJoinBuild;
case HashBuild::State::kWaitForProbe:
@@ -944,13 +942,6 @@ BlockingReason HashBuild::isBlocked(ContinueFuture* future) {
break;
case State::kFinish:
break;
case State::kWaitForSpill:
if (!future_.valid()) {
setRunning();
VELOX_CHECK_NOT_NULL(input_);
addInput(std::move(input_));
}
break;
case State::kWaitForBuild:
[[fallthrough]];
case State::kWaitForProbe:
@@ -1003,8 +994,6 @@ void HashBuild::checkStateTransition(State state) {
break;
case State::kWaitForBuild:
[[fallthrough]];
case State::kWaitForSpill:
[[fallthrough]];
case State::kWaitForProbe:
[[fallthrough]];
case State::kFinish:
@@ -1022,8 +1011,6 @@
return "RUNNING";
case State::kYield:
return "YIELD";
case State::kWaitForSpill:
return "WAIT_FOR_SPILL";
case State::kWaitForBuild:
return "WAIT_FOR_BUILD";
case State::kWaitForProbe:
9 changes: 3 additions & 6 deletions velox/exec/HashBuild.h
@@ -44,17 +44,14 @@ class HashBuild final : public Operator {
/// The yield state that voluntarily yields CPU after running too long when
/// processing input from a spilled file.
kYield = 2,
/// The state that waits for the pending group spill to finish. This state
/// only applies if disk spilling is enabled.
kWaitForSpill = 3,
/// The state that waits for the hash tables to be merged together.
kWaitForBuild = 4,
kWaitForBuild = 3,
/// The state that waits for the hash probe to finish before starting to build
/// the hash table for one of the previously spilled partitions. This state only
/// applies if disk spilling is enabled.
kWaitForProbe = 5,
kWaitForProbe = 4,
/// The finishing state.
kFinish = 6,
kFinish = 5,
};
static std::string stateName(State state);

24 changes: 14 additions & 10 deletions velox/exec/HashJoinBridge.cpp
@@ -382,19 +382,20 @@ uint64_t HashJoinMemoryReclaimer::reclaim(
uint64_t targetBytes,
uint64_t maxWaitMs,
memory::MemoryReclaimer::Stats& stats) {
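// Snapshot the join node pool's reservation so the reclaimed bytes can be
// reported as the node-level delta rather than a per-child sum.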
const auto prevNodeReservedMemory = pool->reservedBytes();

// The flags to track if we have reclaimed from both build and probe operators
// under a hash join node.
bool hasReclaimedFromBuild{false};
bool hasReclaimedFromProbe{false};
uint64_t reclaimedBytes{0};
pool->visitChildren([&](memory::MemoryPool* child) {
VELOX_CHECK_EQ(child->kind(), memory::MemoryPool::Kind::kLeaf);
const bool isBuild = isHashBuildMemoryPool(*child);
if (isBuild) {
if (!hasReclaimedFromBuild) {
// We just need to reclaim from any one of the hash build operators.
hasReclaimedFromBuild = true;
reclaimedBytes += child->reclaim(targetBytes, maxWaitMs, stats);
child->reclaim(targetBytes, maxWaitMs, stats);
}
return !hasReclaimedFromProbe;
}
@@ -403,22 +404,25 @@ uint64_t HashJoinMemoryReclaimer::reclaim(
// As with the build operators, we only need to reclaim from any one of the
// hash probe operators.
hasReclaimedFromProbe = true;
reclaimedBytes += child->reclaim(targetBytes, maxWaitMs, stats);
child->reclaim(targetBytes, maxWaitMs, stats);
}
return !hasReclaimedFromBuild;
});
if (reclaimedBytes != 0) {
return reclaimedBytes;

auto currNodeReservedMemory = pool->reservedBytes();
VELOX_CHECK_LE(currNodeReservedMemory, prevNodeReservedMemory);
if (currNodeReservedMemory < prevNodeReservedMemory) {
return prevNodeReservedMemory - currNodeReservedMemory;
}

auto joinBridge = joinBridge_.lock();
if (joinBridge == nullptr) {
return reclaimedBytes;
return 0;
}
const auto oldNodeReservedMemory = pool->reservedBytes();
joinBridge->reclaim();
const auto newNodeReservedMemory = pool->reservedBytes();
VELOX_CHECK_LE(newNodeReservedMemory, oldNodeReservedMemory);
return oldNodeReservedMemory - newNodeReservedMemory;
currNodeReservedMemory = pool->reservedBytes();
VELOX_CHECK_LE(currNodeReservedMemory, prevNodeReservedMemory);
return prevNodeReservedMemory - currNodeReservedMemory;
}

bool isHashBuildMemoryPool(const memory::MemoryPool& pool) {
246 changes: 244 additions & 2 deletions velox/exec/tests/HashJoinTest.cpp
@@ -19,8 +19,6 @@
#include <fmt/format.h>
#include "folly/experimental/EventCount.h"
#include "velox/common/base/tests/GTestUtils.h"
#include "velox/common/memory/SharedArbitrator.h"
#include "velox/common/memory/tests/SharedArbitratorTestUtil.h"
#include "velox/common/testutil/TestValue.h"
#include "velox/dwio/common/tests/utils/BatchMaker.h"
#include "velox/exec/HashBuild.h"
@@ -8555,4 +8553,248 @@ TEST_F(HashJoinTest, combineSmallVectorsAfterFilter) {
true);
}
}

TEST_F(HashJoinTest, buildReclaimedMemoryReport) {
constexpr int64_t kMaxBytes = 1LL << 30; // 1GB
const int32_t numBuildVectors = 3;
std::vector<RowVectorPtr> buildVectors;
for (int32_t i = 0; i < numBuildVectors; ++i) {
VectorFuzzer fuzzer({.vectorSize = 200}, pool());
buildVectors.push_back(fuzzer.fuzzRow(buildType_));
}

const int32_t numProbeVectors = 3;
std::vector<RowVectorPtr> probeVectors;
for (int32_t i = 0; i < numProbeVectors; ++i) {
VectorFuzzer fuzzer({.vectorSize = 200}, pool());
probeVectors.push_back(fuzzer.fuzzRow(probeType_));
}

const int numDrivers{2};
// DuckDB needs doubled probe and build inputs as we run two drivers for the
// hash join.
std::vector<RowVectorPtr> totalProbeVectors = probeVectors;
totalProbeVectors.insert(
totalProbeVectors.end(), probeVectors.begin(), probeVectors.end());
std::vector<RowVectorPtr> totalBuildVectors = buildVectors;
totalBuildVectors.insert(
totalBuildVectors.end(), buildVectors.begin(), buildVectors.end());

createDuckDbTable("t", totalProbeVectors);
createDuckDbTable("u", totalBuildVectors);

auto tempDirectory = exec::test::TempDirectoryPath::create();
auto queryPool = memory::memoryManager()->addRootPool(
"", kMaxBytes, memory::MemoryReclaimer::create());

auto planNodeIdGenerator = std::make_shared<core::PlanNodeIdGenerator>();
auto plan = PlanBuilder(planNodeIdGenerator)
.values(probeVectors, true)
.hashJoin(
{"t_k1"},
{"u_k1"},
PlanBuilder(planNodeIdGenerator)
.values(buildVectors, true)
.planNode(),
"",
concat(probeType_->names(), buildType_->names()))
.planNode();

folly::EventCount driverWait;
std::atomic_bool driverWaitFlag{true};
folly::EventCount taskWait;
std::atomic_bool taskWaitFlag{true};

Operator* op{nullptr};
SCOPED_TESTVALUE_SET(
"facebook::velox::exec::Driver::runInternal::addInput",
std::function<void(Operator*)>(([&](Operator* testOp) {
if (testOp->operatorType() != "HashBuild") {
return;
}
op = testOp;
})));

std::atomic_bool injectOnce{true};
SCOPED_TESTVALUE_SET(
"facebook::velox::common::memory::MemoryPoolImpl::maybeReserve",
std::function<void(memory::MemoryPoolImpl*)>(
([&](memory::MemoryPoolImpl* pool) {
ASSERT_TRUE(op != nullptr);
if (!isHashBuildMemoryPool(*pool)) {
return;
}
ASSERT_TRUE(op->canReclaim());
if (op->pool()->usedBytes() == 0) {
// Skip triggering memory reclaim when the hash table is empty at memory
// reservation time.
return;
}
if (op->pool()->parent()->reservedBytes() ==
op->pool()->reservedBytes()) {
// Skip triggering memory reclaim if the peer hash build operator
// hasn't run yet.
return;
}
if (!injectOnce.exchange(false)) {
return;
}
uint64_t reclaimableBytes{0};
const bool reclaimable = op->reclaimableBytes(reclaimableBytes);
ASSERT_TRUE(reclaimable);
ASSERT_GT(reclaimableBytes, 0);
auto* driver = op->testingOperatorCtx()->driver();
SuspendedSection suspendedSection(driver);
taskWaitFlag = false;
taskWait.notifyAll();
driverWait.await([&]() { return !driverWaitFlag.load(); });
})));

std::thread taskThread([&]() {
HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get())
.numDrivers(numDrivers)
.planNode(plan)
.queryPool(std::move(queryPool))
.injectSpill(false)
.spillDirectory(tempDirectory->getPath())
.referenceQuery(
"SELECT t_k1, t_k2, t_v1, u_k1, u_k2, u_v1 FROM t, u WHERE t.t_k1 = u.u_k1")
.config(core::QueryConfig::kSpillStartPartitionBit, "29")
.verifier([&](const std::shared_ptr<Task>& task, bool /*unused*/) {
const auto statsPair = taskSpilledStats(*task);
ASSERT_GT(statsPair.first.spilledBytes, 0);
ASSERT_EQ(statsPair.first.spilledPartitions, 16);
ASSERT_GT(statsPair.second.spilledBytes, 0);
ASSERT_EQ(statsPair.second.spilledPartitions, 16);
verifyTaskSpilledRuntimeStats(*task, true);
})
.run();
});

taskWait.await([&]() { return !taskWaitFlag.load(); });
ASSERT_TRUE(op != nullptr);
auto task = op->testingOperatorCtx()->task();
auto* nodePool = op->pool()->parent();
const auto nodeMemoryUsage = nodePool->reservedBytes();
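// Reclaim from the task pool while the hash build driver is suspended, and
// verify that the reported reclaimed bytes match the reservation drop on the
// hash join node's pool.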
{
memory::ScopedMemoryArbitrationContext ctx(op->pool());
const uint64_t reclaimedBytes = task->pool()->reclaim(
task->pool()->capacity(), 1'000'000, reclaimerStats_);
ASSERT_GT(reclaimedBytes, 0);
ASSERT_EQ(nodeMemoryUsage - nodePool->reservedBytes(), reclaimedBytes);
}
// Verify all the memory has been freed.
ASSERT_EQ(nodePool->reservedBytes(), 0);

driverWaitFlag = false;
driverWait.notifyAll();
task.reset();

taskThread.join();
}

TEST_F(HashJoinTest, probeReclaimedMemoryReport) {
constexpr int64_t kMaxBytes = 1LL << 30; // 1GB
const int32_t numBuildVectors = 3;
std::vector<RowVectorPtr> buildVectors;
for (int32_t i = 0; i < numBuildVectors; ++i) {
VectorFuzzer fuzzer({.vectorSize = 200}, pool());
buildVectors.push_back(fuzzer.fuzzRow(buildType_));
}

const int32_t numProbeVectors = 3;
std::vector<RowVectorPtr> probeVectors;
for (int32_t i = 0; i < numProbeVectors; ++i) {
VectorFuzzer fuzzer({.vectorSize = 200}, pool());
probeVectors.push_back(fuzzer.fuzzRow(probeType_));
}

createDuckDbTable("t", probeVectors);
createDuckDbTable("u", buildVectors);

auto tempDirectory = exec::test::TempDirectoryPath::create();
auto queryPool = memory::memoryManager()->addRootPool(
"", kMaxBytes, memory::MemoryReclaimer::create());

auto planNodeIdGenerator = std::make_shared<core::PlanNodeIdGenerator>();
auto plan = PlanBuilder(planNodeIdGenerator)
.values(probeVectors, true)
.hashJoin(
{"t_k1"},
{"u_k1"},
PlanBuilder(planNodeIdGenerator)
.values(buildVectors, true)
.planNode(),
"",
concat(probeType_->names(), buildType_->names()))
.planNode();

folly::EventCount driverWait;
std::atomic_bool driverWaitFlag{true};
folly::EventCount taskWait;
std::atomic_bool taskWaitFlag{true};

Operator* op{nullptr};
std::atomic_int probeInputCount{0};
SCOPED_TESTVALUE_SET(
"facebook::velox::exec::Driver::runInternal::addInput",
std::function<void(Operator*)>(([&](Operator* testOp) {
if (testOp->operatorType() != "HashProbe") {
return;
}
op = testOp;

ASSERT_TRUE(op->canReclaim());
if (probeInputCount++ != 1) {
return;
}
auto* driver = op->testingOperatorCtx()->driver();
SuspendedSection suspendedSection(driver);
taskWaitFlag = false;
taskWait.notifyAll();
driverWait.await([&]() { return !driverWaitFlag.load(); });
})));

std::thread taskThread([&]() {
HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get())
.numDrivers(1)
.planNode(plan)
.queryPool(std::move(queryPool))
.injectSpill(false)
.spillDirectory(tempDirectory->getPath())
.referenceQuery(
"SELECT t_k1, t_k2, t_v1, u_k1, u_k2, u_v1 FROM t, u WHERE t.t_k1 = u.u_k1")
.config(core::QueryConfig::kSpillStartPartitionBit, "29")
.verifier([&](const std::shared_ptr<Task>& task, bool /*unused*/) {
const auto statsPair = taskSpilledStats(*task);
// The spill is triggered at the probe side.
ASSERT_EQ(statsPair.first.spilledBytes, 0);
ASSERT_EQ(statsPair.first.spilledPartitions, 0);
ASSERT_GT(statsPair.second.spilledBytes, 0);
ASSERT_EQ(statsPair.second.spilledPartitions, 16);
})
.run();
});

taskWait.await([&]() { return !taskWaitFlag.load(); });
ASSERT_TRUE(op != nullptr);
auto task = op->testingOperatorCtx()->task();
auto* nodePool = op->pool()->parent();
const auto nodeMemoryUsage = nodePool->reservedBytes();
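// Reclaim from the task pool while the hash probe driver is suspended, and
// verify that the reported reclaimed bytes match the reservation drop on the
// hash join node's pool.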
{
memory::ScopedMemoryArbitrationContext ctx(op->pool());
const uint64_t reclaimedBytes = task->pool()->reclaim(
task->pool()->capacity(), 1'000'000, reclaimerStats_);
ASSERT_GT(reclaimedBytes, 0);
ASSERT_EQ(nodeMemoryUsage - nodePool->reservedBytes(), reclaimedBytes);
}
// Verify all the memory has been freed.
ASSERT_EQ(nodePool->reservedBytes(), 0);

driverWaitFlag = false;
driverWait.notifyAll();
task.reset();

taskThread.join();
}
} // namespace
