Skip to content

Commit

Permalink
add with_pretty builder to unparser
Browse files Browse the repository at this point in the history
  • Loading branch information
MohamedAbdeen21 committed Jul 9, 2024
1 parent 38a04de commit 98893f0
Show file tree
Hide file tree
Showing 6 changed files with 94 additions and 54 deletions.
13 changes: 10 additions & 3 deletions datafusion-examples/examples/parse_sql_expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -135,9 +135,7 @@ async fn query_parquet_demo() -> Result<()> {

/// DataFusion can parse a SQL text and convert it back to SQL using [`Unparser`].
async fn round_trip_parse_sql_expr_demo() -> Result<()> {
// unparser can also remove extra parentheses,
// so `((int_col < 5) OR (double_col = 8))` will also produce the same SQL
let sql = "int_col < 5 OR double_col = 8";
let sql = "((int_col < 5) OR (double_col = 8))";

let ctx = SessionContext::new();
let testdata = datafusion::test_util::parquet_test_data();
Expand All @@ -155,5 +153,14 @@ async fn round_trip_parse_sql_expr_demo() -> Result<()> {

assert_eq!(sql, round_trip_sql);

// enable pretty-unparsing. This makes the output more human-readable
// but can be problematic when passed to other SQL engines due to
// differences in precedence rules between DataFusion and target engines.
let unparser = Unparser::default().with_pretty(true);

let pretty = "int_col < 5 OR double_col = 8";
let pretty_round_trip_sql = unparser.expr_to_sql(&parsed_expr)?.to_string();
assert_eq!(pretty, pretty_round_trip_sql);

Ok(())
}
22 changes: 17 additions & 5 deletions datafusion-examples/examples/plan_to_sql.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,9 @@ use datafusion_sql::unparser::{plan_to_sql, Unparser};
/// 1. [`simple_expr_to_sql_demo`]: Create a simple expression [`Exprs`] with
/// fluent API and convert to sql suitable for passing to another database
///
/// 2. [`simple_expr_to_sql_demo_no_escape`] Create a simple expression
/// [`Exprs`] with fluent API and convert to sql without escaping column names
/// more suitable for displaying to humans.
/// 2. [`simple_expr_to_pretty_sql_demo`] Create a simple expression
/// [`Exprs`] with fluent API and convert to sql without extra parentheses,
/// suitable for displaying to humans
///
/// 3. [`simple_expr_to_sql_demo_escape_mysql_style`]: Create a simple
/// expression [`Exprs`] with fluent API and convert to sql escaping column
Expand All @@ -49,6 +49,7 @@ use datafusion_sql::unparser::{plan_to_sql, Unparser};
async fn main() -> Result<()> {
// See how to evaluate expressions
simple_expr_to_sql_demo()?;
simple_expr_to_pretty_sql_demo()?;
simple_expr_to_sql_demo_escape_mysql_style()?;
simple_plan_to_sql_demo().await?;
round_trip_plan_to_sql_demo().await?;
Expand All @@ -60,6 +61,17 @@ async fn main() -> Result<()> {
fn simple_expr_to_sql_demo() -> Result<()> {
let expr = col("a").lt(lit(5)).or(col("a").eq(lit(8)));
let sql = expr_to_sql(&expr)?.to_string();
assert_eq!(sql, r#"((a < 5) OR (a = 8))"#);
Ok(())
}

/// DataFusion can remove parentheses when converting an expression to SQL.
/// Note that the output is intended for humans, not for other SQL engines,
/// as differences in precedence rules can cause expressions to be parsed differently.
fn simple_expr_to_pretty_sql_demo() -> Result<()> {
let expr = col("a").lt(lit(5)).or(col("a").eq(lit(8)));
// Opt in to pretty output; the default unparser keeps the parentheses.
let unparser = Unparser::default().with_pretty(true);
let sql = unparser.expr_to_sql(&expr)?.to_string();
assert_eq!(sql, r#"a < 5 OR a = 8"#);
Ok(())
}
Expand All @@ -71,7 +83,7 @@ fn simple_expr_to_sql_demo_escape_mysql_style() -> Result<()> {
let dialect = CustomDialect::new(Some('`'));
let unparser = Unparser::new(&dialect);
let sql = unparser.expr_to_sql(&expr)?.to_string();
assert_eq!(sql, r#"`a` < 5 OR `a` = 8"#);
assert_eq!(sql, r#"((`a` < 5) OR (`a` = 8))"#);
Ok(())
}

Expand Down Expand Up @@ -133,7 +145,7 @@ async fn round_trip_plan_to_sql_demo() -> Result<()> {
let sql = plan_to_sql(df.logical_plan())?.to_string();
assert_eq!(
sql,
r#"SELECT alltypes_plain.int_col, alltypes_plain.double_col, CAST(alltypes_plain.date_string_col AS VARCHAR) FROM alltypes_plain WHERE alltypes_plain.id > 1 AND alltypes_plain.tinyint_col < alltypes_plain.double_col"#
r#"SELECT alltypes_plain.int_col, alltypes_plain.double_col, CAST(alltypes_plain.date_string_col AS VARCHAR) FROM alltypes_plain WHERE ((alltypes_plain.id > 1) AND (alltypes_plain.tinyint_col < alltypes_plain.double_col))"#
);

Ok(())
Expand Down
26 changes: 12 additions & 14 deletions datafusion/core/tests/expr_api/parse_sql_expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,17 +39,17 @@ fn schema() -> DFSchemaRef {
#[tokio::test]
async fn round_trip_parse_sql_expr() -> Result<()> {
let tests = vec![
"a = 10",
"(a = 10) AND (b <> 20)",
"(a = 10) OR (b <> 20)",
"((a = 10) AND (b <> 20)) OR (c = a)",
"(a = 10) AND b IN (20, 30)",
"(a = 10) AND b NOT IN (20, 30)",
"(a = 10)",
"((a = 10) AND (b <> 20))",
"((a = 10) OR (b <> 20))",
"(((a = 10) AND (b <> 20)) OR (c = a))",
"((a = 10) AND b IN (20, 30))",
"((a = 10) AND b NOT IN (20, 30))",
"sum(a)",
"sum(a) + 1",
"MIN(a) + MAX(b)",
"MIN(a) + (MAX(b) * sum(c))",
"MIN(a) + ((MAX(b) * sum(c)) / 10)",
"(sum(a) + 1)",
"(MIN(a) + MAX(b))",
"(MIN(a) + (MAX(b) * sum(c)))",
"(MIN(a) + ((MAX(b) * sum(c)) / 10))",
];

for test in tests {
Expand All @@ -65,8 +65,7 @@ fn round_trip_session_context(sql: &str) -> Result<()> {
let df_schema = schema();
let expr = ctx.parse_sql_expr(sql, &df_schema)?;
let sql2 = unparse_sql_expr(&expr)?;
let expr2 = ctx.parse_sql_expr(&sql2, &df_schema)?;
assert_eq!(expr.to_string(), expr2.to_string());
assert_eq!(sql, sql2);

Ok(())
}
Expand All @@ -81,8 +80,7 @@ async fn round_trip_dataframe(sql: &str) -> Result<()> {
.await?;
let expr = df.parse_sql_expr(sql)?;
let sql2 = unparse_sql_expr(&expr)?;
let roundtrip = df.parse_sql_expr(&sql2)?;
assert_eq!(expr, roundtrip);
assert_eq!(sql, sql2);

Ok(())
}
Expand Down
60 changes: 33 additions & 27 deletions datafusion/sql/src/unparser/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ impl Display for Unparsed {
/// let expr = col("a").gt(lit(4));
/// let sql = expr_to_sql(&expr).unwrap();
///
/// assert_eq!(format!("{}", sql), "a > 4")
/// assert_eq!(format!("{}", sql), "(a > 4)")
/// ```
pub fn expr_to_sql(expr: &Expr) -> Result<ast::Expr> {
let unparser = Unparser::default();
Expand All @@ -108,8 +108,11 @@ const IS: &BinaryOperator = &BinaryOperator::BitwiseAnd;

impl Unparser<'_> {
pub fn expr_to_sql(&self, expr: &Expr) -> Result<ast::Expr> {
let root_expr = self.expr_to_sql_inner(expr)?;
Ok(self.remove_unnecessary_nesting(root_expr, LOWEST, LOWEST))
let mut root_expr = self.expr_to_sql_inner(expr)?;
if self.pretty {
root_expr = self.remove_unnecessary_nesting(root_expr, LOWEST, LOWEST);
}
Ok(root_expr)
}

fn expr_to_sql_inner(&self, expr: &Expr) -> Result<ast::Expr> {
Expand Down Expand Up @@ -1260,14 +1263,14 @@ mod tests {
.build()?;

let tests: Vec<(Expr, &str)> = vec![
((col("a") + col("b")).gt(lit(4)), r#"a + b > 4"#),
((col("a") + col("b")).gt(lit(4)), r#"((a + b) > 4)"#),
(
Expr::Column(Column {
relation: Some(TableReference::partial("a", "b")),
name: "c".to_string(),
})
.gt(lit(4)),
r#"a.b.c > 4"#,
r#"(a.b.c > 4)"#,
),
(
case(col("a"))
Expand Down Expand Up @@ -1483,41 +1486,41 @@ mod tests {
(col("a").is_null(), r#"a IS NULL"#),
(
(col("a") + col("b")).gt(lit(4)).is_true(),
r#"(a + b > 4) IS TRUE"#,
r#"((a + b) > 4) IS TRUE"#,
),
(
(col("a") + col("b")).gt(lit(4)).is_not_true(),
r#"(a + b > 4) IS NOT TRUE"#,
r#"((a + b) > 4) IS NOT TRUE"#,
),
(
(col("a") + col("b")).gt(lit(4)).is_false(),
r#"(a + b > 4) IS FALSE"#,
r#"((a + b) > 4) IS FALSE"#,
),
(
(col("a") + col("b")).gt(lit(4)).is_not_false(),
r#"(a + b > 4) IS NOT FALSE"#,
r#"((a + b) > 4) IS NOT FALSE"#,
),
(
(col("a") + col("b")).gt(lit(4)).is_unknown(),
r#"(a + b > 4) IS UNKNOWN"#,
r#"((a + b) > 4) IS UNKNOWN"#,
),
(
(col("a") + col("b")).gt(lit(4)).is_not_unknown(),
r#"(a + b > 4) IS NOT UNKNOWN"#,
r#"((a + b) > 4) IS NOT UNKNOWN"#,
),
(not(col("a")), r#"NOT a"#),
(
Expr::between(col("a"), lit(1), lit(7)),
r#"a BETWEEN 1 AND 7"#,
r#"(a BETWEEN 1 AND 7)"#,
),
(Expr::Negative(Box::new(col("a"))), r#"-a"#),
(
exists(Arc::new(dummy_logical_plan.clone())),
r#"EXISTS (SELECT t.a FROM t WHERE t.a = 1)"#,
r#"EXISTS (SELECT t.a FROM t WHERE (t.a = 1))"#,
),
(
not_exists(Arc::new(dummy_logical_plan.clone())),
r#"NOT EXISTS (SELECT t.a FROM t WHERE t.a = 1)"#,
r#"NOT EXISTS (SELECT t.a FROM t WHERE (t.a = 1))"#,
),
(
try_cast(col("a"), DataType::Date64),
Expand All @@ -1538,21 +1541,24 @@ mod tests {
),
r#"@root.foo"#,
),
(col("x").eq(placeholder("$1")), r#"x = $1"#),
(out_ref_col(DataType::Int32, "t.a").gt(lit(1)), r#"t.a > 1"#),
(col("x").eq(placeholder("$1")), r#"(x = $1)"#),
(
out_ref_col(DataType::Int32, "t.a").gt(lit(1)),
r#"(t.a > 1)"#,
),
(
grouping_set(vec![vec![col("a"), col("b")], vec![col("a")]]),
r#"GROUPING SETS ((a, b), (a))"#,
),
(cube(vec![col("a"), col("b")]), r#"CUBE (a, b)"#),
(rollup(vec![col("a"), col("b")]), r#"ROLLUP (a, b)"#),
(col("table").eq(lit(1)), r#""table" = 1"#),
(col("table").eq(lit(1)), r#"("table" = 1)"#),
(
col("123_need_quoted").eq(lit(1)),
r#""123_need_quoted" = 1"#,
r#"("123_need_quoted" = 1)"#,
),
(col("need-quoted").eq(lit(1)), r#""need-quoted" = 1"#),
(col("need quoted").eq(lit(1)), r#""need quoted" = 1"#),
(col("need-quoted").eq(lit(1)), r#"("need-quoted" = 1)"#),
(col("need quoted").eq(lit(1)), r#"("need quoted" = 1)"#),
(
interval_month_day_nano_lit(
"1 YEAR 1 MONTH 1 DAY 3 HOUR 10 MINUTE 20 SECOND",
Expand All @@ -1570,12 +1576,12 @@ mod tests {
(
interval_month_day_nano_lit("1 MONTH")
.add(interval_month_day_nano_lit("1 DAY")),
r#"INTERVAL '0 YEARS 1 MONS 0 DAYS 0 HOURS 0 MINS 0.000000000 SECS' + INTERVAL '0 YEARS 0 MONS 1 DAYS 0 HOURS 0 MINS 0.000000000 SECS'"#,
r#"(INTERVAL '0 YEARS 1 MONS 0 DAYS 0 HOURS 0 MINS 0.000000000 SECS' + INTERVAL '0 YEARS 0 MONS 1 DAYS 0 HOURS 0 MINS 0.000000000 SECS')"#,
),
(
interval_month_day_nano_lit("1 MONTH")
.sub(interval_month_day_nano_lit("1 DAY")),
r#"INTERVAL '0 YEARS 1 MONS 0 DAYS 0 HOURS 0 MINS 0.000000000 SECS' - INTERVAL '0 YEARS 0 MONS 1 DAYS 0 HOURS 0 MINS 0.000000000 SECS'"#,
r#"(INTERVAL '0 YEARS 1 MONS 0 DAYS 0 HOURS 0 MINS 0.000000000 SECS' - INTERVAL '0 YEARS 0 MONS 1 DAYS 0 HOURS 0 MINS 0.000000000 SECS')"#,
),
(
interval_datetime_lit("10 DAY 1 HOUR 10 MINUTE 20 SECOND"),
Expand All @@ -1599,15 +1605,15 @@ mod tests {
28,
3,
))),
r#"a + b > 100.123"#,
r#"((a + b) > 100.123)"#,
),
(
(col("a") + col("b")).gt(Expr::Literal(ScalarValue::Decimal256(
Some(100123.into()),
28,
3,
))),
r#"a + b > 100.123"#,
r#"((a + b) > 100.123)"#,
),
(
Expr::Cast(Cast {
Expand All @@ -1632,7 +1638,7 @@ mod tests {
#[test]
fn expr_to_unparsed_ok() -> Result<()> {
let tests: Vec<(Expr, &str)> = vec![
((col("a") + col("b")).gt(lit(4)), r#"a + b > 4"#),
((col("a") + col("b")).gt(lit(4)), r#"((a + b) > 4)"#),
(col("a").sort(true, true), r#"a ASC NULLS FIRST"#),
];

Expand All @@ -1657,7 +1663,7 @@ mod tests {

let actual = format!("{}", ast);

let expected = r#"'a' > 4"#;
let expected = r#"('a' > 4)"#;
assert_eq!(actual, expected);

Ok(())
Expand All @@ -1673,7 +1679,7 @@ mod tests {

let actual = format!("{}", ast);

let expected = r#"a > 4"#;
let expected = r#"(a > 4)"#;
assert_eq!(actual, expected);

Ok(())
Expand Down
15 changes: 14 additions & 1 deletion datafusion/sql/src/unparser/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,18 +29,31 @@ pub mod dialect;

pub struct Unparser<'a> {
dialect: &'a dyn Dialect,
pretty: bool,
}

impl<'a> Unparser<'a> {
pub fn new(dialect: &'a dyn Dialect) -> Self {
Self { dialect }
Self {
dialect,
pretty: false,
}
}

/// Allow the unparser to remove parentheses according to the precedence rules of DataFusion.
/// This might make the output invalid SQL for other SQL query engines with different
/// precedence rules, even if it is valid for DataFusion.
///
/// Consumes `self` and returns it with the `pretty` flag set to the given value.
pub fn with_pretty(mut self, pretty: bool) -> Self {
self.pretty = pretty;
self
}
}

impl<'a> Default for Unparser<'a> {
// Builds an unparser that uses `DefaultDialect` with pretty output disabled.
fn default() -> Self {
Self {
dialect: &DefaultDialect {},
pretty: false,
}
}
}
12 changes: 8 additions & 4 deletions datafusion/sql/tests/cases/plan_to_sql.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,12 @@ use crate::common::MockContextProvider;
#[test]
fn roundtrip_expr() {
let tests: Vec<(TableReference, &str, &str)> = vec![
(TableReference::bare("person"), "age > 35", r#"age > 35"#),
(TableReference::bare("person"), "id = '10'", r#"id = '10'"#),
(TableReference::bare("person"), "age > 35", r#"(age > 35)"#),
(
TableReference::bare("person"),
"id = '10'",
r#"(id = '10')"#,
),
(
TableReference::bare("person"),
"CAST(id AS VARCHAR)",
Expand All @@ -46,7 +50,7 @@ fn roundtrip_expr() {
(
TableReference::bare("person"),
"sum((age * 2))",
r#"sum(age * 2)"#,
r#"sum((age * 2))"#,
),
];

Expand Down Expand Up @@ -323,7 +327,7 @@ fn test_pretty_roundtrip() -> Result<()> {
let context = MockContextProvider::default();
let sql_to_rel = SqlToRel::new(&context);

let unparser = Unparser::default();
let unparser = Unparser::default().with_pretty(true);

let sql_to_pretty_unparse = vec![
("((id < 5) OR (age = 8))", "id < 5 OR age = 8"),
Expand Down

0 comments on commit 98893f0

Please sign in to comment.