Commit 31b42eb

fix Python UDF with aggregate

Davies Liu committed Jun 15, 2016
1 parent dae4d5d commit 31b42eb

Showing 4 changed files with 77 additions and 11 deletions.
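For context, the scenario this commit fixes is a query that routes Python UDFs below, inside, and above an aggregate. A minimal PySpark sketch mirroring the new test in this commit (the SparkSession setup is an assumption; the DataFrame, UDFs, and expected rows come from the test below):

from pyspark.sql import SparkSession, Row
from pyspark.sql.functions import udf, col, sum  # sum is the SQL aggregate, not the builtin
from pyspark.sql.types import IntegerType

spark = SparkSession.builder.master("local[2]").appName("udf-agg-demo").getOrCreate()
df = spark.createDataFrame([(1, "1"), (2, "2"), (1, "2"), (1, "2")], ["key", "value"])

my_copy = udf(lambda x: x, IntegerType())             # applied to the grouping key
my_strlen = udf(lambda x: len(x), IntegerType())      # applied inside the aggregate
my_add = udf(lambda a, b: int(a + b), IntegerType())  # applied over the aggregate result

result = df.groupBy(my_copy(col("key")).alias("k")) \
    .agg(sum(my_strlen(col("value"))).alias("s")) \
    .select(my_add(col("k"), col("s")).alias("t"))
print(result.collect())  # [Row(t=4), Row(t=3)] per the new test; row order may vary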
10 changes: 9 additions & 1 deletion python/pyspark/sql/tests.py
@@ -339,13 +339,21 @@ def test_broadcast_in_udf(self):
 
     def test_udf_with_aggregate_function(self):
         df = self.spark.createDataFrame([(1, "1"), (2, "2"), (1, "2"), (1, "2")], ["key", "value"])
-        from pyspark.sql.functions import udf, col
+        from pyspark.sql.functions import udf, col, sum
         from pyspark.sql.types import BooleanType
 
         my_filter = udf(lambda a: a == 1, BooleanType())
         sel = df.select(col("key")).distinct().filter(my_filter(col("key")))
         self.assertEqual(sel.collect(), [Row(key=1)])
 
+        my_copy = udf(lambda x: x, IntegerType())
+        my_add = udf(lambda a, b: int(a + b), IntegerType())
+        my_strlen = udf(lambda x: len(x), IntegerType())
+        sel = df.groupBy(my_copy(col("key")).alias("k"))\
+            .agg(sum(my_strlen(col("value"))).alias("s"))\
+            .select(my_add(col("k"), col("s")).alias("t"))
+        self.assertEqual(sel.collect(), [Row(t=4), Row(t=3)])
+
     def test_basic_functions(self):
         rdd = self.sc.parallelize(['{"foo":"bar"}', '{"foo":"baz"}'])
         df = self.spark.read.json(rdd)
sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala

@@ -20,6 +20,7 @@ package org.apache.spark.sql.execution
 import org.apache.spark.sql.ExperimentalMethods
 import org.apache.spark.sql.catalyst.catalog.SessionCatalog
 import org.apache.spark.sql.catalyst.optimizer.Optimizer
+import org.apache.spark.sql.execution.python.ExtractPythonUDFFromAggregate
 import org.apache.spark.sql.internal.SQLConf
 
 class SparkOptimizer(
@@ -28,6 +29,7 @@ class SparkOptimizer(
     experimentalMethods: ExperimentalMethods)
   extends Optimizer(catalog, conf) {
 
-  override def batches: Seq[Batch] = super.batches :+ Batch(
-    "User Provided Optimizers", fixedPoint, experimentalMethods.extraOptimizations: _*)
+  override def batches: Seq[Batch] = super.batches :+
+    Batch("Extract Python UDF from Aggregate", Once, ExtractPythonUDFFromAggregate) :+
+    Batch("User Provided Optimizers", fixedPoint, experimentalMethods.extraOptimizations: _*)
 }
sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExec.scala

@@ -46,6 +46,8 @@ case class BatchEvalPythonExec(udfs: Seq[PythonUDF], output: Seq[Attribute], child: SparkPlan)
 
   def children: Seq[SparkPlan] = child :: Nil
 
+  // The UDF result columns are appended after the child's output columns, so the
+  // attributes of `output` beyond child.output are produced by this node itself.
+  override def producedAttributes: AttributeSet = AttributeSet(output.drop(child.output.length))
+
   private def collectFunctions(udf: PythonUDF): (ChainedPythonFunctions, Seq[Expression]) = {
     udf.children match {
       case Seq(u: PythonUDF) =>
sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala

@@ -18,12 +18,68 @@
 package org.apache.spark.sql.execution.python
 
 import scala.collection.mutable
+import scala.collection.mutable.ArrayBuffer
 
-import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Expression}
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
+import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan, Project}
+import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.execution
 import org.apache.spark.sql.execution.SparkPlan
 
 
+/**
+ * Extracts the Python UDFs in a logical Aggregate that depend on an aggregate expression or a
+ * grouping key, so that they can be evaluated after the aggregate.
+ */
+private[spark] object ExtractPythonUDFFromAggregate extends Rule[LogicalPlan] {
+
+  /**
+   * Returns whether the expression can only be evaluated within the aggregate.
+   */
+  private def belongAggregate(e: Expression, agg: Aggregate): Boolean = {
+    e.isInstanceOf[AggregateExpression] ||
+      agg.groupingExpressions.exists(_.semanticEquals(e))
+  }
+
+  private def hasPythonUdfOverAggregate(expr: Expression, agg: Aggregate): Boolean = {
+    expr.find {
+      e => e.isInstanceOf[PythonUDF] && e.find(belongAggregate(_, agg)).isDefined
+    }.isDefined
+  }
+
+  private def extract(agg: Aggregate): LogicalPlan = {
+    val projList = new ArrayBuffer[NamedExpression]()
+    val aggExpr = new ArrayBuffer[NamedExpression]()
+    agg.aggregateExpressions.foreach { expr =>
+      if (hasPythonUdfOverAggregate(expr, agg)) {
+        // Python UDFs over an aggregate can only be evaluated after the aggregate
+        val newE = expr transformDown {
+          case e: Expression if belongAggregate(e, agg) =>
+            val alias = e match {
+              case a: NamedExpression => a
+              case o => Alias(e, "agg")()
+            }
+            aggExpr += alias
+            alias.toAttribute
+        }
+        projList += newE.asInstanceOf[NamedExpression]
+      } else {
+        aggExpr += expr
+        projList += expr.toAttribute
+      }
+    }
+    // The rewritten Aggregate no longer contains any Python UDF over an aggregate expression
+    Project(projList, agg.copy(aggregateExpressions = aggExpr))
+  }
+
+  def apply(plan: LogicalPlan): LogicalPlan = plan transformUp {
+    case agg: Aggregate if agg.aggregateExpressions.exists(hasPythonUdfOverAggregate(_, agg)) =>
+      extract(agg)
+  }
+}
+
+
 /**
  * Extracts PythonUDFs from operators, rewriting the query plan so that the UDF can be evaluated
  * alone in a batch.
@@ -59,10 +115,12 @@ private[spark] object ExtractPythonUDFs extends Rule[SparkPlan] {
   }
 
   /**
-   * Extract all the PythonUDFs from the current operator.
+   * Extract all the PythonUDFs from the current operator and evaluate them before the operator.
    */
-  def extract(plan: SparkPlan): SparkPlan = {
+  private def extract(plan: SparkPlan): SparkPlan = {
     val udfs = plan.expressions.flatMap(collectEvaluatableUDF)
+      // Ignore PythonUDFs that come from a second or third aggregate; they are not used here
+      .filter(udf => udf.references.subsetOf(plan.inputSet))
     if (udfs.isEmpty) {
       // If there aren't any, we are done.
       plan
@@ -89,11 +147,7 @@ private[spark] object ExtractPythonUDFs extends Rule[SparkPlan] {
       // Other cases are disallowed as they are ambiguous or would require a cartesian
       // product.
       udfs.filterNot(attributeMap.contains).foreach { udf =>
-        if (udf.references.subsetOf(plan.inputSet)) {
-          sys.error(s"Invalid PythonUDF $udf, requires attributes from more than one child.")
-        } else {
-          sys.error(s"Unable to evaluate PythonUDF $udf. Missing input attributes.")
-        }
+        sys.error(s"Invalid PythonUDF $udf, requires attributes from more than one child.")
       }
 
       val rewritten = plan.transformExpressions {
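To observe the rewrite performed by the new ExtractPythonUDFFromAggregate rule, a hedged sketch (reusing `df` and the UDFs from the snippet near the top; exact plan text varies by Spark build): with this commit applied, the optimized plan should evaluate `my_copy` and `my_strlen` before the Aggregate, and `my_add` in a Project placed above it, rather than failing inside the Aggregate.

# Inspect the plans; the Python UDF over the aggregate result (my_add) should be
# extracted into a Project above the Aggregate by ExtractPythonUDFFromAggregate.
sel = df.groupBy(my_copy(col("key")).alias("k")) \
    .agg(sum(my_strlen(col("value"))).alias("s")) \
    .select(my_add(col("k"), col("s")).alias("t"))
sel.explain(True)  # extended=True prints the parsed, analyzed, optimized, and physical plans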
