[SPARK-23736][SQL] Extending concat function to be able to work with ArrayType
mn-mikke authored and committed Apr 17, 2018
1 parent c9a8977 commit 6c4f8d0
Showing 12 changed files with 512 additions and 114 deletions.
34 changes: 19 additions & 15 deletions python/pyspark/sql/functions.py
@@ -1369,21 +1369,6 @@ def hash(*cols):
del _name, _doc


@since(1.5)
@ignore_unicode_prefix
def concat(*cols):
"""
Concatenates multiple input columns together into a single column.
If all inputs are binary, concat returns an output as binary. Otherwise, it returns as string.

>>> df = spark.createDataFrame([('abcd','123')], ['s', 'd'])
>>> df.select(concat(df.s, df.d).alias('s')).collect()
[Row(s=u'abcd123')]
"""
sc = SparkContext._active_spark_context
return Column(sc._jvm.functions.concat(_to_seq(sc, cols, _to_java_column)))


@since(1.5)
@ignore_unicode_prefix
def concat_ws(sep, *cols):
@@ -1789,6 +1774,25 @@ def array_contains(col, value):
return Column(sc._jvm.functions.array_contains(_to_java_column(col), value))


@since(1.5)
@ignore_unicode_prefix
def concat(*cols):
"""
Concatenates multiple input columns together into a single column.
The function works with strings, binary and compatible array columns.

>>> df = spark.createDataFrame([('abcd','123')], ['s', 'd'])
>>> df.select(concat(df.s, df.d).alias('s')).collect()
[Row(s=u'abcd123')]

>>> df = spark.createDataFrame([([1, 2], [3, 4], [5]), ([1, 2], None, [3])], ['a', 'b', 'c'])
>>> df.select(concat(df.a, df.b, df.c).alias("arr")).collect()
[Row(arr=[1, 2, 3, 4, 5]), Row(arr=None)]
"""
sc = SparkContext._active_spark_context
return Column(sc._jvm.functions.concat(_to_seq(sc, cols, _to_java_column)))


@since(1.4)
def explode(col):
"""Returns a new row for each element in the given array or map.
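For comparison, the same calls from the Scala DataFrame API. A minimal sketch (not part of this diff), assuming a SparkSession named "spark" with this commit applied:

import spark.implicits._
import org.apache.spark.sql.functions.concat

// Strings behave as before:
Seq(("abcd", "123")).toDF("s", "d")
  .select(concat($"s", $"d").alias("s"))
  .show()  // abcd123

// Array columns are newly supported; a null array operand nulls the result:
Seq((Seq(1, 2), Option(Seq(3, 4))), (Seq(1, 2), None)).toDF("a", "b")
  .select(concat($"a", $"b").alias("arr"))
  .show()  // [1, 2, 3, 4], then null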
sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java
@@ -54,8 +54,8 @@

public final class UnsafeArrayData extends ArrayData {

public static int calculateHeaderPortionInBytes(int numElements) {
return (int)calculateHeaderPortionInBytes((long)numElements);
public static int calculateHeaderPortionInBytes(int numFields) {
return (int)calculateHeaderPortionInBytes((long)numFields);
}

public static long calculateHeaderPortionInBytes(long numFields) {
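The renamed int overload now matches the long overload below it. The header laid out by UnsafeArrayData is an 8-byte element count followed by a null bitmap rounded up to whole 8-byte words, i.e. 8 + ((numFields + 63) / 64) * 8 bytes. A small sketch of that arithmetic (illustration only, not part of this diff):

def headerBytes(numFields: Long): Long = 8 + ((numFields + 63) / 64) * 8

headerBytes(1)   // 16: the count word plus one bitmap word
headerBytes(64)  // 16: 64 null bits still fit in a single bitmap word
headerBytes(65)  // 24: the 65th field spills into a second bitmap word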
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
@@ -308,7 +308,6 @@ object FunctionRegistry {
expression[BitLength]("bit_length"),
expression[Length]("char_length"),
expression[Length]("character_length"),
expression[Concat]("concat"),
expression[ConcatWs]("concat_ws"),
expression[Decode]("decode"),
expression[Elt]("elt"),
@@ -407,6 +406,7 @@ object FunctionRegistry {
expression[MapValues]("map_values"),
expression[Size]("size"),
expression[SortArray]("sort_array"),
expression[Concat]("concat"),
expression[Flatten]("flatten"),
expression[Reverse]("reverse"),
CreateStruct.registryEntry,
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
@@ -497,6 +497,14 @@ object TypeCoercion {
case None => a
}

case c @ Concat(children) if children.forall(c => ArrayType.acceptsType(c.dataType)) &&
!haveSameType(children) =>
val types = children.map(_.dataType)
findWiderCommonType(types) match {
case Some(finalDataType) => Concat(children.map(Cast(_, finalDataType)))
case None => c
}

case m @ CreateMap(children) if m.keys.length == m.values.length &&
(!haveSameType(m.keys) || !haveSameType(m.values)) =>
val newKeys = if (haveSameType(m.keys)) {
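This rule lets concat accept array children whose element types differ but share a wider common type: each child is cast to that type before concatenation. A minimal sketch of the observable effect (not part of this diff), assuming a SparkSession named "spark":

val df = spark.sql("SELECT concat(array(1, 2), array(3.5)) AS arr")
df.printSchema()  // arr: array<double>, the int elements were widened
df.show()         // [1.0, 2.0, 3.5]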
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
@@ -20,12 +20,12 @@ import java.util.Comparator

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, CodegenFallback, ExprCode}
import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData, MapData}
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData, MapData, TypeUtils}
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.Platform
import org.apache.spark.unsafe.array.ByteArrayMethods
import org.apache.spark.unsafe.types.UTF8String
import org.apache.spark.unsafe.types.{ByteArray, UTF8String}

/**
* Given an array or map, returns its size. Returns -1 if null.
@@ -378,6 +378,218 @@ case class ArrayContains(left: Expression, right: Expression)
override def prettyName: String = "array_contains"
}

/**
* Concatenates multiple input columns together into a single column.
* The function works with strings, binary and compatible array columns.
*/
@ExpressionDescription(
usage = "_FUNC_(col1, col2, ..., colN) - Returns the concatenation of col1, col2, ..., colN.",
examples = """
Examples:
> SELECT _FUNC_('Spark', 'SQL');
SparkSQL
> SELECT _FUNC_(array(1, 2, 3), array(4, 5), array(6));
[1,2,3,4,5,6]
""")
case class Concat(children: Seq[Expression]) extends Expression {

private val MAX_ARRAY_LENGTH: Int = ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH

val allowedTypes = Seq(StringType, BinaryType, ArrayType)

override def checkInputDataTypes(): TypeCheckResult = {
if (children.isEmpty) {
TypeCheckResult.TypeCheckSuccess
} else {
val childTypes = children.map(_.dataType)
if (childTypes.exists(tpe => !allowedTypes.exists(_.acceptsType(tpe)))) {
return TypeCheckResult.TypeCheckFailure(
s"input to function $prettyName should have been StringType, BinaryType or ArrayType," +
s" but it's " + childTypes.map(_.simpleString).mkString("[", ", ", "]"))
}
TypeUtils.checkForSameTypeInputExpr(childTypes, s"function $prettyName")
}
}

override def dataType: DataType = children.map(_.dataType).headOption.getOrElse(StringType)

override def nullable: Boolean = children.exists(_.nullable)
override def foldable: Boolean = children.forall(_.foldable)

override def eval(input: InternalRow): Any = dataType match {
case BinaryType =>
val inputs = children.map(_.eval(input).asInstanceOf[Array[Byte]])
ByteArray.concat(inputs: _*)
case StringType =>
val inputs = children.map(_.eval(input).asInstanceOf[UTF8String])
UTF8String.concat(inputs: _*)
case ArrayType(elementType, _) =>
val inputs = children.toStream.map(_.eval(input))
if (inputs.contains(null)) {
null
} else {
val arrayData = inputs.map(_.asInstanceOf[ArrayData])
val numberOfElements = arrayData.foldLeft(0L)((sum, ad) => sum + ad.numElements())
if (numberOfElements > MAX_ARRAY_LENGTH) {
throw new RuntimeException(s"Unsuccessful try to concat arrays with $numberOfElements" +
s" elements due to exceeding the array size limit $MAX_ARRAY_LENGTH.")
}
val finalData = new Array[AnyRef](numberOfElements.toInt)
var position = 0
for (ad <- arrayData) {
val arr = ad.toObjectArray(elementType)
Array.copy(arr, 0, finalData, position, arr.length)
position += arr.length
}
new GenericArrayData(finalData)
}
}

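// Codegen path: evaluate every child into an args array, then hand that array
// to a type-specific concatenator: ByteArray for binary, UTF8String for
// strings, or a generated anonymous class for arrays (see below).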
override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
val evals = children.map(_.genCode(ctx))
val args = ctx.freshName("args")

val inputs = evals.zipWithIndex.map { case (eval, index) =>
s"""
${eval.code}
if (!${eval.isNull}) {
$args[$index] = ${eval.value};
}
"""
}

val (concatenator, initCode) = dataType match {
case BinaryType =>
(classOf[ByteArray].getName, s"byte[][] $args = new byte[${evals.length}][];")
case StringType =>
("UTF8String", s"UTF8String[] $args = new UTF8String[${evals.length}];")
case ArrayType(elementType, _) =>
val arrayConcatClass = if (ctx.isPrimitiveType(elementType)) {
genCodeForPrimitiveArrays(ctx, elementType)
} else {
genCodeForNonPrimitiveArrays(ctx, elementType)
}
(arrayConcatClass, s"ArrayData[] $args = new ArrayData[${evals.length}];")
}
val codes = ctx.splitExpressionsWithCurrentInputs(
expressions = inputs,
funcName = "valueConcat",
extraArguments = (s"${ctx.javaType(dataType)}[]", args) :: Nil)
ev.copy(s"""
$initCode
$codes
${ctx.javaType(dataType)} ${ev.value} = $concatenator.concat($args);
boolean ${ev.isNull} = ${ev.value} == null;
""")
}

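/**
* Generates code that sums the element counts of all input arrays into a long
* and throws at runtime if the total exceeds MAX_ARRAY_LENGTH. Returns the
* generated code together with the name of the variable holding the count.
*/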
private def genCodeForNumberOfElements(ctx: CodegenContext) : (String, String) = {
val numElements = ctx.freshName("numElements")
val code = s"""
|long $numElements = 0L;
|for (int z = 0; z < ${children.length}; z++) {
| $numElements += args[z].numElements();
|}
|if ($numElements > $MAX_ARRAY_LENGTH) {
| throw new RuntimeException("Unsuccessful try to concat arrays with $numElements" +
| " elements due to exceeding the array size limit $MAX_ARRAY_LENGTH.");
|}
""".stripMargin

(code, numElements)
}

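/**
* For nullable children, generates a guard that returns null from the
* concatenation as soon as any input array is null.
*/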
private def nullArgumentProtection() : String = {
if (nullable) {
s"""
|for (int z = 0; z < ${children.length}; z++) {
| if (args[z] == null) return null;
|}
""".stripMargin
} else {
""
}
}

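/**
* Generates a concatenator for arrays of primitive elements: the result is
* assembled directly in an UnsafeArrayData buffer, preserving per-element
* null flags.
*/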
private def genCodeForPrimitiveArrays(ctx: CodegenContext, elementType: DataType): String = {
val arrayName = ctx.freshName("array")
val arraySizeName = ctx.freshName("size")
val counter = ctx.freshName("counter")
val arrayData = ctx.freshName("arrayData")

val (numElemCode, numElemName) = genCodeForNumberOfElements(ctx)

val unsafeArraySizeInBytes = s"""
|long $arraySizeName = UnsafeArrayData.calculateSizeOfUnderlyingByteArray(
| $numElemName,
| ${elementType.defaultSize});
|if ($arraySizeName > $MAX_ARRAY_LENGTH) {
| throw new RuntimeException("Unsuccessful try to concat arrays with $arraySizeName bytes" +
| " of data due to exceeding the limit $MAX_ARRAY_LENGTH bytes for UnsafeArrayData.");
|}
""".stripMargin
val baseOffset = Platform.BYTE_ARRAY_OFFSET
val primitiveValueTypeName = ctx.primitiveTypeName(elementType)

s"""
|new Object() {
| public ArrayData concat(${ctx.javaType(dataType)}[] args) {
| ${nullArgumentProtection()}
| $numElemCode
| $unsafeArraySizeInBytes
| byte[] $arrayName = new byte[(int)$arraySizeName];
| UnsafeArrayData $arrayData = new UnsafeArrayData();
| Platform.putLong($arrayName, $baseOffset, $numElemName);
| $arrayData.pointTo($arrayName, $baseOffset, (int)$arraySizeName);
| int $counter = 0;
| for (int y = 0; y < ${children.length}; y++) {
| for (int z = 0; z < args[y].numElements(); z++) {
| if (args[y].isNullAt(z)) {
| $arrayData.setNullAt($counter);
| } else {
| $arrayData.set$primitiveValueTypeName(
| $counter,
| ${ctx.getValue(s"args[y]", elementType, "z")}
| );
| }
| $counter++;
| }
| }
| return $arrayData;
| }
|}""".stripMargin.stripPrefix("\n")
}

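/**
* Generates a concatenator for arrays of non-primitive elements: element
* references are copied into an Object[] and wrapped in a GenericArrayData.
*/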
private def genCodeForNonPrimitiveArrays(ctx: CodegenContext, elementType: DataType): String = {
val genericArrayClass = classOf[GenericArrayData].getName
val arrayData = ctx.freshName("arrayObjects")
val counter = ctx.freshName("counter")

val (numElemCode, numElemName) = genCodeForNumberOfElements(ctx)

s"""
|new Object() {
| public ArrayData concat(${ctx.javaType(dataType)}[] args) {
| ${nullArgumentProtection()}
| $numElemCode
| Object[] $arrayData = new Object[(int)$numElemName];
| int $counter = 0;
| for (int y = 0; y < ${children.length}; y++) {
| for (int z = 0; z < args[y].numElements(); z++) {
| $arrayData[$counter] = ${ctx.getValue(s"args[y]", elementType, "z")};
| $counter++;
| }
| }
| return new $genericArrayClass($arrayData);
| }
|}""".stripMargin.stripPrefix("\n")
}

override def toString: String = s"concat(${children.mkString(", ")})"

override def sql: String = s"concat(${children.map(_.sql).mkString(", ")})"
}

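Taken together, the new expression gives array concat these runtime semantics: a null array operand makes the whole result null, and totals above ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH raise a RuntimeException instead of overflowing. A minimal sketch (not part of this diff), assuming a SparkSession named "spark":

import org.apache.spark.sql.functions.{array, concat, lit}

// Plain concatenation of array columns:
spark.range(1)
  .select(concat(array(lit(1), lit(2)), array(lit(3))).alias("arr"))
  .show()  // [1, 2, 3]

// A null array operand nulls out the entire result (see eval above):
spark.range(1)
  .select(concat(array(lit(1)), lit(null).cast("array<int>")).alias("arr"))
  .show()  // null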
/**
* Transforms an array of arrays into a single array.
*/
(Diffs for the remaining changed files are not shown.)
