From c65d72a61540d04d451976c2bb92a34ceeeb052e Mon Sep 17 00:00:00 2001 From: Mohamed Abdeen Date: Sun, 30 Jun 2024 14:30:37 +0300 Subject: [PATCH 01/17] initial prettier unparse --- .../examples/parse_sql_expr.rs | 48 ++++++++ datafusion/sql/src/unparser/expr.rs | 107 +++++++++++++++++- datafusion/sql/tests/cases/plan_to_sql.rs | 26 +++-- 3 files changed, 167 insertions(+), 14 deletions(-) diff --git a/datafusion-examples/examples/parse_sql_expr.rs b/datafusion-examples/examples/parse_sql_expr.rs index 6444eb68b6b2..af3c2da0d54e 100644 --- a/datafusion-examples/examples/parse_sql_expr.rs +++ b/datafusion-examples/examples/parse_sql_expr.rs @@ -48,6 +48,7 @@ async fn main() -> Result<()> { simple_dataframe_parse_sql_expr_demo().await?; query_parquet_demo().await?; round_trip_parse_sql_expr_demo().await?; + round_trip_parse_sql_expr_pretty_demo().await?; Ok(()) } @@ -155,3 +156,50 @@ async fn round_trip_parse_sql_expr_demo() -> Result<()> { Ok(()) } + +// TODO: Move these to sql/tests/cases/plan_to_sql.rs +/// DataFusion can parse a SQL text and convert it back to SQL using [`Unparser`]. +async fn round_trip_parse_sql_expr_pretty_demo() -> Result<()> { + let ctx = SessionContext::new(); + let testdata = datafusion::test_util::parquet_test_data(); + let df = ctx + .read_parquet( + &format!("{testdata}/alltypes_plain.parquet"), + ParquetReadOptions::default(), + ) + .await?; + + let unparser = Unparser::default(); + + let sql_pairs = vec![ + ( + "((int_col < 5) OR (double_col = 8))", + "int_col < 5 OR double_col = 8", + ), + ( + "((int_col + 5) * (double_col * 8))", + "(int_col + 5) * double_col * 8", + ), + ("(3 + (5 * 6) * 3)", "3 + 5 * 6 * 3"), + ("((3 * (5 + 6)) * 3)", "3 * (5 + 6) * 3"), + ("((3 AND (5 OR 6)) * 3)", "(3 AND (5 OR 6)) * 3"), + ("((3 + (5 + 6)) * 3)", "(3 + 5 + 6) * 3"), + ("((3 + (5 + 6)) * 3)", "(3 + 5 + 6) * 3"), + ( + "((int_col > 10) AND (double_col BETWEEN 10 AND 20))", + "int_col > 10 AND double_col BETWEEN 10 AND 20", + ), + ( + "((int_col > 10) * (double_col BETWEEN 10 AND 20))", + "(int_col > 10) * (double_col BETWEEN 10 AND 20)", + ), + ]; + + for (sql, pretty) in sql_pairs.iter() { + let parsed_expr = df.parse_sql_expr(sql)?; + let round_trip_sql = unparser.pretty_expr_to_sql(&parsed_expr)?.to_string(); + assert_eq!(pretty.to_string(), round_trip_sql); + } + + Ok(()) +} diff --git a/datafusion/sql/src/unparser/expr.rs b/datafusion/sql/src/unparser/expr.rs index ad898de5987a..7a6c2f8bfc54 100644 --- a/datafusion/sql/src/unparser/expr.rs +++ b/datafusion/sql/src/unparser/expr.rs @@ -30,8 +30,8 @@ use arrow_array::{Date32Array, Date64Array, PrimitiveArray}; use arrow_schema::DataType; use sqlparser::ast::Value::SingleQuotedString; use sqlparser::ast::{ - self, Expr as AstExpr, Function, FunctionArg, Ident, Interval, TimezoneInfo, - UnaryOperator, + self, BinaryOperator, Expr as AstExpr, Function, FunctionArg, Ident, Interval, + TimezoneInfo, UnaryOperator, }; use datafusion_common::{ @@ -102,6 +102,16 @@ pub fn expr_to_unparsed(expr: &Expr) -> Result { } impl Unparser<'_> { + /// Try to unparse the expression into a more human-readable format + /// by removing unnecessary parentheses. + pub fn pretty_expr_to_sql(&self, expr: &Expr) -> Result { + let root_expr = self.expr_to_sql(expr)?; + match root_expr { + ast::Expr::Nested(nested) => Ok(self.pretty(*nested, 100, 100)), + expr => Ok(self.pretty(expr, 100, 100)), + } + } + pub fn expr_to_sql(&self, expr: &Expr) -> Result { match expr { Expr::InList(InList { @@ -603,6 +613,55 @@ impl Unparser<'_> { } } + /// Given an expression of the form `(a + b) * (c * d)`, + /// the parenthesing is redundant if the precedence of the nested expression is already higher + /// than the surrounding operators' precedence. The above expression would become + /// `(a + b) * c * d`. + /// + /// Also note that when fetching the precedence of a nested expression, we ignore other nested + /// expressions, so precedence of expr `(a * (b + c))` equals `*` and not `+`. + /// + /// Note that outermost parentheses should be removed before calling this function. + fn pretty( + &self, + expr: ast::Expr, + left_precedence: u8, + right_precedence: u8, + ) -> ast::Expr { + match expr { + ast::Expr::Nested(nested) => { + let surrounding_precedence = left_precedence.min(right_precedence); + let inner_precedence = self.lowest_inner_precedence(&nested); + if inner_precedence >= surrounding_precedence { + self.pretty(*nested, left_precedence, right_precedence) + } else { + ast::Expr::Nested(Box::new(self.pretty(*nested, 100, 100))) + } + } + ast::Expr::BinaryOp { left, op, right } => { + let op_precedence = self.sql_op_precedence(&op); + + ast::Expr::BinaryOp { + left: Box::new(self.pretty(*left, left_precedence, op_precedence)), + right: Box::new(self.pretty(*right, op_precedence, right_precedence)), + op, + } + } + _ => expr, + } + } + + fn lowest_inner_precedence(&self, expr: &ast::Expr) -> u8 { + match expr { + ast::Expr::Nested(_) | ast::Expr::Identifier(_) | ast::Expr::Value(_) => 100, + ast::Expr::BinaryOp { op, .. } => self.sql_op_precedence(op), + ast::Expr::Between { .. } => { + self.sql_op_precedence(&ast::BinaryOperator::And) + } + _ => 0, + } + } + pub(super) fn between_op_to_sql( &self, expr: ast::Expr, @@ -618,6 +677,50 @@ impl Unparser<'_> { } } + // TODO: operator precedence should be defined in sqlparser + // to avoid the need for sql_to_op and sql_op_precedence + fn sql_op_precedence(&self, op: &BinaryOperator) -> u8 { + match self.sql_to_op(op) { + Ok(op) => op.precedence(), + Err(_) => 0, + } + } + + fn sql_to_op(&self, op: &BinaryOperator) -> Result { + match op { + ast::BinaryOperator::Eq => Ok(Operator::Eq), + ast::BinaryOperator::NotEq => Ok(Operator::NotEq), + ast::BinaryOperator::Lt => Ok(Operator::Lt), + ast::BinaryOperator::LtEq => Ok(Operator::LtEq), + ast::BinaryOperator::Gt => Ok(Operator::Gt), + ast::BinaryOperator::GtEq => Ok(Operator::GtEq), + ast::BinaryOperator::Plus => Ok(Operator::Plus), + ast::BinaryOperator::Minus => Ok(Operator::Minus), + ast::BinaryOperator::Multiply => Ok(Operator::Multiply), + ast::BinaryOperator::Divide => Ok(Operator::Divide), + ast::BinaryOperator::Modulo => Ok(Operator::Modulo), + ast::BinaryOperator::And => Ok(Operator::And), + ast::BinaryOperator::Or => Ok(Operator::Or), + ast::BinaryOperator::PGRegexMatch => Ok(Operator::RegexMatch), + ast::BinaryOperator::PGRegexIMatch => Ok(Operator::RegexIMatch), + ast::BinaryOperator::PGRegexNotMatch => Ok(Operator::RegexNotMatch), + ast::BinaryOperator::PGRegexNotIMatch => Ok(Operator::RegexNotIMatch), + ast::BinaryOperator::PGILikeMatch => Ok(Operator::ILikeMatch), + ast::BinaryOperator::PGNotLikeMatch => Ok(Operator::NotLikeMatch), + ast::BinaryOperator::PGLikeMatch => Ok(Operator::LikeMatch), + ast::BinaryOperator::PGNotILikeMatch => Ok(Operator::NotILikeMatch), + ast::BinaryOperator::BitwiseAnd => Ok(Operator::BitwiseAnd), + ast::BinaryOperator::BitwiseOr => Ok(Operator::BitwiseOr), + ast::BinaryOperator::BitwiseXor => Ok(Operator::BitwiseXor), + ast::BinaryOperator::PGBitwiseShiftRight => Ok(Operator::BitwiseShiftRight), + ast::BinaryOperator::PGBitwiseShiftLeft => Ok(Operator::BitwiseShiftLeft), + ast::BinaryOperator::StringConcat => Ok(Operator::StringConcat), + ast::BinaryOperator::AtArrow => Ok(Operator::AtArrow), + ast::BinaryOperator::ArrowAt => Ok(Operator::ArrowAt), + _ => not_impl_err!("unsupported operation: {op:?}"), + } + } + fn op_to_sql(&self, op: &Operator) -> Result { match op { Operator::Eq => Ok(ast::BinaryOperator::Eq), diff --git a/datafusion/sql/tests/cases/plan_to_sql.rs b/datafusion/sql/tests/cases/plan_to_sql.rs index 374403d853f9..c92d52264f50 100644 --- a/datafusion/sql/tests/cases/plan_to_sql.rs +++ b/datafusion/sql/tests/cases/plan_to_sql.rs @@ -104,26 +104,26 @@ fn roundtrip_statement() -> Result<()> { "select id, count(*) as cnt from (select p1.id as id from person p1 inner join person p2 on p1.id=p2.id) group by id", "select id, count(*), first_name from person group by first_name, id", "select id, sum(age), first_name from person group by first_name, id", - "select id, count(*), first_name - from person + "select id, count(*), first_name + from person where id!=3 and first_name=='test' - group by first_name, id + group by first_name, id having count(*)>5 and count(*)<10 order by count(*)", - r#"select id, count("First Name") as count_first_name, "Last Name" + r#"select id, count("First Name") as count_first_name, "Last Name" from person_quoted_cols where id!=3 and "First Name"=='test' - group by "Last Name", id + group by "Last Name", id having count_first_name>5 and count_first_name<10 order by count_first_name, "Last Name""#, r#"select p.id, count("First Name") as count_first_name, - "Last Name", sum(qp.id/p.id - (select sum(id) from person_quoted_cols) ) / (select count(*) from person) + "Last Name", sum(qp.id/p.id - (select sum(id) from person_quoted_cols) ) / (select count(*) from person) from (select id, "First Name", "Last Name" from person_quoted_cols) qp inner join (select * from person) p on p.id = qp.id - where p.id!=3 and "First Name"=='test' and qp.id in + where p.id!=3 and "First Name"=='test' and qp.id in (select id from (select id, count(*) from person group by id having count(*) > 0)) - group by "Last Name", p.id + group by "Last Name", p.id having count_first_name>5 and count_first_name<10 order by count_first_name, "Last Name""#, r#"SELECT j1_string as string FROM j1 @@ -134,12 +134,12 @@ fn roundtrip_statement() -> Result<()> { SELECT j2_string as string FROM j2 ORDER BY string DESC LIMIT 10"#, - "SELECT id, count(*) over (PARTITION BY first_name ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING), - last_name, sum(id) over (PARTITION BY first_name ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING), + "SELECT id, count(*) over (PARTITION BY first_name ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING), + last_name, sum(id) over (PARTITION BY first_name ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING), first_name from person", - r#"SELECT id, count(distinct id) over (ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING), + r#"SELECT id, count(distinct id) over (ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING), sum(id) OVER (PARTITION BY first_name ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) from person"#, - "SELECT id, sum(id) OVER (PARTITION BY first_name ROWS BETWEEN 5 PRECEDING AND 2 FOLLOWING) from person", + "SELECT id, sum(id) OVER (PARTITION BY first_name ROWS BETWEEN 5 PRECEDING AND 2 FOLLOWING) from person", ]; // For each test sql string, we transform as follows: @@ -314,3 +314,5 @@ fn test_table_references_in_plan_to_sql() { "SELECT \"table\".id, \"table\".\"value\" FROM \"table\"", ); } + +// TODO: Pretty unparse tests here From 1e255676345b9335be8c92ab6a7b3ea063ae8045 Mon Sep 17 00:00:00 2001 From: Mohamed Abdeen Date: Sun, 30 Jun 2024 17:50:28 +0300 Subject: [PATCH 02/17] bug fix --- datafusion/sql/src/unparser/expr.rs | 13 ++++--------- 1 file changed, 4 insertions(+), 9 deletions(-) diff --git a/datafusion/sql/src/unparser/expr.rs b/datafusion/sql/src/unparser/expr.rs index 7a6c2f8bfc54..27bd242e2900 100644 --- a/datafusion/sql/src/unparser/expr.rs +++ b/datafusion/sql/src/unparser/expr.rs @@ -106,10 +106,7 @@ impl Unparser<'_> { /// by removing unnecessary parentheses. pub fn pretty_expr_to_sql(&self, expr: &Expr) -> Result { let root_expr = self.expr_to_sql(expr)?; - match root_expr { - ast::Expr::Nested(nested) => Ok(self.pretty(*nested, 100, 100)), - expr => Ok(self.pretty(expr, 100, 100)), - } + Ok(self.pretty(root_expr, 0, 0)) } pub fn expr_to_sql(&self, expr: &Expr) -> Result { @@ -613,15 +610,13 @@ impl Unparser<'_> { } } - /// Given an expression of the form `(a + b) * (c * d)`, + /// Given an expression of the form `((a + b) * (c * d))`, /// the parenthesing is redundant if the precedence of the nested expression is already higher /// than the surrounding operators' precedence. The above expression would become /// `(a + b) * c * d`. /// /// Also note that when fetching the precedence of a nested expression, we ignore other nested /// expressions, so precedence of expr `(a * (b + c))` equals `*` and not `+`. - /// - /// Note that outermost parentheses should be removed before calling this function. fn pretty( &self, expr: ast::Expr, @@ -630,12 +625,12 @@ impl Unparser<'_> { ) -> ast::Expr { match expr { ast::Expr::Nested(nested) => { - let surrounding_precedence = left_precedence.min(right_precedence); + let surrounding_precedence = left_precedence.max(right_precedence); let inner_precedence = self.lowest_inner_precedence(&nested); if inner_precedence >= surrounding_precedence { self.pretty(*nested, left_precedence, right_precedence) } else { - ast::Expr::Nested(Box::new(self.pretty(*nested, 100, 100))) + ast::Expr::Nested(Box::new(self.pretty(*nested, 0, 0))) } } ast::Expr::BinaryOp { left, op, right } => { From 79532b958fcf33d3771e130764e6121075007553 Mon Sep 17 00:00:00 2001 From: Mohamed Abdeen Date: Sun, 30 Jun 2024 18:36:45 +0300 Subject: [PATCH 03/17] handling minus and divide --- .../examples/parse_sql_expr.rs | 5 +++ datafusion/sql/src/unparser/expr.rs | 39 +++++++++++-------- 2 files changed, 28 insertions(+), 16 deletions(-) diff --git a/datafusion-examples/examples/parse_sql_expr.rs b/datafusion-examples/examples/parse_sql_expr.rs index af3c2da0d54e..4ef71aa649ff 100644 --- a/datafusion-examples/examples/parse_sql_expr.rs +++ b/datafusion-examples/examples/parse_sql_expr.rs @@ -193,6 +193,11 @@ async fn round_trip_parse_sql_expr_pretty_demo() -> Result<()> { "((int_col > 10) * (double_col BETWEEN 10 AND 20))", "(int_col > 10) * (double_col BETWEEN 10 AND 20)", ), + ("int_col - (double_col - 8)", "int_col - (double_col - 8)"), + ("((int_col - double_col) - 8)", "int_col - double_col - 8"), + ("(int_col OR (double_col - 8))", "int_col OR double_col - 8"), + ("(int_col / (double_col - 8))", "int_col / (double_col - 8)"), + ("((int_col / double_col) * 8)", "int_col / double_col * 8"), ]; for (sql, pretty) in sql_pairs.iter() { diff --git a/datafusion/sql/src/unparser/expr.rs b/datafusion/sql/src/unparser/expr.rs index 27bd242e2900..9fae65d7e562 100644 --- a/datafusion/sql/src/unparser/expr.rs +++ b/datafusion/sql/src/unparser/expr.rs @@ -101,12 +101,14 @@ pub fn expr_to_unparsed(expr: &Expr) -> Result { unparser.expr_to_unparsed(expr) } +const LOWEST: BinaryOperator = BinaryOperator::BitwiseOr; + impl Unparser<'_> { /// Try to unparse the expression into a more human-readable format /// by removing unnecessary parentheses. pub fn pretty_expr_to_sql(&self, expr: &Expr) -> Result { let root_expr = self.expr_to_sql(expr)?; - Ok(self.pretty(root_expr, 0, 0)) + Ok(self.pretty(root_expr, &LOWEST, &LOWEST)) } pub fn expr_to_sql(&self, expr: &Expr) -> Result { @@ -620,28 +622,33 @@ impl Unparser<'_> { fn pretty( &self, expr: ast::Expr, - left_precedence: u8, - right_precedence: u8, + left_op: &BinaryOperator, + right_op: &BinaryOperator, ) -> ast::Expr { match expr { ast::Expr::Nested(nested) => { - let surrounding_precedence = left_precedence.max(right_precedence); + let surrounding_precedence = self + .sql_op_precedence(left_op) + .max(self.sql_op_precedence(right_op)); + let inner_precedence = self.lowest_inner_precedence(&nested); - if inner_precedence >= surrounding_precedence { - self.pretty(*nested, left_precedence, right_precedence) - } else { - ast::Expr::Nested(Box::new(self.pretty(*nested, 0, 0))) - } - } - ast::Expr::BinaryOp { left, op, right } => { - let op_precedence = self.sql_op_precedence(&op); - ast::Expr::BinaryOp { - left: Box::new(self.pretty(*left, left_precedence, op_precedence)), - right: Box::new(self.pretty(*right, op_precedence, right_precedence)), - op, + let not_associative = + matches!(left_op, BinaryOperator::Minus | BinaryOperator::Divide); + + if inner_precedence == surrounding_precedence && not_associative { + ast::Expr::Nested(Box::new(self.pretty(*nested, &LOWEST, &LOWEST))) + } else if inner_precedence >= surrounding_precedence { + self.pretty(*nested, left_op, right_op) + } else { + ast::Expr::Nested(Box::new(self.pretty(*nested, &LOWEST, &LOWEST))) } } + ast::Expr::BinaryOp { left, op, right } => ast::Expr::BinaryOp { + left: Box::new(self.pretty(*left, left_op, &op)), + right: Box::new(self.pretty(*right, &op, right_op)), + op, + }, _ => expr, } } From 5c6aecaaf871e17e6fd088902d846deee61d7ab2 Mon Sep 17 00:00:00 2001 From: Mohamed Abdeen Date: Sun, 30 Jun 2024 18:51:10 +0300 Subject: [PATCH 04/17] cleaning references and comments --- datafusion-examples/examples/parse_sql_expr.rs | 1 - datafusion/sql/src/unparser/expr.rs | 12 ++++++------ 2 files changed, 6 insertions(+), 7 deletions(-) diff --git a/datafusion-examples/examples/parse_sql_expr.rs b/datafusion-examples/examples/parse_sql_expr.rs index 4ef71aa649ff..68ee9732a220 100644 --- a/datafusion-examples/examples/parse_sql_expr.rs +++ b/datafusion-examples/examples/parse_sql_expr.rs @@ -158,7 +158,6 @@ async fn round_trip_parse_sql_expr_demo() -> Result<()> { } // TODO: Move these to sql/tests/cases/plan_to_sql.rs -/// DataFusion can parse a SQL text and convert it back to SQL using [`Unparser`]. async fn round_trip_parse_sql_expr_pretty_demo() -> Result<()> { let ctx = SessionContext::new(); let testdata = datafusion::test_util::parquet_test_data(); diff --git a/datafusion/sql/src/unparser/expr.rs b/datafusion/sql/src/unparser/expr.rs index 9fae65d7e562..d635a559c940 100644 --- a/datafusion/sql/src/unparser/expr.rs +++ b/datafusion/sql/src/unparser/expr.rs @@ -101,14 +101,14 @@ pub fn expr_to_unparsed(expr: &Expr) -> Result { unparser.expr_to_unparsed(expr) } -const LOWEST: BinaryOperator = BinaryOperator::BitwiseOr; +const LOWEST: &BinaryOperator = &BinaryOperator::BitwiseOr; impl Unparser<'_> { /// Try to unparse the expression into a more human-readable format /// by removing unnecessary parentheses. pub fn pretty_expr_to_sql(&self, expr: &Expr) -> Result { let root_expr = self.expr_to_sql(expr)?; - Ok(self.pretty(root_expr, &LOWEST, &LOWEST)) + Ok(self.pretty(root_expr, LOWEST, LOWEST)) } pub fn expr_to_sql(&self, expr: &Expr) -> Result { @@ -631,17 +631,17 @@ impl Unparser<'_> { .sql_op_precedence(left_op) .max(self.sql_op_precedence(right_op)); - let inner_precedence = self.lowest_inner_precedence(&nested); + let inner_precedence = self.inner_precedence(&nested); let not_associative = matches!(left_op, BinaryOperator::Minus | BinaryOperator::Divide); if inner_precedence == surrounding_precedence && not_associative { - ast::Expr::Nested(Box::new(self.pretty(*nested, &LOWEST, &LOWEST))) + ast::Expr::Nested(Box::new(self.pretty(*nested, LOWEST, LOWEST))) } else if inner_precedence >= surrounding_precedence { self.pretty(*nested, left_op, right_op) } else { - ast::Expr::Nested(Box::new(self.pretty(*nested, &LOWEST, &LOWEST))) + ast::Expr::Nested(Box::new(self.pretty(*nested, LOWEST, LOWEST))) } } ast::Expr::BinaryOp { left, op, right } => ast::Expr::BinaryOp { @@ -653,7 +653,7 @@ impl Unparser<'_> { } } - fn lowest_inner_precedence(&self, expr: &ast::Expr) -> u8 { + fn inner_precedence(&self, expr: &ast::Expr) -> u8 { match expr { ast::Expr::Nested(_) | ast::Expr::Identifier(_) | ast::Expr::Value(_) => 100, ast::Expr::BinaryOp { op, .. } => self.sql_op_precedence(op), From dcc666459b6b950cd0033083fdb535bd1787358f Mon Sep 17 00:00:00 2001 From: Mohamed Abdeen Date: Sun, 30 Jun 2024 21:31:58 +0300 Subject: [PATCH 05/17] moved tests --- .../examples/parse_sql_expr.rs | 52 ------------------- datafusion/sql/tests/cases/plan_to_sql.rs | 50 +++++++++++++++++- 2 files changed, 49 insertions(+), 53 deletions(-) diff --git a/datafusion-examples/examples/parse_sql_expr.rs b/datafusion-examples/examples/parse_sql_expr.rs index 68ee9732a220..6444eb68b6b2 100644 --- a/datafusion-examples/examples/parse_sql_expr.rs +++ b/datafusion-examples/examples/parse_sql_expr.rs @@ -48,7 +48,6 @@ async fn main() -> Result<()> { simple_dataframe_parse_sql_expr_demo().await?; query_parquet_demo().await?; round_trip_parse_sql_expr_demo().await?; - round_trip_parse_sql_expr_pretty_demo().await?; Ok(()) } @@ -156,54 +155,3 @@ async fn round_trip_parse_sql_expr_demo() -> Result<()> { Ok(()) } - -// TODO: Move these to sql/tests/cases/plan_to_sql.rs -async fn round_trip_parse_sql_expr_pretty_demo() -> Result<()> { - let ctx = SessionContext::new(); - let testdata = datafusion::test_util::parquet_test_data(); - let df = ctx - .read_parquet( - &format!("{testdata}/alltypes_plain.parquet"), - ParquetReadOptions::default(), - ) - .await?; - - let unparser = Unparser::default(); - - let sql_pairs = vec![ - ( - "((int_col < 5) OR (double_col = 8))", - "int_col < 5 OR double_col = 8", - ), - ( - "((int_col + 5) * (double_col * 8))", - "(int_col + 5) * double_col * 8", - ), - ("(3 + (5 * 6) * 3)", "3 + 5 * 6 * 3"), - ("((3 * (5 + 6)) * 3)", "3 * (5 + 6) * 3"), - ("((3 AND (5 OR 6)) * 3)", "(3 AND (5 OR 6)) * 3"), - ("((3 + (5 + 6)) * 3)", "(3 + 5 + 6) * 3"), - ("((3 + (5 + 6)) * 3)", "(3 + 5 + 6) * 3"), - ( - "((int_col > 10) AND (double_col BETWEEN 10 AND 20))", - "int_col > 10 AND double_col BETWEEN 10 AND 20", - ), - ( - "((int_col > 10) * (double_col BETWEEN 10 AND 20))", - "(int_col > 10) * (double_col BETWEEN 10 AND 20)", - ), - ("int_col - (double_col - 8)", "int_col - (double_col - 8)"), - ("((int_col - double_col) - 8)", "int_col - double_col - 8"), - ("(int_col OR (double_col - 8))", "int_col OR double_col - 8"), - ("(int_col / (double_col - 8))", "int_col / (double_col - 8)"), - ("((int_col / double_col) * 8)", "int_col / double_col * 8"), - ]; - - for (sql, pretty) in sql_pairs.iter() { - let parsed_expr = df.parse_sql_expr(sql)?; - let round_trip_sql = unparser.pretty_expr_to_sql(&parsed_expr)?.to_string(); - assert_eq!(pretty.to_string(), round_trip_sql); - } - - Ok(()) -} diff --git a/datafusion/sql/tests/cases/plan_to_sql.rs b/datafusion/sql/tests/cases/plan_to_sql.rs index c92d52264f50..75b8088749b3 100644 --- a/datafusion/sql/tests/cases/plan_to_sql.rs +++ b/datafusion/sql/tests/cases/plan_to_sql.rs @@ -315,4 +315,52 @@ fn test_table_references_in_plan_to_sql() { ); } -// TODO: Pretty unparse tests here +#[test] +fn test_pretty_roundtrip() -> Result<()> { + let schema = Schema::new(vec![ + Field::new("id", DataType::Utf8, false), + Field::new("age", DataType::Utf8, false), + ]); + + let df_schema = DFSchema::try_from(schema)?; + + let context = MockContextProvider::default(); + let sql_to_rel = SqlToRel::new(&context); + + let unparser = Unparser::default(); + + let sql_to_pretty_unparse = vec![ + ("((id < 5) OR (age = 8))", "id < 5 OR age = 8"), + ("((id + 5) * (age * 8))", "(id + 5) * age * 8"), + ("(3 + (5 * 6) * 3)", "3 + 5 * 6 * 3"), + ("((3 * (5 + 6)) * 3)", "3 * (5 + 6) * 3"), + ("((3 AND (5 OR 6)) * 3)", "(3 AND (5 OR 6)) * 3"), + ("((3 + (5 + 6)) * 3)", "(3 + 5 + 6) * 3"), + ("((3 + (5 + 6)) + 3)", "3 + 5 + 6 + 3"), + ( + "((id > 10) AND (age BETWEEN 10 AND 20))", + "id > 10 AND age BETWEEN 10 AND 20", + ), + ( + "((id > 10) * (age BETWEEN 10 AND 20))", + "(id > 10) * (age BETWEEN 10 AND 20)", + ), + ("id - (age - 8)", "id - (age - 8)"), + ("((id - age) - 8)", "id - age - 8"), + ("(id OR (age - 8))", "id OR age - 8"), + ("(id / (age - 8))", "id / (age - 8)"), + ("((id / age) * 8)", "id / age * 8"), + ]; + + for (sql, pretty) in sql_to_pretty_unparse.iter() { + let sql_expr = Parser::new(&GenericDialect {}) + .try_with_sql(sql)? + .parse_expr()?; + let expr = + sql_to_rel.sql_to_expr(sql_expr, &df_schema, &mut PlannerContext::new())?; + let round_trip_sql = unparser.pretty_expr_to_sql(&expr)?.to_string(); + assert_eq!(pretty.to_string(), round_trip_sql); + } + + Ok(()) +} From 29b5aa5449294d49271f62a35559bd5c4258f4e9 Mon Sep 17 00:00:00 2001 From: Mohamed Abdeen Date: Tue, 2 Jul 2024 19:28:44 +0300 Subject: [PATCH 06/17] Update precedence of BETWEEN --- datafusion/sql/src/unparser/expr.rs | 4 +++- datafusion/sql/tests/cases/plan_to_sql.rs | 4 ++-- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/datafusion/sql/src/unparser/expr.rs b/datafusion/sql/src/unparser/expr.rs index d635a559c940..fd125bbbc39a 100644 --- a/datafusion/sql/src/unparser/expr.rs +++ b/datafusion/sql/src/unparser/expr.rs @@ -657,8 +657,10 @@ impl Unparser<'_> { match expr { ast::Expr::Nested(_) | ast::Expr::Identifier(_) | ast::Expr::Value(_) => 100, ast::Expr::BinaryOp { op, .. } => self.sql_op_precedence(op), + // closest precedence we currently have to Between is PGLikeMatch + // (https://www.postgresql.org/docs/7.2/sql-precedence.html) ast::Expr::Between { .. } => { - self.sql_op_precedence(&ast::BinaryOperator::And) + self.sql_op_precedence(&ast::BinaryOperator::PGLikeMatch) } _ => 0, } diff --git a/datafusion/sql/tests/cases/plan_to_sql.rs b/datafusion/sql/tests/cases/plan_to_sql.rs index 75b8088749b3..9b116d5e0511 100644 --- a/datafusion/sql/tests/cases/plan_to_sql.rs +++ b/datafusion/sql/tests/cases/plan_to_sql.rs @@ -338,8 +338,8 @@ fn test_pretty_roundtrip() -> Result<()> { ("((3 + (5 + 6)) * 3)", "(3 + 5 + 6) * 3"), ("((3 + (5 + 6)) + 3)", "3 + 5 + 6 + 3"), ( - "((id > 10) AND (age BETWEEN 10 AND 20))", - "id > 10 AND age BETWEEN 10 AND 20", + "((id > 10) || (age BETWEEN 10 AND 20))", + "id > 10 || age BETWEEN 10 AND 20", ), ( "((id > 10) * (age BETWEEN 10 AND 20))", From 384dde1d863bfbf4b387b3657c34e2b8d98ff6af Mon Sep 17 00:00:00 2001 From: Mohamed Abdeen Date: Tue, 2 Jul 2024 19:41:28 +0300 Subject: [PATCH 07/17] rerun CI From 4d6967ccc42dc2d16adb54cb787de308b588ecc8 Mon Sep 17 00:00:00 2001 From: Mohamed Abdeen Date: Tue, 2 Jul 2024 20:34:36 +0300 Subject: [PATCH 08/17] Change precedence to match PGSQLs --- datafusion/expr/src/operator.rs | 23 ++++++++++------------- 1 file changed, 10 insertions(+), 13 deletions(-) diff --git a/datafusion/expr/src/operator.rs b/datafusion/expr/src/operator.rs index 742511822a0f..4ae3b0272948 100644 --- a/datafusion/expr/src/operator.rs +++ b/datafusion/expr/src/operator.rs @@ -223,24 +223,18 @@ impl Operator { match self { Operator::Or => 5, Operator::And => 10, - Operator::NotEq - | Operator::Eq - | Operator::Lt - | Operator::LtEq - | Operator::Gt - | Operator::GtEq => 20, - Operator::Plus | Operator::Minus => 30, - Operator::Multiply | Operator::Divide | Operator::Modulo => 40, + Operator::Eq => 15, + Operator::NotEq => 20, + Operator::LikeMatch + | Operator::NotLikeMatch + | Operator::ILikeMatch + | Operator::NotILikeMatch => 25, Operator::IsDistinctFrom | Operator::IsNotDistinctFrom | Operator::RegexMatch | Operator::RegexNotMatch | Operator::RegexIMatch | Operator::RegexNotIMatch - | Operator::LikeMatch - | Operator::ILikeMatch - | Operator::NotLikeMatch - | Operator::NotILikeMatch | Operator::BitwiseAnd | Operator::BitwiseOr | Operator::BitwiseShiftLeft @@ -248,7 +242,10 @@ impl Operator { | Operator::BitwiseXor | Operator::StringConcat | Operator::AtArrow - | Operator::ArrowAt => 0, + | Operator::ArrowAt => 30, + Operator::Lt | Operator::LtEq | Operator::Gt | Operator::GtEq => 35, + Operator::Plus | Operator::Minus => 40, + Operator::Multiply | Operator::Divide | Operator::Modulo => 45, } } } From 2c8f5c4cf63f123fa7fb98113654ee78afb88ad0 Mon Sep 17 00:00:00 2001 From: Mohamed Abdeen Date: Tue, 2 Jul 2024 20:35:18 +0300 Subject: [PATCH 09/17] more pretty unparser tests --- datafusion/sql/src/unparser/expr.rs | 2 +- datafusion/sql/tests/cases/plan_to_sql.rs | 22 ++++++++++++++++++++-- 2 files changed, 21 insertions(+), 3 deletions(-) diff --git a/datafusion/sql/src/unparser/expr.rs b/datafusion/sql/src/unparser/expr.rs index fd125bbbc39a..8cff8d9e7a1e 100644 --- a/datafusion/sql/src/unparser/expr.rs +++ b/datafusion/sql/src/unparser/expr.rs @@ -101,7 +101,7 @@ pub fn expr_to_unparsed(expr: &Expr) -> Result { unparser.expr_to_unparsed(expr) } -const LOWEST: &BinaryOperator = &BinaryOperator::BitwiseOr; +const LOWEST: &BinaryOperator = &BinaryOperator::Or; impl Unparser<'_> { /// Try to unparse the expression into a more human-readable format diff --git a/datafusion/sql/tests/cases/plan_to_sql.rs b/datafusion/sql/tests/cases/plan_to_sql.rs index 9b116d5e0511..654f9e29ca3d 100644 --- a/datafusion/sql/tests/cases/plan_to_sql.rs +++ b/datafusion/sql/tests/cases/plan_to_sql.rs @@ -337,9 +337,14 @@ fn test_pretty_roundtrip() -> Result<()> { ("((3 AND (5 OR 6)) * 3)", "(3 AND (5 OR 6)) * 3"), ("((3 + (5 + 6)) * 3)", "(3 + 5 + 6) * 3"), ("((3 + (5 + 6)) + 3)", "3 + 5 + 6 + 3"), + ("3 + 5 + 6 + 3", "3 + 5 + 6 + 3"), + ("3 + (5 + (6 + 3))", "3 + 5 + 6 + 3"), + ("3 + ((5 + 6) + 3)", "3 + 5 + 6 + 3"), + ("(3 + 5) + (6 + 3)", "3 + 5 + 6 + 3"), + ("((3 + 5) + (6 + 3))", "3 + 5 + 6 + 3"), ( - "((id > 10) || (age BETWEEN 10 AND 20))", - "id > 10 || age BETWEEN 10 AND 20", + "((id > 10) OR (age BETWEEN 10 AND 20))", + "id > 10 OR age BETWEEN 10 AND 20", ), ( "((id > 10) * (age BETWEEN 10 AND 20))", @@ -360,6 +365,19 @@ fn test_pretty_roundtrip() -> Result<()> { sql_to_rel.sql_to_expr(sql_expr, &df_schema, &mut PlannerContext::new())?; let round_trip_sql = unparser.pretty_expr_to_sql(&expr)?.to_string(); assert_eq!(pretty.to_string(), round_trip_sql); + + // verify that the pretty string parses to the same underlying Expr + let pretty_sql_expr = Parser::new(&GenericDialect {}) + .try_with_sql(pretty)? + .parse_expr()?; + + let pretty_expr = sql_to_rel.sql_to_expr( + pretty_sql_expr, + &df_schema, + &mut PlannerContext::new(), + )?; + + assert_eq!(expr.to_string(), pretty_expr.to_string()); } Ok(()) From f753f05be70c4dd72e75f530e941345a6b0a7f96 Mon Sep 17 00:00:00 2001 From: Mohamed Abdeen Date: Tue, 2 Jul 2024 21:18:31 +0300 Subject: [PATCH 10/17] Update operator precedence to match latest PGSQL --- datafusion/expr/src/operator.rs | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/datafusion/expr/src/operator.rs b/datafusion/expr/src/operator.rs index 4ae3b0272948..0cbf9f00821a 100644 --- a/datafusion/expr/src/operator.rs +++ b/datafusion/expr/src/operator.rs @@ -218,13 +218,13 @@ impl Operator { } /// Get the operator precedence - /// use as a reference + /// use as a reference pub fn precedence(&self) -> u8 { match self { Operator::Or => 5, Operator::And => 10, - Operator::Eq => 15, - Operator::NotEq => 20, + Operator::Eq | Operator::NotEq | Operator::LtEq | Operator::GtEq => 15, + Operator::Lt | Operator::Gt => 20, Operator::LikeMatch | Operator::NotLikeMatch | Operator::ILikeMatch @@ -243,7 +243,6 @@ impl Operator { | Operator::StringConcat | Operator::AtArrow | Operator::ArrowAt => 30, - Operator::Lt | Operator::LtEq | Operator::Gt | Operator::GtEq => 35, Operator::Plus | Operator::Minus => 40, Operator::Multiply | Operator::Divide | Operator::Modulo => 45, } From 912ce3aa3ed60f7cb5bff58d677168da5ab483de Mon Sep 17 00:00:00 2001 From: Mohamed Abdeen Date: Wed, 3 Jul 2024 07:41:43 +0300 Subject: [PATCH 11/17] directly prettify expr_to_sql --- datafusion/sql/src/unparser/expr.rs | 157 +++++++++++----------- datafusion/sql/tests/cases/plan_to_sql.rs | 12 +- 2 files changed, 81 insertions(+), 88 deletions(-) diff --git a/datafusion/sql/src/unparser/expr.rs b/datafusion/sql/src/unparser/expr.rs index 8cff8d9e7a1e..c67f84bf6296 100644 --- a/datafusion/sql/src/unparser/expr.rs +++ b/datafusion/sql/src/unparser/expr.rs @@ -85,7 +85,7 @@ impl Display for Unparsed { /// let expr = col("a").gt(lit(4)); /// let sql = expr_to_sql(&expr).unwrap(); /// -/// assert_eq!(format!("{}", sql), "(a > 4)") +/// assert_eq!(format!("{}", sql), "a > 4") /// ``` pub fn expr_to_sql(expr: &Expr) -> Result { let unparser = Unparser::default(); @@ -104,14 +104,12 @@ pub fn expr_to_unparsed(expr: &Expr) -> Result { const LOWEST: &BinaryOperator = &BinaryOperator::Or; impl Unparser<'_> { - /// Try to unparse the expression into a more human-readable format - /// by removing unnecessary parentheses. - pub fn pretty_expr_to_sql(&self, expr: &Expr) -> Result { - let root_expr = self.expr_to_sql(expr)?; - Ok(self.pretty(root_expr, LOWEST, LOWEST)) + pub fn expr_to_sql(&self, expr: &Expr) -> Result { + let root_expr = self.expr_to_sql_inner(expr)?; + Ok(self.remove_unnecessary_nesting(root_expr, LOWEST, LOWEST)) } - pub fn expr_to_sql(&self, expr: &Expr) -> Result { + fn expr_to_sql_inner(&self, expr: &Expr) -> Result { match expr { Expr::InList(InList { expr, @@ -120,10 +118,10 @@ impl Unparser<'_> { }) => { let list_expr = list .iter() - .map(|e| self.expr_to_sql(e)) + .map(|e| self.expr_to_sql_inner(e)) .collect::>>()?; Ok(ast::Expr::InList { - expr: Box::new(self.expr_to_sql(expr)?), + expr: Box::new(self.expr_to_sql_inner(expr)?), list: list_expr, negated: *negated, }) @@ -137,7 +135,7 @@ impl Unparser<'_> { if matches!(e, Expr::Wildcard { qualifier: None }) { Ok(FunctionArg::Unnamed(ast::FunctionArgExpr::Wildcard)) } else { - self.expr_to_sql(e).map(|e| { + self.expr_to_sql_inner(e).map(|e| { FunctionArg::Unnamed(ast::FunctionArgExpr::Expr(e)) }) } @@ -166,9 +164,9 @@ impl Unparser<'_> { low, high, }) => { - let sql_parser_expr = self.expr_to_sql(expr)?; - let sql_low = self.expr_to_sql(low)?; - let sql_high = self.expr_to_sql(high)?; + let sql_parser_expr = self.expr_to_sql_inner(expr)?; + let sql_low = self.expr_to_sql_inner(low)?; + let sql_high = self.expr_to_sql_inner(high)?; Ok(ast::Expr::Nested(Box::new(self.between_op_to_sql( sql_parser_expr, *negated, @@ -178,8 +176,8 @@ impl Unparser<'_> { } Expr::Column(col) => self.col_to_sql(col), Expr::BinaryExpr(BinaryExpr { left, op, right }) => { - let l = self.expr_to_sql(left.as_ref())?; - let r = self.expr_to_sql(right.as_ref())?; + let l = self.expr_to_sql_inner(left.as_ref())?; + let r = self.expr_to_sql_inner(right.as_ref())?; let op = self.op_to_sql(op)?; Ok(ast::Expr::Nested(Box::new(self.binary_op_to_sql(l, r, op)))) @@ -191,21 +189,21 @@ impl Unparser<'_> { }) => { let conditions = when_then_expr .iter() - .map(|(w, _)| self.expr_to_sql(w)) + .map(|(w, _)| self.expr_to_sql_inner(w)) .collect::>>()?; let results = when_then_expr .iter() - .map(|(_, t)| self.expr_to_sql(t)) + .map(|(_, t)| self.expr_to_sql_inner(t)) .collect::>>()?; let operand = match expr.as_ref() { - Some(e) => match self.expr_to_sql(e) { + Some(e) => match self.expr_to_sql_inner(e) { Ok(sql_expr) => Some(Box::new(sql_expr)), Err(_) => None, }, None => None, }; let else_result = match else_expr.as_ref() { - Some(e) => match self.expr_to_sql(e) { + Some(e) => match self.expr_to_sql_inner(e) { Ok(sql_expr) => Some(Box::new(sql_expr)), Err(_) => None, }, @@ -220,7 +218,7 @@ impl Unparser<'_> { }) } Expr::Cast(Cast { expr, data_type }) => { - let inner_expr = self.expr_to_sql(expr)?; + let inner_expr = self.expr_to_sql_inner(expr)?; Ok(ast::Expr::Cast { kind: ast::CastKind::Cast, expr: Box::new(inner_expr), @@ -229,7 +227,7 @@ impl Unparser<'_> { }) } Expr::Literal(value) => Ok(self.scalar_to_sql(value)?), - Expr::Alias(Alias { expr, name: _, .. }) => self.expr_to_sql(expr), + Expr::Alias(Alias { expr, name: _, .. }) => self.expr_to_sql_inner(expr), Expr::WindowFunction(WindowFunction { fun, args, @@ -264,7 +262,7 @@ impl Unparser<'_> { window_name: None, partition_by: partition_by .iter() - .map(|e| self.expr_to_sql(e)) + .map(|e| self.expr_to_sql_inner(e)) .collect::>>()?, order_by, window_frame: Some(ast::WindowFrame { @@ -305,8 +303,8 @@ impl Unparser<'_> { case_insensitive: _, }) => Ok(ast::Expr::Like { negated: *negated, - expr: Box::new(self.expr_to_sql(expr)?), - pattern: Box::new(self.expr_to_sql(pattern)?), + expr: Box::new(self.expr_to_sql_inner(expr)?), + pattern: Box::new(self.expr_to_sql_inner(pattern)?), escape_char: escape_char.map(|c| c.to_string()), }), Expr::AggregateFunction(agg) => { @@ -314,7 +312,7 @@ impl Unparser<'_> { let args = self.function_args_to_sql(&agg.args)?; let filter = match &agg.filter { - Some(filter) => Some(Box::new(self.expr_to_sql(filter)?)), + Some(filter) => Some(Box::new(self.expr_to_sql_inner(filter)?)), None => None, }; Ok(ast::Expr::Function(Function { @@ -348,7 +346,7 @@ impl Unparser<'_> { Ok(ast::Expr::Subquery(sub_query)) } Expr::InSubquery(insubq) => { - let inexpr = Box::new(self.expr_to_sql(insubq.expr.as_ref())?); + let inexpr = Box::new(self.expr_to_sql_inner(insubq.expr.as_ref())?); let sub_statement = self.plan_to_sql(insubq.subquery.subquery.as_ref())?; let sub_query = if let ast::Statement::Query(inner_query) = sub_statement @@ -386,38 +384,38 @@ impl Unparser<'_> { nulls_first: _, }) => plan_err!("Sort expression should be handled by expr_to_unparsed"), Expr::IsNull(expr) => { - Ok(ast::Expr::IsNull(Box::new(self.expr_to_sql(expr)?))) - } - Expr::IsNotNull(expr) => { - Ok(ast::Expr::IsNotNull(Box::new(self.expr_to_sql(expr)?))) + Ok(ast::Expr::IsNull(Box::new(self.expr_to_sql_inner(expr)?))) } + Expr::IsNotNull(expr) => Ok(ast::Expr::IsNotNull(Box::new( + self.expr_to_sql_inner(expr)?, + ))), Expr::IsTrue(expr) => { - Ok(ast::Expr::IsTrue(Box::new(self.expr_to_sql(expr)?))) - } - Expr::IsNotTrue(expr) => { - Ok(ast::Expr::IsNotTrue(Box::new(self.expr_to_sql(expr)?))) + Ok(ast::Expr::IsTrue(Box::new(self.expr_to_sql_inner(expr)?))) } + Expr::IsNotTrue(expr) => Ok(ast::Expr::IsNotTrue(Box::new( + self.expr_to_sql_inner(expr)?, + ))), Expr::IsFalse(expr) => { - Ok(ast::Expr::IsFalse(Box::new(self.expr_to_sql(expr)?))) - } - Expr::IsNotFalse(expr) => { - Ok(ast::Expr::IsNotFalse(Box::new(self.expr_to_sql(expr)?))) - } - Expr::IsUnknown(expr) => { - Ok(ast::Expr::IsUnknown(Box::new(self.expr_to_sql(expr)?))) - } - Expr::IsNotUnknown(expr) => { - Ok(ast::Expr::IsNotUnknown(Box::new(self.expr_to_sql(expr)?))) - } + Ok(ast::Expr::IsFalse(Box::new(self.expr_to_sql_inner(expr)?))) + } + Expr::IsNotFalse(expr) => Ok(ast::Expr::IsNotFalse(Box::new( + self.expr_to_sql_inner(expr)?, + ))), + Expr::IsUnknown(expr) => Ok(ast::Expr::IsUnknown(Box::new( + self.expr_to_sql_inner(expr)?, + ))), + Expr::IsNotUnknown(expr) => Ok(ast::Expr::IsNotUnknown(Box::new( + self.expr_to_sql_inner(expr)?, + ))), Expr::Not(expr) => { - let sql_parser_expr = self.expr_to_sql(expr)?; + let sql_parser_expr = self.expr_to_sql_inner(expr)?; Ok(AstExpr::UnaryOp { op: UnaryOperator::Not, expr: Box::new(sql_parser_expr), }) } Expr::Negative(expr) => { - let sql_parser_expr = self.expr_to_sql(expr)?; + let sql_parser_expr = self.expr_to_sql_inner(expr)?; Ok(AstExpr::UnaryOp { op: UnaryOperator::Minus, expr: Box::new(sql_parser_expr), @@ -441,7 +439,7 @@ impl Unparser<'_> { }) } Expr::TryCast(TryCast { expr, data_type }) => { - let inner_expr = self.expr_to_sql(expr)?; + let inner_expr = self.expr_to_sql_inner(expr)?; Ok(ast::Expr::Cast { kind: ast::CastKind::TryCast, expr: Box::new(inner_expr), @@ -458,7 +456,7 @@ impl Unparser<'_> { .iter() .map(|set| { set.iter() - .map(|e| self.expr_to_sql(e)) + .map(|e| self.expr_to_sql_inner(e)) .collect::>>() }) .collect::>>()?; @@ -469,7 +467,7 @@ impl Unparser<'_> { let expr_ast_sets = cube .iter() .map(|e| { - let sql = self.expr_to_sql(e)?; + let sql = self.expr_to_sql_inner(e)?; Ok(vec![sql]) }) .collect::>>()?; @@ -479,7 +477,7 @@ impl Unparser<'_> { let expr_ast_sets: Vec> = rollup .iter() .map(|e| { - let sql = self.expr_to_sql(e)?; + let sql = self.expr_to_sql_inner(e)?; Ok(vec![sql]) }) .collect::>>()?; @@ -619,7 +617,7 @@ impl Unparser<'_> { /// /// Also note that when fetching the precedence of a nested expression, we ignore other nested /// expressions, so precedence of expr `(a * (b + c))` equals `*` and not `+`. - fn pretty( + fn remove_unnecessary_nesting( &self, expr: ast::Expr, left_op: &BinaryOperator, @@ -637,16 +635,20 @@ impl Unparser<'_> { matches!(left_op, BinaryOperator::Minus | BinaryOperator::Divide); if inner_precedence == surrounding_precedence && not_associative { - ast::Expr::Nested(Box::new(self.pretty(*nested, LOWEST, LOWEST))) + ast::Expr::Nested(Box::new( + self.remove_unnecessary_nesting(*nested, LOWEST, LOWEST), + )) } else if inner_precedence >= surrounding_precedence { - self.pretty(*nested, left_op, right_op) + self.remove_unnecessary_nesting(*nested, left_op, right_op) } else { - ast::Expr::Nested(Box::new(self.pretty(*nested, LOWEST, LOWEST))) + ast::Expr::Nested(Box::new( + self.remove_unnecessary_nesting(*nested, LOWEST, LOWEST), + )) } } ast::Expr::BinaryOp { left, op, right } => ast::Expr::BinaryOp { - left: Box::new(self.pretty(*left, left_op, &op)), - right: Box::new(self.pretty(*right, &op, right_op)), + left: Box::new(self.remove_unnecessary_nesting(*left, left_op, &op)), + right: Box::new(self.remove_unnecessary_nesting(*right, &op, right_op)), op, }, _ => expr, @@ -681,8 +683,6 @@ impl Unparser<'_> { } } - // TODO: operator precedence should be defined in sqlparser - // to avoid the need for sql_to_op and sql_op_precedence fn sql_op_precedence(&self, op: &BinaryOperator) -> u8 { match self.sql_to_op(op) { Ok(op) => op.precedence(), @@ -1233,14 +1233,14 @@ mod tests { .build()?; let tests: Vec<(Expr, &str)> = vec![ - ((col("a") + col("b")).gt(lit(4)), r#"((a + b) > 4)"#), + ((col("a") + col("b")).gt(lit(4)), r#"a + b > 4"#), ( Expr::Column(Column { relation: Some(TableReference::partial("a", "b")), name: "c".to_string(), }) .gt(lit(4)), - r#"(a.b.c > 4)"#, + r#"a.b.c > 4"#, ), ( case(col("a")) @@ -1481,16 +1481,16 @@ mod tests { (not(col("a")), r#"NOT a"#), ( Expr::between(col("a"), lit(1), lit(7)), - r#"(a BETWEEN 1 AND 7)"#, + r#"a BETWEEN 1 AND 7"#, ), (Expr::Negative(Box::new(col("a"))), r#"-a"#), ( exists(Arc::new(dummy_logical_plan.clone())), - r#"EXISTS (SELECT t.a FROM t WHERE (t.a = 1))"#, + r#"EXISTS (SELECT t.a FROM t WHERE t.a = 1)"#, ), ( not_exists(Arc::new(dummy_logical_plan.clone())), - r#"NOT EXISTS (SELECT t.a FROM t WHERE (t.a = 1))"#, + r#"NOT EXISTS (SELECT t.a FROM t WHERE t.a = 1)"#, ), ( try_cast(col("a"), DataType::Date64), @@ -1511,24 +1511,21 @@ mod tests { ), r#"@root.foo"#, ), - (col("x").eq(placeholder("$1")), r#"(x = $1)"#), - ( - out_ref_col(DataType::Int32, "t.a").gt(lit(1)), - r#"(t.a > 1)"#, - ), + (col("x").eq(placeholder("$1")), r#"x = $1"#), + (out_ref_col(DataType::Int32, "t.a").gt(lit(1)), r#"t.a > 1"#), ( grouping_set(vec![vec![col("a"), col("b")], vec![col("a")]]), r#"GROUPING SETS ((a, b), (a))"#, ), (cube(vec![col("a"), col("b")]), r#"CUBE (a, b)"#), (rollup(vec![col("a"), col("b")]), r#"ROLLUP (a, b)"#), - (col("table").eq(lit(1)), r#"("table" = 1)"#), + (col("table").eq(lit(1)), r#""table" = 1"#), ( col("123_need_quoted").eq(lit(1)), - r#"("123_need_quoted" = 1)"#, + r#""123_need_quoted" = 1"#, ), - (col("need-quoted").eq(lit(1)), r#"("need-quoted" = 1)"#), - (col("need quoted").eq(lit(1)), r#"("need quoted" = 1)"#), + (col("need-quoted").eq(lit(1)), r#""need-quoted" = 1"#), + (col("need quoted").eq(lit(1)), r#""need quoted" = 1"#), ( interval_month_day_nano_lit( "1 YEAR 1 MONTH 1 DAY 3 HOUR 10 MINUTE 20 SECOND", @@ -1546,12 +1543,12 @@ mod tests { ( interval_month_day_nano_lit("1 MONTH") .add(interval_month_day_nano_lit("1 DAY")), - r#"(INTERVAL '0 YEARS 1 MONS 0 DAYS 0 HOURS 0 MINS 0.000000000 SECS' + INTERVAL '0 YEARS 0 MONS 1 DAYS 0 HOURS 0 MINS 0.000000000 SECS')"#, + r#"INTERVAL '0 YEARS 1 MONS 0 DAYS 0 HOURS 0 MINS 0.000000000 SECS' + INTERVAL '0 YEARS 0 MONS 1 DAYS 0 HOURS 0 MINS 0.000000000 SECS'"#, ), ( interval_month_day_nano_lit("1 MONTH") .sub(interval_month_day_nano_lit("1 DAY")), - r#"(INTERVAL '0 YEARS 1 MONS 0 DAYS 0 HOURS 0 MINS 0.000000000 SECS' - INTERVAL '0 YEARS 0 MONS 1 DAYS 0 HOURS 0 MINS 0.000000000 SECS')"#, + r#"INTERVAL '0 YEARS 1 MONS 0 DAYS 0 HOURS 0 MINS 0.000000000 SECS' - INTERVAL '0 YEARS 0 MONS 1 DAYS 0 HOURS 0 MINS 0.000000000 SECS'"#, ), ( interval_datetime_lit("10 DAY 1 HOUR 10 MINUTE 20 SECOND"), @@ -1575,7 +1572,7 @@ mod tests { 28, 3, ))), - r#"((a + b) > 100.123)"#, + r#"a + b > 100.123"#, ), ( (col("a") + col("b")).gt(Expr::Literal(ScalarValue::Decimal256( @@ -1583,7 +1580,7 @@ mod tests { 28, 3, ))), - r#"((a + b) > 100.123)"#, + r#"a + b > 100.123"#, ), ( Expr::Cast(Cast { @@ -1608,7 +1605,7 @@ mod tests { #[test] fn expr_to_unparsed_ok() -> Result<()> { let tests: Vec<(Expr, &str)> = vec![ - ((col("a") + col("b")).gt(lit(4)), r#"((a + b) > 4)"#), + ((col("a") + col("b")).gt(lit(4)), r#"a + b > 4"#), (col("a").sort(true, true), r#"a ASC NULLS FIRST"#), ]; @@ -1632,7 +1629,7 @@ mod tests { let actual = format!("{}", ast); - let expected = r#"('a' > 4)"#; + let expected = r#"'a' > 4"#; assert_eq!(actual, expected); Ok(()) @@ -1648,7 +1645,7 @@ mod tests { let actual = format!("{}", ast); - let expected = r#"(a > 4)"#; + let expected = r#"a > 4"#; assert_eq!(actual, expected); Ok(()) diff --git a/datafusion/sql/tests/cases/plan_to_sql.rs b/datafusion/sql/tests/cases/plan_to_sql.rs index 654f9e29ca3d..205815c517e1 100644 --- a/datafusion/sql/tests/cases/plan_to_sql.rs +++ b/datafusion/sql/tests/cases/plan_to_sql.rs @@ -36,12 +36,8 @@ use crate::common::MockContextProvider; #[test] fn roundtrip_expr() { let tests: Vec<(TableReference, &str, &str)> = vec![ - (TableReference::bare("person"), "age > 35", r#"(age > 35)"#), - ( - TableReference::bare("person"), - "id = '10'", - r#"(id = '10')"#, - ), + (TableReference::bare("person"), "age > 35", r#"age > 35"#), + (TableReference::bare("person"), "id = '10'", r#"id = '10'"#), ( TableReference::bare("person"), "CAST(id AS VARCHAR)", @@ -50,7 +46,7 @@ fn roundtrip_expr() { ( TableReference::bare("person"), "sum((age * 2))", - r#"sum((age * 2))"#, + r#"sum(age * 2)"#, ), ]; @@ -363,7 +359,7 @@ fn test_pretty_roundtrip() -> Result<()> { .parse_expr()?; let expr = sql_to_rel.sql_to_expr(sql_expr, &df_schema, &mut PlannerContext::new())?; - let round_trip_sql = unparser.pretty_expr_to_sql(&expr)?.to_string(); + let round_trip_sql = unparser.expr_to_sql(&expr)?.to_string(); assert_eq!(pretty.to_string(), round_trip_sql); // verify that the pretty string parses to the same underlying Expr From eadc0770b4b21dfda1c70f0ff0870ef18605a044 Mon Sep 17 00:00:00 2001 From: Mohamed Abdeen Date: Wed, 3 Jul 2024 07:58:10 +0300 Subject: [PATCH 12/17] handle IS operator --- datafusion/sql/src/unparser/expr.rs | 37 ++++++++++++++++++++++++----- 1 file changed, 31 insertions(+), 6 deletions(-) diff --git a/datafusion/sql/src/unparser/expr.rs b/datafusion/sql/src/unparser/expr.rs index c67f84bf6296..fbcefd80925c 100644 --- a/datafusion/sql/src/unparser/expr.rs +++ b/datafusion/sql/src/unparser/expr.rs @@ -651,6 +651,30 @@ impl Unparser<'_> { right: Box::new(self.remove_unnecessary_nesting(*right, &op, right_op)), op, }, + ast::Expr::IsTrue(expr) => ast::Expr::IsTrue(Box::new( + self.remove_unnecessary_nesting(*expr, LOWEST, LOWEST), + )), + ast::Expr::IsNotTrue(expr) => ast::Expr::IsNotTrue(Box::new( + self.remove_unnecessary_nesting(*expr, LOWEST, LOWEST), + )), + ast::Expr::IsFalse(expr) => ast::Expr::IsFalse(Box::new( + self.remove_unnecessary_nesting(*expr, LOWEST, LOWEST), + )), + ast::Expr::IsNotFalse(expr) => ast::Expr::IsNotFalse(Box::new( + self.remove_unnecessary_nesting(*expr, LOWEST, LOWEST), + )), + ast::Expr::IsNull(expr) => ast::Expr::IsNull(Box::new( + self.remove_unnecessary_nesting(*expr, LOWEST, LOWEST), + )), + ast::Expr::IsNotNull(expr) => ast::Expr::IsNotNull(Box::new( + self.remove_unnecessary_nesting(*expr, LOWEST, LOWEST), + )), + ast::Expr::IsUnknown(expr) => ast::Expr::IsUnknown(Box::new( + self.remove_unnecessary_nesting(*expr, LOWEST, LOWEST), + )), + ast::Expr::IsNotUnknown(expr) => ast::Expr::IsNotUnknown(Box::new( + self.remove_unnecessary_nesting(*expr, LOWEST, LOWEST), + )), _ => expr, } } @@ -1456,27 +1480,27 @@ mod tests { (col("a").is_null(), r#"a IS NULL"#), ( (col("a") + col("b")).gt(lit(4)).is_true(), - r#"((a + b) > 4) IS TRUE"#, + r#"a + b > 4 IS TRUE"#, ), ( (col("a") + col("b")).gt(lit(4)).is_not_true(), - r#"((a + b) > 4) IS NOT TRUE"#, + r#"a + b > 4 IS NOT TRUE"#, ), ( (col("a") + col("b")).gt(lit(4)).is_false(), - r#"((a + b) > 4) IS FALSE"#, + r#"a + b > 4 IS FALSE"#, ), ( (col("a") + col("b")).gt(lit(4)).is_not_false(), - r#"((a + b) > 4) IS NOT FALSE"#, + r#"a + b > 4 IS NOT FALSE"#, ), ( (col("a") + col("b")).gt(lit(4)).is_unknown(), - r#"((a + b) > 4) IS UNKNOWN"#, + r#"a + b > 4 IS UNKNOWN"#, ), ( (col("a") + col("b")).gt(lit(4)).is_not_unknown(), - r#"((a + b) > 4) IS NOT UNKNOWN"#, + r#"a + b > 4 IS NOT UNKNOWN"#, ), (not(col("a")), r#"NOT a"#), ( @@ -1619,6 +1643,7 @@ mod tests { Ok(()) } + #[test] fn custom_dialect() -> Result<()> { let dialect = CustomDialect::new(Some('\'')); From 91d8b431c5f7dcf9fededb09192cb28736b382f4 Mon Sep 17 00:00:00 2001 From: Mohamed Abdeen Date: Wed, 3 Jul 2024 08:16:23 +0300 Subject: [PATCH 13/17] correct IS precedence --- datafusion/sql/src/unparser/expr.rs | 31 +++++++++++++---------- datafusion/sql/tests/cases/plan_to_sql.rs | 6 +++++ 2 files changed, 23 insertions(+), 14 deletions(-) diff --git a/datafusion/sql/src/unparser/expr.rs b/datafusion/sql/src/unparser/expr.rs index fbcefd80925c..a6968f870798 100644 --- a/datafusion/sql/src/unparser/expr.rs +++ b/datafusion/sql/src/unparser/expr.rs @@ -102,6 +102,9 @@ pub fn expr_to_unparsed(expr: &Expr) -> Result { } const LOWEST: &BinaryOperator = &BinaryOperator::Or; +// closest precedence we have to IS operator is BitwiseAnd (any other) in PG docs +// (https://www.postgresql.org/docs/7.2/sql-precedence.html) +const IS: &BinaryOperator = &BinaryOperator::BitwiseAnd; impl Unparser<'_> { pub fn expr_to_sql(&self, expr: &Expr) -> Result { @@ -652,28 +655,28 @@ impl Unparser<'_> { op, }, ast::Expr::IsTrue(expr) => ast::Expr::IsTrue(Box::new( - self.remove_unnecessary_nesting(*expr, LOWEST, LOWEST), + self.remove_unnecessary_nesting(*expr, LOWEST, IS), )), ast::Expr::IsNotTrue(expr) => ast::Expr::IsNotTrue(Box::new( - self.remove_unnecessary_nesting(*expr, LOWEST, LOWEST), + self.remove_unnecessary_nesting(*expr, LOWEST, IS), )), ast::Expr::IsFalse(expr) => ast::Expr::IsFalse(Box::new( - self.remove_unnecessary_nesting(*expr, LOWEST, LOWEST), + self.remove_unnecessary_nesting(*expr, LOWEST, IS), )), ast::Expr::IsNotFalse(expr) => ast::Expr::IsNotFalse(Box::new( - self.remove_unnecessary_nesting(*expr, LOWEST, LOWEST), + self.remove_unnecessary_nesting(*expr, LOWEST, IS), )), ast::Expr::IsNull(expr) => ast::Expr::IsNull(Box::new( - self.remove_unnecessary_nesting(*expr, LOWEST, LOWEST), + self.remove_unnecessary_nesting(*expr, LOWEST, IS), )), ast::Expr::IsNotNull(expr) => ast::Expr::IsNotNull(Box::new( - self.remove_unnecessary_nesting(*expr, LOWEST, LOWEST), + self.remove_unnecessary_nesting(*expr, LOWEST, IS), )), ast::Expr::IsUnknown(expr) => ast::Expr::IsUnknown(Box::new( - self.remove_unnecessary_nesting(*expr, LOWEST, LOWEST), + self.remove_unnecessary_nesting(*expr, LOWEST, IS), )), ast::Expr::IsNotUnknown(expr) => ast::Expr::IsNotUnknown(Box::new( - self.remove_unnecessary_nesting(*expr, LOWEST, LOWEST), + self.remove_unnecessary_nesting(*expr, LOWEST, IS), )), _ => expr, } @@ -1480,27 +1483,27 @@ mod tests { (col("a").is_null(), r#"a IS NULL"#), ( (col("a") + col("b")).gt(lit(4)).is_true(), - r#"a + b > 4 IS TRUE"#, + r#"(a + b > 4) IS TRUE"#, ), ( (col("a") + col("b")).gt(lit(4)).is_not_true(), - r#"a + b > 4 IS NOT TRUE"#, + r#"(a + b > 4) IS NOT TRUE"#, ), ( (col("a") + col("b")).gt(lit(4)).is_false(), - r#"a + b > 4 IS FALSE"#, + r#"(a + b > 4) IS FALSE"#, ), ( (col("a") + col("b")).gt(lit(4)).is_not_false(), - r#"a + b > 4 IS NOT FALSE"#, + r#"(a + b > 4) IS NOT FALSE"#, ), ( (col("a") + col("b")).gt(lit(4)).is_unknown(), - r#"a + b > 4 IS UNKNOWN"#, + r#"(a + b > 4) IS UNKNOWN"#, ), ( (col("a") + col("b")).gt(lit(4)).is_not_unknown(), - r#"a + b > 4 IS NOT UNKNOWN"#, + r#"(a + b > 4) IS NOT UNKNOWN"#, ), (not(col("a")), r#"NOT a"#), ( diff --git a/datafusion/sql/tests/cases/plan_to_sql.rs b/datafusion/sql/tests/cases/plan_to_sql.rs index 205815c517e1..d93659ac260b 100644 --- a/datafusion/sql/tests/cases/plan_to_sql.rs +++ b/datafusion/sql/tests/cases/plan_to_sql.rs @@ -351,6 +351,12 @@ fn test_pretty_roundtrip() -> Result<()> { ("(id OR (age - 8))", "id OR age - 8"), ("(id / (age - 8))", "id / (age - 8)"), ("((id / age) * 8)", "id / age * 8"), + ("((age + 10) < 20) IS TRUE", "(age + 10 < 20) IS TRUE"), + ( + "(20 > (age + 5)) IS NOT FALSE", + "(20 > age + 5) IS NOT FALSE", + ), + ("(TRUE AND FALSE) IS FALSE", "(TRUE AND FALSE) IS FALSE"), ]; for (sql, pretty) in sql_to_pretty_unparse.iter() { From c3dcb02da2f0db3fed202ae5b33b603db2985c35 Mon Sep 17 00:00:00 2001 From: Mohamed Abdeen Date: Wed, 3 Jul 2024 19:21:19 +0300 Subject: [PATCH 14/17] update unparser tests --- .../core/tests/expr_api/parse_sql_expr.rs | 26 ++++++++++--------- datafusion/sql/tests/cases/plan_to_sql.rs | 3 ++- 2 files changed, 16 insertions(+), 13 deletions(-) diff --git a/datafusion/core/tests/expr_api/parse_sql_expr.rs b/datafusion/core/tests/expr_api/parse_sql_expr.rs index 991579b5a350..20db1fd68fcd 100644 --- a/datafusion/core/tests/expr_api/parse_sql_expr.rs +++ b/datafusion/core/tests/expr_api/parse_sql_expr.rs @@ -39,17 +39,17 @@ fn schema() -> DFSchemaRef { #[tokio::test] async fn round_trip_parse_sql_expr() -> Result<()> { let tests = vec![ - "(a = 10)", - "((a = 10) AND (b <> 20))", - "((a = 10) OR (b <> 20))", - "(((a = 10) AND (b <> 20)) OR (c = a))", - "((a = 10) AND b IN (20, 30))", - "((a = 10) AND b NOT IN (20, 30))", + "a = 10", + "(a = 10) AND (b <> 20)", + "(a = 10) OR (b <> 20)", + "((a = 10) AND (b <> 20)) OR (c = a)", + "(a = 10) AND b IN (20, 30)", + "(a = 10) AND b NOT IN (20, 30)", "sum(a)", - "(sum(a) + 1)", - "(MIN(a) + MAX(b))", - "(MIN(a) + (MAX(b) * sum(c)))", - "(MIN(a) + ((MAX(b) * sum(c)) / 10))", + "sum(a) + 1", + "MIN(a) + MAX(b)", + "MIN(a) + (MAX(b) * sum(c))", + "MIN(a) + ((MAX(b) * sum(c)) / 10)", ]; for test in tests { @@ -65,7 +65,8 @@ fn round_trip_session_context(sql: &str) -> Result<()> { let df_schema = schema(); let expr = ctx.parse_sql_expr(sql, &df_schema)?; let sql2 = unparse_sql_expr(&expr)?; - assert_eq!(sql, sql2); + let expr2 = ctx.parse_sql_expr(&sql2, &df_schema)?; + assert_eq!(expr.to_string(), expr2.to_string()); Ok(()) } @@ -80,7 +81,8 @@ async fn round_trip_dataframe(sql: &str) -> Result<()> { .await?; let expr = df.parse_sql_expr(sql)?; let sql2 = unparse_sql_expr(&expr)?; - assert_eq!(sql, sql2); + let roundtrip = df.parse_sql_expr(&sql2)?; + assert_eq!(expr, roundtrip); Ok(()) } diff --git a/datafusion/sql/tests/cases/plan_to_sql.rs b/datafusion/sql/tests/cases/plan_to_sql.rs index d93659ac260b..22db64e92bcb 100644 --- a/datafusion/sql/tests/cases/plan_to_sql.rs +++ b/datafusion/sql/tests/cases/plan_to_sql.rs @@ -356,7 +356,8 @@ fn test_pretty_roundtrip() -> Result<()> { "(20 > (age + 5)) IS NOT FALSE", "(20 > age + 5) IS NOT FALSE", ), - ("(TRUE AND FALSE) IS FALSE", "(TRUE AND FALSE) IS FALSE"), + ("(true AND false) IS FALSE", "(true AND false) IS FALSE"), + ("true AND (false IS FALSE)", "true AND false IS FALSE"), ]; for (sql, pretty) in sql_to_pretty_unparse.iter() { From 61fddb9620a67d81dbc8b8f319e45da40a470d91 Mon Sep 17 00:00:00 2001 From: Mohamed Abdeen Date: Wed, 3 Jul 2024 20:05:14 +0300 Subject: [PATCH 15/17] update unparser example --- datafusion-examples/examples/parse_sql_expr.rs | 4 +++- datafusion/sql/src/unparser/expr.rs | 16 ++++++++-------- 2 files changed, 11 insertions(+), 9 deletions(-) diff --git a/datafusion-examples/examples/parse_sql_expr.rs b/datafusion-examples/examples/parse_sql_expr.rs index 6444eb68b6b2..3b710783cf23 100644 --- a/datafusion-examples/examples/parse_sql_expr.rs +++ b/datafusion-examples/examples/parse_sql_expr.rs @@ -135,7 +135,9 @@ async fn query_parquet_demo() -> Result<()> { /// DataFusion can parse a SQL text and convert it back to SQL using [`Unparser`]. async fn round_trip_parse_sql_expr_demo() -> Result<()> { - let sql = "((int_col < 5) OR (double_col = 8))"; + // unparser can also remove extra parentheses, + // so `((int_col < 5) OR (double_col = 8))` will also produce the same SQL + let sql = "int_col < 5 OR double_col = 8"; let ctx = SessionContext::new(); let testdata = datafusion::test_util::parquet_test_data(); diff --git a/datafusion/sql/src/unparser/expr.rs b/datafusion/sql/src/unparser/expr.rs index a6968f870798..482b48d7b756 100644 --- a/datafusion/sql/src/unparser/expr.rs +++ b/datafusion/sql/src/unparser/expr.rs @@ -655,28 +655,28 @@ impl Unparser<'_> { op, }, ast::Expr::IsTrue(expr) => ast::Expr::IsTrue(Box::new( - self.remove_unnecessary_nesting(*expr, LOWEST, IS), + self.remove_unnecessary_nesting(*expr, left_op, IS), )), ast::Expr::IsNotTrue(expr) => ast::Expr::IsNotTrue(Box::new( - self.remove_unnecessary_nesting(*expr, LOWEST, IS), + self.remove_unnecessary_nesting(*expr, left_op, IS), )), ast::Expr::IsFalse(expr) => ast::Expr::IsFalse(Box::new( - self.remove_unnecessary_nesting(*expr, LOWEST, IS), + self.remove_unnecessary_nesting(*expr, left_op, IS), )), ast::Expr::IsNotFalse(expr) => ast::Expr::IsNotFalse(Box::new( - self.remove_unnecessary_nesting(*expr, LOWEST, IS), + self.remove_unnecessary_nesting(*expr, left_op, IS), )), ast::Expr::IsNull(expr) => ast::Expr::IsNull(Box::new( - self.remove_unnecessary_nesting(*expr, LOWEST, IS), + self.remove_unnecessary_nesting(*expr, left_op, IS), )), ast::Expr::IsNotNull(expr) => ast::Expr::IsNotNull(Box::new( - self.remove_unnecessary_nesting(*expr, LOWEST, IS), + self.remove_unnecessary_nesting(*expr, left_op, IS), )), ast::Expr::IsUnknown(expr) => ast::Expr::IsUnknown(Box::new( - self.remove_unnecessary_nesting(*expr, LOWEST, IS), + self.remove_unnecessary_nesting(*expr, left_op, IS), )), ast::Expr::IsNotUnknown(expr) => ast::Expr::IsNotUnknown(Box::new( - self.remove_unnecessary_nesting(*expr, LOWEST, IS), + self.remove_unnecessary_nesting(*expr, left_op, IS), )), _ => expr, } From 38a04de89cfdd41f0300a219cbc9336ae0bdf1a4 Mon Sep 17 00:00:00 2001 From: Mohamed Abdeen Date: Wed, 3 Jul 2024 21:17:13 +0300 Subject: [PATCH 16/17] update more unparser examples --- datafusion-examples/examples/plan_to_sql.rs | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/datafusion-examples/examples/plan_to_sql.rs b/datafusion-examples/examples/plan_to_sql.rs index bd708fe52bc1..d4297a95c898 100644 --- a/datafusion-examples/examples/plan_to_sql.rs +++ b/datafusion-examples/examples/plan_to_sql.rs @@ -60,7 +60,7 @@ async fn main() -> Result<()> { fn simple_expr_to_sql_demo() -> Result<()> { let expr = col("a").lt(lit(5)).or(col("a").eq(lit(8))); let sql = expr_to_sql(&expr)?.to_string(); - assert_eq!(sql, r#"((a < 5) OR (a = 8))"#); + assert_eq!(sql, r#"a < 5 OR a = 8"#); Ok(()) } @@ -71,7 +71,7 @@ fn simple_expr_to_sql_demo_escape_mysql_style() -> Result<()> { let dialect = CustomDialect::new(Some('`')); let unparser = Unparser::new(&dialect); let sql = unparser.expr_to_sql(&expr)?.to_string(); - assert_eq!(sql, r#"((`a` < 5) OR (`a` = 8))"#); + assert_eq!(sql, r#"`a` < 5 OR `a` = 8"#); Ok(()) } @@ -133,7 +133,7 @@ async fn round_trip_plan_to_sql_demo() -> Result<()> { let sql = plan_to_sql(df.logical_plan())?.to_string(); assert_eq!( sql, - r#"SELECT alltypes_plain.int_col, alltypes_plain.double_col, CAST(alltypes_plain.date_string_col AS VARCHAR) FROM alltypes_plain WHERE ((alltypes_plain.id > 1) AND (alltypes_plain.tinyint_col < alltypes_plain.double_col))"# + r#"SELECT alltypes_plain.int_col, alltypes_plain.double_col, CAST(alltypes_plain.date_string_col AS VARCHAR) FROM alltypes_plain WHERE alltypes_plain.id > 1 AND alltypes_plain.tinyint_col < alltypes_plain.double_col"# ); Ok(()) From 98893f0a32df7db660ef11cc4a3c02afa7281d65 Mon Sep 17 00:00:00 2001 From: Mohamed Abdeen Date: Tue, 9 Jul 2024 11:16:10 +0300 Subject: [PATCH 17/17] add with_pretty builder to unparser --- .../examples/parse_sql_expr.rs | 13 +++- datafusion-examples/examples/plan_to_sql.rs | 22 +++++-- .../core/tests/expr_api/parse_sql_expr.rs | 26 ++++---- datafusion/sql/src/unparser/expr.rs | 60 ++++++++++--------- datafusion/sql/src/unparser/mod.rs | 15 ++++- datafusion/sql/tests/cases/plan_to_sql.rs | 12 ++-- 6 files changed, 94 insertions(+), 54 deletions(-) diff --git a/datafusion-examples/examples/parse_sql_expr.rs b/datafusion-examples/examples/parse_sql_expr.rs index 3b710783cf23..e6dfaf6c8a82 100644 --- a/datafusion-examples/examples/parse_sql_expr.rs +++ b/datafusion-examples/examples/parse_sql_expr.rs @@ -135,9 +135,7 @@ async fn query_parquet_demo() -> Result<()> { /// DataFusion can parse a SQL text and convert it back to SQL using [`Unparser`]. async fn round_trip_parse_sql_expr_demo() -> Result<()> { - // unparser can also remove extra parentheses, - // so `((int_col < 5) OR (double_col = 8))` will also produce the same SQL - let sql = "int_col < 5 OR double_col = 8"; + let sql = "((int_col < 5) OR (double_col = 8))"; let ctx = SessionContext::new(); let testdata = datafusion::test_util::parquet_test_data(); @@ -155,5 +153,14 @@ async fn round_trip_parse_sql_expr_demo() -> Result<()> { assert_eq!(sql, round_trip_sql); + // enable pretty-unparsing. This make the output more human-readable + // but can be problematic when passed to other SQL engines due to + // difference in precedence rules between DataFusion and target engines. + let unparser = Unparser::default().with_pretty(true); + + let pretty = "int_col < 5 OR double_col = 8"; + let pretty_round_trip_sql = unparser.expr_to_sql(&parsed_expr)?.to_string(); + assert_eq!(pretty, pretty_round_trip_sql); + Ok(()) } diff --git a/datafusion-examples/examples/plan_to_sql.rs b/datafusion-examples/examples/plan_to_sql.rs index d4297a95c898..f719a33fb624 100644 --- a/datafusion-examples/examples/plan_to_sql.rs +++ b/datafusion-examples/examples/plan_to_sql.rs @@ -31,9 +31,9 @@ use datafusion_sql::unparser::{plan_to_sql, Unparser}; /// 1. [`simple_expr_to_sql_demo`]: Create a simple expression [`Exprs`] with /// fluent API and convert to sql suitable for passing to another database /// -/// 2. [`simple_expr_to_sql_demo_no_escape`] Create a simple expression -/// [`Exprs`] with fluent API and convert to sql without escaping column names -/// more suitable for displaying to humans. +/// 2. [`simple_expr_to_pretty_sql_demo`] Create a simple expression +/// [`Exprs`] with fluent API and convert to sql without extra parentheses, +/// suitable for displaying to humans /// /// 3. [`simple_expr_to_sql_demo_escape_mysql_style`]" Create a simple /// expression [`Exprs`] with fluent API and convert to sql escaping column @@ -49,6 +49,7 @@ use datafusion_sql::unparser::{plan_to_sql, Unparser}; async fn main() -> Result<()> { // See how to evaluate expressions simple_expr_to_sql_demo()?; + simple_expr_to_pretty_sql_demo()?; simple_expr_to_sql_demo_escape_mysql_style()?; simple_plan_to_sql_demo().await?; round_trip_plan_to_sql_demo().await?; @@ -60,6 +61,17 @@ async fn main() -> Result<()> { fn simple_expr_to_sql_demo() -> Result<()> { let expr = col("a").lt(lit(5)).or(col("a").eq(lit(8))); let sql = expr_to_sql(&expr)?.to_string(); + assert_eq!(sql, r#"((a < 5) OR (a = 8))"#); + Ok(()) +} + +/// DataFusioon can remove parentheses when converting an expression to SQL. +/// Note that output is intended for humans, not for other SQL engines, +/// as difference in precedence rules can cause expressions to be parsed differently. +fn simple_expr_to_pretty_sql_demo() -> Result<()> { + let expr = col("a").lt(lit(5)).or(col("a").eq(lit(8))); + let unparser = Unparser::default().with_pretty(true); + let sql = unparser.expr_to_sql(&expr)?.to_string(); assert_eq!(sql, r#"a < 5 OR a = 8"#); Ok(()) } @@ -71,7 +83,7 @@ fn simple_expr_to_sql_demo_escape_mysql_style() -> Result<()> { let dialect = CustomDialect::new(Some('`')); let unparser = Unparser::new(&dialect); let sql = unparser.expr_to_sql(&expr)?.to_string(); - assert_eq!(sql, r#"`a` < 5 OR `a` = 8"#); + assert_eq!(sql, r#"((`a` < 5) OR (`a` = 8))"#); Ok(()) } @@ -133,7 +145,7 @@ async fn round_trip_plan_to_sql_demo() -> Result<()> { let sql = plan_to_sql(df.logical_plan())?.to_string(); assert_eq!( sql, - r#"SELECT alltypes_plain.int_col, alltypes_plain.double_col, CAST(alltypes_plain.date_string_col AS VARCHAR) FROM alltypes_plain WHERE alltypes_plain.id > 1 AND alltypes_plain.tinyint_col < alltypes_plain.double_col"# + r#"SELECT alltypes_plain.int_col, alltypes_plain.double_col, CAST(alltypes_plain.date_string_col AS VARCHAR) FROM alltypes_plain WHERE ((alltypes_plain.id > 1) AND (alltypes_plain.tinyint_col < alltypes_plain.double_col))"# ); Ok(()) diff --git a/datafusion/core/tests/expr_api/parse_sql_expr.rs b/datafusion/core/tests/expr_api/parse_sql_expr.rs index 20db1fd68fcd..991579b5a350 100644 --- a/datafusion/core/tests/expr_api/parse_sql_expr.rs +++ b/datafusion/core/tests/expr_api/parse_sql_expr.rs @@ -39,17 +39,17 @@ fn schema() -> DFSchemaRef { #[tokio::test] async fn round_trip_parse_sql_expr() -> Result<()> { let tests = vec![ - "a = 10", - "(a = 10) AND (b <> 20)", - "(a = 10) OR (b <> 20)", - "((a = 10) AND (b <> 20)) OR (c = a)", - "(a = 10) AND b IN (20, 30)", - "(a = 10) AND b NOT IN (20, 30)", + "(a = 10)", + "((a = 10) AND (b <> 20))", + "((a = 10) OR (b <> 20))", + "(((a = 10) AND (b <> 20)) OR (c = a))", + "((a = 10) AND b IN (20, 30))", + "((a = 10) AND b NOT IN (20, 30))", "sum(a)", - "sum(a) + 1", - "MIN(a) + MAX(b)", - "MIN(a) + (MAX(b) * sum(c))", - "MIN(a) + ((MAX(b) * sum(c)) / 10)", + "(sum(a) + 1)", + "(MIN(a) + MAX(b))", + "(MIN(a) + (MAX(b) * sum(c)))", + "(MIN(a) + ((MAX(b) * sum(c)) / 10))", ]; for test in tests { @@ -65,8 +65,7 @@ fn round_trip_session_context(sql: &str) -> Result<()> { let df_schema = schema(); let expr = ctx.parse_sql_expr(sql, &df_schema)?; let sql2 = unparse_sql_expr(&expr)?; - let expr2 = ctx.parse_sql_expr(&sql2, &df_schema)?; - assert_eq!(expr.to_string(), expr2.to_string()); + assert_eq!(sql, sql2); Ok(()) } @@ -81,8 +80,7 @@ async fn round_trip_dataframe(sql: &str) -> Result<()> { .await?; let expr = df.parse_sql_expr(sql)?; let sql2 = unparse_sql_expr(&expr)?; - let roundtrip = df.parse_sql_expr(&sql2)?; - assert_eq!(expr, roundtrip); + assert_eq!(sql, sql2); Ok(()) } diff --git a/datafusion/sql/src/unparser/expr.rs b/datafusion/sql/src/unparser/expr.rs index 482b48d7b756..f67cd5928c79 100644 --- a/datafusion/sql/src/unparser/expr.rs +++ b/datafusion/sql/src/unparser/expr.rs @@ -85,7 +85,7 @@ impl Display for Unparsed { /// let expr = col("a").gt(lit(4)); /// let sql = expr_to_sql(&expr).unwrap(); /// -/// assert_eq!(format!("{}", sql), "a > 4") +/// assert_eq!(format!("{}", sql), "(a > 4)") /// ``` pub fn expr_to_sql(expr: &Expr) -> Result { let unparser = Unparser::default(); @@ -108,8 +108,11 @@ const IS: &BinaryOperator = &BinaryOperator::BitwiseAnd; impl Unparser<'_> { pub fn expr_to_sql(&self, expr: &Expr) -> Result { - let root_expr = self.expr_to_sql_inner(expr)?; - Ok(self.remove_unnecessary_nesting(root_expr, LOWEST, LOWEST)) + let mut root_expr = self.expr_to_sql_inner(expr)?; + if self.pretty { + root_expr = self.remove_unnecessary_nesting(root_expr, LOWEST, LOWEST); + } + Ok(root_expr) } fn expr_to_sql_inner(&self, expr: &Expr) -> Result { @@ -1260,14 +1263,14 @@ mod tests { .build()?; let tests: Vec<(Expr, &str)> = vec![ - ((col("a") + col("b")).gt(lit(4)), r#"a + b > 4"#), + ((col("a") + col("b")).gt(lit(4)), r#"((a + b) > 4)"#), ( Expr::Column(Column { relation: Some(TableReference::partial("a", "b")), name: "c".to_string(), }) .gt(lit(4)), - r#"a.b.c > 4"#, + r#"(a.b.c > 4)"#, ), ( case(col("a")) @@ -1483,41 +1486,41 @@ mod tests { (col("a").is_null(), r#"a IS NULL"#), ( (col("a") + col("b")).gt(lit(4)).is_true(), - r#"(a + b > 4) IS TRUE"#, + r#"((a + b) > 4) IS TRUE"#, ), ( (col("a") + col("b")).gt(lit(4)).is_not_true(), - r#"(a + b > 4) IS NOT TRUE"#, + r#"((a + b) > 4) IS NOT TRUE"#, ), ( (col("a") + col("b")).gt(lit(4)).is_false(), - r#"(a + b > 4) IS FALSE"#, + r#"((a + b) > 4) IS FALSE"#, ), ( (col("a") + col("b")).gt(lit(4)).is_not_false(), - r#"(a + b > 4) IS NOT FALSE"#, + r#"((a + b) > 4) IS NOT FALSE"#, ), ( (col("a") + col("b")).gt(lit(4)).is_unknown(), - r#"(a + b > 4) IS UNKNOWN"#, + r#"((a + b) > 4) IS UNKNOWN"#, ), ( (col("a") + col("b")).gt(lit(4)).is_not_unknown(), - r#"(a + b > 4) IS NOT UNKNOWN"#, + r#"((a + b) > 4) IS NOT UNKNOWN"#, ), (not(col("a")), r#"NOT a"#), ( Expr::between(col("a"), lit(1), lit(7)), - r#"a BETWEEN 1 AND 7"#, + r#"(a BETWEEN 1 AND 7)"#, ), (Expr::Negative(Box::new(col("a"))), r#"-a"#), ( exists(Arc::new(dummy_logical_plan.clone())), - r#"EXISTS (SELECT t.a FROM t WHERE t.a = 1)"#, + r#"EXISTS (SELECT t.a FROM t WHERE (t.a = 1))"#, ), ( not_exists(Arc::new(dummy_logical_plan.clone())), - r#"NOT EXISTS (SELECT t.a FROM t WHERE t.a = 1)"#, + r#"NOT EXISTS (SELECT t.a FROM t WHERE (t.a = 1))"#, ), ( try_cast(col("a"), DataType::Date64), @@ -1538,21 +1541,24 @@ mod tests { ), r#"@root.foo"#, ), - (col("x").eq(placeholder("$1")), r#"x = $1"#), - (out_ref_col(DataType::Int32, "t.a").gt(lit(1)), r#"t.a > 1"#), + (col("x").eq(placeholder("$1")), r#"(x = $1)"#), + ( + out_ref_col(DataType::Int32, "t.a").gt(lit(1)), + r#"(t.a > 1)"#, + ), ( grouping_set(vec![vec![col("a"), col("b")], vec![col("a")]]), r#"GROUPING SETS ((a, b), (a))"#, ), (cube(vec![col("a"), col("b")]), r#"CUBE (a, b)"#), (rollup(vec![col("a"), col("b")]), r#"ROLLUP (a, b)"#), - (col("table").eq(lit(1)), r#""table" = 1"#), + (col("table").eq(lit(1)), r#"("table" = 1)"#), ( col("123_need_quoted").eq(lit(1)), - r#""123_need_quoted" = 1"#, + r#"("123_need_quoted" = 1)"#, ), - (col("need-quoted").eq(lit(1)), r#""need-quoted" = 1"#), - (col("need quoted").eq(lit(1)), r#""need quoted" = 1"#), + (col("need-quoted").eq(lit(1)), r#"("need-quoted" = 1)"#), + (col("need quoted").eq(lit(1)), r#"("need quoted" = 1)"#), ( interval_month_day_nano_lit( "1 YEAR 1 MONTH 1 DAY 3 HOUR 10 MINUTE 20 SECOND", @@ -1570,12 +1576,12 @@ mod tests { ( interval_month_day_nano_lit("1 MONTH") .add(interval_month_day_nano_lit("1 DAY")), - r#"INTERVAL '0 YEARS 1 MONS 0 DAYS 0 HOURS 0 MINS 0.000000000 SECS' + INTERVAL '0 YEARS 0 MONS 1 DAYS 0 HOURS 0 MINS 0.000000000 SECS'"#, + r#"(INTERVAL '0 YEARS 1 MONS 0 DAYS 0 HOURS 0 MINS 0.000000000 SECS' + INTERVAL '0 YEARS 0 MONS 1 DAYS 0 HOURS 0 MINS 0.000000000 SECS')"#, ), ( interval_month_day_nano_lit("1 MONTH") .sub(interval_month_day_nano_lit("1 DAY")), - r#"INTERVAL '0 YEARS 1 MONS 0 DAYS 0 HOURS 0 MINS 0.000000000 SECS' - INTERVAL '0 YEARS 0 MONS 1 DAYS 0 HOURS 0 MINS 0.000000000 SECS'"#, + r#"(INTERVAL '0 YEARS 1 MONS 0 DAYS 0 HOURS 0 MINS 0.000000000 SECS' - INTERVAL '0 YEARS 0 MONS 1 DAYS 0 HOURS 0 MINS 0.000000000 SECS')"#, ), ( interval_datetime_lit("10 DAY 1 HOUR 10 MINUTE 20 SECOND"), @@ -1599,7 +1605,7 @@ mod tests { 28, 3, ))), - r#"a + b > 100.123"#, + r#"((a + b) > 100.123)"#, ), ( (col("a") + col("b")).gt(Expr::Literal(ScalarValue::Decimal256( @@ -1607,7 +1613,7 @@ mod tests { 28, 3, ))), - r#"a + b > 100.123"#, + r#"((a + b) > 100.123)"#, ), ( Expr::Cast(Cast { @@ -1632,7 +1638,7 @@ mod tests { #[test] fn expr_to_unparsed_ok() -> Result<()> { let tests: Vec<(Expr, &str)> = vec![ - ((col("a") + col("b")).gt(lit(4)), r#"a + b > 4"#), + ((col("a") + col("b")).gt(lit(4)), r#"((a + b) > 4)"#), (col("a").sort(true, true), r#"a ASC NULLS FIRST"#), ]; @@ -1657,7 +1663,7 @@ mod tests { let actual = format!("{}", ast); - let expected = r#"'a' > 4"#; + let expected = r#"('a' > 4)"#; assert_eq!(actual, expected); Ok(()) @@ -1673,7 +1679,7 @@ mod tests { let actual = format!("{}", ast); - let expected = r#"a > 4"#; + let expected = r#"(a > 4)"#; assert_eq!(actual, expected); Ok(()) diff --git a/datafusion/sql/src/unparser/mod.rs b/datafusion/sql/src/unparser/mod.rs index fbbed4972b17..e5ffbc8a212a 100644 --- a/datafusion/sql/src/unparser/mod.rs +++ b/datafusion/sql/src/unparser/mod.rs @@ -29,11 +29,23 @@ pub mod dialect; pub struct Unparser<'a> { dialect: &'a dyn Dialect, + pretty: bool, } impl<'a> Unparser<'a> { pub fn new(dialect: &'a dyn Dialect) -> Self { - Self { dialect } + Self { + dialect, + pretty: false, + } + } + + /// Allow unparser to remove parenthesis according to the precedence rules of DataFusion. + /// This might make it invalid SQL for other SQL query engines with different precedence + /// rules, even if its valid for DataFusion. + pub fn with_pretty(mut self, pretty: bool) -> Self { + self.pretty = pretty; + self } } @@ -41,6 +53,7 @@ impl<'a> Default for Unparser<'a> { fn default() -> Self { Self { dialect: &DefaultDialect {}, + pretty: false, } } } diff --git a/datafusion/sql/tests/cases/plan_to_sql.rs b/datafusion/sql/tests/cases/plan_to_sql.rs index 22db64e92bcb..91295b2e8aae 100644 --- a/datafusion/sql/tests/cases/plan_to_sql.rs +++ b/datafusion/sql/tests/cases/plan_to_sql.rs @@ -36,8 +36,12 @@ use crate::common::MockContextProvider; #[test] fn roundtrip_expr() { let tests: Vec<(TableReference, &str, &str)> = vec![ - (TableReference::bare("person"), "age > 35", r#"age > 35"#), - (TableReference::bare("person"), "id = '10'", r#"id = '10'"#), + (TableReference::bare("person"), "age > 35", r#"(age > 35)"#), + ( + TableReference::bare("person"), + "id = '10'", + r#"(id = '10')"#, + ), ( TableReference::bare("person"), "CAST(id AS VARCHAR)", @@ -46,7 +50,7 @@ fn roundtrip_expr() { ( TableReference::bare("person"), "sum((age * 2))", - r#"sum(age * 2)"#, + r#"sum((age * 2))"#, ), ]; @@ -323,7 +327,7 @@ fn test_pretty_roundtrip() -> Result<()> { let context = MockContextProvider::default(); let sql_to_rel = SqlToRel::new(&context); - let unparser = Unparser::default(); + let unparser = Unparser::default().with_pretty(true); let sql_to_pretty_unparse = vec![ ("((id < 5) OR (age = 8))", "id < 5 OR age = 8"),