[SPARK-22771][SQL] Concatenate binary inputs into a binary output #19977

Closed
wants to merge 16 commits into from
3 changes: 2 additions & 1 deletion R/pkg/R/functions.R
@@ -2088,7 +2088,8 @@ setMethod("countDistinct",
})

#' @details
#' \code{concat}: Concatenates multiple input string columns together into a single string column.
#' \code{concat}: Concatenates multiple input columns together into a single column.
#' If all inputs are binary, concat returns the output as binary; otherwise, it returns it as a string.
#'
#' @rdname column_string_functions
#' @aliases concat concat,Column-method
@@ -74,4 +74,29 @@ public static byte[] subStringSQL(byte[] bytes, int pos, int len) {
}
return Arrays.copyOfRange(bytes, start, end);
}

public static byte[] concat(byte[]... inputs) {
// Compute the total length of the result
int totalLength = 0;
for (int i = 0; i < inputs.length; i++) {
if (inputs[i] != null) {
totalLength += inputs[i].length;
} else {
return null;
}
}

// Allocate a new byte array, and copy the inputs one by one into it
final byte[] result = new byte[totalLength];
int offset = 0;
for (int i = 0; i < inputs.length; i++) {
int len = inputs[i].length;
Contributor: null check here too?

Member Author (@maropu), Dec 20, 2017: aha, I see. UTF8String seems to need the same null check?

Platform.copyMemory(
inputs[i], Platform.BYTE_ARRAY_OFFSET,
result, Platform.BYTE_ARRAY_OFFSET + offset,
len);
offset += len;
}
return result;
}
}
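For a quick sanity check of the new helper (illustrative only, not part of the diff), calling `ByteArray.concat` from Scala exercises both behaviors the loop above encodes: byte-order-preserving concatenation and null propagation when any input is null.

```scala
// Illustrative check of the new helper from the Scala REPL (not part of the diff).
import org.apache.spark.unsafe.types.ByteArray

val bytes = ByteArray.concat("ab".getBytes("UTF-8"), "cd".getBytes("UTF-8"))
assert(new String(bytes, "UTF-8") == "abcd")

// Any null input short-circuits the whole result to null, mirroring SQL null semantics.
val nullInput: Array[Byte] = null
assert(ByteArray.concat("ab".getBytes("UTF-8"), nullInput) == null)
```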
2 changes: 2 additions & 0 deletions docs/sql-programming-guide.md
@@ -1780,6 +1780,8 @@ options.

- Since Spark 2.3, when either broadcast hash join or broadcast nested loop join is applicable, we prefer to broadcasting the table that is explicitly specified in a broadcast hint. For details, see the section [Broadcast Hint](#broadcast-hint-for-sql-queries) and [SPARK-22489](https://issues.apache.org/jira/browse/SPARK-22489).

- Since Spark 2.3, when all inputs are binary, `functions.concat()` returns its output as binary; otherwise, it returns it as a string. Prior to Spark 2.3, it always returned a string regardless of the input types. To keep the old behavior, set `spark.sql.function.concatBinaryAsString` to `true`.
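A minimal sketch of the behavior change (assuming an existing `SparkSession` named `spark`; the column names are made up for illustration):

```scala
import org.apache.spark.sql.functions.{col, concat}

// encode() yields binary columns.
val df = spark.sql("SELECT encode('ab', 'utf-8') AS b1, encode('cd', 'utf-8') AS b2")

// Spark 2.3 default (spark.sql.function.concatBinaryAsString=false):
// all-binary inputs produce a binary result.
df.select(concat(col("b1"), col("b2"))).printSchema()   // concat(b1, b2): binary

// Restore the pre-2.3 behavior: always return a string.
spark.conf.set("spark.sql.function.concatBinaryAsString", "true")
df.select(concat(col("b1"), col("b2"))).printSchema()   // concat(b1, b2): string
```

Since the flag is read at analysis time, toggling it with `spark.conf.set` affects subsequent queries in the same session.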

## Upgrading From Spark SQL 2.1 to 2.2

- Spark 2.1.1 introduced a new configuration key: `spark.sql.hive.caseSensitiveInferenceMode`. It had a default setting of `NEVER_INFER`, which kept behavior identical to 2.1.0. However, Spark 2.2.0 changes this setting's default value to `INFER_AND_SAVE` to restore compatibility with reading Hive metastore tables whose underlying file schema have mixed-case column names. With the `INFER_AND_SAVE` configuration value, on first access Spark will perform schema inference on any Hive metastore table for which it has not already saved an inferred schema. Note that schema inference can be a very time consuming operation for tables with thousands of partitions. If compatibility with mixed-case column names is not a concern, you can safely set `spark.sql.hive.caseSensitiveInferenceMode` to `NEVER_INFER` to avoid the initial overhead of schema inference. Note that with the new default `INFER_AND_SAVE` setting, the results of the schema inference are saved as a metastore key for future use. Therefore, the initial schema inference occurs only at a table's first access.
3 changes: 2 additions & 1 deletion python/pyspark/sql/functions.py
@@ -1374,7 +1374,8 @@ def hash(*cols):
@ignore_unicode_prefix
def concat(*cols):
"""
Concatenates multiple input string columns together into a single string column.
Concatenates multiple input columns together into a single column.
If all inputs are binary, concat returns the output as binary; otherwise, it returns it as a string.

>>> df = spark.createDataFrame([('abcd','123')], ['s', 'd'])
>>> df.select(concat(df.s, df.d).alias('s')).collect()
@@ -151,7 +151,7 @@ class Analyzer(
TimeWindowing ::
ResolveInlineTables(conf) ::
ResolveTimeZone(conf) ::
TypeCoercion.typeCoercionRules ++
TypeCoercion.typeCoercionRules(conf) ++
extendedResolutionRules : _*),
Batch("Post-Hoc Resolution", Once, postHocResolutionRules: _*),
Batch("View", Once,
@@ -27,6 +27,7 @@ import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate._
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._


@@ -45,13 +46,14 @@ import org.apache.spark.sql.types._
*/
object TypeCoercion {

val typeCoercionRules =
def typeCoercionRules(conf: SQLConf): List[Rule[LogicalPlan]] =
InConversion ::
WidenSetOperationTypes ::
PromoteStrings ::
DecimalPrecision ::
BooleanEquality ::
FunctionArgumentConversion ::
ConcatCoercion(conf) ::
CaseWhenCoercion ::
IfCoercion ::
StackCoercion ::
@@ -658,6 +660,29 @@ object TypeCoercion {
}
}

/**
* Coerces the types of [[Concat]] children to expected ones.
*
* If `spark.sql.function.concatBinaryAsString` is false and all children types are binary,
* the expected types are binary. Otherwise, the expected ones are strings.
*/
case class ConcatCoercion(conf: SQLConf) extends TypeCoercionRule {

override protected def coerceTypes(plan: LogicalPlan): LogicalPlan = plan transform { case p =>
p transformExpressionsUp {
// Skip nodes if unresolved or empty children
case c @ Concat(children) if !c.childrenResolved || children.isEmpty => c

Member: Remove this line?

Member Author (@maropu): Probably we can't, because we hit unresolved calls.

Member: I meant the empty line.

Member Author (@maropu): ok

case c @ Concat(children) if conf.concatBinaryAsString ||
!children.map(_.dataType).forall(_ == BinaryType) =>
val newChildren = c.children.map { e =>
ImplicitTypeCasts.implicitCast(e, StringType).getOrElse(e)
}
c.copy(children = newChildren)
}
}
}
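Roughly, the rule rewrites `Concat` children as sketched below; the shapes mirror the new `TypeCoercionSuite` cases further down, and the snippet only constructs the before/after expressions for illustration (Catalyst-internal API, not part of the diff).

```scala
// Illustrative expression shapes only (mirrors the TypeCoercionSuite cases below).
import org.apache.spark.sql.catalyst.expressions.{Cast, Concat, Literal}
import org.apache.spark.sql.types.StringType

// Mixed inputs: each non-string child gets an implicit cast to string.
val before = Concat(Seq(Literal(1), Literal("234")))
val after  = Concat(Seq(Cast(Literal(1), StringType), Literal("234")))

// All-binary inputs with spark.sql.function.concatBinaryAsString=false (the default):
// the children are left untouched, so Concat keeps BinaryType as its output type.
val binaryConcat = Concat(Seq(Literal("123".getBytes), Literal("456".getBytes)))
```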

/**
* Turns Add/Subtract of DateType/TimestampType/StringType and CalendarIntervalType
* to TimeAdd/TimeSub
@@ -24,11 +24,10 @@ import java.util.regex.Pattern

import scala.collection.mutable.ArrayBuffer

import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData}
import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData, TypeUtils}
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.{ByteArray, UTF8String}

@@ -38,7 +37,8 @@ import org.apache.spark.unsafe.types.{ByteArray, UTF8String}


/**
* An expression that concatenates multiple input strings into a single string.
* An expression that concatenates multiple inputs into a single output.
* If all inputs are binary, concat returns the output as binary; otherwise, it returns it as a string.
* If any input is null, concat returns null.
*/
@ExpressionDescription(
@@ -48,17 +48,37 @@ import org.apache.spark.unsafe.types.{ByteArray, UTF8String}
> SELECT _FUNC_('Spark', 'SQL');
SparkSQL
""")
case class Concat(children: Seq[Expression]) extends Expression with ImplicitCastInputTypes {
case class Concat(children: Seq[Expression]) extends Expression {

override def inputTypes: Seq[AbstractDataType] = Seq.fill(children.size)(StringType)
override def dataType: DataType = StringType
private lazy val isBinaryMode: Boolean = dataType == BinaryType

override def checkInputDataTypes(): TypeCheckResult = {
if (children.isEmpty) {
TypeCheckResult.TypeCheckSuccess
} else {
val childTypes = children.map(_.dataType)
if (childTypes.exists(tpe => !Seq(StringType, BinaryType).contains(tpe))) {
return TypeCheckResult.TypeCheckFailure(
s"input to function $prettyName should have StringType or BinaryType, but it's " +
childTypes.map(_.simpleString).mkString("[", ", ", "]"))
}
TypeUtils.checkForSameTypeInputExpr(childTypes, s"function $prettyName")
}
}

override def dataType: DataType = children.map(_.dataType).headOption.getOrElse(StringType)

override def nullable: Boolean = children.exists(_.nullable)
override def foldable: Boolean = children.forall(_.foldable)

override def eval(input: InternalRow): Any = {
val inputs = children.map(_.eval(input).asInstanceOf[UTF8String])
UTF8String.concat(inputs : _*)
if (isBinaryMode) {
val inputs = children.map(_.eval(input).asInstanceOf[Array[Byte]])
ByteArray.concat(inputs: _*)
} else {
val inputs = children.map(_.eval(input).asInstanceOf[UTF8String])
UTF8String.concat(inputs : _*)
}
}

override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
Expand All @@ -73,17 +93,27 @@ case class Concat(children: Seq[Expression]) extends Expression with ImplicitCas
}
"""
}

val (concatenator, initCode) = if (isBinaryMode) {
(classOf[ByteArray].getName, s"byte[][] $args = new byte[${evals.length}][];")
} else {
("UTF8String", s"UTF8String[] $args = new UTF8String[${evals.length}];")
}
val codes = ctx.splitExpressionsWithCurrentInputs(
expressions = inputs,
funcName = "valueConcat",
extraArguments = ("UTF8String[]", args) :: Nil)
extraArguments = (s"${ctx.javaType(dataType)}[]", args) :: Nil)
ev.copy(s"""
UTF8String[] $args = new UTF8String[${evals.length}];
$initCode
$codes
UTF8String ${ev.value} = UTF8String.concat($args);
${ctx.javaType(dataType)} ${ev.value} = $concatenator.concat($args);
boolean ${ev.isNull} = ${ev.value} == null;
""")
}

override def toString: String = s"concat(${children.mkString(", ")})"
Member: Also need to override sql

override def sql: String = s"concat(${children.map(_.sql).mkString(", ")})"
}
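For intuition, a small sketch (Catalyst-internal API, illustration only, not part of the diff) of how the reworked `Concat` picks its mode from the children's types:

```scala
// Illustration with Catalyst internals: the output type follows the children's type.
import java.util.Arrays
import org.apache.spark.sql.catalyst.expressions.{Concat, Literal}
import org.apache.spark.sql.types.BinaryType
import org.apache.spark.unsafe.types.UTF8String

// All-binary children: dataType is BinaryType and eval concatenates via ByteArray.concat.
val binaryConcat = Concat(Seq(Literal("ab".getBytes), Literal("cd".getBytes)))
assert(binaryConcat.dataType == BinaryType)
assert(Arrays.equals(binaryConcat.eval().asInstanceOf[Array[Byte]], "abcd".getBytes))

// String children: eval concatenates via UTF8String.concat as before.
val stringConcat = Concat(Seq(Literal("ab"), Literal("cd")))
assert(stringConcat.eval() == UTF8String.fromString("abcd"))
```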


@@ -21,6 +21,7 @@ import scala.collection.immutable.HashSet
import scala.collection.mutable.{ArrayBuffer, Stack}

import org.apache.spark.sql.catalyst.analysis._
import org.apache.spark.sql.catalyst.analysis.TypeCoercion.ImplicitTypeCasts
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.Literal.{FalseLiteral, TrueLiteral}
import org.apache.spark.sql.catalyst.expressions.aggregate._
@@ -645,6 +646,12 @@ object CombineConcats extends Rule[LogicalPlan] {
stack.pop() match {
case Concat(children) =>
stack.pushAll(children.reverse)
// If `spark.sql.function.concatBinaryAsString` is false, nested `Concat` exprs possibly
// have `Concat`s with binary output. Since `TypeCoercion` casts them into strings,
// we need to handle the case to combine all nested `Concat`s.
case c @ Cast(Concat(children), StringType, _) =>
val newChildren = children.map { e => c.copy(child = e) }
stack.pushAll(newChildren.reverse)
case child =>
flattened += child
}
@@ -1044,6 +1044,12 @@ object SQLConf {
"When this conf is not set, the value from `spark.redaction.string.regex` is used.")
.fallbackConf(org.apache.spark.internal.config.STRING_REDACTION_PATTERN)

val CONCAT_BINARY_AS_STRING = buildConf("spark.sql.function.concatBinaryAsString")
.doc("When this option is set to false and all inputs are binary, `functions.concat` returns " +
"an output as binary. Otherwise, it returns as a string. ")
.booleanConf
.createWithDefault(false)

val CONTINUOUS_STREAMING_EXECUTOR_QUEUE_SIZE =
buildConf("spark.sql.streaming.continuous.executorQueueSize")
.internal()
@@ -1378,6 +1384,8 @@ class SQLConf extends Serializable with Logging {
def continuousStreamingExecutorPollIntervalMs: Long =
getConf(CONTINUOUS_STREAMING_EXECUTOR_POLL_INTERVAL_MS)

def concatBinaryAsString: Boolean = getConf(CONCAT_BINARY_AS_STRING)

/** ********************** SQLConf functionality methods ************ */

/** Set Spark SQL configuration properties. */
@@ -869,6 +869,60 @@ class TypeCoercionSuite extends AnalysisTest {
Literal.create(null, IntegerType), Literal.create(null, StringType))))
}

test("type coercion for Concat") {
val rule = TypeCoercion.ConcatCoercion(conf)

ruleTest(rule,
Concat(Seq(Literal("ab"), Literal("cde"))),
Concat(Seq(Literal("ab"), Literal("cde"))))
ruleTest(rule,
Concat(Seq(Literal(null), Literal("abc"))),
Concat(Seq(Cast(Literal(null), StringType), Literal("abc"))))
ruleTest(rule,
Concat(Seq(Literal(1), Literal("234"))),
Concat(Seq(Cast(Literal(1), StringType), Literal("234"))))
ruleTest(rule,
Concat(Seq(Literal("1"), Literal("234".getBytes()))),
Concat(Seq(Literal("1"), Cast(Literal("234".getBytes()), StringType))))
ruleTest(rule,
Concat(Seq(Literal(1L), Literal(2.toByte), Literal(0.1))),
Concat(Seq(Cast(Literal(1L), StringType), Cast(Literal(2.toByte), StringType),
Cast(Literal(0.1), StringType))))
ruleTest(rule,
Concat(Seq(Literal(true), Literal(0.1f), Literal(3.toShort))),
Concat(Seq(Cast(Literal(true), StringType), Cast(Literal(0.1f), StringType),
Cast(Literal(3.toShort), StringType))))
ruleTest(rule,
Concat(Seq(Literal(1L), Literal(0.1))),
Concat(Seq(Cast(Literal(1L), StringType), Cast(Literal(0.1), StringType))))
ruleTest(rule,
Concat(Seq(Literal(Decimal(10)))),
Concat(Seq(Cast(Literal(Decimal(10)), StringType))))
ruleTest(rule,
Concat(Seq(Literal(BigDecimal.valueOf(10)))),
Concat(Seq(Cast(Literal(BigDecimal.valueOf(10)), StringType))))
ruleTest(rule,
Concat(Seq(Literal(java.math.BigDecimal.valueOf(10)))),
Concat(Seq(Cast(Literal(java.math.BigDecimal.valueOf(10)), StringType))))
ruleTest(rule,
Concat(Seq(Literal(new java.sql.Date(0)), Literal(new Timestamp(0)))),
Concat(Seq(Cast(Literal(new java.sql.Date(0)), StringType),
Cast(Literal(new Timestamp(0)), StringType))))

withSQLConf("spark.sql.function.concatBinaryAsString" -> "true") {
ruleTest(rule,
Concat(Seq(Literal("123".getBytes), Literal("456".getBytes))),
Concat(Seq(Cast(Literal("123".getBytes), StringType),
Cast(Literal("456".getBytes), StringType))))
}

withSQLConf("spark.sql.function.concatBinaryAsString" -> "false") {
ruleTest(rule,
Concat(Seq(Literal("123".getBytes), Literal("456".getBytes))),
Concat(Seq(Literal("123".getBytes), Literal("456".getBytes))))
}
}

test("BooleanEquality type cast") {
val be = TypeCoercion.BooleanEquality
// Use something more than a literal to avoid triggering the simplification rules.
@@ -22,7 +22,6 @@ import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules._
import org.apache.spark.sql.types.StringType


class CombineConcatsSuite extends PlanTest {
@@ -37,8 +36,10 @@ class CombineConcatsSuite extends PlanTest {
comparePlans(actual, correctAnswer)
}

def str(s: String): Literal = Literal(s)
def binary(s: String): Literal = Literal(s.getBytes)

test("combine nested Concat exprs") {
def str(s: String): Literal = Literal(s, StringType)
assertEquivalent(
Concat(
Concat(str("a") :: str("b") :: Nil) ::
@@ -72,4 +73,13 @@
Nil),
Concat(str("a") :: str("b") :: str("c") :: str("d") :: Nil))
}

test("combine string and binary exprs") {
assertEquivalent(
Concat(
Concat(str("a") :: str("b") :: Nil) ::
Concat(binary("c") :: binary("d") :: Nil) ::
Nil),
Concat(str("a") :: str("b") :: binary("c") :: binary("d") :: Nil))
}
}
3 changes: 2 additions & 1 deletion sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -2171,7 +2171,8 @@ object functions {
def base64(e: Column): Column = withExpr { Base64(e.expr) }

/**
* Concatenates multiple input string columns together into a single string column.
* Concatenates multiple input columns together into a single column.
* If all inputs are binary, concat returns the output as binary; otherwise, it returns it as a string.
Contributor: shall we update document for python and R?

Member Author (@maropu): ok

*
* @group string_funcs
* @since 1.5.0
14 changes: 14 additions & 0 deletions sql/core/src/test/resources/sql-tests/inputs/string-functions.sql
@@ -24,3 +24,17 @@ select left("abcd", 2), left("abcd", 5), left("abcd", '2'), left("abcd", null);
select left(null, -2), left("abcd", -2), left("abcd", 0), left("abcd", 'a');
select right("abcd", 2), right("abcd", 5), right("abcd", '2'), right("abcd", null);
select right(null, -2), right("abcd", -2), right("abcd", 0), right("abcd", 'a');

-- turn on concatBinaryAsString
set spark.sql.function.concatBinaryAsString=false;
Member: turn on?

Member Author (@maropu): Since most other DBMS-like systems concat binary inputs as binary, IMO turning it off by default is okay.

Member: I meant you said "turn on" in the comment (L28).

Member Author (@maropu): oh....


-- Check if catalyst combines nested `Concat`s when concatBinaryAsString=false
EXPLAIN SELECT ((col1 || col2) || (col3 || col4)) col
FROM (
SELECT
string(id) col1,
string(id + 1) col2,
encode(string(id + 2), 'utf-8') col3,
encode(string(id + 3), 'utf-8') col4
FROM range(10)
);