Backports apache#9664 to branch-1.5
liancheng committed Nov 12, 2015
1 parent b478ee3 commit 5502d69
Showing 4 changed files with 72 additions and 13 deletions.
@@ -458,6 +458,51 @@ class HiveThriftBinaryServerSuite extends HiveThriftJdbcTest {
       assert(conf.get("spark.sql.hive.version") === Some("1.2.1"))
     }
   }
+
+  test("SPARK-11595 ADD JAR with input path having URL scheme") {
+    withJdbcStatement { statement =>
+      val jarPath = "../hive/src/test/resources/TestUDTF.jar"
+      val jarURL = s"file://${System.getProperty("user.dir")}/$jarPath"
+
+      Seq(
+        s"ADD JAR $jarURL",
+        s"""CREATE TEMPORARY FUNCTION udtf_count2
+           |AS 'org.apache.spark.sql.hive.execution.GenericUDTFCount2'
+         """.stripMargin
+      ).foreach(statement.execute)
+
+      val rs1 = statement.executeQuery("DESCRIBE FUNCTION udtf_count2")
+
+      assert(rs1.next())
+      assert(rs1.getString(1) === "Function: udtf_count2")
+
+      assert(rs1.next())
+      assertResult("Class: org.apache.spark.sql.hive.execution.GenericUDTFCount2") {
+        rs1.getString(1)
+      }
+
+      assert(rs1.next())
+      assert(rs1.getString(1) === "Usage: To be added.")
+
+      val dataPath = "../hive/src/test/resources/data/files/kv1.txt"
+
+      Seq(
+        s"CREATE TABLE test_udtf(key INT, value STRING)",
+        s"LOAD DATA LOCAL INPATH '$dataPath' OVERWRITE INTO TABLE test_udtf"
+      ).foreach(statement.execute)
+
+      val rs2 = statement.executeQuery(
+        "SELECT key, cc FROM test_udtf LATERAL VIEW udtf_count2(value) dd AS cc")
+
+      assert(rs2.next())
+      assert(rs2.getInt(1) === 97)
+      assert(rs2.getInt(2) === 500)
+
+      assert(rs2.next())
+      assert(rs2.getInt(1) === 97)
+      assert(rs2.getInt(2) === 500)
+    }
+  }
 }

 class HiveThriftHttpServerSuite extends HiveThriftJdbcTest {
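
The new test drives the whole flow through a real JDBC connection. For readers who want to reproduce it against a running HiveThriftServer2, a minimal standalone sketch follows; the connection URL, port, and jar path are placeholders to adapt, and the Hive JDBC driver must be on the classpath.

    import java.sql.DriverManager

    object AddJarWithUrlScheme {
      def main(args: Array[String]): Unit = {
        // Placeholder URL; point this at your HiveThriftServer2 deployment.
        val conn = DriverManager.getConnection("jdbc:hive2://localhost:10000/default")
        val stmt = conn.createStatement()
        try {
          // SPARK-11595: ADD JAR now accepts paths carrying a URL scheme.
          stmt.execute("ADD JAR file:///tmp/TestUDTF.jar") // hypothetical local jar
          stmt.execute(
            "CREATE TEMPORARY FUNCTION udtf_count2 " +
              "AS 'org.apache.spark.sql.hive.execution.GenericUDTFCount2'")
          val rs = stmt.executeQuery("DESCRIBE FUNCTION udtf_count2")
          while (rs.next()) println(rs.getString(1))
        } finally {
          stmt.close()
          conn.close()
        }
      }
    }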
@@ -57,9 +57,11 @@ import org.apache.spark.util.Utils
 /**
  * This is the HiveQL Dialect, this dialect is strongly bind with HiveContext
  */
-private[hive] class HiveQLDialect extends ParserDialect {
+private[hive] class HiveQLDialect(sqlContext: HiveContext) extends ParserDialect {
   override def parse(sqlText: String): LogicalPlan = {
-    HiveQl.parseSql(sqlText)
+    sqlContext.executionHive.withHiveState {
+      HiveQl.parseSql(sqlText)
+    }
   }
 }
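
The wrapping matters because HiveQl.parseSql consults Hive session state kept in thread-locals, and thread-local state is invisible to any thread that has not installed it, such as a Thrift server handler thread. The effect in miniature (a generic illustration, not Spark code):

    val sessionState = new ThreadLocal[String]
    sessionState.set("session-A") // installed on the current thread only

    val worker = new Thread(new Runnable {
      // Another thread sees null unless something like withHiveState
      // installs the state for it first.
      override def run(): Unit = println(sessionState.get) // prints null
    })
    worker.start()
    worker.join()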

@@ -410,7 +412,7 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) with Logging {
   // Note that HiveUDFs will be overridden by functions registered in this context.
   @transient
   override protected[sql] lazy val functionRegistry: FunctionRegistry =
-    new HiveFunctionRegistry(FunctionRegistry.builtin)
+    new HiveFunctionRegistry(FunctionRegistry.builtin, this)

   /* An analyzer that uses the Hive metastore. */
   @transient
@@ -517,10 +519,12 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) with Logging {
     }
   }

-  override protected[sql] def dialectClassName = if (conf.dialect == "hiveql") {
-    classOf[HiveQLDialect].getCanonicalName
-  } else {
-    super.dialectClassName
+  protected[sql] override def getSQLDialect(): ParserDialect = {
+    if (conf.dialect == "hiveql") {
+      new HiveQLDialect(this)
+    } else {
+      super.getSQLDialect()
+    }
   }

   @transient
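
The shape of this change: the old dialectClassName hook had the base context instantiate the dialect reflectively, which forces a no-argument constructor, while overriding the getSQLDialect() factory lets HiveContext hand itself to the dialect instance. A simplified stand-in showing the difference (these are not Spark's actual classes):

    trait ParserDialect { def parse(sql: String): String }

    class DefaultDialect extends ParserDialect {
      override def parse(sql: String): String = sql
    }

    class BaseContext {
      protected def dialectClassName: String = classOf[DefaultDialect].getCanonicalName
      // Reflective construction only works for no-arg constructors.
      def getSQLDialect(): ParserDialect =
        Class.forName(dialectClassName).newInstance().asInstanceOf[ParserDialect]
    }

    class HiveLikeContext extends BaseContext {
      // A factory override can pass dependencies that reflection cannot.
      override def getSQLDialect(): ParserDialect = new HiveLikeDialect(this)
    }

    class HiveLikeDialect(ctx: HiveLikeContext) extends ParserDialect {
      override def parse(sql: String): String = sql // would run under ctx's Hive state
    }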
@@ -245,7 +245,7 @@ private[hive] class ClientWrapper(
   /**
    * Runs `f` with ThreadLocal session state and classloaders configured for this version of hive.
    */
-  private def withHiveState[A](f: => A): A = retryLocked {
+  def withHiveState[A](f: => A): A = retryLocked {
     val original = Thread.currentThread().getContextClassLoader
     // Set the thread local metastore client to the client associated with this ClientWrapper.
     Hive.set(client)
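
Widening withHiveState from private to public is what allows HiveQLDialect and HiveFunctionRegistry above to call it. Its core discipline, visible in the context lines, is save, mutate, run, restore; a minimal sketch of that pattern (the parameter name is illustrative, not ClientWrapper's exact implementation, which also retries under a lock and manages Hive's SessionState):

    def withThreadState[A](sessionLoader: ClassLoader)(f: => A): A = {
      val original = Thread.currentThread().getContextClassLoader
      Thread.currentThread().setContextClassLoader(sessionLoader)
      try f
      finally Thread.currentThread().setContextClassLoader(original) // always restore
    }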
20 changes: 15 additions & 5 deletions sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala

@@ -44,17 +44,23 @@ import org.apache.spark.sql.hive.HiveShim._
 import org.apache.spark.sql.types._


-private[hive] class HiveFunctionRegistry(underlying: analysis.FunctionRegistry)
+private[hive] class HiveFunctionRegistry(
+    underlying: analysis.FunctionRegistry,
+    hiveContext: HiveContext)
   extends analysis.FunctionRegistry with HiveInspectors {

-  def getFunctionInfo(name: String): FunctionInfo = FunctionRegistry.getFunctionInfo(name)
+  def getFunctionInfo(name: String): FunctionInfo = {
+    hiveContext.executionHive.withHiveState {
+      FunctionRegistry.getFunctionInfo(name)
+    }
+  }

   override def lookupFunction(name: String, children: Seq[Expression]): Expression = {
     Try(underlying.lookupFunction(name, children)).getOrElse {
       // We only look it up to see if it exists, but do not include it in the HiveUDF since it is
       // not always serializable.
       val functionInfo: FunctionInfo =
-        Option(FunctionRegistry.getFunctionInfo(name.toLowerCase)).getOrElse(
+        Option(getFunctionInfo(name.toLowerCase)).getOrElse(
           throw new AnalysisException(s"undefined function $name"))

       val functionClassName = functionInfo.getFunctionClass.getName
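
Routing every lookup through the new withHiveState-wrapped getFunctionInfo matters because Hive's static FunctionRegistry resolves UDF classes through the session classloader, and jars added with ADD JAR exist only on that loader. The classloader sensitivity in miniature (the jar path is hypothetical; adjust to a real jar to run this):

    import java.io.File
    import java.net.URLClassLoader

    object ClassLoaderDemo {
      def main(args: Array[String]): Unit = {
        val jarUrl = new File("/tmp/TestUDTF.jar").toURI.toURL
        val sessionLoader =
          new URLClassLoader(Array(jarUrl), Thread.currentThread().getContextClassLoader)

        // Resolving through the default loader would throw ClassNotFoundException;
        // resolving through the loader that holds the added jar succeeds.
        val udtfClass = Class.forName(
          "org.apache.spark.sql.hive.execution.GenericUDTFCount2", true, sessionLoader)
        println(udtfClass.getName)
      }
    }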
@@ -89,7 +95,7 @@ private[hive] class HiveFunctionRegistry(underlying: analysis.FunctionRegistry)
   override def lookupFunction(name: String): Option[ExpressionInfo] = {
     underlying.lookupFunction(name).orElse(
       Try {
-        val info = FunctionRegistry.getFunctionInfo(name)
+        val info = getFunctionInfo(name)
         val annotation = info.getFunctionClass.getAnnotation(classOf[Description])
         if (annotation != null) {
           Some(new ExpressionInfo(
@@ -98,7 +104,11 @@ private[hive] class HiveFunctionRegistry(
             annotation.value(),
             annotation.extended()))
         } else {
-          None
+          Some(new ExpressionInfo(
+            info.getFunctionClass.getCanonicalName,
+            name,
+            null,
+            null))
         }
       }.getOrElse(None))
   }
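
This final hunk changes what happens for Hive functions whose class carries no @Description annotation: instead of None, which left DESCRIBE FUNCTION with no metadata to report, the lookup now returns a minimal ExpressionInfo with a null usage string, which the Thrift server test above observes as "Usage: To be added.". The fallback in miniature (stand-in types; the Description annotation requires hive-exec on the classpath):

    case class FnInfo(className: String, name: String, usage: String, extended: String)

    def describeFunction(name: String, clazz: Class[_]): FnInfo = {
      val d = clazz.getAnnotation(classOf[org.apache.hadoop.hive.ql.exec.Description])
      if (d != null) FnInfo(clazz.getCanonicalName, name, d.value(), d.extended())
      // Annotation absent: still surface the class name rather than nothing.
      else FnInfo(clazz.getCanonicalName, name, null, null)
    }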
