diff --git a/velox/common/memory/Memory.cpp b/velox/common/memory/Memory.cpp index b73575d546b2c..7c7ffd42574f5 100644 --- a/velox/common/memory/Memory.cpp +++ b/velox/common/memory/Memory.cpp @@ -131,9 +131,6 @@ MemoryManager::MemoryManager(const MemoryManagerOptions& options) debugEnabled_(options.debugEnabled), coreOnAllocationFailureEnabled_(options.coreOnAllocationFailureEnabled), poolDestructionCb_([&](MemoryPool* pool) { dropPool(pool); }), - poolGrowCb_([&](MemoryPool* pool, uint64_t targetBytes) { - return growPool(pool, targetBytes); - }), sysRoot_{std::make_shared( this, std::string(kSysRootName), @@ -141,7 +138,6 @@ MemoryManager::MemoryManager(const MemoryManagerOptions& options) nullptr, nullptr, nullptr, - nullptr, // NOTE: the default root memory pool has no capacity limit, and it is // used for system usage in production such as disk spilling. MemoryPool::Options{ @@ -268,7 +264,6 @@ std::shared_ptr MemoryManager::addRootPool( MemoryPool::Kind::kAggregate, nullptr, std::move(reclaimer), - poolGrowCb_, poolDestructionCb_, options); pools_.emplace(poolName, pool); @@ -290,12 +285,6 @@ std::shared_ptr MemoryManager::addLeafPool( return sysRoot_->addLeafChild(poolName, threadSafe, nullptr); } -bool MemoryManager::growPool(MemoryPool* pool, uint64_t incrementBytes) { - VELOX_CHECK_NOT_NULL(pool); - VELOX_CHECK_NE(pool->capacity(), kMaxMemory); - return arbitrator_->growCapacity(pool, incrementBytes); -} - uint64_t MemoryManager::shrinkPools( uint64_t targetBytes, bool allowSpill, diff --git a/velox/common/memory/Memory.h b/velox/common/memory/Memory.h index de0aa5e4fe46f..a4acb6d26e7c0 100644 --- a/velox/common/memory/Memory.h +++ b/velox/common/memory/Memory.h @@ -307,10 +307,6 @@ class MemoryManager { private: void dropPool(MemoryPool* pool); - // Invoked to grow a memory pool's free capacity with at least - // 'incrementBytes'. The function returns true on success, otherwise false. - bool growPool(MemoryPool* pool, uint64_t incrementBytes); - // Returns the shared references to all the alive memory pools in 'pools_'. std::vector> getAlivePools() const; @@ -328,8 +324,6 @@ class MemoryManager { // tracked by 'pools_'. It is invoked on the root pool destruction and removes // the pool from 'pools_'. const MemoryPoolImpl::DestructionCallback poolDestructionCb_; - // Callback invoked by the root memory pool to request memory capacity growth. - const MemoryPoolImpl::GrowCapacityCallback poolGrowCb_; const std::shared_ptr sysRoot_; const std::shared_ptr spillPool_; diff --git a/velox/common/memory/MemoryArbitrator.cpp b/velox/common/memory/MemoryArbitrator.cpp index edac6dc1973f1..d57e6e575d5d2 100644 --- a/velox/common/memory/MemoryArbitrator.cpp +++ b/velox/common/memory/MemoryArbitrator.cpp @@ -497,6 +497,17 @@ const MemoryArbitrationContext* memoryArbitrationContext() { return arbitrationCtx; } +ScopedMemoryPoolArbitration::ScopedMemoryPoolArbitration(MemoryPool* pool) + : pool_(pool) { + VELOX_CHECK_NOT_NULL(pool_); + VELOX_CHECK(pool_->isLeaf()); + pool_->enterArbitration(); +} + +ScopedMemoryPoolArbitration::~ScopedMemoryPoolArbitration() { + pool_->leaveArbitration(); +} + bool underMemoryArbitration() { return memoryArbitrationContext() != nullptr; } @@ -515,20 +526,13 @@ void testingRunArbitration( MemoryPool* pool, uint64_t targetBytes, bool allowSpill) { - pool->enterArbitration(); - // Seraliazes the testing arbitration injection to make sure that the previous - // op has left arbitration section before starting the next one. This is - // guaranteed by the production code for operation triggered arbitration. - static std::mutex lock; { - std::lock_guard l(lock); + ScopedMemoryPoolArbitration scopedArbitration{pool}; static_cast(pool)->testingManager()->shrinkPools( targetBytes, allowSpill); - pool->leaveArbitration(); } - // This function is simulating an operator triggered arbitration which - // would check if the query has been aborted after finish arbitration by the - // memory pool capacity grow path. + // This function is simulating an arbitration triggered by growCapacity, which + // would check this. static_cast(pool)->testingCheckIfAborted(); } diff --git a/velox/common/memory/MemoryArbitrator.h b/velox/common/memory/MemoryArbitrator.h index be4768ab0d95c..bf14434222548 100644 --- a/velox/common/memory/MemoryArbitrator.h +++ b/velox/common/memory/MemoryArbitrator.h @@ -411,10 +411,10 @@ class ScopedMemoryArbitrationContext { public: explicit ScopedMemoryArbitrationContext(const MemoryPool* requestor); - // Can be used to restore a previously captured MemoryArbitrationContext. - // contextToRestore can be nullptr if there was no context at the time it was - // captured, in which case arbitrationCtx is unchanged upon - // contruction/destruction of this object. + /// Can be used to restore a previously captured MemoryArbitrationContext. + /// contextToRestore can be nullptr if there was no context at the time it was + /// captured, in which case arbitrationCtx is unchanged upon + /// contruction/destruction of this object. explicit ScopedMemoryArbitrationContext( const MemoryArbitrationContext* contextToRestore); @@ -425,6 +425,17 @@ class ScopedMemoryArbitrationContext { MemoryArbitrationContext currentArbitrationCtx_; }; +/// Object used to prepare arbitration on a memory pool. +class ScopedMemoryPoolArbitration { + public: + explicit ScopedMemoryPoolArbitration(MemoryPool* pool); + + ~ScopedMemoryPoolArbitration(); + + private: + MemoryPool* const pool_; +}; + /// Returns the memory arbitration context set by a per-thread local variable if /// the running thread is under memory arbitration processing. const MemoryArbitrationContext* memoryArbitrationContext(); diff --git a/velox/common/memory/MemoryPool.cpp b/velox/common/memory/MemoryPool.cpp index 53b1674289668..f77efa617c9b1 100644 --- a/velox/common/memory/MemoryPool.cpp +++ b/velox/common/memory/MemoryPool.cpp @@ -413,13 +413,12 @@ MemoryPoolImpl::MemoryPoolImpl( Kind kind, std::shared_ptr parent, std::unique_ptr reclaimer, - GrowCapacityCallback growCapacityCb, DestructionCallback destructionCb, const Options& options) : MemoryPool{name, kind, parent, options}, manager_{memoryManager}, allocator_{manager_->allocator()}, - growCapacityCb_(std::move(growCapacityCb)), + arbitrator_{manager_->arbitrator()}, destructionCb_(std::move(destructionCb)), debugPoolNameRegex_(debugEnabled_ ? *(debugPoolNameRegex().rlock()) : ""), reclaimer_(std::move(reclaimer)), @@ -428,8 +427,8 @@ MemoryPoolImpl::MemoryPoolImpl( capacity_(parent_ != nullptr ? kMaxMemory : 0) { VELOX_CHECK(options.threadSafe || isLeaf()); VELOX_CHECK( - isRoot() || (destructionCb_ == nullptr && growCapacityCb_ == nullptr), - "Only root memory pool allows to set destruction and capacity grow callbacks: {}", + isRoot() || destructionCb_ == nullptr, + "Only root memory pool allows to set destruction callbacks: {}", name_); } @@ -733,7 +732,6 @@ std::shared_ptr MemoryPoolImpl::genChild( parent, std::move(reclaimer), nullptr, - nullptr, Options{ .alignment = alignment_, .trackUsage = trackUsage_, @@ -842,8 +840,7 @@ bool MemoryPoolImpl::incrementReservationThreadSafe( VELOX_CHECK_NULL(parent_); - ++numCapacityGrowths_; - if (growCapacityCb_(requestor, size)) { + if (growCapacity(requestor, size)) { TestValue::adjust( "facebook::velox::memory::MemoryPoolImpl::incrementReservationThreadSafe::AfterGrowCallback", this); @@ -865,6 +862,14 @@ bool MemoryPoolImpl::incrementReservationThreadSafe( treeMemoryUsage())); } +bool MemoryPoolImpl::growCapacity(MemoryPool* requestor, uint64_t size) { + VELOX_CHECK(requestor->isLeaf()); + ++numCapacityGrowths_; + + ScopedMemoryPoolArbitration scopedArbitration{requestor}; + return arbitrator_->growCapacity(this, size); +} + bool MemoryPoolImpl::maybeIncrementReservation(uint64_t size) { std::lock_guard l(mutex_); if (isRoot()) { diff --git a/velox/common/memory/MemoryPool.h b/velox/common/memory/MemoryPool.h index 70bc510d81d1f..8d9edfffb30e2 100644 --- a/velox/common/memory/MemoryPool.h +++ b/velox/common/memory/MemoryPool.h @@ -372,16 +372,6 @@ class MemoryPool : public std::enable_shared_from_this { /// Returns the memory reclaimer of this memory pool if not null. virtual MemoryReclaimer* reclaimer() const = 0; - /// Invoked by the memory arbitrator to enter memory arbitration processing. - /// It is a noop if 'reclaimer' is not set, otherwise invoke the reclaimer's - /// corresponding method. - virtual void enterArbitration() = 0; - - /// Invoked by the memory arbitrator to leave memory arbitration processing. - /// It is a noop if 'reclaimer' is not set, otherwise invoke the reclaimer's - /// corresponding method. - virtual void leaveArbitration() noexcept = 0; - /// Function estimates the number of reclaimable bytes and returns in /// 'reclaimableBytes'. If the 'reclaimer' is not set, the function returns /// std::nullopt. Otherwise, it will invoke the corresponding method of the @@ -499,6 +489,16 @@ class MemoryPool : public std::enable_shared_from_this { protected: static constexpr uint64_t kMB = 1 << 20; + /// Invoked by the memory arbitrator to enter memory arbitration processing. + /// It is a noop if 'reclaimer' is not set, otherwise invoke the reclaimer's + /// corresponding method. + virtual void enterArbitration() = 0; + + /// Invoked by the memory arbitrator to leave memory arbitration processing. + /// It is a noop if 'reclaimer' is not set, otherwise invoke the reclaimer's + /// corresponding method. + virtual void leaveArbitration() noexcept = 0; + /// Invoked to free up to the specified amount of free memory by reducing /// this memory pool's capacity without actually freeing any used memory. The /// function returns the actually freed memory capacity in bytes. If @@ -557,6 +557,7 @@ class MemoryPool : public std::enable_shared_from_this { friend class velox::exec::ParallelMemoryReclaimer; friend class MemoryManager; friend class MemoryArbitrator; + friend class ScopedMemoryPoolArbitration; VELOX_FRIEND_TEST(MemoryPoolTest, shrinkAndGrowAPIs); VELOX_FRIEND_TEST(MemoryPoolTest, grow); @@ -573,11 +574,6 @@ class MemoryPoolImpl : public MemoryPool { /// The callback invoked on the root memory pool destruction. It is set by /// memory manager to removes the pool from 'MemoryManager::pools_'. using DestructionCallback = std::function; - /// The callback invoked when the used memory reservation of the root memory - /// pool exceed its capacity. It is set by memory manager to grow the memory - /// pool capacity. The callback returns true if the capacity growth succeeds, - /// otherwise false. - using GrowCapacityCallback = std::function; MemoryPoolImpl( MemoryManager* manager, @@ -585,7 +581,6 @@ class MemoryPoolImpl : public MemoryPool { Kind kind, std::shared_ptr parent, std::unique_ptr reclaimer, - GrowCapacityCallback growCapacityCb, DestructionCallback destructionCb, const Options& options = Options{}); @@ -651,10 +646,6 @@ class MemoryPoolImpl : public MemoryPool { MemoryReclaimer* reclaimer() const override; - void enterArbitration() override; - - void leaveArbitration() noexcept override; - std::optional reclaimableBytes() const override; uint64_t reclaim( @@ -731,6 +722,10 @@ class MemoryPoolImpl : public MemoryPool { } private: + void enterArbitration() override; + + void leaveArbitration() noexcept override; + uint64_t shrink(uint64_t targetBytes = 0) override; bool grow(uint64_t growBytes, uint64_t reservationBytes = 0) override; @@ -872,6 +867,11 @@ class MemoryPoolImpl : public MemoryPool { void releaseThreadSafe(uint64_t size, bool releaseOnly); + // Invoked to grow capacity of the root memory pool from the memory + // arbitrator. 'requestor' is the leaf memory pool that triggers the memory + // capacity growth. 'size' is the memory capacity growth in bytes. + bool growCapacity(MemoryPool* requestor, uint64_t size); + FOLLY_ALWAYS_INLINE void releaseNonThreadSafe( uint64_t size, bool releaseOnly) { @@ -999,7 +999,7 @@ class MemoryPoolImpl : public MemoryPool { MemoryManager* const manager_; MemoryAllocator* const allocator_; - const GrowCapacityCallback growCapacityCb_; + MemoryArbitrator* const arbitrator_; const DestructionCallback destructionCb_; // Regex for filtering on 'name_' when debug mode is enabled. This allows us diff --git a/velox/common/memory/SharedArbitrator.cpp b/velox/common/memory/SharedArbitrator.cpp index e9d429be0d014..8a71f2763ec91 100644 --- a/velox/common/memory/SharedArbitrator.cpp +++ b/velox/common/memory/SharedArbitrator.cpp @@ -204,7 +204,7 @@ void SharedArbitrator::getCandidates( std::shared_lock guard{poolLock_}; op->candidates.reserve(candidates_.size()); for (const auto& candidate : candidates_) { - const bool selfCandidate = op->requestRoot == candidate.first; + const bool selfCandidate = op->requestPool == candidate.first; std::shared_ptr pool = candidate.second.lock(); if (pool == nullptr) { VELOX_CHECK(!selfCandidate); @@ -470,18 +470,18 @@ bool SharedArbitrator::runLocalArbitration( if (!ensureCapacity(op)) { updateArbitrationFailureStats(); - VELOX_MEM_LOG(ERROR) << "Can't grow " << op->requestRoot->name() + VELOX_MEM_LOG(ERROR) << "Can't grow " << op->requestPool->name() << " capacity to " << succinctBytes( - op->requestRoot->capacity() + op->requestBytes) + op->requestPool->capacity() + op->requestBytes) << " which exceeds its max capacity " - << succinctBytes(op->requestRoot->maxCapacity()) + << succinctBytes(op->requestPool->maxCapacity()) << ", current capacity " - << succinctBytes(op->requestRoot->capacity()) + << succinctBytes(op->requestPool->capacity()) << ", request " << succinctBytes(op->requestBytes); return false; } - VELOX_CHECK(!op->requestRoot->aborted()); + VELOX_CHECK(!op->requestPool->aborted()); if (maybeGrowFromSelf(op)) { return true; @@ -499,7 +499,7 @@ bool SharedArbitrator::runLocalArbitration( } }); if (freedBytes >= op->requestBytes) { - checkedGrow(op->requestRoot, freedBytes, op->requestBytes); + checkedGrow(op->requestPool, freedBytes, op->requestBytes); freedBytes = 0; return true; } @@ -510,20 +510,20 @@ bool SharedArbitrator::runLocalArbitration( reclaimFreeMemoryFromCandidates(op, maxGrowTarget - freedBytes, true); if (freedBytes >= op->requestBytes) { const uint64_t bytesToGrow = std::min(maxGrowTarget, freedBytes); - checkedGrow(op->requestRoot, bytesToGrow, op->requestBytes); + checkedGrow(op->requestPool, bytesToGrow, op->requestBytes); freedBytes -= bytesToGrow; return true; } VELOX_CHECK_LT(freedBytes, maxGrowTarget); if (!globalArbitrationEnabled_) { - freedBytes += reclaim(op->requestRoot, maxGrowTarget - freedBytes, true); + freedBytes += reclaim(op->requestPool, maxGrowTarget - freedBytes, true); } checkIfAborted(op); if (freedBytes >= op->requestBytes) { const uint64_t bytesToGrow = std::min(maxGrowTarget, freedBytes); - checkedGrow(op->requestRoot, bytesToGrow, op->requestBytes); + checkedGrow(op->requestPool, bytesToGrow, op->requestBytes); freedBytes -= bytesToGrow; return true; } @@ -557,14 +557,14 @@ bool SharedArbitrator::runGlobalArbitration(ArbitrationOperation* op) { if (attempts > 0) { break; } - VELOX_CHECK(!op->requestRoot->aborted()); + VELOX_CHECK(!op->requestPool->aborted()); if (!handleOOM(op)) { break; } } VELOX_MEM_LOG(ERROR) << "Failed to arbitrate sufficient memory for memory pool " - << op->requestRoot->name() << ", request " + << op->requestPool->name() << ", request " << succinctBytes(op->requestBytes) << " after " << attempts << " attempts, Arbitrator state: " << toString(); updateArbitrationFailureStats(); @@ -576,21 +576,21 @@ void SharedArbitrator::getGrowTargets( uint64_t& maxGrowTarget, uint64_t& minGrowTarget) { maxGrowTarget = std::min( - maxGrowCapacity(*op->requestRoot), + maxGrowCapacity(*op->requestPool), std::max(memoryPoolTransferCapacity_, op->requestBytes)); - minGrowTarget = minGrowCapacity(*op->requestRoot); + minGrowTarget = minGrowCapacity(*op->requestPool); } void SharedArbitrator::checkIfAborted(ArbitrationOperation* op) { - if (op->requestRoot->aborted()) { + if (op->requestPool->aborted()) { updateArbitrationFailureStats(); VELOX_MEM_POOL_ABORTED("The requestor pool has been aborted"); } } bool SharedArbitrator::maybeGrowFromSelf(ArbitrationOperation* op) { - if (op->requestRoot->freeBytes() >= op->requestBytes) { - if (growPool(op->requestRoot, 0, op->requestBytes)) { + if (op->requestPool->freeBytes() >= op->requestBytes) { + if (growPool(op->requestPool, 0, op->requestBytes)) { return true; } } @@ -598,13 +598,13 @@ bool SharedArbitrator::maybeGrowFromSelf(ArbitrationOperation* op) { } bool SharedArbitrator::checkCapacityGrowth(ArbitrationOperation* op) const { - return (maxGrowCapacity(*op->requestRoot) >= op->requestBytes) && - (capacityAfterGrowth(*op->requestRoot, op->requestBytes) <= capacity_); + return (maxGrowCapacity(*op->requestPool) >= op->requestBytes) && + (capacityAfterGrowth(*op->requestPool, op->requestBytes) <= capacity_); } bool SharedArbitrator::ensureCapacity(ArbitrationOperation* op) { if ((op->requestBytes > capacity_) || - (op->requestBytes > op->requestRoot->maxCapacity())) { + (op->requestBytes > op->requestPool->maxCapacity())) { return false; } if (checkCapacityGrowth(op)) { @@ -612,12 +612,12 @@ bool SharedArbitrator::ensureCapacity(ArbitrationOperation* op) { } const uint64_t reclaimedBytes = - reclaim(op->requestRoot, op->requestBytes, true); + reclaim(op->requestPool, op->requestBytes, true); // NOTE: return the reclaimed bytes back to the arbitrator and let the memory // arbitration process to grow the requestor's memory capacity accordingly. incrementFreeCapacity(reclaimedBytes); // Check if the requestor has been aborted in reclaim operation above. - if (op->requestRoot->aborted()) { + if (op->requestPool->aborted()) { updateArbitrationFailureStats(); VELOX_MEM_POOL_ABORTED("The requestor pool has been aborted"); } @@ -626,24 +626,24 @@ bool SharedArbitrator::ensureCapacity(ArbitrationOperation* op) { bool SharedArbitrator::handleOOM(ArbitrationOperation* op) { MemoryPool* victim = findCandidateWithLargestCapacity( - op->requestRoot, op->requestBytes, op->candidates) + op->requestPool, op->requestBytes, op->candidates) .pool.get(); - if (op->requestRoot == victim) { + if (op->requestPool == victim) { VELOX_MEM_LOG(ERROR) - << "Requestor memory pool " << op->requestRoot->name() + << "Requestor memory pool " << op->requestPool->name() << " is selected as victim memory pool so fail the memory arbitration"; return false; } VELOX_MEM_LOG(WARNING) << "Aborting victim memory pool " << victim->name() << " to free up memory for requestor " - << op->requestRoot->name(); + << op->requestPool->name(); try { - if (victim == op->requestRoot) { + if (victim == op->requestPool) { VELOX_MEM_POOL_CAP_EXCEEDED( - memoryPoolAbortMessage(victim, op->requestRoot, op->requestBytes)); + memoryPoolAbortMessage(victim, op->requestPool, op->requestBytes)); } else { VELOX_MEM_POOL_ABORTED( - memoryPoolAbortMessage(victim, op->requestRoot, op->requestBytes)); + memoryPoolAbortMessage(victim, op->requestPool, op->requestBytes)); } } catch (VeloxRuntimeError&) { abort(victim, std::current_exception()); @@ -668,7 +668,7 @@ void SharedArbitrator::checkedGrow( } bool SharedArbitrator::arbitrateMemory(ArbitrationOperation* op) { - VELOX_CHECK(!op->requestRoot->aborted()); + VELOX_CHECK(!op->requestPool->aborted()); uint64_t maxGrowTarget{0}; uint64_t minGrowTarget{0}; getGrowTargets(op, maxGrowTarget, minGrowTarget); @@ -681,7 +681,7 @@ bool SharedArbitrator::arbitrateMemory(ArbitrationOperation* op) { } }); if (freedBytes >= op->requestBytes) { - checkedGrow(op->requestRoot, freedBytes, op->requestBytes); + checkedGrow(op->requestPool, freedBytes, op->requestBytes); freedBytes = 0; return true; } @@ -694,7 +694,7 @@ bool SharedArbitrator::arbitrateMemory(ArbitrationOperation* op) { reclaimFreeMemoryFromCandidates(op, maxGrowTarget - freedBytes, false); if (freedBytes >= op->requestBytes) { const uint64_t bytesToGrow = std::min(maxGrowTarget, freedBytes); - checkedGrow(op->requestRoot, bytesToGrow, op->requestBytes); + checkedGrow(op->requestPool, bytesToGrow, op->requestBytes); freedBytes -= bytesToGrow; return true; } @@ -707,7 +707,7 @@ bool SharedArbitrator::arbitrateMemory(ArbitrationOperation* op) { if (freedBytes < op->requestBytes) { VELOX_MEM_LOG(WARNING) << "Failed to arbitrate sufficient memory for memory pool " - << op->requestRoot->name() << ", request " + << op->requestPool->name() << ", request " << succinctBytes(op->requestBytes) << ", only " << succinctBytes(freedBytes) << " has been freed, Arbitrator state: " << toString(); @@ -715,7 +715,7 @@ bool SharedArbitrator::arbitrateMemory(ArbitrationOperation* op) { } const uint64_t bytesToGrow = std::min(freedBytes, maxGrowTarget); - checkedGrow(op->requestRoot, bytesToGrow, op->requestBytes); + checkedGrow(op->requestPool, bytesToGrow, op->requestBytes); freedBytes -= bytesToGrow; return true; } @@ -734,7 +734,7 @@ uint64_t SharedArbitrator::reclaimFreeMemoryFromCandidates( if (candidate.freeBytes == 0) { break; } - if (isLocalArbitration && (candidate.pool.get() != op->requestRoot) && + if (isLocalArbitration && (candidate.pool.get() != op->requestPool) && isUnderArbitrationLocked(candidate.pool.get())) { // If the reclamation is for local arbitration and the candidate pool is // also under arbitration processing, then we can't reclaim from the @@ -745,7 +745,7 @@ uint64_t SharedArbitrator::reclaimFreeMemoryFromCandidates( const int64_t bytesToReclaim = std::min( reclaimTargetBytes - reclaimedBytes, reclaimableFreeCapacity( - *candidate.pool, candidate.pool.get() == op->requestRoot)); + *candidate.pool, candidate.pool.get() == op->requestPool)); if (bytesToReclaim <= 0) { continue; } @@ -772,7 +772,7 @@ void SharedArbitrator::reclaimUsedMemoryFromCandidatesBySpill( freedBytes += reclaim(candidate.pool.get(), op->requestBytes - freedBytes, false); if ((freedBytes >= op->requestBytes) || - (op->requestRoot != nullptr && op->requestRoot->aborted())) { + (op->requestPool != nullptr && op->requestPool->aborted())) { break; } } @@ -945,7 +945,6 @@ SharedArbitrator::ScopedArbitration::ScopedArbitration( startTime_(std::chrono::steady_clock::now()) { VELOX_CHECK_NOT_NULL(arbitrator_); VELOX_CHECK_NOT_NULL(operation_); - operation_->enterArbitration(); if (arbitrator_->arbitrationStateCheckCb_ != nullptr && operation_->requestPool != nullptr) { arbitrator_->arbitrationStateCheckCb_(*operation_->requestPool); @@ -954,7 +953,6 @@ SharedArbitrator::ScopedArbitration::ScopedArbitration( } SharedArbitrator::ScopedArbitration::~ScopedArbitration() { - operation_->leaveArbitration(); arbitrator_->finishArbitration(operation_); // Report arbitration operation stats. @@ -998,18 +996,6 @@ SharedArbitrator::ScopedArbitration::~ScopedArbitration() { } } -void SharedArbitrator::ArbitrationOperation::enterArbitration() { - if (requestPool != nullptr) { - requestPool->enterArbitration(); - } -} - -void SharedArbitrator::ArbitrationOperation::leaveArbitration() { - if (requestPool != nullptr) { - requestPool->leaveArbitration(); - } -} - void SharedArbitrator::startArbitration(ArbitrationOperation* op) { updateArbitrationRequestStats(); ContinueFuture waitPromise{ContinueFuture::makeEmpty()}; @@ -1017,16 +1003,14 @@ void SharedArbitrator::startArbitration(ArbitrationOperation* op) { std::lock_guard l(stateLock_); ++numPending_; if (op->requestPool != nullptr) { - auto it = arbitrationQueues_.find(op->requestRoot); + auto it = arbitrationQueues_.find(op->requestPool); if (it != arbitrationQueues_.end()) { - it->second->waitPromises.emplace_back(fmt::format( - "Wait for arbitration {}/{}", - op->requestPool->name(), - op->requestRoot->name())); + it->second->waitPromises.emplace_back( + fmt::format("Wait for arbitration {}", op->requestPool->name())); waitPromise = it->second->waitPromises.back().getSemiFuture(); } else { arbitrationQueues_.emplace( - op->requestRoot, std::make_unique(op)); + op->requestPool, std::make_unique(op)); } } } @@ -1051,12 +1035,11 @@ void SharedArbitrator::finishArbitration(ArbitrationOperation* op) { VELOX_CHECK_GT(numPending_, 0); --numPending_; if (op->requestPool != nullptr) { - auto it = arbitrationQueues_.find(op->requestRoot); + auto it = arbitrationQueues_.find(op->requestPool); VELOX_CHECK( it != arbitrationQueues_.end(), - "{}/{} not found", - op->requestPool->name(), - op->requestRoot->name()); + "{} not found", + op->requestPool->name()); auto* runningArbitration = it->second.get(); if (runningArbitration->waitPromises.empty()) { arbitrationQueues_.erase(it); diff --git a/velox/common/memory/SharedArbitrator.h b/velox/common/memory/SharedArbitrator.h index f513a12825b00..d1bc4d27da790 100644 --- a/velox/common/memory/SharedArbitrator.h +++ b/velox/common/memory/SharedArbitrator.h @@ -165,7 +165,6 @@ class SharedArbitrator : public memory::MemoryArbitrator { // Contains the execution state of an arbitration operation. struct ArbitrationOperation { MemoryPool* const requestPool; - MemoryPool* const requestRoot; const uint64_t requestBytes; // The start time of this arbitration operation. const std::chrono::steady_clock::time_point startTime; @@ -187,17 +186,15 @@ class SharedArbitrator : public memory::MemoryArbitrator { ArbitrationOperation(MemoryPool* _requestor, uint64_t _requestBytes) : requestPool(_requestor), - requestRoot(_requestor == nullptr ? nullptr : _requestor->root()), requestBytes(_requestBytes), - startTime(std::chrono::steady_clock::now()) {} + startTime(std::chrono::steady_clock::now()) { + VELOX_CHECK(requestPool == nullptr || requestPool->isRoot()); + } uint64_t waitTimeUs() const { return localArbitrationQueueTimeUs + localArbitrationLockWaitTimeUs + globalArbitrationLockWaitTimeUs; } - - void enterArbitration(); - void leaveArbitration(); }; // Used to start and finish an arbitration operation initiated from a memory diff --git a/velox/common/memory/tests/MemoryArbitratorTest.cpp b/velox/common/memory/tests/MemoryArbitratorTest.cpp index bb7f0bd781381..990dc9ff4eb80 100644 --- a/velox/common/memory/tests/MemoryArbitratorTest.cpp +++ b/velox/common/memory/tests/MemoryArbitratorTest.cpp @@ -526,13 +526,28 @@ class MockLeafMemoryReclaimer : public MemoryReclaimer { public: explicit MockLeafMemoryReclaimer( std::atomic& totalUsedBytes, - bool reclaimable = true) - : reclaimable_(reclaimable), totalUsedBytes_(totalUsedBytes) {} + bool reclaimable = true, + bool* underArbitration = nullptr) + : reclaimable_(reclaimable), + underArbitration_(underArbitration), + totalUsedBytes_(totalUsedBytes) {} ~MockLeafMemoryReclaimer() override { VELOX_CHECK(allocations_.empty()); } + virtual void enterArbitration() override { + if (underArbitration_ != nullptr) { + *underArbitration_ = true; + } + } + + virtual void leaveArbitration() noexcept override { + if (underArbitration_ != nullptr) { + *underArbitration_ = false; + } + } + bool reclaimableBytes(const MemoryPool& pool, uint64_t& bytes) const override { VELOX_CHECK_EQ(pool.name(), pool_->name()); @@ -612,6 +627,7 @@ class MockLeafMemoryReclaimer : public MemoryReclaimer { } const bool reclaimable_{true}; + bool* const underArbitration_{nullptr}; std::atomic_uint64_t& totalUsedBytes_; std::atomic_int reclaimCount_{0}; mutable std::mutex mu_; @@ -1027,6 +1043,43 @@ TEST_F(MemoryReclaimerTest, arbitrationContext) { ASSERT_TRUE(memoryArbitrationContext() == nullptr); } +TEST_F(MemoryReclaimerTest, scopedMemoryPoolArbitration) { + auto root = memory::memoryManager()->addRootPool( + "scopedArbitration", kMaxMemory, MemoryReclaimer::create()); + std::atomic totalUsedBytes{0}; + bool underArbitration{false}; + auto leafChild = root->addLeafChild( + "scopedArbitration", + true, + std::make_unique( + totalUsedBytes, true, &underArbitration)); + ASSERT_FALSE(underArbitration); + { + ScopedMemoryPoolArbitration scopedArbitration(leafChild.get()); + ASSERT_TRUE(memoryArbitrationContext() == nullptr); + ASSERT_TRUE(underArbitration); + } + ASSERT_FALSE(underArbitration); + ASSERT_TRUE(memoryArbitrationContext() == nullptr); + + std::thread abitrationThread([&]() { + ASSERT_TRUE(memoryArbitrationContext() == nullptr); + { + ScopedMemoryPoolArbitration scopedArbitration(leafChild.get()); + ASSERT_TRUE(memoryArbitrationContext() == nullptr); + ASSERT_TRUE(underArbitration); + } + ASSERT_FALSE(underArbitration); + ASSERT_TRUE(memoryArbitrationContext() == nullptr); + }); + abitrationThread.join(); + + ASSERT_FALSE(underArbitration); + + + ASSERT_TRUE(memoryArbitrationContext() == nullptr); +} + TEST_F(MemoryReclaimerTest, concurrentRandomMockReclaims) { auto root = memory::memoryManager()->addRootPool( "concurrentRandomMockReclaims", kMaxMemory, MemoryReclaimer::create()); diff --git a/velox/common/memory/tests/MemoryPoolTest.cpp b/velox/common/memory/tests/MemoryPoolTest.cpp index d64c2225148ff..e96a7d0f98a6c 100644 --- a/velox/common/memory/tests/MemoryPoolTest.cpp +++ b/velox/common/memory/tests/MemoryPoolTest.cpp @@ -161,7 +161,6 @@ TEST_P(MemoryPoolTest, ctor) { MemoryPool::Kind::kAggregate, nullptr, nullptr, - nullptr, nullptr); // We can't construct an aggregate memory pool with non-thread safe. ASSERT_ANY_THROW(std::make_shared( @@ -171,7 +170,6 @@ TEST_P(MemoryPoolTest, ctor) { nullptr, nullptr, nullptr, - nullptr, MemoryPool::Options{.threadSafe = false})); ASSERT_EQ("fake_root", fakeRoot->name()); ASSERT_EQ( diff --git a/velox/common/memory/tests/MockSharedArbitratorTest.cpp b/velox/common/memory/tests/MockSharedArbitratorTest.cpp index 19da53e346fd9..ff037691e763e 100644 --- a/velox/common/memory/tests/MockSharedArbitratorTest.cpp +++ b/velox/common/memory/tests/MockSharedArbitratorTest.cpp @@ -636,8 +636,8 @@ TEST_F(MockSharedArbitrationTest, arbitrationStateCheck) { const int minPoolCapacity = 32 * MB; std::atomic checkCount{0}; MemoryArbitrationStateCheckCB checkCountCb = [&](MemoryPool& pool) { - const std::string re("MockTask.*"); - ASSERT_TRUE(RE2::FullMatch(pool.name(), re)); + const std::string re("RootPool.*"); + ASSERT_TRUE(RE2::FullMatch(pool.name(), re)) << pool.name(); ++checkCount; }; setupMemory(memCapacity, 0, 0, 0, 0, checkCountCb);