Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SPARK-26604][CORE][BACKPORT-2.4] Clean up channel registration for StreamManager #24013

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicLong;

import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Preconditions;
import io.netty.channel.Channel;
import org.apache.commons.lang3.tuple.ImmutablePair;
Expand All @@ -49,7 +50,7 @@ private static class StreamState {
final Iterator<ManagedBuffer> buffers;

// The channel associated to the stream
Channel associatedChannel = null;
final Channel associatedChannel;

// Used to keep track of the index of the buffer that the user has retrieved, just to ensure
// that the caller only requests each chunk one at a time, in order.
Expand All @@ -58,9 +59,10 @@ private static class StreamState {
// Used to keep track of the number of chunks being transferred and not finished yet.
volatile long chunksBeingTransferred = 0L;

StreamState(String appId, Iterator<ManagedBuffer> buffers) {
StreamState(String appId, Iterator<ManagedBuffer> buffers, Channel channel) {
this.appId = appId;
this.buffers = Preconditions.checkNotNull(buffers);
this.associatedChannel = channel;
}
}

Expand All @@ -71,13 +73,6 @@ public OneForOneStreamManager() {
streams = new ConcurrentHashMap<>();
}

@Override
public void registerChannel(Channel channel, long streamId) {
if (streams.containsKey(streamId)) {
streams.get(streamId).associatedChannel = channel;
}
}

@Override
public ManagedBuffer getChunk(long streamId, int chunkIndex) {
StreamState state = streams.get(streamId);
Expand Down Expand Up @@ -195,11 +190,19 @@ public long chunksBeingTransferred() {
*
* If an app ID is provided, only callers who've authenticated with the given app ID will be
* allowed to fetch from this stream.
*
* This method also associates the stream with a single client connection, which is guaranteed
* to be the only reader of the stream. Once the connection is closed, the stream will never
* be used again, enabling cleanup by `connectionTerminated`.
*/
public long registerStream(String appId, Iterator<ManagedBuffer> buffers) {
public long registerStream(String appId, Iterator<ManagedBuffer> buffers, Channel channel) {
long myStreamId = nextStreamId.getAndIncrement();
streams.put(myStreamId, new StreamState(appId, buffers));
streams.put(myStreamId, new StreamState(appId, buffers, channel));
return myStreamId;
}

@VisibleForTesting
public int numStreamStates() {
return streams.size();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -60,16 +60,6 @@ public ManagedBuffer openStream(String streamId) {
throw new UnsupportedOperationException();
}

/**
* Associates a stream with a single client connection, which is guaranteed to be the only reader
* of the stream. The getChunk() method will be called serially on this connection and once the
* connection is closed, the stream will never be used again, enabling cleanup.
*
* This must be called before the first getChunk() on the stream, but it may be invoked multiple
* times with the same channel and stream id.
*/
public void registerChannel(Channel channel, long streamId) { }

/**
* Indicates that the given channel has been terminated. After this occurs, we are guaranteed not
* to read from the associated streams again, so any state can be cleaned up.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,6 @@ private void processFetchRequest(final ChunkFetchRequest req) {
ManagedBuffer buf;
try {
streamManager.checkAuthorization(reverseClient, req.streamChunkId.streamId);
streamManager.registerChannel(channel, req.streamChunkId.streamId);
buf = streamManager.getChunk(req.streamChunkId.streamId, req.streamChunkId.chunkIndex);
} catch (Exception e) {
logger.error(String.format("Error opening block %s for request from %s",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,8 +62,10 @@ public void handleFetchRequestAndStreamRequest() throws Exception {
managedBuffers.add(new TestManagedBuffer(20));
managedBuffers.add(new TestManagedBuffer(30));
managedBuffers.add(new TestManagedBuffer(40));
long streamId = streamManager.registerStream("test-app", managedBuffers.iterator());
streamManager.registerChannel(channel, streamId);
long streamId = streamManager.registerStream("test-app", managedBuffers.iterator(), channel);

assert streamManager.numStreamStates() == 1;

TransportClient reverseClient = mock(TransportClient.class);
TransportRequestHandler requestHandler = new TransportRequestHandler(channel, reverseClient,
rpcHandler, 2L);
Expand Down Expand Up @@ -98,6 +100,9 @@ public void handleFetchRequestAndStreamRequest() throws Exception {
requestHandler.handle(request3);
verify(channel, times(1)).close();
assert responseAndPromisePairs.size() == 3;

streamManager.connectionTerminated(channel);
assert streamManager.numStreamStates() == 0;
}

private class ExtendedChannelPromise extends DefaultChannelPromise {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,14 +37,15 @@ public void managedBuffersAreFeedWhenConnectionIsClosed() throws Exception {
TestManagedBuffer buffer2 = Mockito.spy(new TestManagedBuffer(20));
buffers.add(buffer1);
buffers.add(buffer2);
long streamId = manager.registerStream("appId", buffers.iterator());

Channel dummyChannel = Mockito.mock(Channel.class, Mockito.RETURNS_SMART_NULLS);
manager.registerChannel(dummyChannel, streamId);
manager.registerStream("appId", buffers.iterator(), dummyChannel);
assert manager.numStreamStates() == 1;

manager.connectionTerminated(dummyChannel);

Mockito.verify(buffer1, Mockito.times(1)).release();
Mockito.verify(buffer2, Mockito.times(1)).release();
assert manager.numStreamStates() == 0;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ protected void handleMessage(
OpenBlocks msg = (OpenBlocks) msgObj;
checkAuth(client, msg.appId);
long streamId = streamManager.registerStream(client.getClientId(),
new ManagedBufferIterator(msg.appId, msg.execId, msg.blockIds));
new ManagedBufferIterator(msg.appId, msg.execId, msg.blockIds), client.getChannel());
if (logger.isTraceEnabled()) {
logger.trace("Registered streamId {} with {} buffers for client {} from host {}",
streamId,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,8 @@ public void testOpenShuffleBlocks() {
@SuppressWarnings("unchecked")
ArgumentCaptor<Iterator<ManagedBuffer>> stream = (ArgumentCaptor<Iterator<ManagedBuffer>>)
(ArgumentCaptor<?>) ArgumentCaptor.forClass(Iterator.class);
verify(streamManager, times(1)).registerStream(anyString(), stream.capture());
verify(streamManager, times(1)).registerStream(anyString(), stream.capture(),
any());
Iterator<ManagedBuffer> buffers = stream.getValue();
assertEquals(block0Marker, buffers.next());
assertEquals(block1Marker, buffers.next());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,8 @@ class NettyBlockRpcServer(
val blocksNum = openBlocks.blockIds.length
val blocks = for (i <- (0 until blocksNum).view)
yield blockManager.getBlockData(BlockId.apply(openBlocks.blockIds(i)))
val streamId = streamManager.registerStream(appId, blocks.iterator.asJava)
val streamId = streamManager.registerStream(appId, blocks.iterator.asJava,
client.getChannel)
logTrace(s"Registered streamId $streamId with $blocksNum buffers")
responseContext.onSuccess(new StreamHandle(streamId, blocksNum).toByteBuffer)

Expand Down