Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SPARK-26604][CORE][BACKPORT-1.6] Clean up channel registration for S… #26988

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,8 @@ class NettyBlockRpcServer(
case openBlocks: OpenBlocks =>
val blocks: Seq[ManagedBuffer] =
openBlocks.blockIds.map(BlockId.apply).map(blockManager.getBlockData)
val streamId = streamManager.registerStream(appId, blocks.iterator.asJava)
val streamId = streamManager.registerStream(appId, blocks.iterator.asJava,
client.getChannel)
logTrace(s"Registered streamId $streamId with ${blocks.size} buffers")
responseContext.onSuccess(new StreamHandle(streamId, blocks.size).toByteBuffer)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,15 +48,16 @@ private static class StreamState {
final Iterator<ManagedBuffer> buffers;

// The channel associated to the stream
Channel associatedChannel = null;
final Channel associatedChannel;

// Used to keep track of the index of the buffer that the user has retrieved, just to ensure
// that the caller only requests each chunk one at a time, in order.
int curChunk = 0;

StreamState(String appId, Iterator<ManagedBuffer> buffers) {
StreamState(String appId, Iterator<ManagedBuffer> buffers, Channel associatedChannel) {
this.appId = appId;
this.buffers = Preconditions.checkNotNull(buffers);
this.associatedChannel = associatedChannel;
}
}

Expand All @@ -67,13 +68,6 @@ public OneForOneStreamManager() {
streams = new ConcurrentHashMap<Long, StreamState>();
}

@Override
public void registerChannel(Channel channel, long streamId) {
if (streams.containsKey(streamId)) {
streams.get(streamId).associatedChannel = channel;
}
}

@Override
public ManagedBuffer getChunk(long streamId, int chunkIndex) {
StreamState state = streams.get(streamId);
Expand Down Expand Up @@ -135,10 +129,9 @@ public void checkAuthorization(TransportClient client, long streamId) {
* If an app ID is provided, only callers who've authenticated with the given app ID will be
* allowed to fetch from this stream.
*/
public long registerStream(String appId, Iterator<ManagedBuffer> buffers) {
public long registerStream(String appId, Iterator<ManagedBuffer> buffers, Channel channel) {
long myStreamId = nextStreamId.getAndIncrement();
streams.put(myStreamId, new StreamState(appId, buffers));
streams.put(myStreamId, new StreamState(appId, buffers, channel));
return myStreamId;
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -59,16 +59,6 @@ public ManagedBuffer openStream(String streamId) {
throw new UnsupportedOperationException();
}

/**
* Associates a stream with a single client connection, which is guaranteed to be the only reader
* of the stream. The getChunk() method will be called serially on this connection and once the
* connection is closed, the stream will never be used again, enabling cleanup.
*
* This must be called before the first getChunk() on the stream, but it may be invoked multiple
* times with the same channel and stream id.
*/
public void registerChannel(Channel channel, long streamId) { }

/**
* Indicates that the given channel has been terminated. After this occurs, we are guaranteed not
* to read from the associated streams again, so any state can be cleaned up.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,6 @@ private void processFetchRequest(final ChunkFetchRequest req) {
ManagedBuffer buf;
try {
streamManager.checkAuthorization(reverseClient, req.streamChunkId.streamId);
streamManager.registerChannel(channel, req.streamChunkId.streamId);
buf = streamManager.getChunk(req.streamChunkId.streamId, req.streamChunkId.chunkIndex);
} catch (Exception e) {
logger.error(String.format(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ protected void handleMessage(
for (String blockId : msg.blockIds) {
blocks.add(blockManager.getBlockData(msg.appId, msg.execId, blockId));
}
long streamId = streamManager.registerStream(client.getClientId(), blocks.iterator());
long streamId = streamManager.registerStream(client.getClientId(), blocks.iterator(), client.getChannel());
logger.trace("Registered streamId {} with {} buffers", streamId, msg.blockIds.length);
callback.onSuccess(new StreamHandle(streamId, msg.blockIds.length).toByteBuffer());

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import java.nio.ByteBuffer;
import java.util.Iterator;

import io.netty.channel.Channel;
import org.junit.Before;
import org.junit.Test;
import org.mockito.ArgumentCaptor;
Expand Down Expand Up @@ -94,7 +95,8 @@ public void testOpenShuffleBlocks() {
@SuppressWarnings("unchecked")
ArgumentCaptor<Iterator<ManagedBuffer>> stream = (ArgumentCaptor<Iterator<ManagedBuffer>>)
(ArgumentCaptor<?>) ArgumentCaptor.forClass(Iterator.class);
verify(streamManager, times(1)).registerStream(anyString(), stream.capture());
verify(streamManager, times(1)).registerStream(anyString(), stream.capture(),
(Channel)any());
Iterator<ManagedBuffer> buffers = stream.getValue();
assertEquals(block0Marker, buffers.next());
assertEquals(block1Marker, buffers.next());
Expand Down