From 729c9d2f220e58fb588c5cd2ae5430f02994591b Mon Sep 17 00:00:00 2001 From: Peter Toth Date: Sat, 13 Jan 2024 10:01:26 +0100 Subject: [PATCH] refactor `TreeNode::rewrite()` --- datafusion-examples/examples/rewrite_expr.rs | 6 +- datafusion/common/src/tree_node.rs | 192 ++++++++++-------- .../core/src/datasource/listing/helpers.rs | 18 +- .../physical_plan/parquet/row_filter.rs | 17 +- datafusion/core/src/execution/context/mod.rs | 6 +- .../combine_partial_final_agg.rs | 2 +- .../physical_optimizer/projection_pushdown.rs | 4 +- .../core/src/physical_optimizer/pruning.rs | 2 +- datafusion/expr/src/expr.rs | 2 +- datafusion/expr/src/expr_rewriter/mod.rs | 38 ++-- datafusion/expr/src/expr_rewriter/order_by.rs | 2 +- datafusion/expr/src/logical_plan/display.rs | 18 +- datafusion/expr/src/logical_plan/plan.rs | 79 ++++--- datafusion/expr/src/tree_node/expr.rs | 14 +- datafusion/expr/src/tree_node/plan.rs | 14 +- datafusion/expr/src/utils.rs | 10 +- .../src/analyzer/count_wildcard_rule.rs | 4 +- .../src/analyzer/inline_table_scan.rs | 2 +- datafusion/optimizer/src/analyzer/mod.rs | 4 +- .../optimizer/src/analyzer/rewrite_expr.rs | 4 +- datafusion/optimizer/src/analyzer/subquery.rs | 16 +- .../optimizer/src/analyzer/type_coercion.rs | 10 +- .../optimizer/src/common_subexpr_eliminate.rs | 95 +++++---- datafusion/optimizer/src/decorrelate.rs | 22 +- datafusion/optimizer/src/plan_signature.rs | 4 +- datafusion/optimizer/src/push_down_filter.rs | 14 +- .../optimizer/src/scalar_subquery_to_join.rs | 26 +-- .../simplify_expressions/expr_simplifier.rs | 21 +- .../src/simplify_expressions/guarantees.rs | 4 +- .../simplify_expressions/inlist_simplifier.rs | 4 +- .../or_in_list_simplifier.rs | 4 +- .../src/unwrap_cast_in_comparison.rs | 10 +- datafusion/optimizer/src/utils.rs | 6 +- .../physical-expr/src/equivalence/class.rs | 2 +- .../physical-expr/src/expressions/case.rs | 2 +- datafusion/physical-expr/src/utils/mod.rs | 26 +-- .../library-user-guide/working-with-exprs.md | 2 +- 37 files changed, 355 insertions(+), 351 deletions(-) diff --git a/datafusion-examples/examples/rewrite_expr.rs b/datafusion-examples/examples/rewrite_expr.rs index 5e95562033e60..9dfc238ab9e83 100644 --- a/datafusion-examples/examples/rewrite_expr.rs +++ b/datafusion-examples/examples/rewrite_expr.rs @@ -91,7 +91,7 @@ impl AnalyzerRule for MyAnalyzerRule { impl MyAnalyzerRule { fn analyze_plan(plan: LogicalPlan) -> Result { - plan.transform(&|plan| { + plan.transform_up(&|plan| { Ok(match plan { LogicalPlan::Filter(filter) => { let predicate = Self::analyze_expr(filter.predicate.clone())?; @@ -106,7 +106,7 @@ impl MyAnalyzerRule { } fn analyze_expr(expr: Expr) -> Result { - expr.transform(&|expr| { + expr.transform_up(&|expr| { // closure is invoked for all sub expressions Ok(match expr { Expr::Literal(ScalarValue::Int64(i)) => { @@ -161,7 +161,7 @@ impl OptimizerRule for MyOptimizerRule { /// use rewrite_expr to modify the expression tree. fn my_rewrite(expr: Expr) -> Result { - expr.transform(&|expr| { + expr.transform_up(&|expr| { // closure is invoked for all sub expressions Ok(match expr { Expr::Between(Between { diff --git a/datafusion/common/src/tree_node.rs b/datafusion/common/src/tree_node.rs index c5c4ee824d61f..138519c4e99dd 100644 --- a/datafusion/common/src/tree_node.rs +++ b/datafusion/common/src/tree_node.rs @@ -31,12 +31,12 @@ use crate::Result; macro_rules! handle_tree_recursion { ($EXPR:expr) => { match $EXPR { - VisitRecursion::Continue => {} + TreeNodeRecursion::Continue => {} // If the recursion should skip, do not apply to its children, let // the recursion continue: - VisitRecursion::Skip => return Ok(VisitRecursion::Continue), + TreeNodeRecursion::Skip => return Ok(TreeNodeRecursion::Continue), // If the recursion should stop, do not apply to its children: - VisitRecursion::Stop => return Ok(VisitRecursion::Stop), + TreeNodeRecursion::Stop => return Ok(TreeNodeRecursion::Stop), } }; } @@ -58,10 +58,10 @@ pub trait TreeNode: Sized { /// /// The `op` closure can be used to collect some info from the /// tree node or do some checking for the tree node. - fn apply Result>( + fn apply Result>( &self, op: &mut F, - ) -> Result { + ) -> Result { handle_tree_recursion!(op(self)?); self.apply_children(&mut |node| node.apply(op)) } @@ -88,7 +88,7 @@ pub trait TreeNode: Sized { /// /// If an Err result is returned, recursion is stopped immediately /// - /// If [`VisitRecursion::Stop`] is returned on a call to pre_visit, no + /// If [`TreeNodeRecursion::Stop`] is returned on a call to pre_visit, no /// children of that node will be visited, nor is post_visit /// called on that node. Details see [`TreeNodeVisitor`] /// @@ -97,20 +97,53 @@ pub trait TreeNode: Sized { fn visit>( &self, visitor: &mut V, - ) -> Result { + ) -> Result { handle_tree_recursion!(visitor.pre_visit(self)?); handle_tree_recursion!(self.apply_children(&mut |node| node.visit(visitor))?); visitor.post_visit(self) } - /// Convenience utils for writing optimizers rule: recursively apply the given `op` to the node tree. - /// When `op` does not apply to a given node, it is left unchanged. - /// The default tree traversal direction is transform_up(Postorder Traversal). - fn transform(self, op: &F) -> Result + /// Transforms the tree using `f_down` while traversing the tree top-down + /// (pre-preorder) and using `f_up` while traversing the tree bottom-up (post-order). + /// + /// E.g. for an tree such as: + /// ```text + /// ParentNode + /// left: ChildNode1 + /// right: ChildNode2 + /// ``` + /// + /// The nodes are visited using the following order: + /// ```text + /// f_down(ParentNode) + /// f_down(ChildNode1) + /// f_up(ChildNode1) + /// f_down(ChildNode2) + /// f_up(ChildNode2) + /// f_up(ParentNode) + /// ``` + /// + /// See [`TreeNodeRecursion`] for more details on how the traversal can be controlled. + /// + /// If `f_down` or `f_up` returns [`Err`], recursion is stopped immediately. + fn transform(self, f_down: &mut FD, f_up: &mut FU) -> Result where - F: Fn(Self) -> Result>, + FD: FnMut(Self) -> Result<(Transformed, TreeNodeRecursion)>, + FU: FnMut(Self) -> Result, { - self.transform_up(op) + let (new_node, tnr) = f_down(self).map(|(t, tnr)| (t.into(), tnr))?; + match tnr { + TreeNodeRecursion::Continue => {} + // If the recursion should skip, do not apply to its children. And let the recursion continue + TreeNodeRecursion::Skip => return Ok(new_node), + // If the recursion should stop, do not apply to its children + TreeNodeRecursion::Stop => { + panic!("Stop can't be used in TreeNode::transform()") + } + } + let node_with_new_children = + new_node.map_children(|node| node.transform(f_down, f_up))?; + f_up(node_with_new_children) } /// Convenience utils for writing optimizers rule: recursively apply the given 'op' to the node and all of its @@ -159,56 +192,50 @@ pub trait TreeNode: Sized { Ok(new_node) } - /// Transform the tree node using the given [TreeNodeRewriter] - /// It performs a depth first walk of an node and its children. + /// Implements the [visitor pattern](https://en.wikipedia.org/wiki/Visitor_pattern) for + /// recursively transforming [`TreeNode`]s. /// - /// For an node tree such as + /// E.g. for an tree such as: /// ```text /// ParentNode /// left: ChildNode1 /// right: ChildNode2 /// ``` /// - /// The nodes are visited using the following order + /// The nodes are visited using the following order: /// ```text - /// pre_visit(ParentNode) - /// pre_visit(ChildNode1) - /// mutate(ChildNode1) - /// pre_visit(ChildNode2) - /// mutate(ChildNode2) - /// mutate(ParentNode) + /// TreeNodeRewriter::f_down(ParentNode) + /// TreeNodeRewriter::f_down(ChildNode1) + /// TreeNodeRewriter::f_up(ChildNode1) + /// TreeNodeRewriter::f_down(ChildNode2) + /// TreeNodeRewriter::f_up(ChildNode2) + /// TreeNodeRewriter::f_up(ParentNode) /// ``` /// - /// If an Err result is returned, recursion is stopped immediately - /// - /// If [`false`] is returned on a call to pre_visit, no - /// children of that node will be visited, nor is mutate - /// called on that node + /// See [`TreeNodeRecursion`] for more details on how the traversal can be controlled. /// - /// If using the default [`TreeNodeRewriter::pre_visit`] which - /// returns `true`, [`Self::transform`] should be preferred. - fn rewrite>(self, rewriter: &mut R) -> Result { - let need_mutate = match rewriter.pre_visit(&self)? { - RewriteRecursion::Mutate => return rewriter.mutate(self), - RewriteRecursion::Stop => return Ok(self), - RewriteRecursion::Continue => true, - RewriteRecursion::Skip => false, - }; - - let after_op_children = self.map_children(|node| node.rewrite(rewriter))?; - - // now rewrite this node itself - if need_mutate { - rewriter.mutate(after_op_children) - } else { - Ok(after_op_children) + /// If [`TreeNodeRewriter::f_down()`] or [`TreeNodeRewriter::f_up()`] returns [`Err`], + /// recursion is stopped immediately. + fn rewrite>(self, rewriter: &mut R) -> Result { + let (new_node, tnr) = rewriter.f_down(self)?; + match tnr { + TreeNodeRecursion::Continue => {} + // If the recursion should skip, do not apply to its children. And let the recursion continue + TreeNodeRecursion::Skip => return Ok(new_node), + // If the recursion should stop, do not apply to its children + TreeNodeRecursion::Stop => { + panic!("Stop can't be used in TreeNode::rewrite()") + } } + let node_with_new_children = + new_node.map_children(|node| node.rewrite(rewriter))?; + rewriter.f_up(node_with_new_children) } /// Apply the closure `F` to the node's children - fn apply_children(&self, op: &mut F) -> Result + fn apply_children(&self, op: &mut F) -> Result where - F: FnMut(&Self) -> Result; + F: FnMut(&Self) -> Result; /// Apply transform `F` to the node's children, the transform `F` might have a direction(Preorder or Postorder) fn map_children(self, transform: F) -> Result @@ -231,69 +258,58 @@ pub trait TreeNode: Sized { /// If an [`Err`] result is returned, recursion is stopped /// immediately. /// -/// If [`VisitRecursion::Stop`] is returned on a call to pre_visit, no +/// If [`TreeNodeRecursion::Stop`] is returned on a call to pre_visit, no /// children of that tree node are visited, nor is post_visit /// called on that tree node /// -/// If [`VisitRecursion::Stop`] is returned on a call to post_visit, no +/// If [`TreeNodeRecursion::Stop`] is returned on a call to post_visit, no /// siblings of that tree node are visited, nor is post_visit /// called on its parent tree node /// -/// If [`VisitRecursion::Skip`] is returned on a call to pre_visit, no +/// If [`TreeNodeRecursion::Skip`] is returned on a call to pre_visit, no /// children of that tree node are visited. pub trait TreeNodeVisitor: Sized { /// The node type which is visitable. type N: TreeNode; /// Invoked before any children of `node` are visited. - fn pre_visit(&mut self, node: &Self::N) -> Result; + fn pre_visit(&mut self, node: &Self::N) -> Result; /// Invoked after all children of `node` are visited. Default /// implementation does nothing. - fn post_visit(&mut self, _node: &Self::N) -> Result { - Ok(VisitRecursion::Continue) + fn post_visit(&mut self, _node: &Self::N) -> Result { + Ok(TreeNodeRecursion::Continue) } } -/// Trait for potentially recursively transform an [`TreeNode`] node -/// tree. When passed to `TreeNode::rewrite`, `TreeNodeRewriter::mutate` is -/// invoked recursively on all nodes of a tree. +/// Trait for potentially recursively transform a [`TreeNode`] node tree. pub trait TreeNodeRewriter: Sized { /// The node type which is rewritable. - type N: TreeNode; + type Node: TreeNode; - /// Invoked before (Preorder) any children of `node` are rewritten / - /// visited. Default implementation returns `Ok(Recursion::Continue)` - fn pre_visit(&mut self, _node: &Self::N) -> Result { - Ok(RewriteRecursion::Continue) + /// Invoked while traversing down the tree before any children are rewritten / + /// visited. + /// Default implementation returns the node unmodified and continues recursion. + fn f_down(&mut self, node: Self::Node) -> Result<(Self::Node, TreeNodeRecursion)> { + Ok((node, TreeNodeRecursion::Continue)) } - /// Invoked after (Postorder) all children of `node` have been mutated and - /// returns a potentially modified node. - fn mutate(&mut self, node: Self::N) -> Result; -} - -/// Controls how the [`TreeNode`] recursion should proceed for [`TreeNode::rewrite`]. -#[derive(Debug)] -pub enum RewriteRecursion { - /// Continue rewrite this node tree. - Continue, - /// Call 'op' immediately and return. - Mutate, - /// Do not rewrite the children of this node. - Stop, - /// Keep recursive but skip apply op on this node - Skip, + /// Invoked while traversing up the tree after all children have been rewritten / + /// visited. + /// Default implementation returns the node unmodified. + fn f_up(&mut self, node: Self::Node) -> Result { + Ok(node) + } } -/// Controls how the [`TreeNode`] recursion should proceed for [`TreeNode::visit`]. +/// Controls how [`TreeNode`] recursions should proceed. #[derive(Debug)] -pub enum VisitRecursion { - /// Continue the visit to this node tree. +pub enum TreeNodeRecursion { + /// Continue recursion with the next node. Continue, - /// Keep recursive but skip applying op on the children + /// Skip the current subtree. Skip, - /// Stop the visit to this node tree. + /// Stop recursion. Stop, } @@ -340,14 +356,14 @@ pub trait DynTreeNode { /// [`DynTreeNode`] (such as [`Arc`]) impl TreeNode for Arc { /// Apply the closure `F` to the node's children - fn apply_children(&self, op: &mut F) -> Result + fn apply_children(&self, op: &mut F) -> Result where - F: FnMut(&Self) -> Result, + F: FnMut(&Self) -> Result, { for child in self.arc_children() { handle_tree_recursion!(op(&child)?) } - Ok(VisitRecursion::Continue) + Ok(TreeNodeRecursion::Continue) } fn map_children(self, transform: F) -> Result @@ -382,14 +398,14 @@ pub trait ConcreteTreeNode: Sized { impl TreeNode for T { /// Apply the closure `F` to the node's children - fn apply_children(&self, op: &mut F) -> Result + fn apply_children(&self, op: &mut F) -> Result where - F: FnMut(&Self) -> Result, + F: FnMut(&Self) -> Result, { for child in self.children() { handle_tree_recursion!(op(child)?) } - Ok(VisitRecursion::Continue) + Ok(TreeNodeRecursion::Continue) } fn map_children(self, transform: F) -> Result diff --git a/datafusion/core/src/datasource/listing/helpers.rs b/datafusion/core/src/datasource/listing/helpers.rs index a03bcec7abece..96864672573b1 100644 --- a/datafusion/core/src/datasource/listing/helpers.rs +++ b/datafusion/core/src/datasource/listing/helpers.rs @@ -37,7 +37,7 @@ use crate::{error::Result, scalar::ScalarValue}; use super::PartitionedFile; use crate::datasource::listing::ListingTableUrl; use crate::execution::context::SessionState; -use datafusion_common::tree_node::{TreeNode, VisitRecursion}; +use datafusion_common::tree_node::{TreeNode, TreeNodeRecursion}; use datafusion_common::{internal_err, Column, DFField, DFSchema, DataFusionError}; use datafusion_expr::{Expr, ScalarFunctionDefinition, Volatility}; use datafusion_physical_expr::create_physical_expr; @@ -57,9 +57,9 @@ pub fn expr_applicable_for_cols(col_names: &[String], expr: &Expr) -> bool { Expr::Column(Column { ref name, .. }) => { is_applicable &= col_names.contains(name); if is_applicable { - Ok(VisitRecursion::Skip) + Ok(TreeNodeRecursion::Skip) } else { - Ok(VisitRecursion::Stop) + Ok(TreeNodeRecursion::Stop) } } Expr::Literal(_) @@ -88,27 +88,27 @@ pub fn expr_applicable_for_cols(col_names: &[String], expr: &Expr) -> bool { | Expr::ScalarSubquery(_) | Expr::GetIndexedField { .. } | Expr::GroupingSet(_) - | Expr::Case { .. } => Ok(VisitRecursion::Continue), + | Expr::Case { .. } => Ok(TreeNodeRecursion::Continue), Expr::ScalarFunction(scalar_function) => { match &scalar_function.func_def { ScalarFunctionDefinition::BuiltIn(fun) => { match fun.volatility() { - Volatility::Immutable => Ok(VisitRecursion::Continue), + Volatility::Immutable => Ok(TreeNodeRecursion::Continue), // TODO: Stable functions could be `applicable`, but that would require access to the context Volatility::Stable | Volatility::Volatile => { is_applicable = false; - Ok(VisitRecursion::Stop) + Ok(TreeNodeRecursion::Stop) } } } ScalarFunctionDefinition::UDF(fun) => { match fun.signature().volatility { - Volatility::Immutable => Ok(VisitRecursion::Continue), + Volatility::Immutable => Ok(TreeNodeRecursion::Continue), // TODO: Stable functions could be `applicable`, but that would require access to the context Volatility::Stable | Volatility::Volatile => { is_applicable = false; - Ok(VisitRecursion::Stop) + Ok(TreeNodeRecursion::Stop) } } } @@ -128,7 +128,7 @@ pub fn expr_applicable_for_cols(col_names: &[String], expr: &Expr) -> bool { | Expr::Wildcard { .. } | Expr::Placeholder(_) => { is_applicable = false; - Ok(VisitRecursion::Stop) + Ok(TreeNodeRecursion::Stop) } } }) diff --git a/datafusion/core/src/datasource/physical_plan/parquet/row_filter.rs b/datafusion/core/src/datasource/physical_plan/parquet/row_filter.rs index 3c40509a86d27..ddfeb146b876d 100644 --- a/datafusion/core/src/datasource/physical_plan/parquet/row_filter.rs +++ b/datafusion/core/src/datasource/physical_plan/parquet/row_filter.rs @@ -20,7 +20,7 @@ use arrow::datatypes::{DataType, Schema}; use arrow::error::{ArrowError, Result as ArrowResult}; use arrow::record_batch::RecordBatch; use datafusion_common::cast::as_boolean_array; -use datafusion_common::tree_node::{RewriteRecursion, TreeNode, TreeNodeRewriter}; +use datafusion_common::tree_node::{TreeNode, TreeNodeRecursion, TreeNodeRewriter}; use datafusion_common::{arrow_err, DataFusionError, Result, ScalarValue}; use datafusion_physical_expr::expressions::{Column, Literal}; use datafusion_physical_expr::utils::reassign_predicate_columns; @@ -209,29 +209,32 @@ impl<'a> FilterCandidateBuilder<'a> { } impl<'a> TreeNodeRewriter for FilterCandidateBuilder<'a> { - type N = Arc; + type Node = Arc; - fn pre_visit(&mut self, node: &Arc) -> Result { + fn f_down( + &mut self, + node: Arc, + ) -> Result<(Arc, TreeNodeRecursion)> { if let Some(column) = node.as_any().downcast_ref::() { if let Ok(idx) = self.file_schema.index_of(column.name()) { self.required_column_indices.insert(idx); if DataType::is_nested(self.file_schema.field(idx).data_type()) { self.non_primitive_columns = true; - return Ok(RewriteRecursion::Stop); + return Ok((node, TreeNodeRecursion::Skip)); } } else if self.table_schema.index_of(column.name()).is_err() { // If the column does not exist in the (un-projected) table schema then // it must be a projected column. self.projected_columns = true; - return Ok(RewriteRecursion::Stop); + return Ok((node, TreeNodeRecursion::Skip)); } } - Ok(RewriteRecursion::Continue) + Ok((node, TreeNodeRecursion::Continue)) } - fn mutate(&mut self, expr: Arc) -> Result> { + fn f_up(&mut self, expr: Arc) -> Result> { if let Some(column) = expr.as_any().downcast_ref::() { if self.file_schema.field_with_name(column.name()).is_err() { // the column expr must be in the table schema diff --git a/datafusion/core/src/execution/context/mod.rs b/datafusion/core/src/execution/context/mod.rs index b5ad6174821b9..4f57d873cbdfa 100644 --- a/datafusion/core/src/execution/context/mod.rs +++ b/datafusion/core/src/execution/context/mod.rs @@ -39,7 +39,7 @@ use crate::{ use datafusion_common::{ alias::AliasGenerator, exec_err, not_impl_err, plan_datafusion_err, plan_err, - tree_node::{TreeNode, TreeNodeVisitor, VisitRecursion}, + tree_node::{TreeNode, TreeNodeRecursion, TreeNodeVisitor}, }; use datafusion_execution::registry::SerializerRegistry; use datafusion_expr::{ @@ -2108,7 +2108,7 @@ impl<'a> BadPlanVisitor<'a> { impl<'a> TreeNodeVisitor for BadPlanVisitor<'a> { type N = LogicalPlan; - fn pre_visit(&mut self, node: &Self::N) -> Result { + fn pre_visit(&mut self, node: &Self::N) -> Result { match node { LogicalPlan::Ddl(ddl) if !self.options.allow_ddl => { plan_err!("DDL not supported: {}", ddl.name()) @@ -2122,7 +2122,7 @@ impl<'a> TreeNodeVisitor for BadPlanVisitor<'a> { LogicalPlan::Statement(stmt) if !self.options.allow_statements => { plan_err!("Statement not supported: {}", stmt.name()) } - _ => Ok(VisitRecursion::Continue), + _ => Ok(TreeNodeRecursion::Continue), } } } diff --git a/datafusion/core/src/physical_optimizer/combine_partial_final_agg.rs b/datafusion/core/src/physical_optimizer/combine_partial_final_agg.rs index 61eb2381c63b6..b26d9763e53a5 100644 --- a/datafusion/core/src/physical_optimizer/combine_partial_final_agg.rs +++ b/datafusion/core/src/physical_optimizer/combine_partial_final_agg.rs @@ -178,7 +178,7 @@ fn normalize_group_exprs(group_exprs: GroupExprsRef) -> GroupExprs { fn discard_column_index(group_expr: Arc) -> Arc { group_expr .clone() - .transform(&|expr| { + .transform_up(&|expr| { let normalized_form: Option> = match expr.as_any().downcast_ref::() { Some(column) => Some(Arc::new(Column::new(column.name(), 0))), diff --git a/datafusion/core/src/physical_optimizer/projection_pushdown.rs b/datafusion/core/src/physical_optimizer/projection_pushdown.rs index 2d20c487e473e..64ef92faa865c 100644 --- a/datafusion/core/src/physical_optimizer/projection_pushdown.rs +++ b/datafusion/core/src/physical_optimizer/projection_pushdown.rs @@ -43,7 +43,7 @@ use crate::physical_plan::{Distribution, ExecutionPlan}; use arrow_schema::SchemaRef; use datafusion_common::config::ConfigOptions; -use datafusion_common::tree_node::{Transformed, TreeNode, VisitRecursion}; +use datafusion_common::tree_node::{Transformed, TreeNode, TreeNodeRecursion}; use datafusion_common::JoinSide; use datafusion_physical_expr::expressions::{Column, Literal}; use datafusion_physical_expr::{ @@ -270,7 +270,7 @@ fn try_unifying_projections( if let Some(column) = expr.as_any().downcast_ref::() { *column_ref_map.entry(column.clone()).or_default() += 1; } - VisitRecursion::Continue + TreeNodeRecursion::Continue }) }) .unwrap(); diff --git a/datafusion/core/src/physical_optimizer/pruning.rs b/datafusion/core/src/physical_optimizer/pruning.rs index aa0c26723767e..aa72771b1eb3f 100644 --- a/datafusion/core/src/physical_optimizer/pruning.rs +++ b/datafusion/core/src/physical_optimizer/pruning.rs @@ -837,7 +837,7 @@ fn rewrite_column_expr( column_old: &phys_expr::Column, column_new: &phys_expr::Column, ) -> Result> { - e.transform(&|expr| { + e.transform_up(&|expr| { if let Some(column) = expr.as_any().downcast_ref::() { if column == column_old { return Ok(Transformed::Yes(Arc::new(column_new.clone()))); diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs index c5d158d876385..e0eebf5c8c18a 100644 --- a/datafusion/expr/src/expr.rs +++ b/datafusion/expr/src/expr.rs @@ -1247,7 +1247,7 @@ impl Expr { /// For example, gicen an expression like ` = $0` will infer `$0` to /// have type `int32`. pub fn infer_placeholder_types(self, schema: &DFSchema) -> Result { - self.transform(&|mut expr| { + self.transform_up(&|mut expr| { // Default to assuming the arguments are the same type if let Expr::BinaryExpr(BinaryExpr { left, op: _, right }) = &mut expr { rewrite_placeholder(left.as_mut(), right.as_ref(), schema)?; diff --git a/datafusion/expr/src/expr_rewriter/mod.rs b/datafusion/expr/src/expr_rewriter/mod.rs index 1f04c80833f09..76bd51619954a 100644 --- a/datafusion/expr/src/expr_rewriter/mod.rs +++ b/datafusion/expr/src/expr_rewriter/mod.rs @@ -33,7 +33,7 @@ pub use order_by::rewrite_sort_cols_by_aggs; /// Recursively call [`Column::normalize_with_schemas`] on all [`Column`] expressions /// in the `expr` expression tree. pub fn normalize_col(expr: Expr, plan: &LogicalPlan) -> Result { - expr.transform(&|expr| { + expr.transform_up(&|expr| { Ok({ if let Expr::Column(c) = expr { let col = LogicalPlanBuilder::normalize(plan, c)?; @@ -57,7 +57,7 @@ pub fn normalize_col_with_schemas( schemas: &[&Arc], using_columns: &[HashSet], ) -> Result { - expr.transform(&|expr| { + expr.transform_up(&|expr| { Ok({ if let Expr::Column(c) = expr { let col = c.normalize_with_schemas(schemas, using_columns)?; @@ -75,7 +75,7 @@ pub fn normalize_col_with_schemas_and_ambiguity_check( schemas: &[&[&DFSchema]], using_columns: &[HashSet], ) -> Result { - expr.transform(&|expr| { + expr.transform_up(&|expr| { Ok({ if let Expr::Column(c) = expr { let col = @@ -102,7 +102,7 @@ pub fn normalize_cols( /// Recursively replace all [`Column`] expressions in a given expression tree with /// `Column` expressions provided by the hash map argument. pub fn replace_col(expr: Expr, replace_map: &HashMap<&Column, &Column>) -> Result { - expr.transform(&|expr| { + expr.transform_up(&|expr| { Ok({ if let Expr::Column(c) = &expr { match replace_map.get(c) { @@ -122,7 +122,7 @@ pub fn replace_col(expr: Expr, replace_map: &HashMap<&Column, &Column>) -> Resul /// For example, if there were expressions like `foo.bar` this would /// rewrite it to just `bar`. pub fn unnormalize_col(expr: Expr) -> Expr { - expr.transform(&|expr| { + expr.transform_up(&|expr| { Ok({ if let Expr::Column(c) = expr { let col = Column { @@ -164,7 +164,7 @@ pub fn unnormalize_cols(exprs: impl IntoIterator) -> Vec { /// Recursively remove all the ['OuterReferenceColumn'] and return the inside Column /// in the expression tree. pub fn strip_outer_reference(expr: Expr) -> Expr { - expr.transform(&|expr| { + expr.transform_up(&|expr| { Ok({ if let Expr::OuterReferenceColumn(_, col) = expr { Transformed::Yes(Expr::Column(col)) @@ -250,7 +250,7 @@ pub fn unalias(expr: Expr) -> Expr { /// schema of plan nodes don't change after optimization pub fn rewrite_preserving_name(expr: Expr, rewriter: &mut R) -> Result where - R: TreeNodeRewriter, + R: TreeNodeRewriter, { let original_name = expr.name_for_alias()?; let expr = expr.rewrite(rewriter)?; @@ -263,7 +263,7 @@ mod test { use crate::expr::Sort; use crate::{col, lit, Cast}; use arrow::datatypes::DataType; - use datafusion_common::tree_node::{RewriteRecursion, TreeNode, TreeNodeRewriter}; + use datafusion_common::tree_node::{TreeNode, TreeNodeRecursion, TreeNodeRewriter}; use datafusion_common::{DFField, DFSchema, ScalarValue}; use std::ops::Add; @@ -273,14 +273,14 @@ mod test { } impl TreeNodeRewriter for RecordingRewriter { - type N = Expr; + type Node = Expr; - fn pre_visit(&mut self, expr: &Expr) -> Result { + fn f_down(&mut self, expr: Expr) -> Result<(Expr, TreeNodeRecursion)> { self.v.push(format!("Previsited {expr}")); - Ok(RewriteRecursion::Continue) + Ok((expr, TreeNodeRecursion::Continue)) } - fn mutate(&mut self, expr: Expr) -> Result { + fn f_up(&mut self, expr: Expr) -> Result { self.v.push(format!("Mutated {expr}")); Ok(expr) } @@ -305,11 +305,17 @@ mod test { }; // rewrites "foo" --> "bar" - let rewritten = col("state").eq(lit("foo")).transform(&transformer).unwrap(); + let rewritten = col("state") + .eq(lit("foo")) + .transform_up(&transformer) + .unwrap(); assert_eq!(rewritten, col("state").eq(lit("bar"))); // doesn't rewrite - let rewritten = col("state").eq(lit("baz")).transform(&transformer).unwrap(); + let rewritten = col("state") + .eq(lit("baz")) + .transform_up(&transformer) + .unwrap(); assert_eq!(rewritten, col("state").eq(lit("baz"))); } @@ -444,9 +450,9 @@ mod test { } impl TreeNodeRewriter for TestRewriter { - type N = Expr; + type Node = Expr; - fn mutate(&mut self, _: Expr) -> Result { + fn f_up(&mut self, _: Expr) -> Result { Ok(self.rewrite_to.clone()) } } diff --git a/datafusion/expr/src/expr_rewriter/order_by.rs b/datafusion/expr/src/expr_rewriter/order_by.rs index c87a724d5646b..1e7efcafd04df 100644 --- a/datafusion/expr/src/expr_rewriter/order_by.rs +++ b/datafusion/expr/src/expr_rewriter/order_by.rs @@ -83,7 +83,7 @@ fn rewrite_in_terms_of_projection( ) -> Result { // assumption is that each item in exprs, such as "b + c" is // available as an output column named "b + c" - expr.transform(&|expr| { + expr.transform_up(&|expr| { // search for unnormalized names first such as "c1" (such as aliases) if let Some(found) = proj_exprs.iter().find(|a| (**a) == expr) { let col = Expr::Column( diff --git a/datafusion/expr/src/logical_plan/display.rs b/datafusion/expr/src/logical_plan/display.rs index 112dbf74dba18..ebef7791f8d8d 100644 --- a/datafusion/expr/src/logical_plan/display.rs +++ b/datafusion/expr/src/logical_plan/display.rs @@ -19,7 +19,7 @@ use crate::LogicalPlan; use arrow::datatypes::Schema; use datafusion_common::display::GraphvizBuilder; -use datafusion_common::tree_node::{TreeNodeVisitor, VisitRecursion}; +use datafusion_common::tree_node::{TreeNodeRecursion, TreeNodeVisitor}; use datafusion_common::DataFusionError; use std::fmt; @@ -54,7 +54,7 @@ impl<'a, 'b> TreeNodeVisitor for IndentVisitor<'a, 'b> { fn pre_visit( &mut self, plan: &LogicalPlan, - ) -> datafusion_common::Result { + ) -> datafusion_common::Result { if self.indent > 0 { writeln!(self.f)?; } @@ -69,15 +69,15 @@ impl<'a, 'b> TreeNodeVisitor for IndentVisitor<'a, 'b> { } self.indent += 1; - Ok(VisitRecursion::Continue) + Ok(TreeNodeRecursion::Continue) } fn post_visit( &mut self, _plan: &LogicalPlan, - ) -> datafusion_common::Result { + ) -> datafusion_common::Result { self.indent -= 1; - Ok(VisitRecursion::Continue) + Ok(TreeNodeRecursion::Continue) } } @@ -176,7 +176,7 @@ impl<'a, 'b> TreeNodeVisitor for GraphvizVisitor<'a, 'b> { fn pre_visit( &mut self, plan: &LogicalPlan, - ) -> datafusion_common::Result { + ) -> datafusion_common::Result { let id = self.graphviz_builder.next_id(); // Create a new graph node for `plan` such as @@ -204,18 +204,18 @@ impl<'a, 'b> TreeNodeVisitor for GraphvizVisitor<'a, 'b> { } self.parent_ids.push(id); - Ok(VisitRecursion::Continue) + Ok(TreeNodeRecursion::Continue) } fn post_visit( &mut self, _plan: &LogicalPlan, - ) -> datafusion_common::Result { + ) -> datafusion_common::Result { // always be non-empty as pre_visit always pushes // So it should always be Ok(true) let res = self.parent_ids.pop(); res.ok_or(DataFusionError::Internal("Fail to format".to_string())) - .map(|_| VisitRecursion::Continue) + .map(|_| TreeNodeRecursion::Continue) } } diff --git a/datafusion/expr/src/logical_plan/plan.rs b/datafusion/expr/src/logical_plan/plan.rs index aee3a59dd2da6..80ce38fe93897 100644 --- a/datafusion/expr/src/logical_plan/plan.rs +++ b/datafusion/expr/src/logical_plan/plan.rs @@ -45,8 +45,7 @@ use crate::{ use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; use datafusion_common::tree_node::{ - RewriteRecursion, Transformed, TreeNode, TreeNodeRewriter, TreeNodeVisitor, - VisitRecursion, + Transformed, TreeNode, TreeNodeRecursion, TreeNodeVisitor, }; use datafusion_common::{ aggregate_functional_dependencies, internal_err, plan_err, Column, Constraints, @@ -475,7 +474,7 @@ impl LogicalPlan { })?; using_columns.push(columns); } - Ok(VisitRecursion::Continue) + Ok(TreeNodeRecursion::Continue) })?; Ok(using_columns) @@ -648,31 +647,29 @@ impl LogicalPlan { // Decimal128(Some(69999999999999),30,15) // AND lineitem.l_quantity < Decimal128(Some(2400),15,2) - struct RemoveAliases {} - - impl TreeNodeRewriter for RemoveAliases { - type N = Expr; - - fn pre_visit(&mut self, expr: &Expr) -> Result { - match expr { - Expr::Exists { .. } - | Expr::ScalarSubquery(_) - | Expr::InSubquery(_) => { - // subqueries could contain aliases so we don't recurse into those - Ok(RewriteRecursion::Stop) - } - Expr::Alias(_) => Ok(RewriteRecursion::Mutate), - _ => Ok(RewriteRecursion::Continue), + fn unalias_down( + expr: Expr, + ) -> Result<(Transformed, TreeNodeRecursion)> { + match expr { + Expr::Exists { .. } + | Expr::ScalarSubquery(_) + | Expr::InSubquery(_) => { + // subqueries could contain aliases so we don't recurse into those + Ok((Transformed::No(expr), TreeNodeRecursion::Skip)) } + Expr::Alias(_) => Ok(( + Transformed::Yes(expr.unalias()), + TreeNodeRecursion::Skip, + )), + _ => Ok((Transformed::No(expr), TreeNodeRecursion::Continue)), } + } - fn mutate(&mut self, expr: Expr) -> Result { - Ok(expr.unalias()) - } + fn dummy_up(expr: Expr) -> Result { + Ok(expr) } - let mut remove_aliases = RemoveAliases {}; - let predicate = predicate.rewrite(&mut remove_aliases)?; + let predicate = predicate.transform(&mut unalias_down, &mut dummy_up)?; Filter::try_new(predicate, Arc::new(inputs[0].clone())) .map(LogicalPlan::Filter) @@ -1124,9 +1121,9 @@ impl LogicalPlan { impl LogicalPlan { /// applies `op` to any subqueries in the plan - pub(crate) fn apply_subqueries(&self, op: &mut F) -> datafusion_common::Result<()> + pub(crate) fn apply_subqueries(&self, op: &mut F) -> Result<()> where - F: FnMut(&Self) -> datafusion_common::Result, + F: FnMut(&Self) -> Result, { self.inspect_expressions(|expr| { // recursively look for subqueries @@ -1150,7 +1147,7 @@ impl LogicalPlan { } /// applies visitor to any subqueries in the plan - pub(crate) fn visit_subqueries(&self, v: &mut V) -> datafusion_common::Result<()> + pub(crate) fn visit_subqueries(&self, v: &mut V) -> Result<()> where V: TreeNodeVisitor, { @@ -1225,11 +1222,11 @@ impl LogicalPlan { _ => {} } } - Ok(VisitRecursion::Continue) + Ok(TreeNodeRecursion::Continue) })?; Ok::<(), DataFusionError>(()) })?; - Ok(VisitRecursion::Continue) + Ok(TreeNodeRecursion::Continue) })?; Ok(param_types) @@ -1241,7 +1238,7 @@ impl LogicalPlan { expr: Expr, param_values: &ParamValues, ) -> Result { - expr.transform(&|expr| { + expr.transform_up(&|expr| { match &expr { Expr::Placeholder(Placeholder { id, .. }) => { let value = param_values.get_placeholders_with_values(id)?; @@ -2840,7 +2837,7 @@ digraph { impl TreeNodeVisitor for OkVisitor { type N = LogicalPlan; - fn pre_visit(&mut self, plan: &LogicalPlan) -> Result { + fn pre_visit(&mut self, plan: &LogicalPlan) -> Result { let s = match plan { LogicalPlan::Projection { .. } => "pre_visit Projection", LogicalPlan::Filter { .. } => "pre_visit Filter", @@ -2851,10 +2848,10 @@ digraph { }; self.strings.push(s.into()); - Ok(VisitRecursion::Continue) + Ok(TreeNodeRecursion::Continue) } - fn post_visit(&mut self, plan: &LogicalPlan) -> Result { + fn post_visit(&mut self, plan: &LogicalPlan) -> Result { let s = match plan { LogicalPlan::Projection { .. } => "post_visit Projection", LogicalPlan::Filter { .. } => "post_visit Filter", @@ -2865,7 +2862,7 @@ digraph { }; self.strings.push(s.into()); - Ok(VisitRecursion::Continue) + Ok(TreeNodeRecursion::Continue) } } @@ -2923,18 +2920,18 @@ digraph { impl TreeNodeVisitor for StoppingVisitor { type N = LogicalPlan; - fn pre_visit(&mut self, plan: &LogicalPlan) -> Result { + fn pre_visit(&mut self, plan: &LogicalPlan) -> Result { if self.return_false_from_pre_in.dec() { - return Ok(VisitRecursion::Stop); + return Ok(TreeNodeRecursion::Stop); } self.inner.pre_visit(plan)?; - Ok(VisitRecursion::Continue) + Ok(TreeNodeRecursion::Continue) } - fn post_visit(&mut self, plan: &LogicalPlan) -> Result { + fn post_visit(&mut self, plan: &LogicalPlan) -> Result { if self.return_false_from_post_in.dec() { - return Ok(VisitRecursion::Stop); + return Ok(TreeNodeRecursion::Stop); } self.inner.post_visit(plan) @@ -2992,7 +2989,7 @@ digraph { impl TreeNodeVisitor for ErrorVisitor { type N = LogicalPlan; - fn pre_visit(&mut self, plan: &LogicalPlan) -> Result { + fn pre_visit(&mut self, plan: &LogicalPlan) -> Result { if self.return_error_from_pre_in.dec() { return not_impl_err!("Error in pre_visit"); } @@ -3000,7 +2997,7 @@ digraph { self.inner.pre_visit(plan) } - fn post_visit(&mut self, plan: &LogicalPlan) -> Result { + fn post_visit(&mut self, plan: &LogicalPlan) -> Result { if self.return_error_from_post_in.dec() { return not_impl_err!("Error in post_visit"); } @@ -3306,7 +3303,7 @@ digraph { // after transformation, because plan is not the same anymore, // the parent plan is built again with call to LogicalPlan::with_new_inputs -> with_new_exprs let plan = plan - .transform(&|plan| match plan { + .transform_up(&|plan| match plan { LogicalPlan::TableScan(table) => { let filter = Filter::try_new( external_filter.clone(), diff --git a/datafusion/expr/src/tree_node/expr.rs b/datafusion/expr/src/tree_node/expr.rs index 05464c96d05ef..d937c11633f46 100644 --- a/datafusion/expr/src/tree_node/expr.rs +++ b/datafusion/expr/src/tree_node/expr.rs @@ -24,14 +24,14 @@ use crate::expr::{ }; use crate::{Expr, GetFieldAccess}; -use datafusion_common::tree_node::{TreeNode, VisitRecursion}; +use datafusion_common::tree_node::{TreeNode, TreeNodeRecursion}; use datafusion_common::{internal_err, DataFusionError, Result}; impl TreeNode for Expr { - fn apply_children Result>( + fn apply_children Result>( &self, op: &mut F, - ) -> Result { + ) -> Result { let children = match self { Expr::Alias(Alias{expr, .. }) | Expr::Not(expr) @@ -130,13 +130,13 @@ impl TreeNode for Expr { for child in children { match op(child)? { - VisitRecursion::Continue => {} - VisitRecursion::Skip => return Ok(VisitRecursion::Continue), - VisitRecursion::Stop => return Ok(VisitRecursion::Stop), + TreeNodeRecursion::Continue => {} + TreeNodeRecursion::Skip => return Ok(TreeNodeRecursion::Continue), + TreeNodeRecursion::Stop => return Ok(TreeNodeRecursion::Stop), } } - Ok(VisitRecursion::Continue) + Ok(TreeNodeRecursion::Continue) } fn map_children Result>( diff --git a/datafusion/expr/src/tree_node/plan.rs b/datafusion/expr/src/tree_node/plan.rs index 589bb917a953f..8be24638c1cc8 100644 --- a/datafusion/expr/src/tree_node/plan.rs +++ b/datafusion/expr/src/tree_node/plan.rs @@ -19,14 +19,14 @@ use crate::LogicalPlan; -use datafusion_common::tree_node::{TreeNode, TreeNodeVisitor, VisitRecursion}; +use datafusion_common::tree_node::{TreeNode, TreeNodeRecursion, TreeNodeVisitor}; use datafusion_common::{handle_tree_recursion, Result}; impl TreeNode for LogicalPlan { - fn apply Result>( + fn apply Result>( &self, op: &mut F, - ) -> Result { + ) -> Result { // Compared to the default implementation, we need to invoke // [`Self::apply_subqueries`] before visiting its children handle_tree_recursion!(op(self)?); @@ -57,7 +57,7 @@ impl TreeNode for LogicalPlan { fn visit>( &self, visitor: &mut V, - ) -> Result { + ) -> Result { // Compared to the default implementation, we need to invoke // [`Self::visit_subqueries`] before visiting its children handle_tree_recursion!(visitor.pre_visit(self)?); @@ -66,14 +66,14 @@ impl TreeNode for LogicalPlan { visitor.post_visit(self) } - fn apply_children Result>( + fn apply_children Result>( &self, op: &mut F, - ) -> Result { + ) -> Result { for child in self.inputs() { handle_tree_recursion!(op(child)?) } - Ok(VisitRecursion::Continue) + Ok(TreeNodeRecursion::Continue) } fn map_children(self, transform: F) -> Result diff --git a/datafusion/expr/src/utils.rs b/datafusion/expr/src/utils.rs index 02479c0765bd3..88b6d34c48dc5 100644 --- a/datafusion/expr/src/utils.rs +++ b/datafusion/expr/src/utils.rs @@ -31,7 +31,7 @@ use crate::{ }; use arrow::datatypes::{DataType, TimeUnit}; -use datafusion_common::tree_node::{TreeNode, VisitRecursion}; +use datafusion_common::tree_node::{TreeNode, TreeNodeRecursion}; use datafusion_common::utils::get_at_indices; use datafusion_common::{ internal_err, plan_datafusion_err, plan_err, Column, DFField, DFSchema, DFSchemaRef, @@ -662,10 +662,10 @@ where exprs.push(expr.clone()) } // stop recursing down this expr once we find a match - return Ok(VisitRecursion::Skip); + return Ok(TreeNodeRecursion::Skip); } - Ok(VisitRecursion::Continue) + Ok(TreeNodeRecursion::Continue) }) // pre_visit always returns OK, so this will always too .expect("no way to return error during recursion"); @@ -682,10 +682,10 @@ where if let Err(e) = f(expr) { // save the error for later (it may not be a DataFusionError err = Err(e); - Ok(VisitRecursion::Stop) + Ok(TreeNodeRecursion::Stop) } else { // keep going - Ok(VisitRecursion::Continue) + Ok(TreeNodeRecursion::Continue) } }) // The closure always returns OK, so this will always too diff --git a/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs b/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs index 35a8597832399..90046ca2aac0e 100644 --- a/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs +++ b/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs @@ -115,9 +115,9 @@ fn analyze_internal(plan: LogicalPlan) -> Result> { struct CountWildcardRewriter {} impl TreeNodeRewriter for CountWildcardRewriter { - type N = Expr; + type Node = Expr; - fn mutate(&mut self, old_expr: Expr) -> Result { + fn f_up(&mut self, old_expr: Expr) -> Result { let new_expr = match old_expr.clone() { Expr::WindowFunction(expr::WindowFunction { fun: diff --git a/datafusion/optimizer/src/analyzer/inline_table_scan.rs b/datafusion/optimizer/src/analyzer/inline_table_scan.rs index 90af7aec82935..a418fbf5537be 100644 --- a/datafusion/optimizer/src/analyzer/inline_table_scan.rs +++ b/datafusion/optimizer/src/analyzer/inline_table_scan.rs @@ -74,7 +74,7 @@ fn analyze_internal(plan: LogicalPlan) -> Result> { Transformed::Yes(plan) } LogicalPlan::Filter(filter) => { - let new_expr = filter.predicate.transform(&rewrite_subquery)?; + let new_expr = filter.predicate.transform_up(&rewrite_subquery)?; Transformed::Yes(LogicalPlan::Filter(Filter::try_new( new_expr, filter.input, diff --git a/datafusion/optimizer/src/analyzer/mod.rs b/datafusion/optimizer/src/analyzer/mod.rs index 9d47299a56167..b416e1eb18639 100644 --- a/datafusion/optimizer/src/analyzer/mod.rs +++ b/datafusion/optimizer/src/analyzer/mod.rs @@ -28,7 +28,7 @@ use crate::analyzer::subquery::check_subquery_expr; use crate::analyzer::type_coercion::TypeCoercion; use crate::utils::log_plan; use datafusion_common::config::ConfigOptions; -use datafusion_common::tree_node::{TreeNode, VisitRecursion}; +use datafusion_common::tree_node::{TreeNode, TreeNodeRecursion}; use datafusion_common::{DataFusionError, Result}; use datafusion_expr::expr::Exists; use datafusion_expr::expr::InSubquery; @@ -136,7 +136,7 @@ fn check_plan(plan: &LogicalPlan) -> Result<()> { })?; } - Ok(VisitRecursion::Continue) + Ok(TreeNodeRecursion::Continue) })?; Ok(()) diff --git a/datafusion/optimizer/src/analyzer/rewrite_expr.rs b/datafusion/optimizer/src/analyzer/rewrite_expr.rs index 8f1c844ed0623..829197b4d9481 100644 --- a/datafusion/optimizer/src/analyzer/rewrite_expr.rs +++ b/datafusion/optimizer/src/analyzer/rewrite_expr.rs @@ -94,9 +94,9 @@ pub(crate) struct OperatorToFunctionRewriter { } impl TreeNodeRewriter for OperatorToFunctionRewriter { - type N = Expr; + type Node = Expr; - fn mutate(&mut self, expr: Expr) -> Result { + fn f_up(&mut self, expr: Expr) -> Result { match expr { Expr::BinaryExpr(BinaryExpr { ref left, diff --git a/datafusion/optimizer/src/analyzer/subquery.rs b/datafusion/optimizer/src/analyzer/subquery.rs index 7c5b70b19af0a..7ad9832dea54e 100644 --- a/datafusion/optimizer/src/analyzer/subquery.rs +++ b/datafusion/optimizer/src/analyzer/subquery.rs @@ -17,7 +17,7 @@ use crate::analyzer::check_plan; use crate::utils::collect_subquery_cols; -use datafusion_common::tree_node::{TreeNode, VisitRecursion}; +use datafusion_common::tree_node::{TreeNode, TreeNodeRecursion}; use datafusion_common::{plan_err, DataFusionError, Result}; use datafusion_expr::expr_rewriter::strip_outer_reference; use datafusion_expr::utils::split_conjunction; @@ -146,7 +146,7 @@ fn check_inner_plan( LogicalPlan::Aggregate(_) => { inner_plan.apply_children(&mut |plan| { check_inner_plan(plan, is_scalar, true, can_contain_outer_ref)?; - Ok(VisitRecursion::Continue) + Ok(TreeNodeRecursion::Continue) })?; Ok(()) } @@ -171,7 +171,7 @@ fn check_inner_plan( check_mixed_out_refer_in_window(window)?; inner_plan.apply_children(&mut |plan| { check_inner_plan(plan, is_scalar, is_aggregate, can_contain_outer_ref)?; - Ok(VisitRecursion::Continue) + Ok(TreeNodeRecursion::Continue) })?; Ok(()) } @@ -188,7 +188,7 @@ fn check_inner_plan( | LogicalPlan::SubqueryAlias(_) => { inner_plan.apply_children(&mut |plan| { check_inner_plan(plan, is_scalar, is_aggregate, can_contain_outer_ref)?; - Ok(VisitRecursion::Continue) + Ok(TreeNodeRecursion::Continue) })?; Ok(()) } @@ -206,7 +206,7 @@ fn check_inner_plan( is_aggregate, can_contain_outer_ref, )?; - Ok(VisitRecursion::Continue) + Ok(TreeNodeRecursion::Continue) })?; Ok(()) } @@ -221,7 +221,7 @@ fn check_inner_plan( JoinType::Full => { inner_plan.apply_children(&mut |plan| { check_inner_plan(plan, is_scalar, is_aggregate, false)?; - Ok(VisitRecursion::Continue) + Ok(TreeNodeRecursion::Continue) })?; Ok(()) } @@ -290,9 +290,9 @@ fn get_correlated_expressions(inner_plan: &LogicalPlan) -> Result> { correlated .into_iter() .for_each(|expr| exprs.push(strip_outer_reference(expr.clone()))); - return Ok(VisitRecursion::Continue); + return Ok(TreeNodeRecursion::Continue); } - Ok(VisitRecursion::Continue) + Ok(TreeNodeRecursion::Continue) })?; Ok(exprs) } diff --git a/datafusion/optimizer/src/analyzer/type_coercion.rs b/datafusion/optimizer/src/analyzer/type_coercion.rs index c0dad2ef40063..0f20ede0f2391 100644 --- a/datafusion/optimizer/src/analyzer/type_coercion.rs +++ b/datafusion/optimizer/src/analyzer/type_coercion.rs @@ -22,7 +22,7 @@ use std::sync::Arc; use arrow::datatypes::{DataType, IntervalUnit}; use datafusion_common::config::ConfigOptions; -use datafusion_common::tree_node::{RewriteRecursion, TreeNodeRewriter}; +use datafusion_common::tree_node::TreeNodeRewriter; use datafusion_common::{ exec_err, internal_err, plan_datafusion_err, plan_err, DFSchema, DFSchemaRef, DataFusionError, Result, ScalarValue, @@ -126,13 +126,9 @@ pub(crate) struct TypeCoercionRewriter { } impl TreeNodeRewriter for TypeCoercionRewriter { - type N = Expr; + type Node = Expr; - fn pre_visit(&mut self, _expr: &Expr) -> Result { - Ok(RewriteRecursion::Continue) - } - - fn mutate(&mut self, expr: Expr) -> Result { + fn f_up(&mut self, expr: Expr) -> Result { match expr { Expr::ScalarSubquery(Subquery { subquery, diff --git a/datafusion/optimizer/src/common_subexpr_eliminate.rs b/datafusion/optimizer/src/common_subexpr_eliminate.rs index fe71171ce5455..564addd53f29c 100644 --- a/datafusion/optimizer/src/common_subexpr_eliminate.rs +++ b/datafusion/optimizer/src/common_subexpr_eliminate.rs @@ -25,7 +25,7 @@ use crate::{utils, OptimizerConfig, OptimizerRule}; use arrow::datatypes::DataType; use datafusion_common::tree_node::{ - RewriteRecursion, TreeNode, TreeNodeRewriter, TreeNodeVisitor, VisitRecursion, + TreeNode, TreeNodeRecursion, TreeNodeRewriter, TreeNodeVisitor, }; use datafusion_common::{ internal_err, Column, DFField, DFSchema, DFSchemaRef, DataFusionError, Result, @@ -614,21 +614,21 @@ impl ExprIdentifierVisitor<'_> { impl TreeNodeVisitor for ExprIdentifierVisitor<'_> { type N = Expr; - fn pre_visit(&mut self, expr: &Expr) -> Result { + fn pre_visit(&mut self, expr: &Expr) -> Result { // related to https://github.com/apache/arrow-datafusion/issues/8814 // If the expr contain volatile expression or is a short-circuit expression, skip it. if expr.short_circuits() || is_volatile_expression(expr)? { - return Ok(VisitRecursion::Skip); + return Ok(TreeNodeRecursion::Skip); } self.visit_stack .push(VisitRecord::EnterMark(self.node_count)); self.node_count += 1; // put placeholder self.id_array.push((0, "".to_string())); - Ok(VisitRecursion::Continue) + Ok(TreeNodeRecursion::Continue) } - fn post_visit(&mut self, expr: &Expr) -> Result { + fn post_visit(&mut self, expr: &Expr) -> Result { self.series_number += 1; let (idx, sub_expr_desc) = self.pop_enter_mark(); @@ -637,7 +637,7 @@ impl TreeNodeVisitor for ExprIdentifierVisitor<'_> { self.id_array[idx].0 = self.series_number; let desc = Self::desc_expr(expr); self.visit_stack.push(VisitRecord::ExprItem(desc)); - return Ok(VisitRecursion::Continue); + return Ok(TreeNodeRecursion::Continue); } let mut desc = Self::desc_expr(expr); desc.push_str(&sub_expr_desc); @@ -651,7 +651,7 @@ impl TreeNodeVisitor for ExprIdentifierVisitor<'_> { .entry(desc) .or_insert_with(|| (expr.clone(), 0, data_type)) .1 += 1; - Ok(VisitRecursion::Continue) + Ok(TreeNodeRecursion::Continue) } } @@ -694,74 +694,71 @@ struct CommonSubexprRewriter<'a> { } impl TreeNodeRewriter for CommonSubexprRewriter<'_> { - type N = Expr; + type Node = Expr; - fn pre_visit(&mut self, expr: &Expr) -> Result { + fn f_down(&mut self, expr: Expr) -> Result<(Expr, TreeNodeRecursion)> { // The `CommonSubexprRewriter` relies on `ExprIdentifierVisitor` to generate // the `id_array`, which records the expr's identifier used to rewrite expr. So if we // skip an expr in `ExprIdentifierVisitor`, we should skip it here, too. - if expr.short_circuits() || is_volatile_expression(expr)? { - return Ok(RewriteRecursion::Stop); + if expr.short_circuits() || is_volatile_expression(&expr)? { + return Ok((expr, TreeNodeRecursion::Skip)); } if self.curr_index >= self.id_array.len() || self.max_series_number > self.id_array[self.curr_index].0 { - return Ok(RewriteRecursion::Stop); + return Ok((expr, TreeNodeRecursion::Skip)); } let curr_id = &self.id_array[self.curr_index].1; // skip `Expr`s without identifier (empty identifier). if curr_id.is_empty() { self.curr_index += 1; - return Ok(RewriteRecursion::Skip); + return Ok((expr, TreeNodeRecursion::Continue)); } match self.expr_set.get(curr_id) { Some((_, counter, _)) => { if *counter > 1 { self.affected_id.insert(curr_id.clone()); - Ok(RewriteRecursion::Mutate) + + // This expr tree is finished. + if self.curr_index >= self.id_array.len() { + return Ok((expr, TreeNodeRecursion::Skip)); + } + + let (series_number, id) = &self.id_array[self.curr_index]; + self.curr_index += 1; + // Skip sub-node of a replaced tree, or without identifier, or is not repeated expr. + let expr_set_item = self.expr_set.get(id).ok_or_else(|| { + DataFusionError::Internal("expr_set invalid state".to_string()) + })?; + if *series_number < self.max_series_number + || id.is_empty() + || expr_set_item.1 <= 1 + { + return Ok((expr, TreeNodeRecursion::Skip)); + } + + self.max_series_number = *series_number; + // step index to skip all sub-node (which has smaller series number). + while self.curr_index < self.id_array.len() + && *series_number > self.id_array[self.curr_index].0 + { + self.curr_index += 1; + } + + let expr_name = expr.display_name()?; + // Alias this `Column` expr to it original "expr name", + // `projection_push_down` optimizer use "expr name" to eliminate useless + // projections. + Ok((col(id).alias(expr_name), TreeNodeRecursion::Skip)) } else { self.curr_index += 1; - Ok(RewriteRecursion::Skip) + Ok((expr, TreeNodeRecursion::Continue)) } } _ => internal_err!("expr_set invalid state"), } } - - fn mutate(&mut self, expr: Expr) -> Result { - // This expr tree is finished. - if self.curr_index >= self.id_array.len() { - return Ok(expr); - } - - let (series_number, id) = &self.id_array[self.curr_index]; - self.curr_index += 1; - // Skip sub-node of a replaced tree, or without identifier, or is not repeated expr. - let expr_set_item = self.expr_set.get(id).ok_or_else(|| { - DataFusionError::Internal("expr_set invalid state".to_string()) - })?; - if *series_number < self.max_series_number - || id.is_empty() - || expr_set_item.1 <= 1 - { - return Ok(expr); - } - - self.max_series_number = *series_number; - // step index to skip all sub-node (which has smaller series number). - while self.curr_index < self.id_array.len() - && *series_number > self.id_array[self.curr_index].0 - { - self.curr_index += 1; - } - - let expr_name = expr.display_name()?; - // Alias this `Column` expr to it original "expr name", - // `projection_push_down` optimizer use "expr name" to eliminate useless - // projections. - Ok(col(id).alias(expr_name)) - } } fn replace_common_expr( diff --git a/datafusion/optimizer/src/decorrelate.rs b/datafusion/optimizer/src/decorrelate.rs index b1000f042c987..49d3c322ca2b0 100644 --- a/datafusion/optimizer/src/decorrelate.rs +++ b/datafusion/optimizer/src/decorrelate.rs @@ -18,7 +18,7 @@ use crate::simplify_expressions::{ExprSimplifier, SimplifyContext}; use crate::utils::collect_subquery_cols; use datafusion_common::tree_node::{ - RewriteRecursion, Transformed, TreeNode, TreeNodeRewriter, + Transformed, TreeNode, TreeNodeRecursion, TreeNodeRewriter, }; use datafusion_common::{plan_err, Result}; use datafusion_common::{Column, DFSchemaRef, DataFusionError, ScalarValue}; @@ -56,19 +56,19 @@ pub const UN_MATCHED_ROW_INDICATOR: &str = "__always_true"; pub type ExprResultMap = HashMap; impl TreeNodeRewriter for PullUpCorrelatedExpr { - type N = LogicalPlan; + type Node = LogicalPlan; - fn pre_visit(&mut self, plan: &LogicalPlan) -> Result { + fn f_down(&mut self, plan: LogicalPlan) -> Result<(LogicalPlan, TreeNodeRecursion)> { match plan { - LogicalPlan::Filter(_) => Ok(RewriteRecursion::Continue), + LogicalPlan::Filter(_) => Ok((plan, TreeNodeRecursion::Continue)), LogicalPlan::Union(_) | LogicalPlan::Sort(_) | LogicalPlan::Extension(_) => { let plan_hold_outer = !plan.all_out_ref_exprs().is_empty(); if plan_hold_outer { // the unsupported case self.can_pull_up = false; - Ok(RewriteRecursion::Stop) + Ok((plan, TreeNodeRecursion::Skip)) } else { - Ok(RewriteRecursion::Continue) + Ok((plan, TreeNodeRecursion::Continue)) } } LogicalPlan::Limit(_) => { @@ -77,21 +77,21 @@ impl TreeNodeRewriter for PullUpCorrelatedExpr { (false, true) => { // the unsupported case self.can_pull_up = false; - Ok(RewriteRecursion::Stop) + Ok((plan, TreeNodeRecursion::Skip)) } - _ => Ok(RewriteRecursion::Continue), + _ => Ok((plan, TreeNodeRecursion::Continue)), } } _ if plan.expressions().iter().any(|expr| expr.contains_outer()) => { // the unsupported cases, the plan expressions contain out reference columns(like window expressions) self.can_pull_up = false; - Ok(RewriteRecursion::Stop) + Ok((plan, TreeNodeRecursion::Skip)) } - _ => Ok(RewriteRecursion::Continue), + _ => Ok((plan, TreeNodeRecursion::Continue)), } } - fn mutate(&mut self, plan: LogicalPlan) -> Result { + fn f_up(&mut self, plan: LogicalPlan) -> Result { let subquery_schema = plan.schema().clone(); match &plan { LogicalPlan::Filter(plan_filter) => { diff --git a/datafusion/optimizer/src/plan_signature.rs b/datafusion/optimizer/src/plan_signature.rs index 07f495a7262df..8b8814192d383 100644 --- a/datafusion/optimizer/src/plan_signature.rs +++ b/datafusion/optimizer/src/plan_signature.rs @@ -15,7 +15,7 @@ // specific language governing permissions and limitations // under the License. -use datafusion_common::tree_node::{TreeNode, VisitRecursion}; +use datafusion_common::tree_node::{TreeNode, TreeNodeRecursion}; use std::{ collections::hash_map::DefaultHasher, hash::{Hash, Hasher}, @@ -75,7 +75,7 @@ fn get_node_number(plan: &LogicalPlan) -> NonZeroUsize { let mut node_number = 0; plan.apply(&mut |_plan| { node_number += 1; - Ok(VisitRecursion::Continue) + Ok(TreeNodeRecursion::Continue) }) // Closure always return Ok .unwrap(); diff --git a/datafusion/optimizer/src/push_down_filter.rs b/datafusion/optimizer/src/push_down_filter.rs index 7086c5cda56f8..0ae0bc696a352 100644 --- a/datafusion/optimizer/src/push_down_filter.rs +++ b/datafusion/optimizer/src/push_down_filter.rs @@ -22,7 +22,7 @@ use crate::optimizer::ApplyOrder; use crate::utils::is_volatile_expression; use crate::{OptimizerConfig, OptimizerRule}; -use datafusion_common::tree_node::{Transformed, TreeNode, VisitRecursion}; +use datafusion_common::tree_node::{Transformed, TreeNode, TreeNodeRecursion}; use datafusion_common::{ internal_err, plan_datafusion_err, Column, DFSchema, DFSchemaRef, DataFusionError, JoinConstraint, Result, @@ -222,7 +222,7 @@ fn can_evaluate_as_join_condition(predicate: &Expr) -> Result { Expr::Column(_) | Expr::Literal(_) | Expr::Placeholder(_) - | Expr::ScalarVariable(_, _) => Ok(VisitRecursion::Skip), + | Expr::ScalarVariable(_, _) => Ok(TreeNodeRecursion::Skip), Expr::Exists { .. } | Expr::InSubquery(_) | Expr::ScalarSubquery(_) @@ -232,7 +232,7 @@ fn can_evaluate_as_join_condition(predicate: &Expr) -> Result { .. }) => { is_evaluate = false; - Ok(VisitRecursion::Stop) + Ok(TreeNodeRecursion::Stop) } Expr::Alias(_) | Expr::BinaryExpr(_) @@ -254,7 +254,7 @@ fn can_evaluate_as_join_condition(predicate: &Expr) -> Result { | Expr::Cast(_) | Expr::TryCast(_) | Expr::ScalarFunction(..) - | Expr::InList { .. } => Ok(VisitRecursion::Continue), + | Expr::InList { .. } => Ok(TreeNodeRecursion::Continue), Expr::Sort(_) | Expr::AggregateFunction(_) | Expr::WindowFunction(_) @@ -1039,12 +1039,12 @@ fn contain(e: &Expr, check_map: &HashMap) -> bool { match check_map.get(&c.flat_name()) { Some(_) => { is_contain = true; - VisitRecursion::Stop + TreeNodeRecursion::Stop } - None => VisitRecursion::Continue, + None => TreeNodeRecursion::Continue, } } else { - VisitRecursion::Continue + TreeNodeRecursion::Continue }) }) .unwrap(); diff --git a/datafusion/optimizer/src/scalar_subquery_to_join.rs b/datafusion/optimizer/src/scalar_subquery_to_join.rs index 34ed4a9475cba..e1c35e468f68a 100644 --- a/datafusion/optimizer/src/scalar_subquery_to_join.rs +++ b/datafusion/optimizer/src/scalar_subquery_to_join.rs @@ -21,7 +21,7 @@ use crate::utils::replace_qualified_name; use crate::{OptimizerConfig, OptimizerRule}; use datafusion_common::alias::AliasGenerator; use datafusion_common::tree_node::{ - RewriteRecursion, Transformed, TreeNode, TreeNodeRewriter, + Transformed, TreeNode, TreeNodeRecursion, TreeNodeRewriter, }; use datafusion_common::{plan_err, Column, DataFusionError, Result, ScalarValue}; use datafusion_expr::expr_rewriter::create_col_from_scalar_expr; @@ -201,16 +201,9 @@ struct ExtractScalarSubQuery { } impl TreeNodeRewriter for ExtractScalarSubQuery { - type N = Expr; + type Node = Expr; - fn pre_visit(&mut self, expr: &Expr) -> Result { - match expr { - Expr::ScalarSubquery(_) => Ok(RewriteRecursion::Mutate), - _ => Ok(RewriteRecursion::Continue), - } - } - - fn mutate(&mut self, expr: Expr) -> Result { + fn f_down(&mut self, expr: Expr) -> Result<(Expr, TreeNodeRecursion)> { match expr { Expr::ScalarSubquery(subquery) => { let subqry_alias = self.alias_gen.next("__scalar_sq"); @@ -220,12 +213,15 @@ impl TreeNodeRewriter for ExtractScalarSubQuery { .subquery .head_output_expr()? .map_or(plan_err!("single expression required."), Ok)?; - Ok(Expr::Column(create_col_from_scalar_expr( - &scalar_expr, - subqry_alias, - )?)) + Ok(( + Expr::Column(create_col_from_scalar_expr( + &scalar_expr, + subqry_alias, + )?), + TreeNodeRecursion::Skip, + )) } - _ => Ok(expr), + _ => Ok((expr, TreeNodeRecursion::Continue)), } } } diff --git a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs index 1c12289491711..fd77071ea7286 100644 --- a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs +++ b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs @@ -33,9 +33,10 @@ use arrow::{ datatypes::{DataType, Field, Schema}, record_batch::RecordBatch, }; +use datafusion_common::tree_node::TreeNodeRecursion; use datafusion_common::{ cast::{as_large_list_array, as_list_array}, - tree_node::{RewriteRecursion, TreeNode, TreeNodeRewriter}, + tree_node::{TreeNode, TreeNodeRewriter}, }; use datafusion_common::{ internal_err, DFSchema, DFSchemaRef, DataFusionError, Result, ScalarValue, @@ -246,9 +247,9 @@ impl Canonicalizer { } impl TreeNodeRewriter for Canonicalizer { - type N = Expr; + type Node = Expr; - fn mutate(&mut self, expr: Expr) -> Result { + fn f_up(&mut self, expr: Expr) -> Result { let Expr::BinaryExpr(BinaryExpr { left, op, right }) = expr else { return Ok(expr); }; @@ -310,9 +311,9 @@ enum ConstSimplifyResult { } impl<'a> TreeNodeRewriter for ConstEvaluator<'a> { - type N = Expr; + type Node = Expr; - fn pre_visit(&mut self, expr: &Expr) -> Result { + fn f_down(&mut self, expr: Expr) -> Result<(Expr, TreeNodeRecursion)> { // Default to being able to evaluate this node self.can_evaluate.push(true); @@ -320,7 +321,7 @@ impl<'a> TreeNodeRewriter for ConstEvaluator<'a> { // stack as not ok (as all parents have at least one child or // descendant that can not be evaluated - if !Self::can_evaluate(expr) { + if !Self::can_evaluate(&expr) { // walk back up stack, marking first parent that is not mutable let parent_iter = self.can_evaluate.iter_mut().rev(); for p in parent_iter { @@ -336,10 +337,10 @@ impl<'a> TreeNodeRewriter for ConstEvaluator<'a> { // NB: do not short circuit recursion even if we find a non // evaluatable node (so we can fold other children, args to // functions, etc) - Ok(RewriteRecursion::Continue) + Ok((expr, TreeNodeRecursion::Continue)) } - fn mutate(&mut self, expr: Expr) -> Result { + fn f_up(&mut self, expr: Expr) -> Result { match self.can_evaluate.pop() { // Certain expressions such as `CASE` and `COALESCE` are short circuiting // and may not evalute all their sub expressions. Thus if @@ -504,10 +505,10 @@ impl<'a, S> Simplifier<'a, S> { } impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { - type N = Expr; + type Node = Expr; /// rewrite the expression simplifying any constant expressions - fn mutate(&mut self, expr: Expr) -> Result { + fn f_up(&mut self, expr: Expr) -> Result { use datafusion_expr::Operator::{ And, BitwiseAnd, BitwiseOr, BitwiseShiftLeft, BitwiseShiftRight, BitwiseXor, Divide, Eq, Modulo, Multiply, NotEq, Or, RegexIMatch, RegexMatch, diff --git a/datafusion/optimizer/src/simplify_expressions/guarantees.rs b/datafusion/optimizer/src/simplify_expressions/guarantees.rs index aa7bb4f78a93f..e7c619c046de8 100644 --- a/datafusion/optimizer/src/simplify_expressions/guarantees.rs +++ b/datafusion/optimizer/src/simplify_expressions/guarantees.rs @@ -57,9 +57,9 @@ impl<'a> GuaranteeRewriter<'a> { } impl<'a> TreeNodeRewriter for GuaranteeRewriter<'a> { - type N = Expr; + type Node = Expr; - fn mutate(&mut self, expr: Expr) -> Result { + fn f_up(&mut self, expr: Expr) -> Result { if self.guarantees.is_empty() { return Ok(expr); } diff --git a/datafusion/optimizer/src/simplify_expressions/inlist_simplifier.rs b/datafusion/optimizer/src/simplify_expressions/inlist_simplifier.rs index fa95f1688e6f4..867e96d213d99 100644 --- a/datafusion/optimizer/src/simplify_expressions/inlist_simplifier.rs +++ b/datafusion/optimizer/src/simplify_expressions/inlist_simplifier.rs @@ -49,9 +49,9 @@ impl InListSimplifier { } impl TreeNodeRewriter for InListSimplifier { - type N = Expr; + type Node = Expr; - fn mutate(&mut self, expr: Expr) -> Result { + fn f_up(&mut self, expr: Expr) -> Result { if let Expr::BinaryExpr(BinaryExpr { left, op, right }) = &expr { if let (Expr::InList(l1), Operator::And, Expr::InList(l2)) = (left.as_ref(), op, right.as_ref()) diff --git a/datafusion/optimizer/src/simplify_expressions/or_in_list_simplifier.rs b/datafusion/optimizer/src/simplify_expressions/or_in_list_simplifier.rs index fd5c9ecaf82c5..ea02c1f3af8a2 100644 --- a/datafusion/optimizer/src/simplify_expressions/or_in_list_simplifier.rs +++ b/datafusion/optimizer/src/simplify_expressions/or_in_list_simplifier.rs @@ -37,9 +37,9 @@ impl OrInListSimplifier { } impl TreeNodeRewriter for OrInListSimplifier { - type N = Expr; + type Node = Expr; - fn mutate(&mut self, expr: Expr) -> Result { + fn f_up(&mut self, expr: Expr) -> Result { if let Expr::BinaryExpr(BinaryExpr { left, op, right }) = &expr { if *op == Operator::Or { let left = as_inlist(left); diff --git a/datafusion/optimizer/src/unwrap_cast_in_comparison.rs b/datafusion/optimizer/src/unwrap_cast_in_comparison.rs index 91603e82a54fc..0232a28c722a6 100644 --- a/datafusion/optimizer/src/unwrap_cast_in_comparison.rs +++ b/datafusion/optimizer/src/unwrap_cast_in_comparison.rs @@ -24,7 +24,7 @@ use arrow::datatypes::{ DataType, TimeUnit, MAX_DECIMAL_FOR_EACH_PRECISION, MIN_DECIMAL_FOR_EACH_PRECISION, }; use arrow::temporal_conversions::{MICROSECONDS, MILLISECONDS, NANOSECONDS}; -use datafusion_common::tree_node::{RewriteRecursion, TreeNodeRewriter}; +use datafusion_common::tree_node::TreeNodeRewriter; use datafusion_common::{ internal_err, DFSchema, DFSchemaRef, DataFusionError, Result, ScalarValue, }; @@ -127,13 +127,9 @@ struct UnwrapCastExprRewriter { } impl TreeNodeRewriter for UnwrapCastExprRewriter { - type N = Expr; + type Node = Expr; - fn pre_visit(&mut self, _expr: &Expr) -> Result { - Ok(RewriteRecursion::Continue) - } - - fn mutate(&mut self, expr: Expr) -> Result { + fn f_up(&mut self, expr: Expr) -> Result { match &expr { // For case: // try_cast/cast(expr as data_type) op literal diff --git a/datafusion/optimizer/src/utils.rs b/datafusion/optimizer/src/utils.rs index 5671dc6ae94d3..13b67794c7ddc 100644 --- a/datafusion/optimizer/src/utils.rs +++ b/datafusion/optimizer/src/utils.rs @@ -18,7 +18,7 @@ //! Collection of utility functions that are leveraged by the query optimizer rules use crate::{OptimizerConfig, OptimizerRule}; -use datafusion_common::tree_node::{TreeNode, VisitRecursion}; +use datafusion_common::tree_node::{TreeNode, TreeNodeRecursion}; use datafusion_common::{Column, DFSchemaRef}; use datafusion_common::{DFSchema, Result}; use datafusion_expr::expr::is_volatile; @@ -100,9 +100,9 @@ pub(crate) fn is_volatile_expression(e: &Expr) -> Result { e.apply(&mut |expr| { Ok(if is_volatile(expr)? { is_volatile_expr = true; - VisitRecursion::Stop + TreeNodeRecursion::Stop } else { - VisitRecursion::Continue + TreeNodeRecursion::Continue }) })?; Ok(is_volatile_expr) diff --git a/datafusion/physical-expr/src/equivalence/class.rs b/datafusion/physical-expr/src/equivalence/class.rs index f0bd1740d5d2d..29a6825ddcf70 100644 --- a/datafusion/physical-expr/src/equivalence/class.rs +++ b/datafusion/physical-expr/src/equivalence/class.rs @@ -260,7 +260,7 @@ impl EquivalenceGroup { /// class it matches with (if any). pub fn normalize_expr(&self, expr: Arc) -> Arc { expr.clone() - .transform(&|expr| { + .transform_up(&|expr| { for cls in self.iter() { if cls.contains(&expr) { return Ok(Transformed::Yes(cls.canonical_expr().unwrap())); diff --git a/datafusion/physical-expr/src/expressions/case.rs b/datafusion/physical-expr/src/expressions/case.rs index 6a168e2f1e5fa..b04c66b237289 100644 --- a/datafusion/physical-expr/src/expressions/case.rs +++ b/datafusion/physical-expr/src/expressions/case.rs @@ -960,7 +960,7 @@ mod tests { let expr2 = expr .clone() - .transform(&|e| { + .transform_up(&|e| { let transformed = match e.as_any().downcast_ref::() { Some(lit_value) => match lit_value.value() { diff --git a/datafusion/physical-expr/src/utils/mod.rs b/datafusion/physical-expr/src/utils/mod.rs index e14ff26921463..8d4f4cad4afaa 100644 --- a/datafusion/physical-expr/src/utils/mod.rs +++ b/datafusion/physical-expr/src/utils/mod.rs @@ -29,9 +29,7 @@ use crate::{PhysicalExpr, PhysicalSortExpr}; use arrow::array::{make_array, Array, ArrayRef, BooleanArray, MutableArrayData}; use arrow::compute::{and_kleene, is_not_null, SlicesIterator}; use arrow::datatypes::SchemaRef; -use datafusion_common::tree_node::{ - Transformed, TreeNode, TreeNodeRewriter, VisitRecursion, -}; +use datafusion_common::tree_node::{Transformed, TreeNode, TreeNodeRecursion}; use datafusion_common::Result; use datafusion_expr::Operator; @@ -130,11 +128,10 @@ pub fn get_indices_of_exprs_strict>>( pub type ExprTreeNode = ExprContext>; -/// This struct facilitates the [TreeNodeRewriter] mechanism to convert a -/// [PhysicalExpr] tree into a DAEG (i.e. an expression DAG) by collecting -/// identical expressions in one node. Caller specifies the node type in the -/// DAEG via the `constructor` argument, which constructs nodes in the DAEG -/// from the [ExprTreeNode] ancillary object. +/// This struct is used to convert a [PhysicalExpr] tree into a DAEG (i.e. an expression +/// DAG) by collecting identical expressions in one node. Caller specifies the node type +/// in the DAEG via the `constructor` argument, which constructs nodes in the DAEG from +/// the [ExprTreeNode] ancillary object. struct PhysicalExprDAEGBuilder<'a, T, F: Fn(&ExprTreeNode) -> Result> { // The resulting DAEG (expression DAG). graph: StableGraph, @@ -144,16 +141,15 @@ struct PhysicalExprDAEGBuilder<'a, T, F: Fn(&ExprTreeNode) -> Result< constructor: &'a F, } -impl<'a, T, F: Fn(&ExprTreeNode) -> Result> TreeNodeRewriter - for PhysicalExprDAEGBuilder<'a, T, F> +impl<'a, T, F: Fn(&ExprTreeNode) -> Result> + PhysicalExprDAEGBuilder<'a, T, F> { - type N = ExprTreeNode; // This method mutates an expression node by transforming it to a physical expression // and adding it to the graph. The method returns the mutated expression node. fn mutate( &mut self, mut node: ExprTreeNode, - ) -> Result> { + ) -> Result>> { // Get the expression associated with the input expression node. let expr = &node.expr; @@ -176,7 +172,7 @@ impl<'a, T, F: Fn(&ExprTreeNode) -> Result> TreeNodeRewriter // Set the data field of the input expression node to the corresponding node index. node.data = Some(node_idx); // Return the mutated expression node. - Ok(node) + Ok(Transformed::Yes(node)) } } @@ -197,7 +193,7 @@ where constructor, }; // Use the builder to transform the expression tree node into a DAG. - let root = init.rewrite(&mut builder)?; + let root = init.transform_up_mut(&mut |node| builder.mutate(node))?; // Return a tuple containing the root node index and the DAG. Ok((root.data.unwrap(), builder.graph)) } @@ -211,7 +207,7 @@ pub fn collect_columns(expr: &Arc) -> HashSet { columns.insert(column.clone()); } } - Ok(VisitRecursion::Continue) + Ok(TreeNodeRecursion::Continue) }) // pre_visit always returns OK, so this will always too .expect("no way to return error during recursion"); diff --git a/docs/source/library-user-guide/working-with-exprs.md b/docs/source/library-user-guide/working-with-exprs.md index 96be8ef7f1aeb..b128d661f31a9 100644 --- a/docs/source/library-user-guide/working-with-exprs.md +++ b/docs/source/library-user-guide/working-with-exprs.md @@ -96,7 +96,7 @@ To implement the inlining, we'll need to write a function that takes an `Expr` a ```rust fn rewrite_add_one(expr: Expr) -> Result { - expr.transform(&|expr| { + expr.transform_up(&|expr| { Ok(match expr { Expr::ScalarUDF(scalar_fun) if scalar_fun.fun.name == "add_one" => { let input_arg = scalar_fun.args[0].clone();