From 396c0e1bf10d4ca69675801aa68bf9b21ba5c9bf Mon Sep 17 00:00:00 2001
From: Cheng Hao
Date: Mon, 8 Dec 2014 22:09:52 -0800
Subject: [PATCH] avoid Simple UDF to be serialized

---
 .../org/apache/spark/sql/hive/hiveUdfs.scala | 53 +++++++++++--------
 1 file changed, 32 insertions(+), 21 deletions(-)

diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala
index 48b3b15c83a9b..842ffb8579c28 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala
@@ -86,39 +86,50 @@ class HiveFunctionCache(var functionClassName: String) extends java.io.Externali
   private var instance: Any = null
 
   def writeExternal(out: java.io.ObjectOutput) {
-    // Some of the UDF are serializable, but some others are not
-    // Hive Utilities can handle both cases
-    val baos = new java.io.ByteArrayOutputStream()
-    HiveShim.serializePlan(instance, baos)
-    val functionInBytes = baos.toByteArray
-
     // output the function name
     out.writeUTF(functionClassName)
 
-    // output the function bytes
-    out.writeInt(functionInBytes.length)
-    out.write(functionInBytes, 0, functionInBytes.length)
+    // Write a flag if instance is null or not
+    out.writeBoolean(instance != null)
+    if (instance != null) {
+      // Some of the UDF are serializable, but some others are not
+      // Hive Utilities can handle both cases
+      val baos = new java.io.ByteArrayOutputStream()
+      HiveShim.serializePlan(instance, baos)
+      val functionInBytes = baos.toByteArray
+
+      // output the function bytes
+      out.writeInt(functionInBytes.length)
+      out.write(functionInBytes, 0, functionInBytes.length)
+    }
   }
 
   def readExternal(in: java.io.ObjectInput) {
     // read the function name
     functionClassName = in.readUTF()
 
-    // read the function in bytes
-    val functionInBytesLength = in.readInt()
-    val functionInBytes = new Array[Byte](functionInBytesLength)
-    in.read(functionInBytes, 0, functionInBytesLength)
+    if (in.readBoolean()) {
+      // if the instance is not null
+      // read the function in bytes; readFully because read() may return fewer bytes
+      val functionInBytesLength = in.readInt()
+      val functionInBytes = new Array[Byte](functionInBytesLength)
+      in.readFully(functionInBytes, 0, functionInBytesLength)
 
-    // deserialize the function object via Hive Utilities
-    instance = HiveShim.deserializePlan(new java.io.ByteArrayInputStream(functionInBytes),
-      getContextOrSparkClassLoader.loadClass(functionClassName))
+      // deserialize the function object via Hive Utilities
+      instance = HiveShim.deserializePlan(new java.io.ByteArrayInputStream(functionInBytes),
+        getContextOrSparkClassLoader.loadClass(functionClassName))
+    }
   }
 
-  def createFunction[UDFType]() = {
-    if (instance == null) {
-      instance = getContextOrSparkClassLoader.loadClass(functionClassName).newInstance
+  def createFunction[UDFType](alwaysCreateNewInstance: Boolean = false) = {
+    if (alwaysCreateNewInstance) {
+      getContextOrSparkClassLoader.loadClass(functionClassName).newInstance.asInstanceOf[UDFType]
+    } else {
+      if (instance == null) {
+        instance = getContextOrSparkClassLoader.loadClass(functionClassName).newInstance
+      }
+      instance.asInstanceOf[UDFType]
     }
-    instance.asInstanceOf[UDFType]
   }
 }
 
@@ -130,7 +141,7 @@ private[hive] case class HiveSimpleUdf(cache: HiveFunctionCache, children: Seq[E
   def nullable = true
 
   @transient
-  lazy val function = cache.createFunction[UDFType]()
+  lazy val function = cache.createFunction[UDFType](true) // Simple UDF should be not serialized.
 
   @transient
   protected lazy val method =