[SPARK-22274][PYTHON][SQL] User-defined aggregation functions with pandas udf (full shuffle) #19872
Conversation
cc @HyukjinKwon @holdenk @ueshin Passing some basic tests. I will work on this more next week to clean up and add more testing.
Test build #84414 has finished for PR 19872 at commit
@@ -113,6 +113,7 @@ object ExtractPythonUDFs extends Rule[SparkPlan] with PredicateHelper {
  def apply(plan: SparkPlan): SparkPlan = plan transformUp {
    // FlatMapGroupsInPandas can be evaluated directly in python worker
    // Therefore we don't need to extract the UDFs
`FlatMapGroupsInPandas` and `AggregateInPandasExec` can be...
Added
Test build #84415 has finished for PR 19872 at commit
python/pyspark/sql/group.py
Outdated
    jdf = self._jgd.aggInPandas(
        _to_seq(self.sql_ctx._sc, [c._jc for c in exprs]))
else:
    jdf = self._jgd.agg(exprs[0]._jc,
What if `exprs[n]` (n > 0) is a `UDFColumn`? I think we should make sure that if any column is a `UDFColumn`, all columns are `UDFColumn`s.
This code is removed.
python/pyspark/sql/group.py
Outdated
jdf = self._jgd.agg(exprs[0]._jc,
                    _to_seq(self.sql_ctx._sc, [c._jc for c in exprs[1:]]))
if isinstance(exprs[0], UDFColumn):
    assert all(isinstance(c, UDFColumn) for c in exprs)
An informative error message would be better.
Like "all exprs should be UDFColumn".
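For illustration, a minimal sketch of the check being suggested here, with an explicit message (hypothetical code; `exprs` and `UDFColumn` are names from this PR, which later dropped `UDFColumn` entirely):

```python
def _check_udf_columns(exprs):
    """Hypothetical helper: if any aggregation column is a pandas UDF
    column, require that all of them are, and fail with a readable
    message instead of a bare assert."""
    if any(isinstance(c, UDFColumn) for c in exprs) and \
            not all(isinstance(c, UDFColumn) for c in exprs):
        raise ValueError("Cannot mix pandas UDF aggregations with other "
                         "expressions: all exprs should be UDFColumn")
```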
So I'm a little worried about this change: if other folks have wrapped Java UDAFs (which is reasonable, since there weren't other ways to make UDAFs in PySpark before this), it seems like they won't be able to mix them. I'd suggest maybe doing what @viirya suggested below, but with just a warning instead of a failure until Spark 3.

What do y'all think?
I am still trying to figure out the best way to dispatch this, but either way I think we won't be able to mix Java UDAFs with pandas UDFs.

@holdenk I am not sure what kind of warning message you have in mind. Can you please explain?
Ah, so what you're saying is you don't support mixed Python & Java UDAFs? That's certainly something which needs to be communicated in both the documentation and the error message.
Is there a reason why we don't support this?
Answered in #19872 (comment)
Thanks for working on this. I'm off for a flight to Strata, but a few quick questions first. I'll read this more over the coming week :)
python/pyspark/sql/udf.py
Outdated
@@ -56,6 +56,10 @@ def _create_udf(f, returnType, evalType):
    return udf_obj._wrapped()


class UDFColumn(Column):
Why did we add this new sub-class?
@@ -2070,6 +2070,8 @@ class PandasUDFType(object):

    GROUP_MAP = PythonEvalType.SQL_PANDAS_GROUP_MAP_UDF

    GROUP_AGG = PythonEvalType.SQL_PANDAS_GROUP_AGG_UDF
So I'm worried that it isn't clear to the user that this will result in a full shuffle with no partial aggregation. Is there maybe a place we can document this warning?
Added in the docstrings of `pandas_udf` and `groupby().agg()`.
I thought @ueshin was working on this, BTW.
val argOffsets = inputs.map { input =>
  input.map { e =>
    allInputs += e
indentation nit
Fixed. Thanks!
    functionExprs: Seq[Expression],
    output: Seq[Attribute],
    child: LogicalPlan
) extends UnaryNode {
nit: `child: LogicalPlan) extends UnaryNode {`
Removed
python/pyspark/sql/udf.py
Outdated
@@ -56,6 +56,10 @@ def _create_udf(f, returnType, evalType):
    return udf_obj._wrapped()


class UDFColumn(Column):
BTW, what do you think about adding an attribute in `__call__` instead, like a flag?
Removed
val childrenExpressions = exprs.flatMap(expr =>
  expr.children.map {
    case ne: NamedExpression => ne
    case other => Alias(other, other.toString)()
indentation nit
Removed
val udfOutputs = exprs.flatMap(expr =>
  Seq(AttributeReference(expr.name, expr.dataType)())
)
I think this could be inlined.
Removed
python/pyspark/sql/tests.py
Outdated
class GroupbyAggTests(ReusedSQLTestCase):
    def assertFramesEqual(self, expected, result):
        msg = ("DataFrames are not equal: " +
               ("\n\nExpected:\n%s\n%s" % (expected, expected.dtypes)) +
indentation nit
Fixed.
Oh, I certainly don't want to duplicate @ueshin's work. I am under the impression that @ueshin is working on a two-stage PySpark UDAF with pandas_udf, but I cannot really find the JIRA for it... @ueshin, can you point me to what you are working on so I don't overstep?
Force-pushed from a1058b8 to c1dc543.
Test build #84446 has finished for PR 19872 at commit
Test build #84628 has finished for PR 19872 at commit
I ended up removing `UDFColumn`. The code works and three tests (test_basic, test_alias, test_multiple) pass now, but the code is kind of messy. I am going on vacation next week, but I will clean up the code and move this PR forward when I get back (Dec 16). Thanks all.
And to @holdenk's question: a pandas group_agg UDF fundamentally uses a different physical plan than the existing Java/Scala UDAFs, and therefore it's hard to combine them. I don't know a good way to do this; the closest is maybe to compute the Java/Scala and Python aggregations separately and join them together.
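As a rough sketch of that join-based workaround (hypothetical names throughout: `df`, `v`, and `my_pandas_agg` are assumptions, not code from this PR):

```python
from pyspark.sql.functions import mean

# Compute the pandas aggregation and the built-in (JVM) aggregation in two
# separate groupby().agg() calls, then join the results on the grouping key.
pandas_part = df.groupby('id').agg(my_pandas_agg(df.v).alias('udf_result'))
jvm_part = df.groupby('id').agg(mean(df.v).alias('mean_v'))
combined = pandas_part.join(jvm_part, on='id')
```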
Force-pushed from 3352050 to 184b37f.
Test build #84630 has finished for PR 19872 at commit
Test build #84631 has finished for PR 19872 at commit
Test build #84632 has finished for PR 19872 at commit
@icexelloss I'm sorry for the late response.
@@ -32,7 +31,5 @@ case class PythonUDF( | |||
evalType: Int) | |||
extends Expression with Unevaluable with NonSQLExpression with UserDefinedExpression { | |||
|
|||
override def toString: String = s"$name(${children.mkString(", ")})" |
Why was this removed?
Whoops, my bad, adding it back.
Added back
python/pyspark/sql/tests.py
Outdated
@@ -4016,6 +4016,124 @@ def test_unsupported_types(self):
        with self.assertRaisesRegexp(Exception, 'Unsupported data type'):
            df.groupby('id').apply(f).collect()


@unittest.skipIf(not _have_pandas or not _have_arrow, "Pandas or Arrow not installed")
class GroupbyAggTests(ReusedSQLTestCase):
    def assertFramesEqual(self, expected, result):
nit: how about making this a common method?
val joined = new JoinedRow
val resultProj = UnsafeProjection.create(output, output)

columnarBatchIter.map(_.rowIterator.next()).map{ outputRow =>
nit: `columnarBatchIter.flatMap(_.rowIterator)`?

nit: style, add a space between `map` and `{ outputRow =>`.
`columnarBatchIter.flatMap(_.rowIterator)` doesn't work because `rowIterator` is a Java iterator, not a Scala iterator. We can convert it, but I am not sure it's better. @ueshin, if you prefer the flatMap one I can change it.
Sorry, I meant `columnarBatchIter.flatMap(_.rowIterator.asScala)`. I'd prefer this one.
@@ -48,9 +48,26 @@ object ExtractPythonUDFFromAggregate extends Rule[LogicalPlan] {
    }.isDefined
  }

  private def isPandasGroupAggUdf(expr: Expression): Boolean = expr match {
    case _ @ PythonUDF(_, _, _, _, PythonEvalType.SQL_PANDAS_GROUP_AGG_UDF ) => true
We don't need `_ @` here.

nit: remove the extra space after `SQL_PANDAS_GROUP_AGG_UDF`.
Fixed.
if (hasPandasGroupAggUdf(agg)) {
  Aggregate(agg.groupingExpressions, agg.aggregateExpressions, agg.child)
} else {
nit: style, we need to indent this block.
Fixed.
@@ -15,10 +15,9 @@
 * limitations under the License.
 */

-package org.apache.spark.sql.execution.python
+package org.apache.spark.sql.catalyst.expressions
Do we need to move the package to catalyst?
We do. This is similar to https://github.com/apache/spark/blob/master/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala. The reason is that we need to access the class `PythonUDF` in the analyzer.
I see, thanks!
Test build #85136 has finished for PR 19872 at commit
Test build #85137 has finished for PR 19872 at commit
Test build #85138 has finished for PR 19872 at commit
@ramacode2014 Hi, I'm not sure why you received notifications from this PR, but I guess you can unsubscribe via the "Unsubscribe" button in the right column of this page. Sorry for the inconvenience. Thanks!
alias.toAttribute

if (hasPandasGroupAggUdf(agg)) {
  Aggregate(agg.groupingExpressions, agg.aggregateExpressions, agg.child)
Do we need to copy?
I am not sure, but I added a copy in `ExtractGroupAggPandasUDFFromAggregate`, similar to existing rules.
}

private def hasPandasGroupAggUdf(agg: Aggregate): Boolean = {
  val actualAggExpr = agg.aggregateExpressions.drop(agg.groupingExpressions.length)
Do we need to drop the grouping expressions? If we do, we should drop them only if `conf.dataFrameRetainGroupColumns == true`; otherwise `aggregateExpressions` doesn't contain `groupingExpressions`.
This is fixed. Added a `test_retain_grouping_columns` test.
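For context, a small sketch of the behavior that test covers. `spark.sql.retainGroupColumns` is the public config behind `conf.dataFrameRetainGroupColumns`; `df` and `my_agg` are assumed names:

```python
# By default, grouping columns are retained in the aggregated output.
spark.conf.set("spark.sql.retainGroupColumns", "true")
df.groupby("id").agg(my_agg(df.v).alias("res"))   # columns: id, res

# With retention disabled, only the aggregate columns remain.
spark.conf.set("spark.sql.retainGroupColumns", "false")
df.groupby("id").agg(my_agg(df.v).alias("res"))   # columns: res
```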
val allInputs = new ArrayBuffer[Expression]
val dataTypes = new ArrayBuffer[DataType]

allInputs.appendAll(groupingExpressions)
I guess we don't need to append `groupingExpressions`. Seems like they are dropped later.
This is fixed.
  .compute(projectedRowIter, context.partitionId(), context)

val joined = new JoinedRow
val resultProj = UnsafeProjection.create(output, output)
We need to handle `resultExpressions` for the following cases (note: `lit` added to the import, since the test uses it):

    def test_result_expressions(self):
        import numpy as np
        from pyspark.sql.functions import lit, mean, pandas_udf, PandasUDFType

        df = self.data

        @pandas_udf('double', PandasUDFType.GROUP_AGG)
        def mean_udf(v, w):
            return np.average(v, weights=w)

        result1 = (df.groupby('id')
                   .agg(mean_udf(df.v, lit(1.0)) + 1)
                   .sort('id')
                   .toPandas())
        expected1 = (df.groupby('id')
                     .agg(mean(df.v) + 1)
                     .sort('id')
                     .toPandas())
        self.assertPandasEqual(expected1, result1)
Thanks @ueshin for reminding me of this. Just want to clarify the semantics: does

    .agg(mean(df.v) + 1)

mean "compute the mean of df.v and add one to the mean as output", i.e., the same as

    .agg(mean(df.v).alias('mean')).withColumn('mean', col('mean') + 1)

?
Yes, I think that's the behavior. I guess the plan could be different, though. We can compare the behavior with non-UDF aggregation and follow that.
I added an `ExtractGroupAggPandasUDFFromAggregate` rule to deal with this.
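To make the semantics above concrete, a small sketch of the two formulations being discussed, reusing `df` and `mean_udf` from the test case above (values should match; physical plans may differ):

```python
from pyspark.sql.functions import col, lit

# An expression on top of an aggregate is applied to the aggregated value...
r1 = df.groupby('id').agg(mean_udf(df.v, lit(1.0)) + 1)

# ...which should produce the same values as aliasing the aggregate and
# adding one afterwards.
r2 = (df.groupby('id')
        .agg(mean_udf(df.v, lit(1.0)).alias('m'))
        .withColumn('m', col('m') + 1))
```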
  actualAggExpr.exists(isPandasGroupAggUdf)
}
nit: remove an extra line.
Removed
Test build #85152 has finished for PR 19872 at commit
Force-pushed from 99367a6 to 66a31f9.
@ueshin I pushed some more changes to address your comments. There is one regression in an existing test.
Test build #85442 has finished for PR 19872 at commit
Test build #85446 has finished for PR 19872 at commit
Force-pushed from a94b146 to 17fad5c.
Test build #86345 has finished for PR 19872 at commit
Test build #86344 has finished for PR 19872 at commit
Test build #86346 has finished for PR 19872 at commit
Test build #86350 has finished for PR 19872 at commit
@ueshin I think all comments are addressed. Can you take a final look? Thanks!
We also need to add `PythonEvalType.SQL_PANDAS_GROUP_AGG_UDF` to udf.py#L40-L41 to pass `require_minimum_pyarrow_version()`.
LGTM except for the comments.
Btw, I'm afraid we shouldn't merge this into branch-2.3 since we are already close to the 2.3 release.
WDYT? @HyukjinKwon @cloud-fan
python/pyspark/sql/functions.py
Outdated
3. GROUP_AGG

   A group aggregate UDF defines a transformation: One or more `pandas.Series` -> A scalar
   The `returnType` should be a primitive data type, e.g, :class:`DoubleType`.
very small nit: `e.g.` instead of `e.g`.
Fixed. Thanks!
+1 for master-only. We can cherry-pick and backport if we need to, even after this gets merged. As a reminder, we should complete the doc in #19575 too.
Addressed the latest comments. Yeah, I think master-only is fine.
Test build #86487 has finished for PR 19872 at commit
Test build #86492 has finished for PR 19872 at commit
Thanks! Merging to master.
Thanks all for the review!
@@ -199,7 +200,7 @@ object ExtractFiltersAndInnerJoins extends PredicateHelper {
object PhysicalAggregation {
  // groupingExpressions, aggregateExpressions, resultExpressions, child
  type ReturnType =
-    (Seq[NamedExpression], Seq[AggregateExpression], Seq[NamedExpression], LogicalPlan)
+    (Seq[NamedExpression], Seq[Expression], Seq[NamedExpression], LogicalPlan)
@icexelloss Thank you for this contribution! I just came across the change in this file. I am not sure that changing the type here is the best option. The reason is that whenever we use this `PhysicalAggregation` rule, we have to check the instance type of those aggregate expressions and do casting. To me, it seems better to leave this rule untouched and create a new rule just for Python UDAFs. What do you think?

(Maybe you and the reviewers already discussed it. If so, can you point me to the discussion?)

Thank you!
Hi @yhuai,

You bring up a good point. I agree that ideally we should avoid doing this. When I was making the change, I found the solution implemented here results in the least amount of duplicate code, because a lot of logic is shared between `AggregateExpression` and Python UDFs, but the downside is exactly what you mentioned.

One alternative is to create new rules for Python UDAFs; my concern is that could result in quite a bit of code duplication. Maybe there is a way to avoid code duplication and keep the type safety; I am happy to explore the option. (Maybe create a parent class for `AggregateExpression` and Python UDAF?)
I prefer that we try out using a new rule. We can create a utility function to reuse code. Will you have a chance to try it out?
@yhuai Yeah I can certainly try it out. Created https://issues.apache.org/jira/browse/SPARK-23302 to track.
I assume this is not urgent?
It will be good to try it out soon. But it is not urgent.
from pyspark.sql.functions import pandas_udf, PandasUDFType

with QuietTest(self.sc):
    with self.assertRaisesRegex(NotImplementedError, 'not supported'):
@icexelloss This line does not compile (we need `assertRaisesRegexp`). Can you file a PR to fix it? Thanks! Meanwhile, we will look into the Jenkins setup and see why the test was not exercised.
I'll file the follow-up PR to fix it soon.
I filed #20467. Thanks.
@yhuai, if you meant not running tests in Python 2, this link might be helpful. Let me leave it just in case: #19884 (comment).
@ueshin Thanks for fixing this. (I am late to the party)
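For anyone following along, a minimal standalone illustration of the incompatibility (not code from the PR): `assertRaisesRegex` was added in Python 3.2, while Python 2's `unittest` only provides `assertRaisesRegexp`, which is why the follow-up switches to the latter.

```python
import unittest

class CompatExample(unittest.TestCase):
    def test_error_message(self):
        # assertRaisesRegexp works on both Python 2.7 and Python 3,
        # where it survives as a deprecated alias of assertRaisesRegex.
        with self.assertRaisesRegexp(ValueError, 'not supported'):
            raise ValueError('operation not supported')

if __name__ == '__main__':
    unittest.main()
```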
… of `assertRaisesRegex`.

## What changes were proposed in this pull request?

This is a follow-up pr of apache#19872 which uses `assertRaisesRegex`, but it doesn't exist in Python 2, so some tests fail when running tests in a Python 2 environment. Unfortunately, we missed it because currently the Python 2 environment of the pr builder doesn't have proper versions of pandas or pyarrow, so the tests were skipped.

This pr modifies to use `assertRaisesRegexp` instead of `assertRaisesRegex`.

## How was this patch tested?

Tested manually in my local environment.

Author: Takuya UESHIN <ueshin@databricks.com>

Closes apache#20467 from ueshin/issues/SPARK-22274/fup1.
What changes were proposed in this pull request?
Add support for using pandas UDFs with groupby().agg().

This PR introduces a new type of pandas UDF: the group aggregate pandas UDF. This type of UDF defines a transformation of multiple pandas Series -> a scalar value. Group aggregate pandas UDFs can be used with groupby().agg(). Note that group aggregate pandas UDFs don't support partial aggregation, i.e., a full shuffle is required. See the usage sketch below.
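A brief usage sketch, following the API as it appears in this PR (the `GROUP_AGG` name comes from the thread above; `spark` is an assumed SparkSession):

```python
from pyspark.sql.functions import pandas_udf, PandasUDFType

df = spark.createDataFrame(
    [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)], ("id", "v"))

@pandas_udf("double", PandasUDFType.GROUP_AGG)
def mean_udf(v):
    # v is a pandas.Series holding all values of column v for one group;
    # the UDF reduces it to a single scalar.
    return v.mean()

# Each group is shuffled to a single worker (no partial aggregation).
df.groupby("id").agg(mean_udf(df.v).alias("mean_v")).show()
```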
This PR doesn't support group aggregate pandas UDFs that return ArrayType, StructType or MapType. Support for these types is left for a future PR.
How was this patch tested?
GroupbyAggPandasUDFTests