Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: common_subexpr_eliminate rule should not apply to short-circuit expression #8928

Merged
merged 7 commits into from
Jan 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 48 additions & 0 deletions datafusion/expr/src/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1266,6 +1266,54 @@ impl Expr {
Ok(Transformed::Yes(expr))
})
}

/// Returns true if some of this `exprs` subexpressions may not be evaluated
/// and thus any side effects (like divide by zero) may not be encountered
pub fn short_circuits(&self) -> bool {
match self {
Expr::ScalarFunction(ScalarFunction { func_def, .. }) => {
matches!(func_def, ScalarFunctionDefinition::BuiltIn(fun) if *fun == BuiltinScalarFunction::Coalesce)
}
Expr::BinaryExpr(BinaryExpr { op, .. }) => {
matches!(op, Operator::And | Operator::Or)
}
Expr::Case { .. } => true,
// Use explicit pattern match instead of a default
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

👍

// implementation, so that in the future if someone adds
// new Expr types, they will check here as well
Expr::AggregateFunction(..)
| Expr::Alias(..)
| Expr::Between(..)
| Expr::Cast(..)
| Expr::Column(..)
| Expr::Exists(..)
| Expr::GetIndexedField(..)
| Expr::GroupingSet(..)
| Expr::InList(..)
| Expr::InSubquery(..)
| Expr::IsFalse(..)
| Expr::IsNotFalse(..)
| Expr::IsNotNull(..)
| Expr::IsNotTrue(..)
| Expr::IsNotUnknown(..)
| Expr::IsNull(..)
| Expr::IsTrue(..)
| Expr::IsUnknown(..)
| Expr::Like(..)
| Expr::ScalarSubquery(..)
| Expr::ScalarVariable(_, _)
| Expr::SimilarTo(..)
| Expr::Not(..)
| Expr::Negative(..)
| Expr::OuterReferenceColumn(_, _)
| Expr::TryCast(..)
| Expr::Wildcard { .. }
| Expr::WindowFunction(..)
| Expr::Literal(..)
| Expr::Sort(..)
| Expr::Placeholder(..) => false,
}
}
}

// modifies expr if it is a placeholder with datatype of right
Expand Down
17 changes: 11 additions & 6 deletions datafusion/optimizer/src/common_subexpr_eliminate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -616,8 +616,8 @@ impl TreeNodeVisitor for ExprIdentifierVisitor<'_> {

