Cleanup SASL state upon connection termination
aarondav committed Nov 4, 2014
1 parent 7b42adb commit f6177d7
Showing 8 changed files with 61 additions and 31 deletions.
@@ -19,10 +19,9 @@

import java.io.Closeable;
import java.util.UUID;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;

import com.google.common.base.Objects;
import com.google.common.base.Preconditions;
import com.google.common.base.Throwables;
import com.google.common.util.concurrent.SettableFuture;
@@ -187,9 +186,11 @@ public void close() {
channel.close().awaitUninterruptibly(10, TimeUnit.SECONDS);
}

/** Returns a stable key for the given channel. Only valid after the channel is connected. */
public String getChannelKey() {
return String.format("[%s, %s, %s]", channel.remoteAddress(), channel.localAddress(),
channel.hashCode());
@Override
public String toString() {
return Objects.toStringHelper(this)
.add("remoteAdress", channel.remoteAddress())
.add("isActive", isActive())
.toString();
}
}
@@ -21,7 +21,7 @@
import org.apache.spark.network.client.TransportClient;

/** An RpcHandler suitable for a client-only TransportContext, which cannot receive RPCs. */
public class NoOpRpcHandler implements RpcHandler {
public class NoOpRpcHandler extends RpcHandler {
private final StreamManager streamManager;

public NoOpRpcHandler() {
@@ -23,24 +23,33 @@
/**
* Handler for sendRPC() messages sent by {@link org.apache.spark.network.client.TransportClient}s.
*/
public interface RpcHandler {
public abstract class RpcHandler {
/**
* Receive a single RPC message. Any exception thrown while in this method will be sent back to
* the client in string form as a standard RPC failure.
*
* This method will not be called in parallel for a single TransportClient (i.e., channel).
*
* @param client A channel client which enables the handler to make requests back to the sender
* of this RPC.
* of this RPC. This will always be the exact same object for a particular channel.
* @param message The serialized bytes of the RPC.
* @param callback Callback which should be invoked exactly once upon success or failure of the
* RPC.
*/
void receive(TransportClient client, byte[] message, RpcResponseCallback callback);
public abstract void receive(
TransportClient client,
byte[] message,
RpcResponseCallback callback);

/**
* Returns the StreamManager which contains the state about which streams are currently being
* fetched by a TransportClient.
*/
StreamManager getStreamManager();
public abstract StreamManager getStreamManager();

/**
* Invoked when the connection associated with the given client has been invalidated.
* No further requests will come from this client.
*/
public void connectionTerminated(TransportClient client) { }
}
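
With RpcHandler converted from an interface to an abstract class, connectionTerminated() becomes an optional hook with a no-op default, so existing handlers keep compiling while handlers that hold per-client state can release it. A minimal sketch of such a subclass (the class name, the counter field, and the import paths other than TransportClient are illustrative assumptions, not code from this commit):

import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;

import org.apache.spark.network.client.RpcResponseCallback;
import org.apache.spark.network.client.TransportClient;

/** Hypothetical echo handler that tracks per-client RPC counts and drops them on termination. */
public class CountingRpcHandler extends RpcHandler {
  private final StreamManager streamManager;
  private final ConcurrentMap<TransportClient, Integer> rpcCounts =
    new ConcurrentHashMap<TransportClient, Integer>();

  public CountingRpcHandler(StreamManager streamManager) {
    this.streamManager = streamManager;
  }

  @Override
  public void receive(TransportClient client, byte[] message, RpcResponseCallback callback) {
    // Safe without atomics: receive() is never called in parallel for a single client.
    Integer previous = rpcCounts.get(client);
    rpcCounts.put(client, previous == null ? 1 : previous + 1);
    callback.onSuccess(message);
  }

  @Override
  public StreamManager getStreamManager() {
    return streamManager;
  }

  @Override
  public void connectionTerminated(TransportClient client) {
    // No further requests will arrive from this client, so its entry can be removed.
    rpcCounts.remove(client);
  }
}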
@@ -86,6 +86,7 @@ public void channelUnregistered() {
for (long streamId : streamIds) {
streamManager.connectionTerminated(streamId);
}
rpcHandler.connectionTerminated(reverseClient);
}

@Override
@@ -38,17 +38,26 @@ public SaslBootstrap(String secretKeyId, SecretKeyHolder secretKeyHolder) {

public void doBootstrap(TransportClient client) {
SparkSaslClient saslClient = new SparkSaslClient(secretKeyId, secretKeyHolder);
byte[] payload = saslClient.firstToken();

while (!saslClient.isComplete()) {
SaslMessage msg = new SaslMessage(secretKeyId, payload);
logger.info("Sending msg {} {}", secretKeyId, payload.length);
ByteBuf buf = Unpooled.buffer(msg.encodedLength());
msg.encode(buf);

byte[] response = client.sendRpcSync(buf.array(), 300000);
logger.info("Got response {} {}", secretKeyId, response.length);
payload = saslClient.response(response);
try {
byte[] payload = saslClient.firstToken();

while (!saslClient.isComplete()) {
SaslMessage msg = new SaslMessage(secretKeyId, payload);
logger.info("Sending msg {} {}", secretKeyId, payload.length);
ByteBuf buf = Unpooled.buffer(msg.encodedLength());
msg.encode(buf);

byte[] response = client.sendRpcSync(buf.array(), 300000);
logger.info("Got response {} {}", secretKeyId, response.length);
payload = saslClient.response(response);
}
} finally {
try {
// Once authentication is complete, the server will trust all remaining communication.
saslClient.dispose();
} catch (RuntimeException e) {
logger.error("Error while disposing SASL client", e);
}
}
}
}
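
For context, the loop above is the standard client-side SASL pattern; the sketch below runs the same challenge/response exchange in-process against a SaslServer using the JDK's javax.security.sasl API, which is roughly what SparkSaslClient and SparkSaslServer wrap. The mechanism, protocol and server names, and the callback handlers are placeholder assumptions:

import java.util.Collections;
import javax.security.auth.callback.CallbackHandler;
import javax.security.sasl.Sasl;
import javax.security.sasl.SaslClient;
import javax.security.sasl.SaslServer;

public class SaslHandshakeSketch {
  public static void handshake(CallbackHandler clientCallbacks, CallbackHandler serverCallbacks)
      throws Exception {
    SaslClient client = Sasl.createSaslClient(new String[] { "DIGEST-MD5" }, null, "spark",
      "default", Collections.<String, Object>emptyMap(), clientCallbacks);
    SaslServer server = Sasl.createSaslServer("DIGEST-MD5", "spark", "default",
      Collections.<String, Object>emptyMap(), serverCallbacks);
    try {
      // DIGEST-MD5 has no initial client response, so the first token is empty.
      byte[] token = client.hasInitialResponse() ? client.evaluateChallenge(new byte[0]) : new byte[0];
      while (!client.isComplete()) {
        byte[] challenge = server.evaluateResponse(token);  // server side of each RPC round trip
        token = client.evaluateChallenge(challenge);         // client side of each round trip
      }
    } finally {
      // Mirror the finally block above: release negotiation state even if the handshake failed.
      client.dispose();
      server.dispose();
    }
  }
}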
@@ -36,15 +36,18 @@
* RPC Handler which performs SASL authentication before delegating to a child RPC handler.
* The delegate will only receive messages if the given connection has been successfully
* authenticated. A connection may be authenticated at most once.
*
* Note that the authentication process consists of multiple challenge-response pairs, each of
* which is an individual RPC.
*/
public class SaslRpcHandler implements RpcHandler {
public class SaslRpcHandler extends RpcHandler {
private final Logger logger = LoggerFactory.getLogger(SaslRpcHandler.class);

private final RpcHandler delegate;
private final SecretKeyHolder secretKeyHolder;

// TODO: Invalidate channels that have closed!
private final ConcurrentMap<String, SparkSaslServer> channelAuthenticationMap;
private final ConcurrentMap<TransportClient, SparkSaslServer> channelAuthenticationMap;

public SaslRpcHandler(RpcHandler delegate, SecretKeyHolder secretKeyHolder) {
this.delegate = delegate;
@@ -54,9 +57,7 @@ public SaslRpcHandler(RpcHandler delegate, SecretKeyHolder secretKeyHolder) {

@Override
public void receive(TransportClient client, byte[] message, RpcResponseCallback callback) {
String channelKey = client.getChannelKey();

SparkSaslServer saslServer = channelAuthenticationMap.get(channelKey);
SparkSaslServer saslServer = channelAuthenticationMap.get(client);
if (saslServer != null && saslServer.isComplete()) {
// Authentication complete, delegate to base handler.
delegate.receive(client, message, callback);
@@ -66,13 +67,14 @@ public void receive(TransportClient client, byte[] message, RpcResponseCallback
SaslMessage saslMessage = SaslMessage.decode(Unpooled.wrappedBuffer(message));

if (saslServer == null) {
// First message in the handshake, set up the necessary state.
saslServer = new SparkSaslServer(saslMessage.appId, secretKeyHolder);
channelAuthenticationMap.put(channelKey, saslServer);
channelAuthenticationMap.put(client, saslServer);
}

byte[] response = saslServer.response(saslMessage.payload);
if (saslServer.isComplete()) {
logger.debug("SASL authentication successful for channel {}", channelKey);
logger.debug("SASL authentication successful for channel {}", client);
}
callback.onSuccess(response);
}
@@ -81,5 +83,13 @@ public void receive(TransportClient client, byte[] message, RpcResponseCallback
public StreamManager getStreamManager() {
return delegate.getStreamManager();
}

@Override
public void connectionTerminated(TransportClient client) {
SparkSaslServer saslServer = channelAuthenticationMap.remove(client);
if (saslServer != null) {
saslServer.dispose();
}
}
}
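
Wiring is unchanged by this commit: a server still wraps its application handler in a SaslRpcHandler, but the per-client SparkSaslServer state is now disposed of when the channel goes away instead of accumulating. A hedged sketch, assuming SecretKeyHolder exposes getSaslUser(appId) and getSecretKey(appId) lookups (not shown in this diff):

// Assumption: SecretKeyHolder's two lookup methods; only the SaslRpcHandler constructor is shown above.
SecretKeyHolder keyHolder = new SecretKeyHolder() {
  @Override
  public String getSaslUser(String appId) { return "sparkSaslUser"; }
  @Override
  public String getSecretKey(String appId) { return "shared-secret"; }
};
RpcHandler appHandler = new NoOpRpcHandler();                        // stands in for a real handler
RpcHandler serverHandler = new SaslRpcHandler(appHandler, keyHolder);
// When the channel closes, TransportRequestHandler.channelUnregistered() calls
// serverHandler.connectionTerminated(client), which disposes and removes that client's SparkSaslServer.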

@@ -41,7 +41,7 @@
* with the "one-for-one" strategy, meaning each Transport-layer Chunk is equivalent to one Spark-
* level shuffle block.
*/
public class ExternalShuffleBlockHandler implements RpcHandler {
public class ExternalShuffleBlockHandler extends RpcHandler {
private final Logger logger = LoggerFactory.getLogger(ExternalShuffleBlockHandler.class);

private final ExternalShuffleBlockManager blockManager;
@@ -158,7 +158,7 @@ public void testNoSaslServer() {
}

/** RPC handler which simply responds with the message it received. */
public static class TestRpcHandler implements RpcHandler {
public static class TestRpcHandler extends RpcHandler {
@Override
public void receive(TransportClient client, byte[] message, RpcResponseCallback callback) {
callback.onSuccess(message);
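
Because RpcHandler is now an abstract class rather than an interface, even this trivial echo handler must still provide getStreamManager(); the rest of the test class is collapsed in this view, so the following minimal shape is a hedged reconstruction (the OneForOneStreamManager choice is an assumption), not the exact code from this commit:

public static class TestRpcHandler extends RpcHandler {
  @Override
  public void receive(TransportClient client, byte[] message, RpcResponseCallback callback) {
    callback.onSuccess(message);  // echo the request bytes straight back
  }

  @Override
  public StreamManager getStreamManager() {
    return new OneForOneStreamManager();  // no streams are actually registered in this test
  }
}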
