[SPARK-28962][SQL] Provide index argument to filter lambda functions #25666
Conversation
Ok to test
ok to test
Test build #110089 has finished for PR 25666 at commit
Test build #110572 has finished for PR 25666 at commit
retest this please
Test build #110580 has finished for PR 25666 at commit
ok to test
Test build #110695 has finished for PR 25666 at commit
LGTM so far, but we might need to add tests to DataFrameFunctionsSuite
for the new usage.
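As a hedged illustration, a DataFrameFunctionsSuite-style test for the new usage might look like the sketch below; it assumes the Scala filter(Column, (Column, Column) => Column) overload added alongside this change and the usual QueryTest helpers, and is not the PR's actual test:

import org.apache.spark.sql.Row
import org.apache.spark.sql.functions.{col, filter}

test("filter with (element, index) lambda - DataFrame API") {
  import testImplicits._  // available in QueryTest-based suites
  // Keep only elements strictly greater than their 0-based index.
  val df = Seq(Seq(0, 2, 3)).toDF("arr")
  checkAnswer(
    df.select(filter(col("arr"), (x, i) => x > i)),
    Seq(Row(Seq(2, 3))))
}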
@@ -344,6 +344,8 @@ case class MapFilter(
    Examples:
      > SELECT _FUNC_(array(1, 2, 3), x -> x % 2 == 1);
       [1,3]
      > SELECT _FUNC_(array(0, 2, 3), (x, i) -> x > i);
       [2, 3]
  """,
  since = "2.4.0")
Could you add a note to describe that this can take the index argument since 3.0.0? E.g., from collectionOperations.scala (lines 1049 to 1051 at 7402935):

note = """
  Reverse logic for arrays is available since 2.4.0.
"""
@@ -344,6 +344,8 @@ case class MapFilter(
    Examples:
      > SELECT _FUNC_(array(1, 2, 3), x -> x % 2 == 1);
       [1,3]
      > SELECT _FUNC_(array(0, 2, 3), (x, i) -> x > i);
       [2, 3]
nit: [2,3]?
@transient lazy val LambdaFunction(_, Seq(elementVar: NamedLambdaVariable), _) = function
@transient lazy val (elementVar, indexVar) = {
  val LambdaFunction(_, (elementVar: NamedLambdaVariable) +: tail, _) = function
  val indexVar = if (tail.nonEmpty) {
nit: val indexVar = tail.headOption.map(_.asInstanceOf[NamedLambdaVariable])
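Applied to the diff above, that suggestion would read roughly as follows (a sketch using the names from the snippet; the actual merged code may differ):

@transient lazy val (elementVar, indexVar) = {
  val LambdaFunction(_, (elementVar: NamedLambdaVariable) +: tail, _) = function
  // Present only when the lambda declares a second (index) argument.
  val indexVar = tail.headOption.map(_.asInstanceOf[NamedLambdaVariable])
  (elementVar, indexVar)
}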
case LambdaFunction(_, arguments, _) if arguments.size == 2 =>
  copy(function = f(function, (elementType, containsNull) :: (IntegerType, false) :: Nil))
case _ =>
  copy(function = f(function, (elementType, containsNull) :: Nil))
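For context, a hedged sketch of how the surrounding bind override might look with this match, modeled on ArrayTransform's shape rather than on the exact merged code:

override def bind(
    f: (Expression, Seq[(DataType, Boolean)]) => LambdaFunction): ArrayFilter = {
  val ArrayType(elementType, containsNull) = argument.dataType
  function match {
    case LambdaFunction(_, arguments, _) if arguments.size == 2 =>
      // (element, index) lambda: bind the element type plus a non-nullable integer index.
      copy(function = f(function, (elementType, containsNull) :: (IntegerType, false) :: Nil))
    case _ =>
      // Element-only lambda.
      copy(function = f(function, (elementType, containsNull) :: Nil))
  }
}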
Don't we need to validate the number of arguments here (the case where arguments.size > 2)?
Can you check the current error message for that case?
ArrayTransform doesn't validate arguments.size > 2. I'm not sure what happens in that case either.
Never mind. I checked that the error handling works well for this case.
Yes, it does. See the test here: https://github.com/apache/spark/pull/25666/files#diff-8e1a34391fdefa4a3a0349d7d454d86fR2204.
Should we also provide similar overloads with index arguments in
@nvander1 I'm not sure whether we also need the index argument in
@ueshin comments addressed. I added a test to
Test build #111596 has finished for PR 25666 at commit
Jenkins, retest this please.
LGTM, pending tests.
Test build #111619 has finished for PR 25666 at commit
@@ -369,6 +383,9 @@ case class ArrayFilter(
  var i = 0
  while (i < arr.numElements) {
    elementVar.value.set(arr.get(i, elementVar.dataType))
    if (indexVar.isDefined) {
Can you avoid this per-row check? The current code causes unnecessary runtime overhead.
@maropu do you have a suggestion about how to do this without implementing codegen? I tried rewriting the logic like so:
@transient private lazy val evalFn: (InternalRow, Any) => Any = indexVar match {
  case None => (inputRow, argumentValue) =>
    val arr = argumentValue.asInstanceOf[ArrayData]
    val f = functionForEval
    val buffer = new mutable.ArrayBuffer[Any](arr.numElements)
    var i = 0
    while (i < arr.numElements) {
      elementVar.value.set(arr.get(i, elementVar.dataType))
      if (f.eval(inputRow).asInstanceOf[Boolean]) {
        buffer += elementVar.value.get
      }
      i += 1
    }
    new GenericArrayData(buffer)
  case Some(expr) => (inputRow, argumentValue) =>
    val arr = argumentValue.asInstanceOf[ArrayData]
    val f = functionForEval
    val buffer = new mutable.ArrayBuffer[Any](arr.numElements)
    var i = 0
    while (i < arr.numElements) {
      elementVar.value.set(arr.get(i, elementVar.dataType))
      expr.value.set(i)
      if (f.eval(inputRow).asInstanceOf[Boolean]) {
        buffer += elementVar.value.get
      }
      i += 1
    }
    new GenericArrayData(buffer)
}

override def nullSafeEval(inputRow: InternalRow, argumentValue: Any): Any = {
  evalFn(inputRow, argumentValue)
}
But from some hacky microbenchmarking this doesn't seem to be meaningfully faster and if anything is marginally slower.
This is the benchmark code I was using:
test("ArrayFilter - benchmark") {
import scala.concurrent.duration._
val b = new Benchmark(
"array_filter",
1000,
warmupTime = 5.seconds,
minTime = 5.seconds)
val ai0 = Literal.create(Seq(1, 2, 3), ArrayType(IntegerType, containsNull = false))
val isEven: Expression => Expression = x => x % 2 === 0
b.addCase("filter") { _ =>
var i = 0
while (i < 1000) {
filter(ai0, isEven).eval()
i += 1
}
}
b.run()
}
@maropu @henrydavidge The best-performing way to avoid the per-row check in a non-codegen setting is to introduce a new expression type, say ArrayFilterWithIndex.
The tradeoff between the inline per-row check and the lambda batch solution is that on input arrays that are small (like the one @henrydavidge used in his benchmark), the overhead of the lambda invocation (which is not guaranteed to be inlined and optimized) may exceed the overhead of the per-row check. You'd need a fairly large input array to amortize that.
If we want to make it stay simple for now, I'm okay with the inline per-row check version.
I was thinking of code like this:
@transient lazy val (elementVar, mayFillIndex) = function match {
  case LambdaFunction(_, Seq(elemVar: NamedLambdaVariable), _) =>
    (elemVar, (_: Int) => {})
  case LambdaFunction(_, Seq(elemVar: NamedLambdaVariable, idxVar: NamedLambdaVariable), _) =>
    (elemVar, (i: Int) => idxVar.value.set(i))
}

override def nullSafeEval(inputRow: InternalRow, argumentValue: Any): Any = {
  val arr = argumentValue.asInstanceOf[ArrayData]
  val f = functionForEval
  val buffer = new mutable.ArrayBuffer[Any](arr.numElements)
  var i = 0
  while (i < arr.numElements) {
    elementVar.value.set(arr.get(i, elementVar.dataType))
    mayFillIndex(i)
    if (f.eval(inputRow).asInstanceOf[Boolean]) {
      buffer += elementVar.value.get
    }
    i += 1
  }
  new GenericArrayData(buffer)
}
Ok, tried that as well. It doesn't seem to be significantly different from the others.
Yeah, if there's no big difference, I'd prefer handling it similarly to the others, e.g., https://github.com/apache/spark/blob/master/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala#L555
I think this is good enough to go.
How about merging this for now, and addressing it in a separate PR?
transform is implemented the same way, so I think we should do the same thing there if needed, maybe at the same time.
+1 for this is ready to go for now and we can address the optimization separately.
Side-comment on the version that @maropu gave:
The lambda version that @henrydavidge gave (i.e. "batch-wise lambda") would technically have less overhead:
// lambda invocation overhead outside of loop
for each element in array
do specialized filter action
whereas the version that @maropu gave (i.e. "element-wise lambda") would be:
// shared loop between the two versions
for each element in array
// lambda invocation overhead per element
invoke mayFillIndex lambda
With @maropu's version, assuming we're running on the HotSpot JVM and both the with-index and without-index paths have been exercised, the best the HotSpot JIT compiler could do is a profile-guided bimorphic devirtualization on that lambda call site, which would look like the following after devirtualization+inlining:
local_mayFillIndex = this.mayFillIndex
klazz = local_mayFillIndex.klass
for each element in array {
  // ...
  if (klazz == lambda_klass_1) {
    // no-op
  } else if (klazz == lambda_klass_2) {
    idxVar.value.set(i)
  } else {
    uncommon_trap() // aka deoptimize, or potentially a full virtual call
  }
}
The point is that this JIT-optimized version is actually a degenerate version of Henry's hand-written inline per-element check, so I wouldn't want to go down this route.
Thanks, kris! That explanation's very helpful to me.
Thanks all, I'll merge this for now as per the agreement at #25666 (comment).
Thanks. Merging to master.
@@ -344,8 +344,13 @@ case class MapFilter(
    Examples:
      > SELECT _FUNC_(array(1, 2, 3), x -> x % 2 == 1);
       [1,3]
      > SELECT _FUNC_(array(0, 2, 3), (x, i) -> x > i);
Here, the indices start at 0, but it sounds like the other built-in functions start at 1.
I remember there was a (not-merged) PR to standardize one-based column indexes in built-in functions: #24051. Would it be better to fix them up for consistency?
…cala function API filter

### What changes were proposed in this pull request?
This PR is a follow-up of PR #25666, adding the description and example for the Scala function API `filter`.

### Why are the changes needed?
It is hard to tell which parameter is the index column.

### Does this PR introduce any user-facing change?
No

### How was this patch tested?
N/A

Closes #27336 from gatorsmile/spark28962.

Authored-by: Xiao Li <gatorsmile@gmail.com>
Signed-off-by: Dongjoon Hyun <dhyun@apple.com>
What changes were proposed in this pull request?
Lambda functions passed to array `filter` can now take the index as well as the element as input. This behavior matches array `transform`.
Why are the changes needed?
See JIRA. It's generally useful, and particularly so if you're working with fixed length arrays.
Does this PR introduce any user-facing change?
Previously, filter lambdas had to look like `filter(arr, el -> whatever)`.
Now, lambdas can take an index argument as well: `filter(array, (el, idx) -> whatever)`.
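For example, a quick illustration from the Scala shell (a hedged sketch: it assumes an active SparkSession named `spark`, and the index is 0-based):

// Keeps elements strictly greater than their 0-based index; expected result: [2, 3].
spark.sql("SELECT filter(array(0, 2, 3), (x, i) -> x > i) AS kept").show()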
How was this patch tested?
I added unit tests to HigherOrderFunctionsSuite.