Skip to content

Commit

Permalink
avoid Simple UDF to be serialized
Browse files Browse the repository at this point in the history
  • Loading branch information
chenghao-intel committed Dec 9, 2014
1 parent e9c3212 commit 396c0e1
Showing 1 changed file with 32 additions and 21 deletions.
53 changes: 32 additions & 21 deletions sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala
Original file line number Diff line number Diff line change
Expand Up @@ -86,39 +86,50 @@ class HiveFunctionCache(var functionClassName: String) extends java.io.Externali
private var instance: Any = null

def writeExternal(out: java.io.ObjectOutput) {
// Some of the UDF are serializable, but some others are not
// Hive Utilities can handle both cases
val baos = new java.io.ByteArrayOutputStream()
HiveShim.serializePlan(instance, baos)
val functionInBytes = baos.toByteArray

// output the function name
out.writeUTF(functionClassName)

// output the function bytes
out.writeInt(functionInBytes.length)
out.write(functionInBytes, 0, functionInBytes.length)
// Write a flag if instance is null or not
out.writeBoolean(instance != null)
if (instance != null) {
// Some of the UDF are serializable, but some others are not
// Hive Utilities can handle both cases
val baos = new java.io.ByteArrayOutputStream()
HiveShim.serializePlan(instance, baos)
val functionInBytes = baos.toByteArray

// output the function bytes
out.writeInt(functionInBytes.length)
out.write(functionInBytes, 0, functionInBytes.length)
}
}

def readExternal(in: java.io.ObjectInput) {
// read the function name
functionClassName = in.readUTF()

// read the function in bytes
val functionInBytesLength = in.readInt()
val functionInBytes = new Array[Byte](functionInBytesLength)
in.read(functionInBytes, 0, functionInBytesLength)
if (in.readBoolean()) {
// if the instance is not null
// read the function in bytes
val functionInBytesLength = in.readInt()
val functionInBytes = new Array[Byte](functionInBytesLength)
in.read(functionInBytes, 0, functionInBytesLength)

// deserialize the function object via Hive Utilities
instance = HiveShim.deserializePlan(new java.io.ByteArrayInputStream(functionInBytes),
getContextOrSparkClassLoader.loadClass(functionClassName))
// deserialize the function object via Hive Utilities
instance = HiveShim.deserializePlan(new java.io.ByteArrayInputStream(functionInBytes),
getContextOrSparkClassLoader.loadClass(functionClassName))
}
}

def createFunction[UDFType]() = {
if (instance == null) {
instance = getContextOrSparkClassLoader.loadClass(functionClassName).newInstance
def createFunction[UDFType](alwaysCreateNewInstance: Boolean = false) = {
if (alwaysCreateNewInstance) {
getContextOrSparkClassLoader.loadClass(functionClassName).newInstance.asInstanceOf[UDFType]
} else {
if (instance == null) {
instance = getContextOrSparkClassLoader.loadClass(functionClassName).newInstance
}
instance.asInstanceOf[UDFType]
}
instance.asInstanceOf[UDFType]
}
}

Expand All @@ -130,7 +141,7 @@ private[hive] case class HiveSimpleUdf(cache: HiveFunctionCache, children: Seq[E
def nullable = true

@transient
lazy val function = cache.createFunction[UDFType]()
lazy val function = cache.createFunction[UDFType](true) // Simple UDF should be not serialized.

@transient
protected lazy val method =
Expand Down

0 comments on commit 396c0e1

Please sign in to comment.