fn pre_visit(&mut self, expr: &Expr) -> Result<VisitRecursion> {
// related to https://github.com/apache/arrow-datafusion/issues/8814
// If the expr contain volatile expression or is a case expression, skip it.
if matches!(expr, Expr::Case(..)) || is_volatile_expression(expr)? {
// If the expr contain volatile expression or is a short-circuit expression, skip it.
if expr.short_circuits() || is_volatile_expression(expr)? {
return Ok(VisitRecursion::Skip);
}
self.visit_stack
Expand Down Expand Up @@ -696,7 +696,13 @@ struct CommonSubexprRewriter<'a> {
impl TreeNodeRewriter for CommonSubexprRewriter<'_> {
type N = Expr;

fn pre_visit(&mut self, _: &Expr) -> Result<RewriteRecursion> {
fn pre_visit(&mut self, expr: &Expr) -> Result<RewriteRecursion> {
// The `CommonSubexprRewriter` relies on `ExprIdentifierVisitor` to generate
// the `id_array`, which records the expr's identifier used to rewrite expr. So if we
// skip an expr in `ExprIdentifierVisitor`, we should skip it here, too.
if expr.short_circuits() || is_volatile_expression(expr)? {
return Ok(RewriteRecursion::Stop);
}
if self.curr_index >= self.id_array.len()
|| self.max_series_number > self.id_array[self.curr_index].0
{
Expand Down Expand Up @@ -1249,12 +1255,11 @@ mod test {
let table_scan = test_table_scan()?;

let plan = LogicalPlanBuilder::from(table_scan)
.filter(lit(1).gt(col("a")).and(lit(1).gt(col("a"))))?
.filter((lit(1) + col("a") - lit(10)).gt(lit(1) + col("a")))?
.build()?;

let expected = "Projection: test.a, test.b, test.c\
\n Filter: Int32(1) > test.atest.aInt32(1) AS Int32(1) > test.a AND Int32(1) > test.atest.aInt32(1) AS Int32(1) > test.a\
\n Projection: Int32(1) > test.a AS Int32(1) > test.atest.aInt32(1), test.a, test.b, test.c\
\n Filter: Int32(1) + test.atest.aInt32(1) AS Int32(1) + test.a - Int32(10) > Int32(1) + test.atest.aInt32(1) AS Int32(1) + test.a\n Projection: Int32(1) + test.a AS Int32(1) + test.atest.aInt32(1), test.a, test.b, test.c\
\n TableScan: test";

assert_optimized_plan_eq(expected, &plan);
Expand Down
44 changes: 44 additions & 0 deletions datafusion/sqllogictest/test_files/select.slt
Original file line number Diff line number Diff line change
Expand Up @@ -1129,5 +1129,49 @@ FROM t AS A, (SELECT * FROM t WHERE x = 0) AS B;
0 0
0 0

# Expressions that short circuit should not be refactored out as that may cause side effects (divide by zero)
# at plan time that would not actually happen during execution, so the follow three query should not be extract
# the common sub-expression
query TT
explain select coalesce(1, y/x), coalesce(2, y/x) from t;
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

because of the reason describe in #8927, use plan to test.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we please add some comments here explaining the rationale and what the expected outputs are (so future readers know if changes are expected)

Something like this perhaps:

# Expressions that short circuit should not be refactored out as that may cause side effects (divide by zero)
# at plan time that would not actually happen during execution

Also, can you please add the tests that now pass (e.g. select coalesce(1, y/x), coalesce(2, y/x) from t;) so if someone breaks this code by accident, those queries would start failing, which might be easier to quickly tell is incorrect

----
logical_plan
Projection: coalesce(Int64(1), CAST(t.y / t.x AS Int64)), coalesce(Int64(2), CAST(t.y / t.x AS Int64))
--TableScan: t projection=[x, y]
physical_plan
ProjectionExec: expr=[coalesce(1, CAST(y@1 / x@0 AS Int64)) as coalesce(Int64(1),t.y / t.x), coalesce(2, CAST(y@1 / x@0 AS Int64)) as coalesce(Int64(2),t.y / t.x)]
--MemoryExec: partitions=1, partition_sizes=[1]

query TT
EXPLAIN SELECT y > 0 and 1 / y < 1, x > 0 and y > 0 and 1 / y < 1 / x from t;
----
logical_plan
Projection: t.y > Int32(0) AND Int64(1) / CAST(t.y AS Int64) < Int64(1) AS t.y > Int64(0) AND Int64(1) / t.y < Int64(1), t.x > Int32(0) AND t.y > Int32(0) AND Int64(1) / CAST(t.y AS Int64) < Int64(1) / CAST(t.x AS Int64) AS t.x > Int64(0) AND t.y > Int64(0) AND Int64(1) / t.y < Int64(1) / t.x
--TableScan: t projection=[x, y]
physical_plan
ProjectionExec: expr=[y@1 > 0 AND 1 / CAST(y@1 AS Int64) < 1 as t.y > Int64(0) AND Int64(1) / t.y < Int64(1), x@0 > 0 AND y@1 > 0 AND 1 / CAST(y@1 AS Int64) < 1 / CAST(x@0 AS Int64) as t.x > Int64(0) AND t.y > Int64(0) AND Int64(1) / t.y < Int64(1) / t.x]
--MemoryExec: partitions=1, partition_sizes=[1]

query TT
EXPLAIN SELECT y = 0 or 1 / y < 1, x = 0 or y = 0 or 1 / y < 1 / x from t;
----
logical_plan
Projection: t.y = Int32(0) OR Int64(1) / CAST(t.y AS Int64) < Int64(1) AS t.y = Int64(0) OR Int64(1) / t.y < Int64(1), t.x = Int32(0) OR t.y = Int32(0) OR Int64(1) / CAST(t.y AS Int64) < Int64(1) / CAST(t.x AS Int64) AS t.x = Int64(0) OR t.y = Int64(0) OR Int64(1) / t.y < Int64(1) / t.x
--TableScan: t projection=[x, y]
physical_plan
ProjectionExec: expr=[y@1 = 0 OR 1 / CAST(y@1 AS Int64) < 1 as t.y = Int64(0) OR Int64(1) / t.y < Int64(1), x@0 = 0 OR y@1 = 0 OR 1 / CAST(y@1 AS Int64) < 1 / CAST(x@0 AS Int64) as t.x = Int64(0) OR t.y = Int64(0) OR Int64(1) / t.y < Int64(1) / t.x]
--MemoryExec: partitions=1, partition_sizes=[1]

# due to the reason describe in https://github.com/apache/arrow-datafusion/issues/8927,
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm unsure if those tests are appropriate.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's related with #8910

# the following queries will fail
query error
select coalesce(1, y/x), coalesce(2, y/x) from t;

query error
SELECT y > 0 and 1 / y < 1, x > 0 and y > 0 and 1 / y < 1 / x from t;

query error
SELECT y = 0 or 1 / y < 1, x = 0 or y = 0 or 1 / y < 1 / x from t;

statement ok
DROP TABLE t;
Loading