Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SPARK-34388][SQL] Propagate the registered UDF name to ScalaUDF, ScalaUDAF and ScalaAggregator #31500

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -1409,9 +1409,14 @@ class SessionCatalog(
Utils.classForName("org.apache.spark.sql.expressions.UserDefinedAggregateFunction")
if (clsForUDAF.isAssignableFrom(clazz)) {
val cls = Utils.classForName("org.apache.spark.sql.execution.aggregate.ScalaUDAF")
val e = cls.getConstructor(classOf[Seq[Expression]], clsForUDAF, classOf[Int], classOf[Int])
.newInstance(input,
clazz.getConstructor().newInstance().asInstanceOf[Object], Int.box(1), Int.box(1))
val e = cls.getConstructor(
classOf[Seq[Expression]], clsForUDAF, classOf[Int], classOf[Int], classOf[Option[String]])
.newInstance(
input,
clazz.getConstructor().newInstance().asInstanceOf[Object],
Int.box(1),
Int.box(1),
Some(name))
.asInstanceOf[ImplicitCastInputTypes]

// Check input argument size
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1088,4 +1088,6 @@ trait ComplexTypeMergingExpression extends Expression {
* Common base trait for user-defined functions, including UDF/UDAF/UDTF of different languages
* and Hive function wrappers.
*/
trait UserDefinedExpression
trait UserDefinedExpression {
def name: String
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe default to using the class name or something?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's an internal trait, seems OK to require it.

}
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,9 @@ case class ScalaUDF(

override lazy val deterministic: Boolean = udfDeterministic && children.forall(_.deterministic)

override def toString: String = s"${udfName.getOrElse("UDF")}(${children.mkString(", ")})"
override def toString: String = s"$name(${children.mkString(", ")})"

override def name: String = udfName.getOrElse("UDF")

override lazy val canonicalized: Expression = {
// SPARK-32307: `ExpressionEncoder` can't be canonicalized, and technically we don't
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
@deprecated("Aggregator[IN, BUF, OUT] should now be registered as a UDF" +
" via the functions.udaf(agg) method.", "3.0.0")
def register(name: String, udaf: UserDefinedAggregateFunction): UserDefinedAggregateFunction = {
def builder(children: Seq[Expression]) = ScalaUDAF(children, udaf)
def builder(children: Seq[Expression]) = ScalaUDAF(children, udaf, udafName = Some(name))
functionRegistry.createOrReplaceTempFunction(name, builder)
udaf
}
Expand All @@ -109,15 +109,15 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
* @since 2.2.0
*/
def register(name: String, udf: UserDefinedFunction): UserDefinedFunction = {
udf match {
udf.withName(name) match {
case udaf: UserDefinedAggregator[_, _, _] =>
def builder(children: Seq[Expression]) = udaf.scalaAggregator(children)
functionRegistry.createOrReplaceTempFunction(name, builder)
udf
case _ =>
def builder(children: Seq[Expression]) = udf.apply(children.map(Column.apply) : _*).expr
udaf
case other =>
def builder(children: Seq[Expression]) = other.apply(children.map(Column.apply) : _*).expr
functionRegistry.createOrReplaceTempFunction(name, builder)
udf
other
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -325,7 +325,8 @@ case class ScalaUDAF(
children: Seq[Expression],
udaf: UserDefinedAggregateFunction,
mutableAggBufferOffset: Int = 0,
inputAggBufferOffset: Int = 0)
inputAggBufferOffset: Int = 0,
udafName: Option[String] = None)
extends ImperativeAggregate
with NonSQLExpression
with Logging
Expand Down Expand Up @@ -447,10 +448,12 @@ case class ScalaUDAF(
}

override def toString: String = {
s"""${udaf.getClass.getSimpleName}(${children.mkString(",")})"""
s"""$nodeName(${children.mkString(",")})"""
}

override def nodeName: String = udaf.getClass.getSimpleName
override def nodeName: String = name

override def name: String = udafName.getOrElse(udaf.getClass.getSimpleName)
}

case class ScalaAggregator[IN, BUF, OUT](
Expand All @@ -461,7 +464,8 @@ case class ScalaAggregator[IN, BUF, OUT](
nullable: Boolean = true,
isDeterministic: Boolean = true,
mutableAggBufferOffset: Int = 0,
inputAggBufferOffset: Int = 0)
inputAggBufferOffset: Int = 0,
aggregatorName: Option[String] = None)
extends TypedImperativeAggregate[BUF]
with NonSQLExpression
with UserDefinedExpression
Expand Down Expand Up @@ -513,7 +517,9 @@ case class ScalaAggregator[IN, BUF, OUT](

override def toString: String = s"""${nodeName}(${children.mkString(",")})"""

override def nodeName: String = agg.getClass.getSimpleName
override def nodeName: String = name

override def name: String = aggregatorName.getOrElse(agg.getClass.getSimpleName)
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,8 @@ private[sql] case class UserDefinedAggregator[IN, BUF, OUT](
def scalaAggregator(exprs: Seq[Expression]): ScalaAggregator[IN, BUF, OUT] = {
val iEncoder = inputEncoder.asInstanceOf[ExpressionEncoder[IN]]
val bEncoder = aggregator.bufferEncoder.asInstanceOf[ExpressionEncoder[BUF]]
ScalaAggregator(exprs, aggregator, iEncoder, bEncoder, nullable, deterministic)
ScalaAggregator(
exprs, aggregator, iEncoder, bEncoder, nullable, deterministic, aggregatorName = name)
}

override def withName(name: String): UserDefinedAggregator[IN, BUF, OUT] = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -321,26 +321,34 @@ object IntegratedUDFTestUtils extends SQLHelper {
* casted_col.cast(df.schema("col").dataType)
* }}}
*/
case class TestScalaUDF(name: String) extends TestUDF {
private[IntegratedUDFTestUtils] lazy val udf = new SparkUserDefinedFunction(
(input: Any) => if (input == null) {
null
} else {
input.toString
},
StringType,
inputEncoders = Seq.fill(1)(None),
name = Some(name)) {

override def apply(exprs: Column*): Column = {
assert(exprs.length == 1, "Defined UDF only has one column")
val expr = exprs.head.expr
assert(expr.resolved, "column should be resolved to use the same type " +
"as input. Try df(name) or df.col(name)")
Column(Cast(createScalaUDF(Cast(expr, StringType) :: Nil), expr.dataType))
}
class TestInternalScalaUDF(name: String) extends SparkUserDefinedFunction(
(input: Any) => if (input == null) {
null
} else {
input.toString
},
StringType,
inputEncoders = Seq.fill(1)(None),
name = Some(name)) {

override def apply(exprs: Column*): Column = {
assert(exprs.length == 1, "Defined UDF only has one column")
val expr = exprs.head.expr
assert(expr.resolved, "column should be resolved to use the same type " +
"as input. Try df(name) or df.col(name)")
Column(Cast(createScalaUDF(Cast(expr, StringType) :: Nil), expr.dataType))
}

override def withName(name: String): TestInternalScalaUDF = {
// "withName" should overridden to return TestInternalScalaUDF. Otherwise, the current object
// is sliced and the overridden "apply" is not invoked.
new TestInternalScalaUDF(name)
}
}

case class TestScalaUDF(name: String) extends TestUDF {
private[IntegratedUDFTestUtils] lazy val udf = new TestInternalScalaUDF(name)

def apply(exprs: Column*): Column = udf(exprs: _*)

val prettyName: String = "Scala UDF"
Expand Down
52 changes: 49 additions & 3 deletions sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -26,15 +26,18 @@ import scala.collection.mutable.{ArrayBuffer, WrappedArray}

import org.apache.spark.SparkException
import org.apache.spark.sql.api.java._
import org.apache.spark.sql.catalyst.encoders.OuterScopes
import org.apache.spark.sql.catalyst.FunctionIdentifier
import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, OuterScopes}
import org.apache.spark.sql.catalyst.expressions.{Literal, ScalaUDF}
import org.apache.spark.sql.catalyst.plans.logical.Project
import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.sql.execution.{QueryExecution, SimpleMode}
import org.apache.spark.sql.execution.aggregate.{ScalaAggregator, ScalaUDAF}
import org.apache.spark.sql.execution.columnar.InMemoryRelation
import org.apache.spark.sql.execution.command.{CreateDataSourceTableAsSelectCommand, ExplainCommand}
import org.apache.spark.sql.execution.datasources.InsertIntoHadoopFsRelationCommand
import org.apache.spark.sql.expressions.SparkUserDefinedFunction
import org.apache.spark.sql.functions.{lit, struct, udf}
import org.apache.spark.sql.expressions.{Aggregator, MutableAggregationBuffer, SparkUserDefinedFunction, UserDefinedAggregateFunction}
import org.apache.spark.sql.functions.{lit, struct, udaf, udf}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.SharedSparkSession
import org.apache.spark.sql.test.SQLTestData._
Expand Down Expand Up @@ -798,4 +801,47 @@ class UDFSuite extends QueryTest with SharedSparkSession {
.select(myUdf(Column("col"))),
Row(ArrayBuffer(100)))
}

test("SPARK-34388: UDF name is propagated with registration for ScalaUDF") {
spark.udf.register("udf34388", udf((value: Int) => value > 2))
spark.sessionState.catalog.lookupFunction(
FunctionIdentifier("udf34388"), Seq(Literal(1))) match {
case udf: ScalaUDF => assert(udf.name === "udf34388")
}
}

test("SPARK-34388: UDF name is propagated with registration for ScalaAggregator") {
val agg = new Aggregator[Long, Long, Long] {
override def zero: Long = 0L
override def reduce(b: Long, a: Long): Long = a + b
override def merge(b1: Long, b2: Long): Long = b1 + b2
override def finish(reduction: Long): Long = reduction
override def bufferEncoder: Encoder[Long] = ExpressionEncoder[Long]()
override def outputEncoder: Encoder[Long] = ExpressionEncoder[Long]()
}

spark.udf.register("agg34388", udaf(agg))
spark.sessionState.catalog.lookupFunction(
FunctionIdentifier("agg34388"), Seq(Literal(1))) match {
case agg: ScalaAggregator[_, _, _] => assert(agg.name === "agg34388")
}
}

test("SPARK-34388: UDF name is propagated with registration for ScalaUDAF") {
val udaf = new UserDefinedAggregateFunction {
def inputSchema: StructType = new StructType().add("a", LongType)
def bufferSchema: StructType = new StructType().add("product", LongType)
def dataType: DataType = LongType
def deterministic: Boolean = true
def initialize(buffer: MutableAggregationBuffer): Unit = {}
def update(buffer: MutableAggregationBuffer, input: Row): Unit = {}
def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {}
def evaluate(buffer: Row): Any = buffer.getLong(0)
}
spark.udf.register("udaf34388", udaf)
spark.sessionState.catalog.lookupFunction(
FunctionIdentifier("udaf34388"), Seq(Literal(1))) match {
case udaf: ScalaUDAF => assert(udaf.name === "udaf34388")
}
}
}