-
Notifications
You must be signed in to change notification settings - Fork 1.3k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Enhance short circuit handling in CommonSubexprEliminate
#11197
Changes from 1 commit
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -56,9 +56,13 @@ struct Identifier<'n> { | |
} | ||
|
||
impl<'n> Identifier<'n> { | ||
fn new(expr: &'n Expr, random_state: &RandomState) -> Self { | ||
fn new(expr: &'n Expr, is_tree: bool, random_state: &RandomState) -> Self { | ||
let mut hasher = random_state.build_hasher(); | ||
expr.hash_node(&mut hasher); | ||
if is_tree { | ||
expr.hash(&mut hasher); | ||
} else { | ||
expr.hash_node(&mut hasher); | ||
} | ||
let hash = hasher.finish(); | ||
Self { hash, expr } | ||
} | ||
|
@@ -908,31 +912,28 @@ struct ExprIdentifierVisitor<'a, 'n> { | |
found_common: bool, | ||
} | ||
|
||
/// Record item that used when traversing a expression tree. | ||
/// Record item that used when traversing an expression tree. | ||
enum VisitRecord<'n> { | ||
/// `usize` postorder index assigned in `f-down`(). Starts from 0. | ||
EnterMark(usize), | ||
/// the node's children were skipped => jump to f_up on same node | ||
JumpMark, | ||
EnterMark(usize, bool), | ||
/// Accumulated identifier of sub expression. | ||
ExprItem(Identifier<'n>), | ||
} | ||
|
||
impl<'n> ExprIdentifierVisitor<'_, 'n> { | ||
/// Find the first `EnterMark` in the stack, and accumulates every `ExprItem` | ||
/// before it. | ||
fn pop_enter_mark(&mut self) -> Option<(usize, Option<Identifier<'n>>)> { | ||
fn pop_enter_mark(&mut self) -> (usize, bool, Option<Identifier<'n>>) { | ||
let mut expr_id = None; | ||
|
||
while let Some(item) = self.visit_stack.pop() { | ||
match item { | ||
VisitRecord::EnterMark(idx) => { | ||
return Some((idx, expr_id)); | ||
VisitRecord::EnterMark(down_index, tree) => { | ||
return (down_index, tree, expr_id); | ||
} | ||
VisitRecord::ExprItem(id) => { | ||
expr_id = Some(id.combine(expr_id)); | ||
} | ||
VisitRecord::JumpMark => return None, | ||
} | ||
} | ||
unreachable!("Enter mark should paired with node number"); | ||
|
@@ -944,30 +945,30 @@ impl<'n> TreeNodeVisitor<'n> for ExprIdentifierVisitor<'_, 'n> { | |
|
||
fn f_down(&mut self, expr: &'n Expr) -> Result<TreeNodeRecursion> { | ||
// TODO: consider non-volatile sub-expressions for CSE | ||
// TODO: consider surely executed children of "short circuited"s for CSE | ||
|
||
// If an expression can short circuit its children then don't consider it for CSE | ||
// (https://github.com/apache/arrow-datafusion/issues/8814). | ||
if expr.short_circuits() { | ||
self.visit_stack.push(VisitRecord::JumpMark); | ||
|
||
return Ok(TreeNodeRecursion::Jump); | ||
} | ||
// If an expression can short circuit its children then don't consider its | ||
// children for CSE (https://github.com/apache/arrow-datafusion/issues/8814). | ||
// TODO: consider surely executed children of "short circuited"s for CSE | ||
let is_tree = expr.short_circuits(); | ||
let tnr = if is_tree { | ||
TreeNodeRecursion::Jump | ||
} else { | ||
TreeNodeRecursion::Continue | ||
}; | ||
|
||
self.id_array.push((0, None)); | ||
self.visit_stack | ||
.push(VisitRecord::EnterMark(self.down_index)); | ||
.push(VisitRecord::EnterMark(self.down_index, is_tree)); | ||
self.down_index += 1; | ||
|
||
Ok(TreeNodeRecursion::Continue) | ||
Ok(tnr) | ||
} | ||
|
||
fn f_up(&mut self, expr: &'n Expr) -> Result<TreeNodeRecursion> { | ||
let Some((down_index, sub_expr_id)) = self.pop_enter_mark() else { | ||
return Ok(TreeNodeRecursion::Continue); | ||
}; | ||
let (down_index, is_tree, sub_expr_id) = self.pop_enter_mark(); | ||
|
||
let expr_id = Identifier::new(expr, self.random_state).combine(sub_expr_id); | ||
let expr_id = | ||
Identifier::new(expr, is_tree, self.random_state).combine(sub_expr_id); | ||
|
||
self.id_array[down_index].0 = self.up_index; | ||
if !self.expr_mask.ignores(expr) { | ||
|
@@ -1012,19 +1013,22 @@ impl TreeNodeRewriter for CommonSubexprRewriter<'_, '_> { | |
self.alias_counter += 1; | ||
} | ||
|
||
// The `CommonSubexprRewriter` relies on `ExprIdentifierVisitor` to generate | ||
// the `id_array`, which records the expr's identifier used to rewrite expr. So if we | ||
// The `CommonSubexprRewriter` relies on `ExprIdentifierVisitor` to generate the | ||
// `id_array`, which records the expr's identifier used to rewrite expr. So if we | ||
// skip an expr in `ExprIdentifierVisitor`, we should skip it here, too. | ||
if expr.short_circuits() { | ||
return Ok(Transformed::new(expr, false, TreeNodeRecursion::Jump)); | ||
} | ||
let is_tree = expr.short_circuits(); | ||
let tnr = if is_tree { | ||
TreeNodeRecursion::Jump | ||
} else { | ||
TreeNodeRecursion::Continue | ||
}; | ||
|
||
let (up_index, expr_id) = self.id_array[self.down_index]; | ||
self.down_index += 1; | ||
|
||
// skip `Expr`s without identifier (empty identifier). | ||
let Some(expr_id) = expr_id else { | ||
return Ok(Transformed::no(expr)); | ||
return Ok(Transformed::new(expr, false, tnr)); | ||
}; | ||
|
||
let count = self.expr_stats.get(&expr_id).unwrap(); | ||
|
@@ -1052,7 +1056,7 @@ impl TreeNodeRewriter for CommonSubexprRewriter<'_, '_> { | |
|
||
Ok(Transformed::new(rewritten, true, TreeNodeRecursion::Jump)) | ||
} else { | ||
Ok(Transformed::no(expr)) | ||
Ok(Transformed::new(expr, false, tnr)) | ||
} | ||
} | ||
|
||
|
@@ -1799,4 +1803,34 @@ mod test { | |
assert!(result.len() == 1); | ||
Ok(()) | ||
} | ||
|
||
#[test] | ||
fn test_short_circuits() -> Result<()> { | ||
let table_scan = test_table_scan()?; | ||
|
||
let extracted_short_circuit = col("a").eq(lit(0)).or(col("b").eq(lit(0))); | ||
let not_extracted_short_circuit_leg = (col("a") + col("b")).eq(lit(0)); | ||
let plan = LogicalPlanBuilder::from(table_scan.clone()) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think this test covers the negative case too. |
||
.project(vec![ | ||
extracted_short_circuit.clone().alias("c1"), | ||
extracted_short_circuit.alias("c2"), | ||
col("c") | ||
.gt(lit(0)) | ||
.or(not_extracted_short_circuit_leg.clone()) | ||
.alias("c3"), | ||
col("c") | ||
.gt(lit(1)) | ||
.or(not_extracted_short_circuit_leg) | ||
.alias("c4"), | ||
])? | ||
.build()?; | ||
|
||
let expected = "Projection: __common_expr_1 AS c1, __common_expr_1 AS c2, test.c > Int32(0) OR test.a + test.b = Int32(0) AS c3, test.c > Int32(1) OR test.a + test.b = Int32(0) AS c4\ | ||
\n Projection: test.a = Int32(0) OR test.b = Int32(0) AS __common_expr_1, test.a, test.b, test.c\ | ||
\n TableScan: test"; | ||
|
||
assert_optimized_plan_eq(expected, plan, None); | ||
|
||
Ok(()) | ||
} | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can we add a test case like the one below to check if (a or b) can be extracted as a common subexpr?
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This PR doesn't extract surely evaluated children of short circuiting expressions so I kept that There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It would be nice to add the test (as a negative test perhaps) There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I've adjusted the test case to test both legs of an OR expression: 8b75d82. In this PR none of them are extracted, but after the follow-up PR the surely evaluated first leg (called |
||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't understand the meaning of
is_tree
here. Maybe we could add a comment explaining that
Jump
will skip children but continue with siblings. There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
maybe we can rename
is_tree
to is_short_circuits
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Oh, I forgot to add a comment here. Basically I wanted to express that we handle the expression as a subtree (not just a node) in this case.
I added a comment in c02bae9.