Skip to content

Commit

Permalink
[SPARK-3046] use executor's class loader as the default serializer cl…
Browse files Browse the repository at this point in the history
…ass loader.
  • Loading branch information
rxin committed Aug 15, 2014
1 parent fd9fcd2 commit d879e67
Show file tree
Hide file tree
Showing 4 changed files with 99 additions and 1 deletion.
3 changes: 3 additions & 0 deletions core/src/main/scala/org/apache/spark/executor/Executor.scala
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,9 @@ private[spark] class Executor(
private val urlClassLoader = createClassLoader()
private val replClassLoader = addReplClassLoaderIfNeeded(urlClassLoader)

// Set the classloader for serializer
env.serializer.setDefaultClassLoader(urlClassLoader)

// Akka's message frame size. If task result is bigger than this, we use the block manager
// to send the result back.
private val akkaFrameSize = AkkaUtils.maxFrameSizeBytes(conf)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,9 @@ class KryoSerializer(conf: SparkConf)
val instantiator = new EmptyScalaKryoInstantiator
val kryo = instantiator.newKryo()
kryo.setRegistrationRequired(registrationRequired)
val classLoader = Thread.currentThread.getContextClassLoader

val oldClassLoader = Thread.currentThread.getContextClassLoader
val classLoader = defaultClassLoader.getOrElse(Thread.currentThread.getContextClassLoader)

// Allow disabling Kryo reference tracking if user knows their object graphs don't have loops.
// Do this before we invoke the user registrator so the user registrator can override this.
Expand All @@ -84,10 +86,15 @@ class KryoSerializer(conf: SparkConf)
try {
val reg = Class.forName(regCls, true, classLoader).newInstance()
.asInstanceOf[KryoRegistrator]

// Use the default classloader when calling the user registrator.
Thread.currentThread.setContextClassLoader(classLoader)
reg.registerClasses(kryo)
} catch {
case e: Exception =>
throw new SparkException(s"Failed to invoke $regCls", e)
} finally {
Thread.currentThread.setContextClassLoader(oldClassLoader)
}
}

Expand Down
17 changes: 17 additions & 0 deletions core/src/main/scala/org/apache/spark/serializer/Serializer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,23 @@ import org.apache.spark.util.{ByteBufferInputStream, NextIterator}
*/
@DeveloperApi
trait Serializer {

/**
* Default ClassLoader to use in deserialization. Implementations of [[Serializer]] should
* make sure it is using this when set.
*/
@volatile protected var defaultClassLoader: Option[ClassLoader] = None

/**
* Sets a class loader for the serializer to use in deserialization.
*
* @return this Serializer object
*/
def setDefaultClassLoader(classLoader: ClassLoader): Serializer = {
defaultClassLoader = Some(classLoader)
this
}

def newInstance(): SerializerInstance
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.serializer

import org.apache.spark.util.Utils

import com.esotericsoftware.kryo.Kryo
import org.scalatest.FunSuite

import org.apache.spark.{LocalSparkContext, SparkConf, SparkContext, SparkEnv, TestUtils}
import org.apache.spark.SparkContext._
import org.apache.spark.serializer.KryoDistributedTest._

class KryoSerializerDistributedSuite extends FunSuite {

test("kryo objects are serialised consistently in different processes") {
val conf = new SparkConf(false)
conf.set("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
conf.set("spark.kryo.registrator", classOf[AppJarRegistrator].getName)
conf.set("spark.task.maxFailures", "1")

val jar = TestUtils.createJarWithClasses(List(AppJarRegistrator.customClassName))
conf.setJars(List(jar.getPath))

val sc = new SparkContext("local-cluster[2,1,512]", "test", conf)
val original = Thread.currentThread.getContextClassLoader
val loader = new java.net.URLClassLoader(Array(jar), Utils.getContextOrSparkClassLoader)
SparkEnv.get.serializer.setDefaultClassLoader(loader)

val cachedRDD = sc.parallelize((0 until 10).map((_, new MyCustomClass)), 3).cache()

// Randomly mix the keys so that the join below will require a shuffle with each partition
// sending data to multiple other partitions.
val shuffledRDD = cachedRDD.map { case (i, o) => (i * i * i - 10 * i * i, o)}

// Join the two RDDs, and force evaluation
assert(shuffledRDD.join(cachedRDD).collect().size == 1)

LocalSparkContext.stop(sc)
}
}

object KryoDistributedTest {
class MyCustomClass

class AppJarRegistrator extends KryoRegistrator {
override def registerClasses(k: Kryo) {
val classLoader = Thread.currentThread.getContextClassLoader
k.register(Class.forName(AppJarRegistrator.customClassName, true, classLoader))
}
}

object AppJarRegistrator {
val customClassName = "KryoSerializerDistributedSuiteCustomClass"
}
}

0 comments on commit d879e67

Please sign in to comment.