Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SPARK-20778][SQL] Implement array_intersect function. #18010

Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -363,6 +363,7 @@ object FunctionRegistry {
// collection functions
expression[CreateArray]("array"),
expression[ArrayContains]("array_contains"),
expression[ArrayIntersect]("array_intersect"),
expression[CreateMap]("map"),
expression[CreateNamedStruct]("named_struct"),
expression[MapKeys]("map_keys"),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
*/
package org.apache.spark.sql.catalyst.expressions

import java.util._
import java.util.Comparator

import org.apache.spark.sql.catalyst.InternalRow
Expand Down Expand Up @@ -287,3 +288,126 @@ case class ArrayContains(left: Expression, right: Expression)

override def prettyName: String = "array_contains"
}

@ExpressionDescription(
usage = "_FUNC_(array, array, ...) - Returns intersection of multiple arrays.",
extended = """
Examples:
> SELECT _FUNC_(array(1, 2, 3), array(3, 4), array(0, 1, 3));
array(1)
""")
case class ArrayIntersect(children: Seq[Expression]) extends Expression {

override def foldable: Boolean = children.forall(_.foldable)

override def checkInputDataTypes(): TypeCheckResult = {
val types = children.map(_.dataType)
types.foreach { t =>
if (!t.isInstanceOf[NullType] && !t.isInstanceOf[ArrayType]) {
return TypeCheckResult.TypeCheckFailure(
s"input to $prettyName should be an array type, but it's " +
types.map(_.simpleString).mkString("[", ", ", "]"))
}
}

TypeCheckResult.TypeCheckSuccess
}

override def dataType: DataType = {
children.headOption.map(_.dataType).getOrElse(NullType)
}

override def nullable: Boolean = children.exists(_.nullable)

override def eval(input: InternalRow): Any = {
if (nullable) {
null
} else {
val arrays = children.map(_.eval(input).asInstanceOf[ArrayData].array)
var results = arrays.head
arrays.tail.foreach {
array => results = results.filter(elem => array.contains(elem))
}
new GenericArrayData(results)
}
}

private def doGenJavaArray(
ctx: CodegenContext,
arrayCodeType: (ExprCode, DataType)): (String, String) = {
val objArrayName = ctx.freshName("array")
val tmpIndex = ctx.freshName("index")

val (ev, arrayDataType) = arrayCodeType
val elemDataType = arrayDataType.asInstanceOf[ArrayType].elementType
val boxedJavaDataType = ctx.boxedType(elemDataType)
val getValueCode = ctx.getValue(ev.value, elemDataType, tmpIndex)

(objArrayName,
s"""
${ev.code}
${boxedJavaDataType}[] ${objArrayName} = new ${boxedJavaDataType}[${ev.value}.numElements()];
for (int ${tmpIndex}=0; ${tmpIndex}<${ev.value}.numElements(); ${tmpIndex}++) {
${objArrayName}[${tmpIndex}] = ${getValueCode};
}
""".stripMargin
)
}

override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
val arraysCode = children.map(e => (
e.genCode(ctx),
e.dataType))

val arrayDataName = ctx.freshName("arrayData")
val resultsArrayListName = ctx.freshName("resultArrayList")

val genericArrayClass = classOf[GenericArrayData].getName
val arrayListClass = classOf[ArrayList[Any]].getName
val listClass = classOf[List[Any]].getName
val arraysClass = classOf[Arrays].getName

if (nullable) {
ev.copy(code = s"""
boolean ${ev.isNull} = true;
${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
""")
} else {
val (resultsName, genResultsCode) = doGenJavaArray(ctx, arraysCode.head)
val setupResultsCode =
s"""
${arrayListClass} ${resultsArrayListName} = new ${arrayListClass}();
${genResultsCode}
${resultsArrayListName}.addAll(${arraysClass}.asList(${resultsName}));
""".stripMargin

val intersectArraysCode = arraysCode.tail.map {
arrayCode => {
val tmpListName = ctx.freshName("array")
val (arrayTmpName, genArrayTmpCode) = doGenJavaArray(ctx, arrayCode)
s"""
${genArrayTmpCode}
${listClass} ${tmpListName} = ${arraysClass}.asList(${arrayTmpName});
${resultsArrayListName}.retainAll(${tmpListName});
""".stripMargin
}
}

val resultsAsArrayDataCode =
s"""
final ArrayData ${arrayDataName} = new ${genericArrayClass}(${resultsArrayListName});
""".stripMargin

ev.copy(
code = setupResultsCode
+ ctx.splitExpressions(
intersectArraysCode, "apply",
("InternalRow", ctx.INPUT_ROW) :: (arrayListClass, resultsArrayListName) :: Nil)
+ resultsAsArrayDataCode,
value = arrayDataName,
isNull = "false")
}
}

override def prettyName: String = "array_intersect"
}
Original file line number Diff line number Diff line change
Expand Up @@ -105,4 +105,117 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper
checkEvaluation(ArrayContains(a3, Literal("")), null)
checkEvaluation(ArrayContains(a3, Literal.create(null, StringType)), null)
}

