diff --git a/velox/core/PlanNode.h b/velox/core/PlanNode.h index 17bda5780e8e..e62b7ef4aea1 100644 --- a/velox/core/PlanNode.h +++ b/velox/core/PlanNode.h @@ -135,6 +135,25 @@ class PlanNode { /// The name of the plan node, used in toString. virtual std::string_view name() const = 0; + /// Recursively checks the node tree for a first node that satisfy a given + /// condition. Returns pointer to the node if found, nullptr if not. + static const PlanNode* visit( + const PlanNode* node, + const std::function& predicate) { + if (predicate(node)) { + return node; + } + + // Recursively go further through the sources. + for (const auto& source : node->sources()) { + const auto* ret = PlanNode::visit(source.get(), predicate); + if (ret != nullptr) { + return ret; + } + } + return nullptr; + } + private: /// The details of the plan node in textual format. virtual void addDetails(std::stringstream& stream) const = 0; @@ -516,6 +535,14 @@ class AggregationNode : public PlanNode { return "Aggregation"; } + bool isFinal() const { + return step_ == Step::kFinal; + } + + bool isSingle() const { + return step_ == Step::kSingle; + } + private: void addDetails(std::stringstream& stream) const override; diff --git a/velox/core/tests/CMakeLists.txt b/velox/core/tests/CMakeLists.txt index 036907a35409..7fdccd46cea7 100644 --- a/velox/core/tests/CMakeLists.txt +++ b/velox/core/tests/CMakeLists.txt @@ -12,8 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. -add_executable(velox_core_test TestQueryConfig.cpp TestString.cpp - TestTypeAnalysis.cpp) +add_executable(velox_core_test TestPlanNode.cpp TestQueryConfig.cpp + TestString.cpp TestTypeAnalysis.cpp) add_test(velox_core_test velox_core_test) diff --git a/velox/core/tests/TestPlanNode.cpp b/velox/core/tests/TestPlanNode.cpp new file mode 100644 index 000000000000..29ec82fea144 --- /dev/null +++ b/velox/core/tests/TestPlanNode.cpp @@ -0,0 +1,66 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include + +#include "velox/core/PlanNode.h" + +using namespace ::facebook::velox; +using namespace ::facebook::velox::core; + +TEST(TestPlanNode, visit) { + std::shared_ptr bigIntType = + std::make_shared>(); + auto rowType = std::make_shared( + std::vector{"name1"}, + std::vector>{bigIntType}); + + std::shared_ptr tableHandle; + std::unordered_map> + assignments; + + std::shared_ptr tableScan3 = + std::make_shared("3", rowType, tableHandle, assignments); + std::shared_ptr tableScan2 = + std::make_shared("2", rowType, tableHandle, assignments); + + std::vector sortingKeys; + std::vector sortingOrders; + std::shared_ptr localMerge1 = std::make_shared( + "1", + sortingKeys, + sortingOrders, + std::vector{tableScan2, tableScan3}); + + std::vector names; + std::vector projections; + std::shared_ptr project0 = + std::make_shared("0", names, projections, localMerge1); + + EXPECT_EQ( + tableScan3.get(), + PlanNode::visit(project0.get(), [](const PlanNode* node) { + return node->id() == "3"; + })); + + EXPECT_EQ( + project0.get(), PlanNode::visit(project0.get(), [](const PlanNode* node) { + return node->name() == "Project"; + })); + + EXPECT_EQ(nullptr, PlanNode::visit(project0.get(), [](const PlanNode* node) { + return node->name() == "Unknown"; + })); +}