From 2c3068398e35010cf27acb8d5d333a3fd51f85c1 Mon Sep 17 00:00:00 2001 From: hhj Date: Sun, 21 Jan 2024 12:50:45 +0800 Subject: [PATCH 1/7] fix: common_subexpr_eliminate rule should not apply to short-circuit expression --- .../optimizer/src/common_subexpr_eliminate.rs | 42 ++++++++++++++----- 1 file changed, 32 insertions(+), 10 deletions(-) diff --git a/datafusion/optimizer/src/common_subexpr_eliminate.rs b/datafusion/optimizer/src/common_subexpr_eliminate.rs index f29c7406acc9..6ffb5790d74b 100644 --- a/datafusion/optimizer/src/common_subexpr_eliminate.rs +++ b/datafusion/optimizer/src/common_subexpr_eliminate.rs @@ -34,7 +34,10 @@ use datafusion_expr::expr::Alias; use datafusion_expr::logical_plan::{ Aggregate, Filter, LogicalPlan, Projection, Sort, Window, }; -use datafusion_expr::{col, Expr, ExprSchemable}; +use datafusion_expr::{ + col, expr::ScalarFunction, BinaryExpr, BuiltinScalarFunction, Expr, ExprSchemable, + Operator, ScalarFunctionDefinition, +}; /// A map from expression's identifier to tuple including /// - the expression itself (cloned) @@ -616,8 +619,8 @@ impl TreeNodeVisitor for ExprIdentifierVisitor<'_> { fn pre_visit(&mut self, expr: &Expr) -> Result { // related to https://github.com/apache/arrow-datafusion/issues/8814 - // If the expr contain volatile expression or is a case expression, skip it. - if matches!(expr, Expr::Case(..)) || is_volatile_expression(expr)? { + // If the expr contain volatile expression or is a short-circuit expression, skip it. + if is_short_circuit_expression(expr) || is_volatile_expression(expr)? { return Ok(VisitRecursion::Skip); } self.visit_stack @@ -655,6 +658,20 @@ impl TreeNodeVisitor for ExprIdentifierVisitor<'_> { } } +/// Check if the expression is short-circuit expression +fn is_short_circuit_expression(expr: &Expr) -> bool { + match expr { + Expr::ScalarFunction(ScalarFunction { func_def, .. 
}) => { + matches!(func_def, ScalarFunctionDefinition::BuiltIn(fun) if *fun == BuiltinScalarFunction::Coalesce) + } + Expr::BinaryExpr(BinaryExpr { op, .. }) => { + matches!(op, Operator::And | Operator::Or) + } + Expr::Case { .. } => true, + _ => false, + } +} + /// Go through an expression tree and generate identifier for every node in this tree. fn expr_to_identifier( expr: &Expr, @@ -696,7 +713,13 @@ struct CommonSubexprRewriter<'a> { impl TreeNodeRewriter for CommonSubexprRewriter<'_> { type N = Expr; - fn pre_visit(&mut self, _: &Expr) -> Result { + fn pre_visit(&mut self, expr: &Expr) -> Result { + // The `CommonSubexprRewriter` relies on `ExprIdentifierVisitor` to generate + // the `id_array`, which records the expr's identifier used to rewrite expr. So if we + // skip an expr in `ExprIdentifierVisitor`, we should skip it here, too. + if is_short_circuit_expression(expr) || is_volatile_expression(expr)? { + return Ok(RewriteRecursion::Stop); + } if self.curr_index >= self.id_array.len() || self.max_series_number > self.id_array[self.curr_index].0 { @@ -782,7 +805,7 @@ mod test { use datafusion_common::DFSchema; use datafusion_expr::logical_plan::{table_scan, JoinType}; use datafusion_expr::{ - avg, col, lit, logical_plan::builder::LogicalPlanBuilder, sum, + avg, col, lit, logical_plan::builder::LogicalPlanBuilder, sum, not, }; use datafusion_expr::{ grouping_set, AccumulatorFactoryFunction, AggregateUDF, Signature, @@ -1249,13 +1272,12 @@ mod test { let table_scan = test_table_scan()?; let plan = LogicalPlanBuilder::from(table_scan) - .filter(lit(1).gt(col("a")).and(lit(1).gt(col("a"))))? + .project(vec![not(lit(1).gt(col("a"))), lit(1).gt(col("a"))])? 
.build()?; - let expected = "Projection: test.a, test.b, test.c\ - \n Filter: Int32(1) > test.atest.aInt32(1) AS Int32(1) > test.a AND Int32(1) > test.atest.aInt32(1) AS Int32(1) > test.a\ - \n Projection: Int32(1) > test.a AS Int32(1) > test.atest.aInt32(1), test.a, test.b, test.c\ - \n TableScan: test"; + let expected = "Projection: NOT Int32(1) > test.atest.aInt32(1) AS Int32(1) > test.a, Int32(1) > test.atest.aInt32(1) AS Int32(1) > test.a\ + \n Projection: Int32(1) > test.a AS Int32(1) > test.atest.aInt32(1), test.a, test.b, test.c\ + \n TableScan: test"; assert_optimized_plan_eq(expected, &plan); From 584f5c848c119c6f2684e81fe1eeefa39d82f345 Mon Sep 17 00:00:00 2001 From: hhj Date: Sun, 21 Jan 2024 14:26:08 +0800 Subject: [PATCH 2/7] add more tests --- .../optimizer/src/common_subexpr_eliminate.rs | 10 +++---- datafusion/sqllogictest/test_files/select.slt | 30 +++++++++++++++++++ 2 files changed, 35 insertions(+), 5 deletions(-) diff --git a/datafusion/optimizer/src/common_subexpr_eliminate.rs b/datafusion/optimizer/src/common_subexpr_eliminate.rs index 6ffb5790d74b..95bde8f98b2a 100644 --- a/datafusion/optimizer/src/common_subexpr_eliminate.rs +++ b/datafusion/optimizer/src/common_subexpr_eliminate.rs @@ -805,7 +805,7 @@ mod test { use datafusion_common::DFSchema; use datafusion_expr::logical_plan::{table_scan, JoinType}; use datafusion_expr::{ - avg, col, lit, logical_plan::builder::LogicalPlanBuilder, sum, not, + avg, col, lit, logical_plan::builder::LogicalPlanBuilder, sum, }; use datafusion_expr::{ grouping_set, AccumulatorFactoryFunction, AggregateUDF, Signature, @@ -1272,12 +1272,12 @@ mod test { let table_scan = test_table_scan()?; let plan = LogicalPlanBuilder::from(table_scan) - .project(vec![not(lit(1).gt(col("a"))), lit(1).gt(col("a"))])? + .filter((lit(1)+col("a")-lit(10)).gt(lit(1)+col("a")))? 
.build()?; - let expected = "Projection: NOT Int32(1) > test.atest.aInt32(1) AS Int32(1) > test.a, Int32(1) > test.atest.aInt32(1) AS Int32(1) > test.a\ - \n Projection: Int32(1) > test.a AS Int32(1) > test.atest.aInt32(1), test.a, test.b, test.c\ - \n TableScan: test"; + let expected = "Projection: test.a, test.b, test.c\ + \n Filter: Int32(1) + test.atest.aInt32(1) AS Int32(1) + test.a - Int32(10) > Int32(1) + test.atest.aInt32(1) AS Int32(1) + test.a\n Projection: Int32(1) + test.a AS Int32(1) + test.atest.aInt32(1), test.a, test.b, test.c\ + \n TableScan: test"; assert_optimized_plan_eq(expected, &plan); diff --git a/datafusion/sqllogictest/test_files/select.slt b/datafusion/sqllogictest/test_files/select.slt index ca48c07b0914..9c82980ebdee 100644 --- a/datafusion/sqllogictest/test_files/select.slt +++ b/datafusion/sqllogictest/test_files/select.slt @@ -1129,5 +1129,35 @@ FROM t AS A, (SELECT * FROM t WHERE x = 0) AS B; 0 0 0 0 +query TT +explain select coalesce(1, y/x), coalesce(2, y/x) from t; +---- +logical_plan +Projection: coalesce(Int64(1), CAST(t.y / t.x AS Int64)), coalesce(Int64(2), CAST(t.y / t.x AS Int64)) +--TableScan: t projection=[x, y] +physical_plan +ProjectionExec: expr=[coalesce(1, CAST(y@1 / x@0 AS Int64)) as coalesce(Int64(1),t.y / t.x), coalesce(2, CAST(y@1 / x@0 AS Int64)) as coalesce(Int64(2),t.y / t.x)] +--MemoryExec: partitions=1, partition_sizes=[1] + +query TT +EXPLAIN SELECT y > 0 and 1 / y < 1, x > 0 and y > 0 and 1 / y < 1 / x from t; +---- +logical_plan +Projection: t.y > Int32(0) AND Int64(1) / CAST(t.y AS Int64) < Int64(1) AS t.y > Int64(0) AND Int64(1) / t.y < Int64(1), t.x > Int32(0) AND t.y > Int32(0) AND Int64(1) / CAST(t.y AS Int64) < Int64(1) / CAST(t.x AS Int64) AS t.x > Int64(0) AND t.y > Int64(0) AND Int64(1) / t.y < Int64(1) / t.x +--TableScan: t projection=[x, y] +physical_plan +ProjectionExec: expr=[y@1 > 0 AND 1 / CAST(y@1 AS Int64) < 1 as t.y > Int64(0) AND Int64(1) / t.y < Int64(1), x@0 > 0 AND y@1 > 0 AND 1 / 
CAST(y@1 AS Int64) < 1 / CAST(x@0 AS Int64) as t.x > Int64(0) AND t.y > Int64(0) AND Int64(1) / t.y < Int64(1) / t.x] +--MemoryExec: partitions=1, partition_sizes=[1] + +query TT +EXPLAIN SELECT y == 0 or 1 / y < 1, x == 0 or y == 0 or 1 / y < 1 / x from t; +---- +logical_plan +Projection: t.y = Int32(0) OR Int64(1) / CAST(t.y AS Int64) < Int64(1) AS t.y = Int64(0) OR Int64(1) / t.y < Int64(1), t.x = Int32(0) OR t.y = Int32(0) OR Int64(1) / CAST(t.y AS Int64) < Int64(1) / CAST(t.x AS Int64) AS t.x = Int64(0) OR t.y = Int64(0) OR Int64(1) / t.y < Int64(1) / t.x +--TableScan: t projection=[x, y] +physical_plan +ProjectionExec: expr=[y@1 = 0 OR 1 / CAST(y@1 AS Int64) < 1 as t.y = Int64(0) OR Int64(1) / t.y < Int64(1), x@0 = 0 OR y@1 = 0 OR 1 / CAST(y@1 AS Int64) < 1 / CAST(x@0 AS Int64) as t.x = Int64(0) OR t.y = Int64(0) OR Int64(1) / t.y < Int64(1) / t.x] +--MemoryExec: partitions=1, partition_sizes=[1] + statement ok DROP TABLE t; From 7329b91165f69d6657c1813b2e76a4f0070c2ec5 Mon Sep 17 00:00:00 2001 From: hhj Date: Sun, 21 Jan 2024 14:42:14 +0800 Subject: [PATCH 3/7] format --- datafusion/optimizer/src/common_subexpr_eliminate.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/datafusion/optimizer/src/common_subexpr_eliminate.rs b/datafusion/optimizer/src/common_subexpr_eliminate.rs index 95bde8f98b2a..e722392dd76b 100644 --- a/datafusion/optimizer/src/common_subexpr_eliminate.rs +++ b/datafusion/optimizer/src/common_subexpr_eliminate.rs @@ -1272,7 +1272,7 @@ mod test { let table_scan = test_table_scan()?; let plan = LogicalPlanBuilder::from(table_scan) - .filter((lit(1)+col("a")-lit(10)).gt(lit(1)+col("a")))? + .filter((lit(1) + col("a") - lit(10)).gt(lit(1) + col("a")))? 
.build()?; let expected = "Projection: test.a, test.b, test.c\ From f89a4d12ee9a4a20961e131226ed837a95692340 Mon Sep 17 00:00:00 2001 From: hhj Date: Sun, 21 Jan 2024 15:19:07 +0800 Subject: [PATCH 4/7] minor --- datafusion/sqllogictest/test_files/select.slt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/datafusion/sqllogictest/test_files/select.slt b/datafusion/sqllogictest/test_files/select.slt index 9c82980ebdee..5af310bc22e9 100644 --- a/datafusion/sqllogictest/test_files/select.slt +++ b/datafusion/sqllogictest/test_files/select.slt @@ -1150,7 +1150,7 @@ ProjectionExec: expr=[y@1 > 0 AND 1 / CAST(y@1 AS Int64) < 1 as t.y > Int64(0) A --MemoryExec: partitions=1, partition_sizes=[1] query TT -EXPLAIN SELECT y == 0 or 1 / y < 1, x == 0 or y == 0 or 1 / y < 1 / x from t; +EXPLAIN SELECT y = 0 or 1 / y < 1, x = 0 or y = 0 or 1 / y < 1 / x from t; ---- logical_plan Projection: t.y = Int32(0) OR Int64(1) / CAST(t.y AS Int64) < Int64(1) AS t.y = Int64(0) OR Int64(1) / t.y < Int64(1), t.x = Int32(0) OR t.y = Int32(0) OR Int64(1) / CAST(t.y AS Int64) < Int64(1) / CAST(t.x AS Int64) AS t.x = Int64(0) OR t.y = Int64(0) OR Int64(1) / t.y < Int64(1) / t.x From 560cb6c0b54490170376c6e209fd35f4b29024c2 Mon Sep 17 00:00:00 2001 From: hhj Date: Mon, 22 Jan 2024 16:02:26 +0800 Subject: [PATCH 5/7] apply reviews --- datafusion/expr/src/expr.rs | 45 +++++++++++++++++++ .../optimizer/src/common_subexpr_eliminate.rs | 23 ++-------- datafusion/sqllogictest/test_files/select.slt | 14 ++++++ 3 files changed, 62 insertions(+), 20 deletions(-) diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs index 40d40692e593..ea8a601f0238 100644 --- a/datafusion/expr/src/expr.rs +++ b/datafusion/expr/src/expr.rs @@ -1266,6 +1266,51 @@ impl Expr { Ok(Transformed::Yes(expr)) }) } + + /// Returns true if some of this `expr`'s subexpressions may not be evaluated + /// and thus any side effects (like divide by zero) may not be encountered + pub fn
short_circuits(&self) -> bool { + match self { + Expr::ScalarFunction(ScalarFunction { func_def, .. }) => { + matches!(func_def, ScalarFunctionDefinition::BuiltIn(fun) if *fun == BuiltinScalarFunction::Coalesce) + } + Expr::BinaryExpr(BinaryExpr { op, .. }) => { + matches!(op, Operator::And | Operator::Or) + } + Expr::Case { .. } => true, + Expr::AggregateFunction(..) + | Expr::Alias(..) + | Expr::Between(..) + | Expr::Cast(..) + | Expr::Column(..) + | Expr::Exists(..) + | Expr::GetIndexedField(..) + | Expr::GroupingSet(..) + | Expr::InList(..) + | Expr::InSubquery(..) + | Expr::IsFalse(..) + | Expr::IsNotFalse(..) + | Expr::IsNotNull(..) + | Expr::IsNotTrue(..) + | Expr::IsNotUnknown(..) + | Expr::IsNull(..) + | Expr::IsTrue(..) + | Expr::IsUnknown(..) + | Expr::Like(..) + | Expr::ScalarSubquery(..) + | Expr::ScalarVariable(_, _) + | Expr::SimilarTo(..) + | Expr::Not(..) + | Expr::Negative(..) + | Expr::OuterReferenceColumn(_, _) + | Expr::TryCast(..) + | Expr::Wildcard { .. } + | Expr::WindowFunction(..) + | Expr::Literal(..) + | Expr::Sort(..) + | Expr::Placeholder(..) 
=> false, + } + } } // modifies expr if it is a placeholder with datatype of right diff --git a/datafusion/optimizer/src/common_subexpr_eliminate.rs b/datafusion/optimizer/src/common_subexpr_eliminate.rs index e722392dd76b..fe71171ce545 100644 --- a/datafusion/optimizer/src/common_subexpr_eliminate.rs +++ b/datafusion/optimizer/src/common_subexpr_eliminate.rs @@ -34,10 +34,7 @@ use datafusion_expr::expr::Alias; use datafusion_expr::logical_plan::{ Aggregate, Filter, LogicalPlan, Projection, Sort, Window, }; -use datafusion_expr::{ - col, expr::ScalarFunction, BinaryExpr, BuiltinScalarFunction, Expr, ExprSchemable, - Operator, ScalarFunctionDefinition, -}; +use datafusion_expr::{col, Expr, ExprSchemable}; /// A map from expression's identifier to tuple including /// - the expression itself (cloned) @@ -620,7 +617,7 @@ impl TreeNodeVisitor for ExprIdentifierVisitor<'_> { fn pre_visit(&mut self, expr: &Expr) -> Result { // related to https://github.com/apache/arrow-datafusion/issues/8814 // If the expr contain volatile expression or is a short-circuit expression, skip it. - if is_short_circuit_expression(expr) || is_volatile_expression(expr)? { + if expr.short_circuits() || is_volatile_expression(expr)? { return Ok(VisitRecursion::Skip); } self.visit_stack @@ -658,20 +655,6 @@ impl TreeNodeVisitor for ExprIdentifierVisitor<'_> { } } -/// Check if the expression is short-circuit expression -fn is_short_circuit_expression(expr: &Expr) -> bool { - match expr { - Expr::ScalarFunction(ScalarFunction { func_def, .. }) => { - matches!(func_def, ScalarFunctionDefinition::BuiltIn(fun) if *fun == BuiltinScalarFunction::Coalesce) - } - Expr::BinaryExpr(BinaryExpr { op, .. }) => { - matches!(op, Operator::And | Operator::Or) - } - Expr::Case { .. } => true, - _ => false, - } -} - /// Go through an expression tree and generate identifier for every node in this tree. 
fn expr_to_identifier( expr: &Expr, @@ -717,7 +700,7 @@ impl TreeNodeRewriter for CommonSubexprRewriter<'_> { // The `CommonSubexprRewriter` relies on `ExprIdentifierVisitor` to generate // the `id_array`, which records the expr's identifier used to rewrite expr. So if we // skip an expr in `ExprIdentifierVisitor`, we should skip it here, too. - if is_short_circuit_expression(expr) || is_volatile_expression(expr)? { + if expr.short_circuits() || is_volatile_expression(expr)? { return Ok(RewriteRecursion::Stop); } if self.curr_index >= self.id_array.len() diff --git a/datafusion/sqllogictest/test_files/select.slt b/datafusion/sqllogictest/test_files/select.slt index 5af310bc22e9..9ffddc6e2d46 100644 --- a/datafusion/sqllogictest/test_files/select.slt +++ b/datafusion/sqllogictest/test_files/select.slt @@ -1129,6 +1129,9 @@ FROM t AS A, (SELECT * FROM t WHERE x = 0) AS B; 0 0 0 0 +# Expressions that short circuit should not be refactored out as that may cause side effects (divide by zero) +# at plan time that would not actually happen during execution, so the following three queries should not have +# their common sub-expressions extracted query TT explain select coalesce(1, y/x), coalesce(2, y/x) from t; ---- @@ -1159,5 +1162,16 @@ physical_plan ProjectionExec: expr=[y@1 = 0 OR 1 / CAST(y@1 AS Int64) < 1 as t.y = Int64(0) OR Int64(1) / t.y < Int64(1), x@0 = 0 OR y@1 = 0 OR 1 / CAST(y@1 AS Int64) < 1 / CAST(x@0 AS Int64) as t.x = Int64(0) OR t.y = Int64(0) OR Int64(1) / t.y < Int64(1) / t.x] --MemoryExec: partitions=1, partition_sizes=[1] +# due to the reason described in https://github.com/apache/arrow-datafusion/issues/8927, +# the following queries will fail +query error +select coalesce(1, y/x), coalesce(2, y/x) from t; + +query error +SELECT y > 0 and 1 / y < 1, x > 0 and y > 0 and 1 / y < 1 / x from t; + +query error +SELECT y = 0 or 1 / y < 1, x = 0 or y = 0 or 1 / y < 1 / x from t; + statement ok DROP TABLE t; From fcb37cce9d6d35d5bc2dda2b6cef95a36109d483 Mon Sep 17
00:00:00 2001 From: hhj Date: Mon, 22 Jan 2024 16:06:31 +0800 Subject: [PATCH 6/7] add some comments --- datafusion/expr/src/expr.rs | 3 +++ 1 file changed, 3 insertions(+) diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs index ea8a601f0238..6514f5f24c75 100644 --- a/datafusion/expr/src/expr.rs +++ b/datafusion/expr/src/expr.rs @@ -1278,6 +1278,9 @@ impl Expr { matches!(op, Operator::And | Operator::Or) } Expr::Case { .. } => true, + // Use explicit pattern match instead of a default + // implementation, so that in the future if someone adds + // new Expr types, they will check here as well Expr::AggregateFunction(..) | Expr::Alias(..) | Expr::Between(..) From b3f50c0f042dab2787578188ac676af375dd9ff3 Mon Sep 17 00:00:00 2001 From: hhj Date: Mon, 22 Jan 2024 16:09:10 +0800 Subject: [PATCH 7/7] fmt --- datafusion/expr/src/expr.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs index 6514f5f24c75..1de458a9838f 100644 --- a/datafusion/expr/src/expr.rs +++ b/datafusion/expr/src/expr.rs @@ -1280,7 +1280,7 @@ impl Expr { Expr::Case { .. } => true, // Use explicit pattern match instead of a default // implementation, so that in the future if someone adds - // new Expr types, they will check here as well + // new Expr types, they will check here as well Expr::AggregateFunction(..) | Expr::Alias(..) | Expr::Between(..)