Skip to content

Commit

Permalink
fix tests
Browse files Browse the repository at this point in the history
  • Loading branch information
imback82 committed Feb 7, 2021
1 parent a68b977 commit 56c3001
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 19 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
case udaf: UserDefinedAggregator[_, _, _] =>
def builder(children: Seq[Expression]) = udaf.scalaAggregator(children)
functionRegistry.createOrReplaceTempFunction(name, builder)
udf
udaf
case other =>
def builder(children: Seq[Expression]) = other.apply(children.map(Column.apply) : _*).expr
functionRegistry.createOrReplaceTempFunction(name, builder)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -321,26 +321,34 @@ object IntegratedUDFTestUtils extends SQLHelper {
* casted_col.cast(df.schema("col").dataType)
* }}}
*/
case class TestScalaUDF(name: String) extends TestUDF {
private[IntegratedUDFTestUtils] lazy val udf = new SparkUserDefinedFunction(
(input: Any) => if (input == null) {
null
} else {
input.toString
},
StringType,
inputEncoders = Seq.fill(1)(None),
name = Some(name)) {

override def apply(exprs: Column*): Column = {
assert(exprs.length == 1, "Defined UDF only has one column")
val expr = exprs.head.expr
assert(expr.resolved, "column should be resolved to use the same type " +
"as input. Try df(name) or df.col(name)")
Column(Cast(createScalaUDF(Cast(expr, StringType) :: Nil), expr.dataType))
}
class TestInternalScalaUDF(name: String) extends SparkUserDefinedFunction(
(input: Any) => if (input == null) {
null
} else {
input.toString
},
StringType,
inputEncoders = Seq.fill(1)(None),
name = Some(name)) {

override def apply(exprs: Column*): Column = {
assert(exprs.length == 1, "Defined UDF only has one column")
val expr = exprs.head.expr
assert(expr.resolved, "column should be resolved to use the same type " +
"as input. Try df(name) or df.col(name)")
Column(Cast(createScalaUDF(Cast(expr, StringType) :: Nil), expr.dataType))
}

override def withName(name: String): TestInternalScalaUDF = {
// "withName" should overridden to return TestInternalScalaUDF. Otherwise, the current object
// is sliced and the overridden "apply" is not invoked.
new TestInternalScalaUDF(name)
}
}

case class TestScalaUDF(name: String) extends TestUDF {
private[IntegratedUDFTestUtils] lazy val udf = new TestInternalScalaUDF(name)

def apply(exprs: Column*): Column = udf(exprs: _*)

val prettyName: String = "Scala UDF"
Expand Down

0 comments on commit 56c3001

Please sign in to comment.