From 179c6fdf261d3392d4d3477a68f7fde60d190435 Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Thu, 28 Dec 2017 11:36:42 +0900 Subject: [PATCH] Fix --- .../sql/catalyst/analysis/TypeCoercion.scala | 5 ++--- .../sql/catalyst/optimizer/expressions.scala | 13 +++++++------ .../optimizer/CombineConcatsSuite.scala | 14 ++++++++++++-- .../sql-tests/inputs/string-functions.sql | 2 +- .../results/string-functions.sql.out | 19 +------------------ 5 files changed, 23 insertions(+), 30 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala index 8e82dac0eb631..dab3b05a65f7d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala @@ -680,9 +680,8 @@ object TypeCoercion { // Skip nodes if unresolved or empty children case c @ Concat(children) if !c.childrenResolved || children.isEmpty => c - case c @ Concat(children) if !children.map(_.dataType).forall(_ == BinaryType) => - typeCastToString(c) - case c @ Concat(children) if conf.concatBinaryAsString => + case c @ Concat(children) if conf.concatBinaryAsString || + !children.map(_.dataType).forall(_ == BinaryType) => typeCastToString(c) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala index 293fa16b84dda..64fa3cf3726d1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala @@ -646,16 +646,17 @@ object CombineConcats extends Rule[LogicalPlan] { stack.pop() match { case Concat(children) => stack.pushAll(children.reverse) - case Cast(Concat(children), StringType, _) => - stack.pushAll(children.reverse) + // If `spark.sql.function.concatBinaryAsString` is false, nested `Concat` exprs possibly + // have `Concat`s with binary output. Since `TypeCoercion` casts them into strings, + // we need to handle the case to combine all nested `Concat`s. + case c @ Cast(Concat(children), StringType, _) => + val newChildren = children.map { e => c.copy(child = e) } + stack.pushAll(newChildren.reverse) case child => flattened += child } } - val newChildren = flattened.map { e => - ImplicitTypeCasts.implicitCast(e, StringType).getOrElse(e) - } - Concat(newChildren) + Concat(flattened) } def apply(plan: LogicalPlan): LogicalPlan = plan.transformExpressionsDown { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CombineConcatsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CombineConcatsSuite.scala index 412e199dfaae3..441c15340a778 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CombineConcatsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CombineConcatsSuite.scala @@ -22,7 +22,6 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules._ -import org.apache.spark.sql.types.StringType class CombineConcatsSuite extends PlanTest { @@ -37,8 +36,10 @@ class CombineConcatsSuite extends PlanTest { comparePlans(actual, correctAnswer) } + def str(s: String): Literal = Literal(s) + def binary(s: String): Literal = Literal(s.getBytes) + test("combine nested Concat exprs") { - def str(s: String): Literal = Literal(s, StringType) assertEquivalent( Concat( Concat(str("a") :: str("b") :: Nil) :: @@ -72,4 +73,13 @@ class CombineConcatsSuite extends PlanTest { Nil), Concat(str("a") :: str("b") :: str("c") :: str("d") :: Nil)) } + + test("combine string and binary exprs") { + assertEquivalent( + Concat( + Concat(str("a") :: str("b") :: Nil) :: + Concat(binary("c") :: binary("d") :: Nil) :: + Nil), + Concat(str("a") :: str("b") :: binary("c") :: binary("d") :: Nil)) + } } diff --git a/sql/core/src/test/resources/sql-tests/inputs/string-functions.sql b/sql/core/src/test/resources/sql-tests/inputs/string-functions.sql index 0439b2a142dc7..3ed3db8c85134 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/string-functions.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/string-functions.sql @@ -29,7 +29,7 @@ select right(null, -2), right("abcd", -2), right("abcd", 0), right("abcd", 'a'); set spark.sql.function.concatBinaryAsString=false; -- Check if catalyst combine nested `Concat`s if concatBinaryAsString=false -EXPLAIN EXTENDED SELECT ((col1 || col2) || (col3 || col4)) col +EXPLAIN SELECT ((col1 || col2) || (col3 || col4)) col FROM ( SELECT string(id) col1, diff --git a/sql/core/src/test/resources/sql-tests/results/string-functions.sql.out b/sql/core/src/test/resources/sql-tests/results/string-functions.sql.out index 8d84b4ab4d253..3f182c5c50c39 100644 --- a/sql/core/src/test/resources/sql-tests/results/string-functions.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/string-functions.sql.out @@ -129,7 +129,7 @@ spark.sql.function.concatBinaryAsString false -- !query 13 -EXPLAIN EXTENDED SELECT ((col1 || col2) || (col3 || col4)) col +EXPLAIN SELECT ((col1 || col2) || (col3 || col4)) col FROM ( SELECT string(id) col1, @@ -141,23 +141,6 @@ FROM ( -- !query 13 schema struct -- !query 13 output -== Parsed Logical Plan == -'Project [concat(concat('col1, 'col2), concat('col3, 'col4)) AS col#x] -+- 'SubqueryAlias __auto_generated_subquery_name - +- 'Project ['string('id) AS col1#x, 'string(('id + 1)) AS col2#x, 'encode('string(('id + 2)), utf-8) AS col3#x, 'encode('string(('id + 3)), utf-8) AS col4#x] - +- 'UnresolvedTableValuedFunction range, [10] - -== Analyzed Logical Plan == -col: string -Project [concat(concat(col1#x, col2#x), cast(concat(col3#x, col4#x) as string)) AS col#x] -+- SubqueryAlias __auto_generated_subquery_name - +- Project [cast(id#xL as string) AS col1#x, cast((id#xL + cast(1 as bigint)) as string) AS col2#x, encode(cast((id#xL + cast(2 as bigint)) as string), utf-8) AS col3#x, encode(cast((id#xL + cast(3 as bigint)) as string), utf-8) AS col4#x] - +- Range (0, 10, step=1, splits=None) - -== Optimized Logical Plan == -Project [concat(cast(id#xL as string), cast((id#xL + 1) as string), cast(encode(cast((id#xL + 2) as string), utf-8) as string), cast(encode(cast((id#xL + 3) as string), utf-8) as string)) AS col#x] -+- Range (0, 10, step=1, splits=None) - == Physical Plan == *Project [concat(cast(id#xL as string), cast((id#xL + 1) as string), cast(encode(cast((id#xL + 2) as string), utf-8) as string), cast(encode(cast((id#xL + 3) as string), utf-8) as string)) AS col#x] +- *Range (0, 10, step=1, splits=2)