Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SPARK-17714][Core][test-maven][test-hadoop2.6]Avoid using ExecutorClassLoader to load Netty generated classes #16859

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -62,8 +62,20 @@ public class TransportContext {
private final RpcHandler rpcHandler;
private final boolean closeIdleConnections;

private final MessageEncoder encoder;
private final MessageDecoder decoder;
/**
* Force to create MessageEncoder and MessageDecoder so that we can make sure they will be created
* before switching the current context class loader to ExecutorClassLoader.
*
* Netty's MessageToMessageEncoder uses Javassist to generate a matcher class and the
* implementation calls "Class.forName" to check if this calls is already generated. If the
* following two objects are created in "ExecutorClassLoader.findClass", it will cause
* "ClassCircularityError". This is because loading this Netty generated class will call
* "ExecutorClassLoader.findClass" to search this class, and "ExecutorClassLoader" will try to use
* RPC to load it and cause to load the non-exist matcher class again. JVM will report
* `ClassCircularityError` to prevent such infinite recursion. (See SPARK-17714)
*/
private static final MessageEncoder ENCODER = MessageEncoder.INSTANCE;
private static final MessageDecoder DECODER = MessageDecoder.INSTANCE;

public TransportContext(TransportConf conf, RpcHandler rpcHandler) {
this(conf, rpcHandler, false);
Expand All @@ -75,8 +87,6 @@ public TransportContext(
boolean closeIdleConnections) {
this.conf = conf;
this.rpcHandler = rpcHandler;
this.encoder = new MessageEncoder();
this.decoder = new MessageDecoder();
this.closeIdleConnections = closeIdleConnections;
}

Expand Down Expand Up @@ -135,9 +145,9 @@ public TransportChannelHandler initializePipeline(
try {
TransportChannelHandler channelHandler = createChannelHandler(channel, channelRpcHandler);
channel.pipeline()
.addLast("encoder", encoder)
.addLast("encoder", ENCODER)
.addLast(TransportFrameDecoder.HANDLER_NAME, NettyUtils.createFrameDecoder())
.addLast("decoder", decoder)
.addLast("decoder", DECODER)
.addLast("idleStateHandler", new IdleStateHandler(0, 0, conf.connectionTimeoutMs() / 1000))
// NOTE: Chunks are currently guaranteed to be returned in the order of request, but this
// would require more logic to guarantee if this were not part of the same event loop.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,10 @@ public final class MessageDecoder extends MessageToMessageDecoder<ByteBuf> {

private static final Logger logger = LoggerFactory.getLogger(MessageDecoder.class);

public static final MessageDecoder INSTANCE = new MessageDecoder();

private MessageDecoder() {}

@Override
public void decode(ChannelHandlerContext ctx, ByteBuf in, List<Object> out) {
Message.Type msgType = Message.Type.decode(in);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,10 @@ public final class MessageEncoder extends MessageToMessageEncoder<Message> {

private static final Logger logger = LoggerFactory.getLogger(MessageEncoder.class);

public static final MessageEncoder INSTANCE = new MessageEncoder();

private MessageEncoder() {}

/***
* Encodes a Message by invoking its encode() method. For non-data messages, we will add one
* ByteBuf to 'out' containing the total frame length, the message type, and the message itself.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,15 +18,14 @@
package org.apache.spark.network.server;

import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.SimpleChannelInboundHandler;
import io.netty.channel.ChannelInboundHandlerAdapter;
import io.netty.handler.timeout.IdleState;
import io.netty.handler.timeout.IdleStateEvent;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import org.apache.spark.network.client.TransportClient;
import org.apache.spark.network.client.TransportResponseHandler;
import org.apache.spark.network.protocol.Message;
import org.apache.spark.network.protocol.RequestMessage;
import org.apache.spark.network.protocol.ResponseMessage;
import static org.apache.spark.network.util.NettyUtils.getRemoteAddress;
Expand All @@ -48,7 +47,7 @@
* on the channel for at least `requestTimeoutMs`. Note that this is duplex traffic; we will not
* timeout if the client is continuously sending but getting no responses, for simplicity.
*/
public class TransportChannelHandler extends SimpleChannelInboundHandler<Message> {
public class TransportChannelHandler extends ChannelInboundHandlerAdapter {
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

SimpleChannelInboundHandler also uses Javassist to generate a matcher class. Since SimpleChannelInboundHandler provides little value for us, I just changed to extend ChannelInboundHandlerAdapter directly.

private static final Logger logger = LoggerFactory.getLogger(TransportChannelHandler.class);

private final TransportClient client;
Expand Down Expand Up @@ -114,11 +113,13 @@ public void channelInactive(ChannelHandlerContext ctx) throws Exception {
}

@Override
public void channelRead0(ChannelHandlerContext ctx, Message request) throws Exception {
public void channelRead(ChannelHandlerContext ctx, Object request) throws Exception {
if (request instanceof RequestMessage) {
requestHandler.handle((RequestMessage) request);
} else {
} else if (request instanceof ResponseMessage) {
responseHandler.handle((ResponseMessage) request);
} else {
ctx.fireChannelRead(request);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,11 +49,11 @@
public class ProtocolSuite {
private void testServerToClient(Message msg) {
EmbeddedChannel serverChannel = new EmbeddedChannel(new FileRegionEncoder(),
new MessageEncoder());
MessageEncoder.INSTANCE);
serverChannel.writeOutbound(msg);

EmbeddedChannel clientChannel = new EmbeddedChannel(
NettyUtils.createFrameDecoder(), new MessageDecoder());
NettyUtils.createFrameDecoder(), MessageDecoder.INSTANCE);

while (!serverChannel.outboundMessages().isEmpty()) {
clientChannel.writeInbound(serverChannel.readOutbound());
Expand All @@ -65,11 +65,11 @@ private void testServerToClient(Message msg) {

private void testClientToServer(Message msg) {
EmbeddedChannel clientChannel = new EmbeddedChannel(new FileRegionEncoder(),
new MessageEncoder());
MessageEncoder.INSTANCE);
clientChannel.writeOutbound(msg);

EmbeddedChannel serverChannel = new EmbeddedChannel(
NettyUtils.createFrameDecoder(), new MessageDecoder());
NettyUtils.createFrameDecoder(), MessageDecoder.INSTANCE);

while (!clientChannel.outboundMessages().isEmpty()) {
serverChannel.writeInbound(clientChannel.readOutbound());
Expand Down
16 changes: 4 additions & 12 deletions core/src/main/scala/org/apache/spark/util/Utils.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2601,12 +2601,8 @@ private[util] object CallerContext extends Logging {
val callerContextSupported: Boolean = {
SparkHadoopUtil.get.conf.getBoolean("hadoop.caller.context.enabled", false) && {
try {
// `Utils.classForName` will make `ReplSuite` fail with `ClassCircularityError` in
// master Maven build, so do not use it before resolving SPARK-17714.
// scalastyle:off classforname
Class.forName("org.apache.hadoop.ipc.CallerContext")
Class.forName("org.apache.hadoop.ipc.CallerContext$Builder")
// scalastyle:on classforname
Utils.classForName("org.apache.hadoop.ipc.CallerContext")
Utils.classForName("org.apache.hadoop.ipc.CallerContext$Builder")
true
} catch {
case _: ClassNotFoundException =>
Expand Down Expand Up @@ -2681,12 +2677,8 @@ private[spark] class CallerContext(
def setCurrentContext(): Unit = {
if (CallerContext.callerContextSupported) {
try {
// `Utils.classForName` will make `ReplSuite` fail with `ClassCircularityError` in
// master Maven build, so do not use it before resolving SPARK-17714.
// scalastyle:off classforname
val callerContext = Class.forName("org.apache.hadoop.ipc.CallerContext")
val builder = Class.forName("org.apache.hadoop.ipc.CallerContext$Builder")
// scalastyle:on classforname
val callerContext = Utils.classForName("org.apache.hadoop.ipc.CallerContext")
val builder = Utils.classForName("org.apache.hadoop.ipc.CallerContext$Builder")
val builderInst = builder.getConstructor(classOf[String]).newInstance(context)
val hdfsContext = builder.getMethod("build").invoke(builderInst)
callerContext.getMethod("setCurrent", callerContext).invoke(null, hdfsContext)
Expand Down