[SPARK-3046] use executor's class loader as the default serializer classloader #1972

Closed · wants to merge 3 commits
core/src/main/scala/org/apache/spark/executor/Executor.scala (3 additions, 0 deletions)
@@ -99,6 +99,9 @@ private[spark] class Executor(
private val urlClassLoader = createClassLoader()
private val replClassLoader = addReplClassLoaderIfNeeded(urlClassLoader)

+  // Set the classloader for serializer
+  env.serializer.setDefaultClassLoader(urlClassLoader)

// Akka's message frame size. If task result is bigger than this, we use the block manager
// to send the result back.
private val akkaFrameSize = AkkaUtils.maxFrameSizeBytes(conf)
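This executor change is the heart of the patch: previously the serializer resolved classes against the context class loader of whichever thread happened to call it, which on an executor does not necessarily include the user's application jars. A minimal sketch of the path this fixes — `taskBytes` and the function name are hypothetical stand-ins, not Spark API:

```scala
import java.nio.ByteBuffer
import org.apache.spark.SparkEnv

// Sketch: deserializing a value whose class may exist only in the user's
// application jar. Before this patch the serializer resolved classes via
// the thread context class loader (often missing user jars, leading to
// ClassNotFoundException); with the executor's URL class loader installed
// as the serializer default, resolution succeeds.
def deserializeOnExecutor(taskBytes: ByteBuffer): AnyRef =
  SparkEnv.get.serializer.newInstance().deserialize[AnyRef](taskBytes)
```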
core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala
@@ -63,7 +63,9 @@ extends DeserializationStream {
def close() { objIn.close() }
}

-private[spark] class JavaSerializerInstance(counterReset: Int) extends SerializerInstance {
+private[spark] class JavaSerializerInstance(counterReset: Int, defaultClassLoader: ClassLoader)
+  extends SerializerInstance {

def serialize[T: ClassTag](t: T): ByteBuffer = {
val bos = new ByteArrayOutputStream()
val out = serializeStream(bos)
@@ -109,7 +111,10 @@ private[spark] class JavaSerializerInstance(counterReset: Int) extends SerializerInstance {
class JavaSerializer(conf: SparkConf) extends Serializer with Externalizable {
private var counterReset = conf.getInt("spark.serializer.objectStreamReset", 100)

-  def newInstance(): SerializerInstance = new JavaSerializerInstance(counterReset)
+  override def newInstance(): SerializerInstance = {
+    val classLoader = defaultClassLoader.getOrElse(Thread.currentThread.getContextClassLoader)
+    new JavaSerializerInstance(counterReset, classLoader)
+  }

override def writeExternal(out: ObjectOutput) {
out.writeInt(counterReset)
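Passing the loader into `JavaSerializerInstance` only helps if deserialization actually consults it: `java.io.ObjectInputStream` resolves classes via its protected `resolveClass` hook, which defaults to the caller's class loader. A minimal sketch of a loader-aware stream in the spirit of this change (illustrative; the PR's actual deserialization stream may differ in detail):

```scala
import java.io.{InputStream, ObjectInputStream, ObjectStreamClass}

// Sketch: an ObjectInputStream that resolves classes against an explicit
// loader rather than the JVM default. The `false` flag skips running
// static initializers during resolution.
class LoaderAwareObjectInputStream(in: InputStream, loader: ClassLoader)
  extends ObjectInputStream(in) {

  override def resolveClass(desc: ObjectStreamClass): Class[_] =
    Class.forName(desc.getName, false, loader)
}
```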
core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala
@@ -61,7 +61,9 @@ class KryoSerializer(conf: SparkConf)
val instantiator = new EmptyScalaKryoInstantiator
val kryo = instantiator.newKryo()
kryo.setRegistrationRequired(registrationRequired)
-    val classLoader = Thread.currentThread.getContextClassLoader
+
+    val oldClassLoader = Thread.currentThread.getContextClassLoader
+    val classLoader = defaultClassLoader.getOrElse(Thread.currentThread.getContextClassLoader)

// Allow disabling Kryo reference tracking if user knows their object graphs don't have loops.
// Do this before we invoke the user registrator so the user registrator can override this.
@@ -84,10 +86,15 @@
try {
val reg = Class.forName(regCls, true, classLoader).newInstance()
.asInstanceOf[KryoRegistrator]

+        // Use the default classloader when calling the user registrator.
+        Thread.currentThread.setContextClassLoader(classLoader)
reg.registerClasses(kryo)
} catch {
case e: Exception =>
throw new SparkException(s"Failed to invoke $regCls", e)
+      } finally {
+        Thread.currentThread.setContextClassLoader(oldClassLoader)
}
}

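The save/set/restore sequence around `reg.registerClasses(kryo)` is the standard pattern for temporarily swapping a thread's context class loader, so that any classes the user registrator loads reflectively come from the right place. The same idea as a reusable helper (the helper name is mine, not Spark's):

```scala
// Sketch: run `body` with `loader` as the thread's context class loader,
// restoring the previous loader even if `body` throws.
def withContextClassLoader[T](loader: ClassLoader)(body: => T): T = {
  val thread = Thread.currentThread
  val old = thread.getContextClassLoader
  thread.setContextClassLoader(loader)
  try body finally thread.setContextClassLoader(old)
}
```

With such a helper, the registrator call above would read `withContextClassLoader(classLoader) { reg.registerClasses(kryo) }`.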
core/src/main/scala/org/apache/spark/serializer/Serializer.scala (17 additions, 0 deletions)
@@ -44,6 +44,23 @@ import org.apache.spark.util.{ByteBufferInputStream, NextIterator}
*/
@DeveloperApi
trait Serializer {

+  /**
+   * Default ClassLoader to use in deserialization. Implementations of [[Serializer]] should
+   * make sure it is using this when set.
+   */
+  @volatile protected var defaultClassLoader: Option[ClassLoader] = None
+
+  /**
+   * Sets a class loader for the serializer to use in deserialization.
+   *
+   * @return this Serializer object
+   */
+  def setDefaultClassLoader(classLoader: ClassLoader): Serializer = {
+    defaultClassLoader = Some(classLoader)
+    this
+  }

def newInstance(): SerializerInstance
}

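Because `setDefaultClassLoader` returns `this`, callers can install a loader fluently when constructing a serializer. A usage sketch — `userLoader` is a placeholder for a real URL class loader over the application jars:

```scala
import org.apache.spark.SparkConf
import org.apache.spark.serializer.{KryoSerializer, Serializer}

// Sketch: installing a loader fluently, mirroring what Executor does
// with its URL class loader.
val userLoader: ClassLoader = getClass.getClassLoader // placeholder loader
val ser: Serializer = new KryoSerializer(new SparkConf).setDefaultClassLoader(userLoader)
val inst = ser.newInstance() // class resolution now consults userLoader
```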
core/src/test/scala/org/apache/spark/serializer/KryoSerializerDistributedSuite.scala (new file)
@@ -0,0 +1,71 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.serializer

import org.apache.spark.util.Utils

import com.esotericsoftware.kryo.Kryo
import org.scalatest.FunSuite

import org.apache.spark.{LocalSparkContext, SparkConf, SparkContext, SparkEnv, TestUtils}
import org.apache.spark.SparkContext._
import org.apache.spark.serializer.KryoDistributedTest._

class KryoSerializerDistributedSuite extends FunSuite {

test("kryo objects are serialised consistently in different processes") {
val conf = new SparkConf(false)
conf.set("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
conf.set("spark.kryo.registrator", classOf[AppJarRegistrator].getName)
conf.set("spark.task.maxFailures", "1")

val jar = TestUtils.createJarWithClasses(List(AppJarRegistrator.customClassName))
conf.setJars(List(jar.getPath))

val sc = new SparkContext("local-cluster[2,1,512]", "test", conf)
val original = Thread.currentThread.getContextClassLoader
val loader = new java.net.URLClassLoader(Array(jar), Utils.getContextOrSparkClassLoader)
SparkEnv.get.serializer.setDefaultClassLoader(loader)

val cachedRDD = sc.parallelize((0 until 10).map((_, new MyCustomClass)), 3).cache()

// Randomly mix the keys so that the join below will require a shuffle with each partition
// sending data to multiple other partitions.
val shuffledRDD = cachedRDD.map { case (i, o) => (i * i * i - 10 * i * i, o)}

// Join the two RDDs, and force evaluation
assert(shuffledRDD.join(cachedRDD).collect().size == 1)

LocalSparkContext.stop(sc)
}
}

object KryoDistributedTest {
class MyCustomClass

class AppJarRegistrator extends KryoRegistrator {
override def registerClasses(k: Kryo) {
val classLoader = Thread.currentThread.getContextClassLoader
k.register(Class.forName(AppJarRegistrator.customClassName, true, classLoader))
}
}

object AppJarRegistrator {
val customClassName = "KryoSerializerDistributedSuiteCustomClass"
}
}
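The suite above builds a jar containing a class that exists nowhere on the default classpath, then registers a `URLClassLoader` over it as the serializer default, so shuffle deserialization on the `local-cluster` executors must go through that loader. The core trick in isolation (`jarUrl`, `className`, and the function name are placeholders):

```scala
import java.net.{URL, URLClassLoader}

// Sketch: loading a class that exists only in an external jar, which is
// the situation the suite sets up with TestUtils.createJarWithClasses.
def loadFromJar(jarUrl: URL, className: String): Class[_] = {
  val loader = new URLClassLoader(Array(jarUrl), Thread.currentThread.getContextClassLoader)
  Class.forName(className, true, loader)
}
```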
core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala
@@ -23,7 +23,7 @@ import scala.reflect.ClassTag
import com.esotericsoftware.kryo.Kryo
import org.scalatest.FunSuite

-import org.apache.spark.SharedSparkContext
+import org.apache.spark.{SparkConf, SharedSparkContext}
import org.apache.spark.serializer.KryoTest._

class KryoSerializerSuite extends FunSuite with SharedSparkContext {
@@ -217,8 +217,29 @@ class KryoSerializerSuite extends FunSuite with SharedSparkContext {
val thrown = intercept[SparkException](new KryoSerializer(conf).newInstance())
assert(thrown.getMessage.contains("Failed to invoke this.class.does.not.exist"))
}

test("default class loader can be set by a different thread") {
val ser = new KryoSerializer(new SparkConf)

// First serialize the object
val serInstance = ser.newInstance()
val bytes = serInstance.serialize(new ClassLoaderTestingObject)

// Deserialize the object to make sure normal deserialization works
serInstance.deserialize[ClassLoaderTestingObject](bytes)

// Set a special, broken ClassLoader and make sure we get an exception on deserialization
ser.setDefaultClassLoader(new ClassLoader() {
override def loadClass(name: String) = throw new UnsupportedOperationException
})
intercept[UnsupportedOperationException] {
ser.newInstance().deserialize[ClassLoaderTestingObject](bytes)
}
}
}

+class ClassLoaderTestingObject

class KryoSerializerResizableOutputSuite extends FunSuite {
import org.apache.spark.SparkConf
import org.apache.spark.SparkContext