From d34c99cf721dca1f3d8e7b8167b49d198ce192de Mon Sep 17 00:00:00 2001 From: mn-mikke Date: Wed, 18 Apr 2018 10:00:27 +0200 Subject: [PATCH] [SPARK-24042][SQL] Collection function: zip_with_index --- python/pyspark/sql/functions.py | 18 ++ .../catalyst/analysis/FunctionRegistry.scala | 1 + .../expressions/collectionOperations.scala | 137 ++++++++++ .../CollectionExpressionsSuite.scala | 54 ++++ .../expressions/ExpressionEvalHelper.scala | 5 + .../org/apache/spark/sql/functions.scala | 11 + .../spark/sql/DataFrameFunctionsSuite.scala | 249 ++++++++++++------ 7 files changed, 401 insertions(+), 74 deletions(-) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index d36787750fcda..a26389ead6b4b 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -2064,6 +2064,24 @@ def reverse(col): return Column(sc._jvm.functions.reverse(_to_java_column(col))) +@since(2.4) +def zip_with_index(col, indexFirst=False): + """ + Collection function: transforms the input array by encapsulating elements into pairs + with indexes indicating the order. + + :param col: name of column or expression + + >>> df = spark.createDataFrame([([2, 5, 3],), ([],)], ['data']) + >>> df.select(zip_with_index(df.data).alias('r')).collect() + [Row(r=[[value=2, index=0], [value=5, index=1], [value=3, index=2]]), Row(r=[])] + >>> df.select(zip_with_index(df.data, indexFirst=True).alias('r')).collect() + [Row(r=[[index=0, value=2], [index=1, value=5], [index=2, value=3]]), Row(r=[])] + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.zip_with_index(_to_java_column(col), indexFirst)) + + @since(2.4) def flatten(col): """ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index 1bcae752b5bd7..3cdb176036782 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -409,6 +409,7 @@ object FunctionRegistry { expression[Concat]("concat"), expression[Flatten]("flatten"), expression[Reverse]("reverse"), + expression[ZipWithIndex]("zip_with_index"), CreateStruct.registryEntry, // misc functions diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 5db36fbe43de9..c60b4086fe8da 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.catalyst.expressions import java.util.Comparator +import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions.codegen._ @@ -765,3 +766,139 @@ case class Flatten(child: Expression) extends UnaryExpression { override def prettyName: String = "flatten" } + +/** + * Returns the maximum value in the array. + */ +// scalastyle:off line.size.limit +@ExpressionDescription( + usage = "_FUNC_(array[, indexFirst]) - Transforms the input array by encapsulating elements into pairs with indexes indicating the order.", + examples = """ + Examples: + > SELECT _FUNC_(array("d", "a", null, "b")); + [("d",0),("a",1),(null,2),("b",3)] + > SELECT _FUNC_(array("d", "a", null, "b"), true); + [(0,"d"),(1,"a"),(2,null),(3,"b")] + """, + since = "2.4.0") +case class ZipWithIndex(child: Expression, indexFirst: Expression) + extends UnaryExpression with ExpectsInputTypes { + + def this(e: Expression) = this(e, Literal.FalseLiteral) + + val indexFirstValue: Boolean = indexFirst match { + case Literal(v: Boolean, BooleanType) => v + case _ => throw new AnalysisException("The second argument has to be a boolean constant.") + } + + private val MAX_ARRAY_LENGTH: Int = ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH + + override def inputTypes: Seq[AbstractDataType] = Seq(ArrayType) + + lazy val childArrayType: ArrayType = child.dataType.asInstanceOf[ArrayType] + + override def dataType: DataType = { + val elementField = StructField("value", childArrayType.elementType, childArrayType.containsNull) + val indexField = StructField("index", IntegerType, false) + + val fields = if (indexFirstValue) Seq(indexField, elementField) else Seq(elementField, indexField) + + ArrayType(StructType(fields), false) + } + + override protected def nullSafeEval(input: Any): Any = { + val array = input.asInstanceOf[ArrayData].toObjectArray(childArrayType.elementType) + + val makeStruct = (v: Any, i: Int) => if (indexFirstValue) InternalRow(i, v) else InternalRow(v, i) + val resultData = array.zipWithIndex.map{case (v, i) => makeStruct(v, i)} + + new GenericArrayData(resultData) + } + + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + nullSafeCodeGen(ctx, ev, c => { + if (ctx.isPrimitiveType(childArrayType.elementType)) { + genCodeForPrimitiveElements(ctx, c, ev.value) + } else { + genCodeForNonPrimitiveElements(ctx, c, ev.value) + } + }) + } + + private def genCodeForPrimitiveElements( + ctx: CodegenContext, + childVariableName: String, + arrayData: String): String = { + val numElements = ctx.freshName("numElements") + val byteArraySize = ctx.freshName("byteArraySize") + val data = ctx.freshName("byteArray") + val unsafeRow = ctx.freshName("unsafeRow") + val structSize = ctx.freshName("structSize") + val unsafeArrayData = ctx.freshName("unsafeArrayData") + val structsOffset = ctx.freshName("structsOffset") + val calculateArraySize = "UnsafeArrayData.calculateSizeOfUnderlyingByteArray" + val calculateHeader = "UnsafeArrayData.calculateHeaderPortionInBytes" + + val baseOffset = Platform.BYTE_ARRAY_OFFSET + val longSize = LongType.defaultSize + val primitiveValueTypeName = ctx.primitiveTypeName(childArrayType.elementType) + val valuePosition = if (indexFirstValue) "1" else "0" + val indexPosition = if (indexFirstValue) "0" else "1" + s""" + |final int $numElements = $childVariableName.numElements(); + |final int $structSize = ${UnsafeRow.calculateBitSetWidthInBytes(2) + longSize * 2}; + |final long $byteArraySize = $calculateArraySize($numElements, $longSize + $structSize); + |final int $structsOffset = $calculateHeader($numElements) + $numElements * $longSize; + |if ($byteArraySize > $MAX_ARRAY_LENGTH) { + | throw new RuntimeException("Unsuccessful try to zip array with index due to exceeding" + + | " the limit $MAX_ARRAY_LENGTH bytes for UnsafeArrayData. " + $byteArraySize + + | " bytes of data are required for performing the operation with the given array."); + |} + |final byte[] $data = new byte[(int)$byteArraySize]; + |UnsafeArrayData $unsafeArrayData = new UnsafeArrayData(); + |Platform.putLong($data, $baseOffset, $numElements); + |$unsafeArrayData.pointTo($data, $baseOffset, (int)$byteArraySize); + |UnsafeRow $unsafeRow = new UnsafeRow(2); + |for (int z = 0; z < $numElements; z++) { + | long offset = $structsOffset + z * $structSize; + | $unsafeArrayData.setLong(z, (offset << 32) + $structSize); + | $unsafeRow.pointTo($data, $baseOffset + offset, $structSize); + | if ($childVariableName.isNullAt(z)) { + | $unsafeRow.setNullAt($valuePosition); + | } else { + | $unsafeRow.set$primitiveValueTypeName( + | $valuePosition, + | ${ctx.getValue(childVariableName, childArrayType.elementType, "z")} + | ); + | } + | $unsafeRow.setInt($indexPosition, z); + |} + |$arrayData = $unsafeArrayData; + """.stripMargin + } + + private def genCodeForNonPrimitiveElements( + ctx: CodegenContext, + childVariableName: String, + arrayData: String): String = { + val genericArrayClass = classOf[GenericArrayData].getName + val rowClass = classOf[GenericInternalRow].getName + val numberOfElements = ctx.freshName("numElements") + val data = ctx.freshName("internalRowArray") + + val elementValue = ctx.getValue(childVariableName, childArrayType.elementType, "z") + val arguments = if (indexFirstValue) s"z, $elementValue" else s"$elementValue, z" + + s""" + |final int $numberOfElements = $childVariableName.numElements(); + |final Object[] $data = new Object[$numberOfElements]; + |for (int z = 0; z < $numberOfElements; z++) { + | $data[z] = new $rowClass(new Object[]{$arguments}); + |} + |$arrayData = new $genericArrayClass($data); + """.stripMargin + } + + override def prettyName: String = "zip_with_index" +} + diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala index 49eaf60ae0221..fceb2268e07d6 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.types._ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { @@ -286,4 +287,57 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(Reverse(as7), null) checkEvaluation(Reverse(aa), Seq(Seq("e"), Seq("c", "d"), Seq("a", "b"))) } + + test("Zip With Index") { + def r(values: Any*): InternalRow = create_row(values: _*) + val t = Literal.TrueLiteral + val f = Literal.FalseLiteral + + // Primitive-type elements + val ai0 = Literal.create(Seq(2, 8, 4, 7), ArrayType(IntegerType)) + val ai1 = Literal.create(Seq(null, 4, null, 2), ArrayType(IntegerType)) + val ai2 = Literal.create(Seq(null, null, null), ArrayType(IntegerType)) + val ai3 = Literal.create(Seq(1), ArrayType(IntegerType)) + val ai4 = Literal.create(Seq.empty, ArrayType(IntegerType)) + val ai5 = Literal.create(null, ArrayType(IntegerType)) + + checkEvaluation(ZipWithIndex(ai0, f), Seq(r(2, 0), r(8, 1), r(4, 2), r(7, 3))) + checkEvaluation(ZipWithIndex(ai1, f), Seq(r(null, 0), r(4, 1), r(null, 2), r(2, 3))) + checkEvaluation(ZipWithIndex(ai2, f), Seq(r(null, 0), r(null, 1), r(null, 2))) + checkEvaluation(ZipWithIndex(ai3, f), Seq(r(1, 0))) + checkEvaluation(ZipWithIndex(ai4, f), Seq.empty) + checkEvaluation(ZipWithIndex(ai5, f), null) + + checkEvaluation(ZipWithIndex(ai0, t), Seq(r(0, 2), r(1, 8), r(2, 4), r(3, 7))) + checkEvaluation(ZipWithIndex(ai1, t), Seq(r(0, null), r(1, 4), r(2, null), r(3, 2))) + checkEvaluation(ZipWithIndex(ai2, t), Seq(r(0, null), r(1, null), r(2, null))) + checkEvaluation(ZipWithIndex(ai3, t), Seq(r(0, 1))) + checkEvaluation(ZipWithIndex(ai4, t), Seq.empty) + checkEvaluation(ZipWithIndex(ai5, t), null) + + // Non-primitive-type elements + val as0 = Literal.create(Seq("b", "a", "y", "z"), ArrayType(StringType)) + val as1 = Literal.create(Seq(null, "x", null, "y"), ArrayType(StringType)) + val as2 = Literal.create(Seq(null, null, null), ArrayType(StringType)) + val as3 = Literal.create(Seq("a"), ArrayType(StringType)) + val as4 = Literal.create(Seq.empty, ArrayType(StringType)) + val as5 = Literal.create(null, ArrayType(StringType)) + val aas = Literal.create(Seq(Seq("e"), Seq("c", "d")), ArrayType(ArrayType(StringType))) + + checkEvaluation(ZipWithIndex(as0, f), Seq(r("b", 0), r("a", 1), r("y", 2), r("z", 3))) + checkEvaluation(ZipWithIndex(as1, f), Seq(r(null, 0), r("x", 1), r(null, 2), r("y", 3))) + checkEvaluation(ZipWithIndex(as2, f), Seq(r(null, 0), r(null, 1), r(null, 2))) + checkEvaluation(ZipWithIndex(as3, f), Seq(r("a", 0))) + checkEvaluation(ZipWithIndex(as4, f), Seq.empty) + checkEvaluation(ZipWithIndex(as5, f), null) + checkEvaluation(ZipWithIndex(aas, f), Seq(r(Seq("e"), 0), r(Seq("c", "d"), 1))) + + checkEvaluation(ZipWithIndex(as0, t), Seq(r(0, "b"), r(1, "a"), r(2, "y"), r(3, "z"))) + checkEvaluation(ZipWithIndex(as1, t), Seq(r(0, null), r(1, "x"), r(2, null), r(3, "y"))) + checkEvaluation(ZipWithIndex(as2, t), Seq(r(0, null), r(1, null), r(2, null))) + checkEvaluation(ZipWithIndex(as3, t), Seq(r(0, "a"))) + checkEvaluation(ZipWithIndex(as4, t), Seq.empty) + checkEvaluation(ZipWithIndex(as5, t), null) + checkEvaluation(ZipWithIndex(aas, t), Seq(r(0, Seq("e")), r(1, Seq("c", "d")))) + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala index b4c8eab19c5cc..93f0e9466dfa9 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala @@ -24,6 +24,7 @@ import org.scalatest.prop.GeneratorDrivenPropertyChecks import org.apache.spark.{SparkConf, SparkFunSuite} import org.apache.spark.serializer.JavaSerializer +import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} import org.apache.spark.sql.catalyst.analysis.{ResolveTimeZone, SimpleAnalyzer} import org.apache.spark.sql.catalyst.expressions.codegen._ @@ -88,6 +89,10 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks { if (expected.isNaN) result.isNaN else expected == result case (result: Float, expected: Float) => if (expected.isNaN) result.isNaN else expected == result + case (result: InternalRow, expected: InternalRow) => + val structType = dataType.asInstanceOf[StructType] + result.toSeq(structType) == expected.toSeq(structType) + case (result: Row, expected: InternalRow) => result.toSeq == expected.toSeq(result.schema) case _ => result == expected } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 4f2d6fc0c6af9..2598f9d430cf3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -3249,6 +3249,17 @@ object functions { */ def flatten(e: Column): Column = withExpr { Flatten(e.expr) } + /** + * Transforms the input array by encapsulating elements into pairs + * with indexes indicating the order. + * + * @group collection_funcs + * @since 2.4.0 + */ + def zip_with_index(e: Column, indexFirst: Boolean = false): Column = withExpr { + ZipWithIndex(e.expr, Literal(indexFirst)) + } + /** * Returns an unordered array containing the keys of the map. * @group collection_funcs diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index fd847c8a9cd57..eea6d2411b172 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -413,80 +413,6 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { ) } - test("concat function - arrays") { - val nseqi : Seq[Int] = null - val nseqs : Seq[String] = null - val df = Seq( - - (Seq(1), Seq(2, 3), Seq(5L, 6L), nseqi, Seq("a", "b", "c"), Seq("d", "e"), Seq("f"), nseqs), - (Seq(1, 0), Seq.empty[Int], Seq(2L), nseqi, Seq("a"), Seq.empty[String], Seq(null), nseqs) - ).toDF("i1", "i2", "i3", "in", "s1", "s2", "s3", "sn") - - val dummyFilter = (c: Column) => c.isNull || c.isNotNull // switch codeGen on - - // Simple test cases - checkAnswer( - df.selectExpr("array(1, 2, 3L)"), - Seq(Row(Seq(1L, 2L, 3L)), Row(Seq(1L, 2L, 3L))) - ) - - checkAnswer ( - df.select(concat($"i1", $"s1")), - Seq(Row(Seq("1", "a", "b", "c")), Row(Seq("1", "0", "a"))) - ) - checkAnswer( - df.select(concat($"i1", $"i2", $"i3")), - Seq(Row(Seq(1, 2, 3, 5, 6)), Row(Seq(1, 0, 2))) - ) - checkAnswer( - df.filter(dummyFilter($"i1")).select(concat($"i1", $"i2", $"i3")), - Seq(Row(Seq(1, 2, 3, 5, 6)), Row(Seq(1, 0, 2))) - ) - checkAnswer( - df.selectExpr("concat(array(1, null), i2, i3)"), - Seq(Row(Seq(1, null, 2, 3, 5, 6)), Row(Seq(1, null, 2))) - ) - checkAnswer( - df.select(concat($"s1", $"s2", $"s3")), - Seq(Row(Seq("a", "b", "c", "d", "e", "f")), Row(Seq("a", null))) - ) - checkAnswer( - df.selectExpr("concat(s1, s2, s3)"), - Seq(Row(Seq("a", "b", "c", "d", "e", "f")), Row(Seq("a", null))) - ) - checkAnswer( - df.filter(dummyFilter($"s1"))select(concat($"s1", $"s2", $"s3")), - Seq(Row(Seq("a", "b", "c", "d", "e", "f")), Row(Seq("a", null))) - ) - - // Null test cases - checkAnswer( - df.select(concat($"i1", $"in")), - Seq(Row(null), Row(null)) - ) - checkAnswer( - df.select(concat($"in", $"i1")), - Seq(Row(null), Row(null)) - ) - checkAnswer( - df.select(concat($"s1", $"sn")), - Seq(Row(null), Row(null)) - ) - checkAnswer( - df.select(concat($"sn", $"s1")), - Seq(Row(null), Row(null)) - ) - - // Type error test cases - intercept[AnalysisException] { - df.selectExpr("concat(i1, i2, null)") - } - - intercept[AnalysisException] { - df.selectExpr("concat(i1, array(i1, i2))") - } - } - test("flatten function") { val dummyFilter = (c: Column) => c.isNull || c.isNotNull // to switch codeGen on val oneRowDF = Seq((1, "a", Seq(1, 2, 3))).toDF("i", "s", "arr") @@ -660,6 +586,181 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { } } + test("concat function - arrays") { + val nseqi : Seq[Int] = null + val nseqs : Seq[String] = null + val df = Seq( + + (Seq(1), Seq(2, 3), Seq(5L, 6L), nseqi, Seq("a", "b", "c"), Seq("d", "e"), Seq("f"), nseqs), + (Seq(1, 0), Seq.empty[Int], Seq(2L), nseqi, Seq("a"), Seq.empty[String], Seq(null), nseqs) + ).toDF("i1", "i2", "i3", "in", "s1", "s2", "s3", "sn") + + val dummyFilter = (c: Column) => c.isNull || c.isNotNull // switch codeGen on + + // Simple test cases + checkAnswer( + df.selectExpr("array(1, 2, 3L)"), + Seq(Row(Seq(1L, 2L, 3L)), Row(Seq(1L, 2L, 3L))) + ) + + checkAnswer ( + df.select(concat($"i1", $"s1")), + Seq(Row(Seq("1", "a", "b", "c")), Row(Seq("1", "0", "a"))) + ) + checkAnswer( + df.select(concat($"i1", $"i2", $"i3")), + Seq(Row(Seq(1, 2, 3, 5, 6)), Row(Seq(1, 0, 2))) + ) + checkAnswer( + df.filter(dummyFilter($"i1")).select(concat($"i1", $"i2", $"i3")), + Seq(Row(Seq(1, 2, 3, 5, 6)), Row(Seq(1, 0, 2))) + ) + checkAnswer( + df.selectExpr("concat(array(1, null), i2, i3)"), + Seq(Row(Seq(1, null, 2, 3, 5, 6)), Row(Seq(1, null, 2))) + ) + checkAnswer( + df.select(concat($"s1", $"s2", $"s3")), + Seq(Row(Seq("a", "b", "c", "d", "e", "f")), Row(Seq("a", null))) + ) + checkAnswer( + df.selectExpr("concat(s1, s2, s3)"), + Seq(Row(Seq("a", "b", "c", "d", "e", "f")), Row(Seq("a", null))) + ) + checkAnswer( + df.filter(dummyFilter($"s1"))select(concat($"s1", $"s2", $"s3")), + Seq(Row(Seq("a", "b", "c", "d", "e", "f")), Row(Seq("a", null))) + ) + + // Null test cases + checkAnswer( + df.select(concat($"i1", $"in")), + Seq(Row(null), Row(null)) + ) + checkAnswer( + df.select(concat($"in", $"i1")), + Seq(Row(null), Row(null)) + ) + checkAnswer( + df.select(concat($"s1", $"sn")), + Seq(Row(null), Row(null)) + ) + checkAnswer( + df.select(concat($"sn", $"s1")), + Seq(Row(null), Row(null)) + ) + + // Type error test cases + intercept[AnalysisException] { + df.selectExpr("concat(i1, i2, null)") + } + + intercept[AnalysisException] { + df.selectExpr("concat(i1, array(i1, i2))") + } + } + + test("zip_with_index function") { + val dummyFilter = (c: Column) => c.isNull || c.isNotNull // switch codegen on + val oneRowDF = Seq(("Spark", 3215, true)).toDF("s", "i", "b") + + // Test cases with primitive-type elements + val idf = Seq( + Seq(1, 9, 8, 7), + Seq.empty, + null + ).toDF("i") + + checkAnswer( + idf.select(zip_with_index('i)), + Seq(Row(Seq(Row(1, 0), Row(9, 1), Row(8, 2), Row(7, 3))), Row(Seq.empty), Row(null)) + ) + checkAnswer( + idf.filter(dummyFilter('i)).select(zip_with_index('i)), + Seq(Row(Seq(Row(1, 0), Row(9, 1), Row(8, 2), Row(7, 3))), Row(Seq.empty), Row(null)) + ) + checkAnswer( + idf.select(zip_with_index('i, true)), + Seq(Row(Seq(Row(0, 1), Row(1, 9), Row(2, 8), Row(3, 7))), Row(Seq.empty), Row(null)) + ) + checkAnswer( + idf.selectExpr("zip_with_index(i)"), + Seq(Row(Seq(Row(1, 0), Row(9, 1), Row(8, 2), Row(7, 3))), Row(Seq.empty), Row(null)) + ) + checkAnswer( + idf.selectExpr("zip_with_index(i, true)"), + Seq(Row(Seq(Row(0, 1), Row(1, 9), Row(2, 8), Row(3, 7))), Row(Seq.empty), Row(null)) + ) + checkAnswer( + oneRowDF.selectExpr("zip_with_index(array(null, 2, null), false)"), + Seq(Row(Seq(Row(null, 0), Row(2, 1), Row(null, 2)))) + ) + checkAnswer( + oneRowDF.selectExpr("zip_with_index(array(null, 2, null), true)"), + Seq(Row(Seq(Row(0, null), Row(1, 2), Row(2, null)))) + ) + + // Test cases with non-primitive-type elements + val sdf = Seq( + Seq("c", "a", "d", "b"), + Seq(null, "x", null), + Seq.empty, + null + ).toDF("s") + + checkAnswer( + sdf.select(zip_with_index('s)), + Seq( + Row(Seq(Row("c", 0), Row("a", 1), Row("d", 2), Row("b", 3))), + Row(Seq(Row(null, 0), Row("x", 1), Row(null, 2))), + Row(Seq.empty), + Row(null)) + ) + checkAnswer( + sdf.filter(dummyFilter('s)).select(zip_with_index('s)), + Seq( + Row(Seq(Row("c", 0), Row("a", 1), Row("d", 2), Row("b", 3))), + Row(Seq(Row(null, 0), Row("x", 1), Row(null, 2))), + Row(Seq.empty), + Row(null)) + ) + checkAnswer( + sdf.select(zip_with_index('s, true)), + Seq( + Row(Seq(Row(0, "c"), Row(1, "a"), Row(2, "d"), Row(3, "b"))), + Row(Seq(Row(0, null), Row(1, "x"), Row(2, null))), + Row(Seq.empty), + Row(null)) + ) + checkAnswer( + sdf.selectExpr("zip_with_index(s)"), + Seq( + Row(Seq(Row("c", 0), Row("a", 1), Row("d", 2), Row("b", 3))), + Row(Seq(Row(null, 0), Row("x", 1), Row(null, 2))), + Row(Seq.empty), + Row(null)) + ) + checkAnswer( + sdf.selectExpr("zip_with_index(s, true)"), + Seq( + Row(Seq(Row(0, "c"), Row(1, "a"), Row(2, "d"), Row(3, "b"))), + Row(Seq(Row(0, null), Row(1, "x"), Row(2, null))), + Row(Seq.empty), + Row(null)) + ) + + // Error test cases + intercept[AnalysisException] { + oneRowDF.select(zip_with_index('s)) + } + intercept[AnalysisException] { + oneRowDF.selectExpr("zip_with_index(array(1, 2, 3), b)") + } + intercept[AnalysisException] { + oneRowDF.selectExpr("zip_with_index(array(1, 2, 3), 1)") + } + } + private def assertValuesDoNotChangeAfterCoalesceOrUnion(v: Column): Unit = { import DataFrameFunctionsSuite.CodegenFallbackExpr for ((codegenFallback, wholeStage) <- Seq((true, false), (false, false), (false, true))) {