test("Array intersects") {
val a0 = Literal.create(1, IntegerType)
val a1 = Literal.create(2, IntegerType)
val a2 = Literal.create(3, IntegerType)
val a3 = Literal.create(4, IntegerType)

val b0 = Literal.create(1L, LongType)
val b2 = Literal.create(3L, LongType)

val c0 = Literal.create(1.0, DoubleType)
val d0 = Literal.create("1", StringType)

val nullLiteral = Literal.create(null)

checkEvaluation(ArrayIntersect(Seq(nullLiteral)), null)

checkEvaluation(ArrayIntersect(Seq(CreateArray(Seq()))), Seq())

checkEvaluation(ArrayIntersect(Seq(
CreateArray(Seq(a0, a1)), CreateArray(Seq(a2)), CreateArray(Seq(a3)))), Seq())

checkEvaluation(ArrayIntersect(Seq(
CreateArray(Seq(a0, a1)), CreateArray(Seq(a0)))), Seq(1))

checkEvaluation(ArrayIntersect(Seq(
CreateArray(Seq(a0, a1)), CreateArray(Seq(a0)), CreateArray(Seq(a0)))), Seq(1))

checkEvaluation(ArrayIntersect(Seq(ArrayIntersect(Seq(CreateArray(
Seq(a0, a1)), CreateArray(Seq(a2)))), CreateArray(Seq(a0)))), Seq())

checkEvaluation(ArrayIntersect(Seq(CreateArray(Seq(a0, a1)),
CreateArray(Seq(Cast(b0, IntegerType), Cast(b2, IntegerType))))), Seq(1))

checkEvaluation(ArrayIntersect(
Seq(CreateArray(Seq(a0, a0, a1, a3)), CreateArray(Seq(a0, a2, a3, a3)),
CreateArray(Seq(a0, a1, a3)))), Seq(1, 1, 4))

checkEvaluation(ArrayIntersect(Seq(nullLiteral, CreateArray(Seq(a0, a2, a3, a3)),
CreateArray(Seq(a0, a1, a3)))), null)

checkEvaluation(ArrayIntersect(Seq(CreateArray(Seq(a0, a1)), CreateArray(Seq(a0)))), Seq(1))

checkEvaluation(If(LessThan(Rand(0L), c0), a0, a0), 1)

checkEvaluation(ArrayIntersect(Seq(
CreateArray(Seq(If(LessThan(Rand(0L), c0), a0, a0), a0, a1, a3)),
CreateArray(Seq(If(LessThan(Rand(0L), c0), a0, a0), a2, a3, a3)),
CreateArray(Seq(If(LessThan(Rand(0L), c0), a0, a0), a2, a3)))), Seq(1, 1, 4))

checkEvaluation(ArrayIntersect(Seq(
CreateArray(Seq(If(LessThan(Rand(0L), c0), a0, a0), a1)),
CreateArray(Seq(If(LessThan(Rand(0L), c0), a0, a0))))), Seq(1))

checkEvaluation(ArrayIntersect(Seq(
CreateArray(Seq(If(LessThan(Rand(0L), c0), a0, a0), a1)),
CreateArray(Seq(If(LessThan(Rand(0L), c0), d0, d0))))), Seq())

checkEvaluation(ArrayIntersect(Seq(
CreateArray(Seq(a0, a1)), CreateArray(Seq(If(LessThan(Rand(0L), c0), a0, a0))))), Seq(1))

checkEvaluation(ArrayIntersect(Seq(
CreateArray(Seq(a0, a1)), CreateArray(Seq(If(LessThan(Rand(0L), c0), d0, d0))))), Seq())

checkEvaluation(ArrayIntersect(Seq(
CreateArray(Seq(If(LessThan(Rand(0L), c0), a0, a0), a1)), CreateArray(Seq(a0)))), Seq(1))

checkEvaluation(ArrayIntersect(Seq(
CreateArray(Seq(a0, a1, a2)),
CreateArray(Seq(If(LessThan(Rand(0L), c0), a0, a0), a1, a2)),
CreateArray(Seq(a0, a1, a2)),
CreateArray(Seq(If(LessThan(Rand(0L), c0), a0, a0), a2)))),
Seq(1, 3))

checkEvaluation(ArrayIntersect(Seq(
CreateArray(Seq(d0, Cast(a1, StringType), Cast(a2, StringType))),
CreateArray(Seq(a1, a2)))),
Seq())

checkEvaluation(ArrayIntersect(Seq(
CreateArray(Seq(a0, a1, a2)),
CreateArray(Seq(If(LessThan(Rand(0L), c0), a0, a0), a1, a2)),
CreateArray(Seq(a0, a1, a2)),
CreateArray(Seq(If(LessThan(Rand(0L), c0), d0, d0), Cast(a2, StringType))))),
Seq())

checkEvaluation(ArrayIntersect(Seq(
nullLiteral,
CreateArray(Seq(a0, a1, a2)),
CreateArray(Seq(If(LessThan(Rand(0L), c0), a0, a0), a2)),
CreateArray(Seq(If(LessThan(Rand(0L), c0), d0, d0))))),
null)

checkEvaluation(ArrayIntersect(Seq(
CreateArray(Seq(If(LessThan(Rand(0L), c0), nullLiteral, nullLiteral))),
CreateArray(Seq(If(LessThan(Rand(0L), c0), a0, a0), a2)),
CreateArray(Seq(If(LessThan(Rand(0L), c0), d0, d0))))),
Seq())

checkEvaluation(ArrayIntersect(Seq(CreateArray(Seq(a0, a1)), nullLiteral)), null)

checkEvaluation(ArrayIntersect(Seq(nullLiteral, CreateArray(Seq(a0, a1)))), null)

checkEvaluation(ArrayIntersect(Seq(
CreateArray(Seq(If(LessThan(Rand(0L), c0), nullLiteral, nullLiteral))),
nullLiteral, nullLiteral, CreateArray(Seq(a0)))), null)

checkEvaluation(ArrayIntersect(Seq(
CreateArray(Seq(If(LessThan(Rand(0L), c0), nullLiteral, nullLiteral))),
CreateArray(Seq(a0, a1)),
CreateArray(Seq(a0)))),
Seq())
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,7 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks {
Alias(expression, s"Optimized($expression)2")() :: Nil),
expression)

plan.initialize(0)
val unsafeRow = plan(inputRow)
val input = if (inputRow == EmptyRow) "" else s", input: $inputRow"

Expand Down