support UDAF
gatorsmile committed Jul 21, 2017
1 parent f18b905 · commit 4028155
Showing 6 changed files with 137 additions and 5 deletions.
@@ -17,13 +17,15 @@

package org.apache.spark.sql.catalyst.catalog

import java.lang.reflect.InvocationTargetException
import java.net.URI
import java.util.Locale
import java.util.concurrent.Callable
import javax.annotation.concurrent.GuardedBy

import scala.collection.mutable
import scala.util.{Failure, Success, Try}
import scala.util.control.NonFatal

import com.google.common.cache.{Cache, CacheBuilder}
import org.apache.hadoop.conf.Configuration
@@ -40,6 +42,7 @@ import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, SubqueryAlias,
import org.apache.spark.sql.catalyst.util.StringUtils
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.StructType
import org.apache.spark.util.Utils

object SessionCatalog {
val DEFAULT_DATABASE = "default"
@@ -1096,8 +1099,43 @@ class SessionCatalog(
* This performs reflection to decide what type of [[Expression]] to return in the builder.
*/
protected def makeFunctionBuilder(name: String, functionClassName: String): FunctionBuilder = {
// TODO: at least support UDAFs here
throw new UnsupportedOperationException("Use sqlContext.udf.register(...) instead.")
makeFunctionBuilder(name, Utils.classForName(functionClassName))
}

/**
* Construct a [[FunctionBuilder]] based on the provided class that represents a function.
*/
private def makeFunctionBuilder(name: String, clazz: Class[_]): FunctionBuilder = {
// Instantiating the ScalaUDAF class may throw an exception if the input
// expressions do not satisfy the UDAF, e.g., a type mismatch or the wrong
// number of arguments. Catch such exceptions here and rethrow them as an
// AnalysisException.
(children: Seq[Expression]) => {
try {
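// ScalaUDAF and UserDefinedAggregateFunction live in sql/core, which
// sql/catalyst cannot depend on at compile time, so both classes are
// looked up and instantiated via reflection.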
val clsForUDAF =
Utils.classForName("org.apache.spark.sql.expressions.UserDefinedAggregateFunction")
if (clsForUDAF.isAssignableFrom(clazz)) {
val cls = Utils.classForName("org.apache.spark.sql.execution.aggregate.ScalaUDAF")
cls.getConstructor(classOf[Seq[Expression]], clsForUDAF, classOf[Int], classOf[Int])
.newInstance(children, clazz.newInstance().asInstanceOf[Object], Int.box(1), Int.box(1))
.asInstanceOf[Expression]
} else {
throw new UnsupportedOperationException("Use sqlContext.udf.register(...) instead.")
}
} catch {
case NonFatal(exception) =>
val e = exception match {
// Since the constructor is invoked reflectively, any exception it throws
// is wrapped in an InvocationTargetException; unwrap the real cause.
case i: InvocationTargetException => i.getCause
case o => o
}
val analysisException =
new AnalysisException(s"No handler for UDAF '${clazz.getCanonicalName}': $e")
analysisException.setStackTrace(e.getStackTrace)
throw analysisException
}
}
}

/**
@@ -1116,12 +1154,17 @@ class SessionCatalog(
overrideIfExists: Boolean,
functionBuilder: Option[FunctionBuilder] = None): Unit = {
val func = funcDefinition.identifier
val className = funcDefinition.className
if (functionRegistry.functionExists(func) && !overrideIfExists) {
throw new AnalysisException(s"Function $func already exists")
}
val info = new ExpressionInfo(funcDefinition.className, func.database.orNull, func.funcName)
if (!Utils.classIsLoadable(className)) {
throw new AnalysisException(s"Can not load class '$className' when registering " +
s"the function '$func', please make sure it is on the classpath")
}
val info = new ExpressionInfo(className, func.database.orNull, func.funcName)
val builder =
functionBuilder.getOrElse(makeFunctionBuilder(func.unquotedString, funcDefinition.className))
functionBuilder.getOrElse(makeFunctionBuilder(func.unquotedString, className))
functionRegistry.registerFunction(func, info, builder)
}
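
The reflective construction in makeFunctionBuilder above follows a standard JVM pattern: look up a constructor by its parameter types, invoke it, and unwrap the InvocationTargetException in which reflection wraps anything the constructor itself throws. A minimal standalone sketch of that pattern (ReflectiveCtorDemo is a hypothetical name; the snippet uses only JDK classes, not Spark's):

import java.lang.reflect.InvocationTargetException

object ReflectiveCtorDemo {
  def main(args: Array[String]): Unit = {
    // Look up java.lang.StringBuilder's (String) constructor and invoke it,
    // mirroring how makeFunctionBuilder constructs ScalaUDAF.
    val clazz = Class.forName("java.lang.StringBuilder")
    try {
      val sb = clazz.getConstructor(classOf[String]).newInstance("hello")
      println(sb) // prints "hello"
    } catch {
      case i: InvocationTargetException =>
        // Reflection wraps exceptions thrown inside the constructor;
        // the interesting error is the cause.
        throw i.getCause
    }
  }
}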

@@ -18,7 +18,7 @@
package org.apache.spark.sql.execution.aggregate

import org.apache.spark.internal.Logging
import org.apache.spark.sql.Row
import org.apache.spark.sql.{AnalysisException, Row}
import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Expression, _}
import org.apache.spark.sql.catalyst.expressions.aggregate.ImperativeAggregate
@@ -326,6 +326,11 @@ case class ScalaUDAF(
inputAggBufferOffset: Int = 0)
extends ImperativeAggregate with NonSQLExpression with Logging with ImplicitCastInputTypes {

if (children.length != udaf.inputSchema.length) {
throw new AnalysisException(s"Invalid number of arguments for the function " +
s"Expected: ${udaf.inputSchema.length}; Found: ${children.length}")
}

override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate =
copy(mutableAggBufferOffset = newMutableAggBufferOffset)

13 changes: 13 additions & 0 deletions sql/core/src/test/resources/sql-tests/inputs/udaf.sql
@@ -0,0 +1,13 @@
CREATE OR REPLACE TEMPORARY VIEW t1 AS SELECT * FROM VALUES
(1), (2), (3), (4)
as t1(int_col1);

CREATE FUNCTION myDoubleAvg AS 'test.org.apache.spark.sql.MyDoubleAvg';

SELECT default.myDoubleAvg(int_col1) as my_avg from t1;

SELECT default.myDoubleAvg(int_col1, 3) as my_avg from t1;

CREATE FUNCTION udaf1 AS 'test.non.existent.udaf';

SELECT default.udaf1(int_col1) as udaf1 from t1;
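
The golden results that follow assume MyDoubleAvg averages its single numeric argument and adds 100 (hence 102.5 over the values 1-4 here, and 105.0 over 1-9 in the Hive suite below). The real class is a Java UDAF in Spark's test sources; MyDoubleAvgSketch below is a hypothetical Scala stand-in with the same behavior, written against the UserDefinedAggregateFunction API that the new code path targets:

import org.apache.spark.sql.Row
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types._

// Illustrative stand-in for test.org.apache.spark.sql.MyDoubleAvg.
class MyDoubleAvgSketch extends UserDefinedAggregateFunction {
  def inputSchema: StructType = StructType(StructField("value", DoubleType) :: Nil)
  def bufferSchema: StructType =
    StructType(StructField("sum", DoubleType) :: StructField("count", LongType) :: Nil)
  def dataType: DataType = DoubleType
  def deterministic: Boolean = true
  def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer(0) = 0.0 // running sum
    buffer(1) = 0L  // running count
  }
  def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    if (!input.isNullAt(0)) {
      buffer(0) = buffer.getDouble(0) + input.getDouble(0)
      buffer(1) = buffer.getLong(1) + 1L
    }
  }
  def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    buffer1(0) = buffer1.getDouble(0) + buffer2.getDouble(0)
    buffer1(1) = buffer1.getLong(1) + buffer2.getLong(1)
  }
  def evaluate(buffer: Row): Any = {
    if (buffer.getLong(1) == 0L) null
    else buffer.getDouble(0) / buffer.getLong(1) + 100.0
  }
}

Since ScalaUDAF mixes in ImplicitCastInputTypes (see its extends clause above), the integer int_col1 is implicitly cast to the DoubleType declared in inputSchema before update is called.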
54 changes: 54 additions & 0 deletions sql/core/src/test/resources/sql-tests/results/udaf.sql.out
@@ -0,0 +1,54 @@
-- Automatically generated by SQLQueryTestSuite
-- Number of queries: 6


-- !query 0
CREATE OR REPLACE TEMPORARY VIEW t1 AS SELECT * FROM VALUES
(1), (2), (3), (4)
as t1(int_col1)
-- !query 0 schema
struct<>
-- !query 0 output



-- !query 1
CREATE FUNCTION myDoubleAvg AS 'test.org.apache.spark.sql.MyDoubleAvg'
-- !query 1 schema
struct<>
-- !query 1 output



-- !query 2
SELECT default.myDoubleAvg(int_col1) as my_avg from t1
-- !query 2 schema
struct<my_avg:double>
-- !query 2 output
102.5


-- !query 3
SELECT default.myDoubleAvg(int_col1, 3) as my_avg from t1
-- !query 3 schema
struct<>
-- !query 3 output
org.apache.spark.sql.AnalysisException
No handler for UDAF 'test.org.apache.spark.sql.MyDoubleAvg': org.apache.spark.sql.AnalysisException: Invalid number of arguments for the function. Expected: 1; Found: 2;; line 1 pos 7


-- !query 4
CREATE FUNCTION udaf1 AS 'test.non.existent.udaf'
-- !query 4 schema
struct<>
-- !query 4 output



-- !query 5
SELECT default.udaf1(int_col1) as udaf1 from t1
-- !query 5 schema
struct<>
-- !query 5 output
org.apache.spark.sql.AnalysisException
Cannot load class 'test.non.existent.udaf' when registering the function 'default.udaf1'; please make sure it is on the classpath; line 1 pos 7
@@ -34,6 +34,8 @@ import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder
import org.apache.spark.sql.catalyst.catalog.{CatalogFunction, FunctionResourceLoader, GlobalTempViewManager, SessionCatalog}
import org.apache.spark.sql.catalyst.expressions.{Cast, Expression}
import org.apache.spark.sql.catalyst.parser.ParserInterface
import org.apache.spark.sql.execution.aggregate.ScalaUDAF
import org.apache.spark.sql.expressions.UserDefinedAggregateFunction
import org.apache.spark.sql.hive.HiveShim.HiveFunctionWrapper
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.{DecimalType, DoubleType}
@@ -95,6 +97,8 @@ private[sql] class HiveSessionCatalog(
val udtf = HiveGenericUDTF(name, new HiveFunctionWrapper(clazz.getName), children)
udtf.elementSchema // Force it to check input data types.
udtf
} else if (classOf[UserDefinedAggregateFunction].isAssignableFrom(clazz)) {
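// sql/hive depends on sql/core, so ScalaUDAF can be referenced directly
// here rather than via reflection as in SessionCatalog.makeFunctionBuilder.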
ScalaUDAF(children, clazz.newInstance().asInstanceOf[UserDefinedAggregateFunction])
} else {
throw new AnalysisException(s"No handler for Hive UDF '${clazz.getCanonicalName}'")
}
@@ -26,6 +26,7 @@ import org.apache.hadoop.hive.ql.util.JavaDataModel
import org.apache.hadoop.hive.serde2.objectinspector.{ObjectInspector, ObjectInspectorFactory}
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory
import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo
import test.org.apache.spark.sql.MyDoubleAvg

import org.apache.spark.sql.{AnalysisException, QueryTest, Row}
import org.apache.spark.sql.execution.aggregate.ObjectHashAggregateExec
@@ -86,6 +87,18 @@ class HiveUDAFSuite extends QueryTest with TestHiveSingleton with SQLTestUtils {
))
}

test("call JAVA UDAF") {
withTempView("temp") {
withUserDefinedFunction("myDoubleAvg" -> false) {
spark.range(1, 10).toDF("value").createOrReplaceTempView("temp")
sql(s"CREATE FUNCTION myDoubleAvg AS '${classOf[MyDoubleAvg].getName}'")
checkAnswer(
spark.sql("SELECT default.myDoubleAvg(value) as my_avg from temp"),
Row(105.0))
}
}
}

test("non-deterministic children expressions of UDAF") {
withTempView("view1") {
spark.range(1).selectExpr("id as x", "id as y").createTempView("view1")