Skip to content

Commit

Permalink
[native]Improve prestissimo tests by inheriting from velox operator t…
Browse files Browse the repository at this point in the history
…est base
  • Loading branch information
xiaoxmeng authored and tanjialiang committed Nov 7, 2024
1 parent 261613b commit 0a90d6e
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 37 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,10 @@
#include "presto_cpp/main/ServerOperation.h"
#include <gtest/gtest.h>
#include "presto_cpp/main/PrestoServerOperations.h"
#include "velox/common/base/Exceptions.h"
#include "velox/common/base/tests/GTestUtils.h"
#include "velox/common/memory/Memory.h"
#include "velox/connectors/hive/HiveConnector.h"
#include "velox/exec/tests/utils/OperatorTestBase.h"
#include "velox/exec/tests/utils/PlanBuilder.h"

DECLARE_bool(velox_memory_leak_check_enabled);
Expand All @@ -26,10 +26,14 @@ using namespace facebook::velox;

namespace facebook::presto {

class ServerOperationTest : public testing::Test {
class ServerOperationTest : public exec::test::OperatorTestBase {
void SetUp() override {
FLAGS_velox_memory_leak_check_enabled = true;
memory::MemoryManager::testingSetInstance({});
exec::test::OperatorTestBase::SetUp();
}

void TearDown() override {
exec::test::OperatorTestBase::TearDown();
}
};

Expand Down Expand Up @@ -238,4 +242,4 @@ TEST_F(ServerOperationTest, systemConfigEndpoint) {
EXPECT_EQ(getPropertyResponse, "16\n");
}

} // namespace facebook::presto
} // namespace facebook::presto
42 changes: 9 additions & 33 deletions presto-native-execution/presto_cpp/main/tests/TaskManagerTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
#include "velox/dwio/dwrf/writer/Writer.h"
#include "velox/exec/Exchange.h"
#include "velox/exec/Values.h"
#include "velox/exec/tests/utils/OperatorTestBase.h"
#include "velox/exec/tests/utils/PlanBuilder.h"
#include "velox/exec/tests/utils/QueryAssertions.h"
#include "velox/exec/tests/utils/TempDirectoryPath.h"
Expand Down Expand Up @@ -181,9 +182,10 @@ void setAggregationSpillConfig(
queryConfigs.emplace(core::QueryConfig::kAggregationSpillEnabled, "true");
}

class TaskManagerTest : public testing::Test {
class TaskManagerTest : public exec::test::OperatorTestBase {
public:
static void SetUpTestCase() {
OperatorTestBase::SetUpTestCase();
filesystems::registerLocalFileSystem();
if (!connector::hasConnectorFactory(
connector::hive::HiveConnectorFactory::kHiveConnectorName)) {
Expand All @@ -194,32 +196,12 @@ class TaskManagerTest : public testing::Test {
SystemConfig::instance()->setValue(
std::string(SystemConfig::kMemoryArbitratorKind), "SHARED");
ASSERT_EQ(SystemConfig::instance()->memoryArbitratorKind(), "SHARED");
FLAGS_velox_enable_memory_usage_track_in_default_memory_pool = true;
FLAGS_velox_memory_leak_check_enabled = true;
velox::memory::SharedArbitrator::registerFactory();
velox::memory::MemoryManagerOptions options;
options.allocatorCapacity = 8L << 30;
options.arbitratorCapacity = 6L << 30;
options.extraArbitratorConfigs = {
{std::string(velox::memory::SharedArbitrator::ExtraConfig::
kMemoryPoolInitialCapacity),
"512MB"},
{std::string(velox::memory::SharedArbitrator::ExtraConfig::
kMemoryPoolMinReclaimBytes),
"0B"}};
options.arbitratorKind = "SHARED";
options.checkUsageLeak = true;
options.arbitrationStateCheckCb = memoryArbitrationStateCheck;
memory::MemoryManager::testingSetInstance(options);
common::testutil::TestValue::enable();
}

protected:
void SetUp() override {
FLAGS_velox_memory_leak_check_enabled = true;
functions::prestosql::registerAllScalarFunctions();
aggregate::prestosql::registerAllAggregateFunctions();
parse::registerTypeResolver();
OperatorTestBase::SetUp();
dwrf::registerDwrfWriterFactory();
dwrf::registerDwrfReaderFactory();
exec::ExchangeSource::registerFactory(
Expand All @@ -240,9 +222,6 @@ class TaskManagerTest : public testing::Test {
connPool.get(),
nullptr);
});
if (!isRegisteredVectorSerde()) {
serializer::presto::PrestoVectorSerde::registerVectorSerde();
};

registerPrestoToVeloxConnector(std::make_unique<HivePrestoToVeloxConnector>(
connector::hive::HiveConnectorFactory::kHiveConnectorName));
Expand All @@ -255,17 +234,14 @@ class TaskManagerTest : public testing::Test {
std::unordered_map<std::string, std::string>()));
connector::registerConnector(hiveConnector);

rootPool_ = memory::memoryManager()->addRootPool("TaskManagerTest.root");
leafPool_ =
memory::deprecatedAddDefaultLeafMemoryPool("TaskManagerTest.leaf");
rowType_ = ROW({"c0", "c1"}, {INTEGER(), VARCHAR()});

taskManager_ = std::make_unique<TaskManager>(
driverExecutor_.get(), httpSrvCpuExecutor_.get(), nullptr);

auto validator = std::make_shared<facebook::presto::VeloxPlanValidator>();
taskResource_ = std::make_unique<TaskResource>(
leafPool_.get(),
pool_.get(),
httpSrvCpuExecutor_.get(),
validator.get(),
*taskManager_.get());
Expand Down Expand Up @@ -298,14 +274,16 @@ class TaskManagerTest : public testing::Test {
connector::hive::HiveConnectorFactory::kHiveConnectorName);
dwrf::unregisterDwrfWriterFactory();
dwrf::unregisterDwrfReaderFactory();
taskManager_.reset();
OperatorTestBase::TearDown();
}

std::vector<RowVectorPtr> makeVectors(int count, int rowsPerVector) {
std::vector<RowVectorPtr> vectors;
for (int i = 0; i < count; ++i) {
auto vector = std::dynamic_pointer_cast<RowVector>(
facebook::velox::test::BatchMaker::createBatch(
rowType_, rowsPerVector, *leafPool_));
rowType_, rowsPerVector, *pool_));
vectors.emplace_back(vector);
}
return vectors;
Expand Down Expand Up @@ -398,7 +376,7 @@ class TaskManagerTest : public testing::Test {
const protocol::TaskId& taskId,
const RowTypePtr& resultType,
const std::vector<std::string>& allTaskIds) {
Cursor cursor(taskManager_.get(), taskId, resultType, leafPool_.get());
Cursor cursor(taskManager_.get(), taskId, resultType, pool_.get());
std::vector<RowVectorPtr> vectors;
for (;;) {
auto moreVectors = cursor.next();
Expand Down Expand Up @@ -678,8 +656,6 @@ class TaskManagerTest : public testing::Test {
taskId, updateRequest, planFragment, std::move(queryCtx), 0);
}

std::shared_ptr<memory::MemoryPool> rootPool_;
std::shared_ptr<memory::MemoryPool> leafPool_;
RowTypePtr rowType_;
exec::test::DuckDbQueryRunner duckDbQueryRunner_;
std::unique_ptr<TaskManager> taskManager_;
Expand Down

0 comments on commit 0a90d6e

Please sign in to comment.