Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

OkHttp: Flushes headers out immediately for client streaming and bidi streaming #274

Merged
merged 1 commit into from
Apr 27, 2015
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
import com.squareup.okhttp.internal.spdy.Header;

import io.grpc.Metadata;
import io.grpc.MethodType;
import io.grpc.Status;
import io.grpc.transport.ClientStreamListener;
import io.grpc.transport.Http2ClientStream;
Expand All @@ -58,14 +59,17 @@ class OkHttpClientStream extends Http2ClientStream {
private static int WINDOW_UPDATE_THRESHOLD =
OkHttpClientTransport.DEFAULT_INITIAL_WINDOW_SIZE / 2;

private final MethodType type;

/**
* Construct a new client stream.
*/
static OkHttpClientStream newStream(ClientStreamListener listener,
AsyncFrameWriter frameWriter,
OkHttpClientTransport transport,
OutboundFlowController outboundFlow) {
return new OkHttpClientStream(listener, frameWriter, transport, outboundFlow);
OutboundFlowController outboundFlow,
MethodType type) {
return new OkHttpClientStream(listener, frameWriter, transport, outboundFlow, type);
}

@GuardedBy("lock")
Expand All @@ -82,11 +86,20 @@ static OkHttpClientStream newStream(ClientStreamListener listener,
private OkHttpClientStream(ClientStreamListener listener,
AsyncFrameWriter frameWriter,
OkHttpClientTransport transport,
OutboundFlowController outboundFlow) {
OutboundFlowController outboundFlow,
MethodType type) {
super(new OkHttpWritableBufferAllocator(), listener);
this.frameWriter = frameWriter;
this.transport = transport;
this.outboundFlow = outboundFlow;
this.type = type;
}

/**
* Returns the type of this stream.
*/
public MethodType getType() {
return type;
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@

import io.grpc.Metadata;
import io.grpc.MethodDescriptor;
import io.grpc.MethodType;
import io.grpc.Status;
import io.grpc.Status.Code;
import io.grpc.transport.ClientStreamListener;
Expand Down Expand Up @@ -209,7 +210,7 @@ public OkHttpClientStream newStream(MethodDescriptor<?, ?> method,
Preconditions.checkNotNull(listener, "listener");

OkHttpClientStream clientStream =
OkHttpClientStream.newStream(listener, frameWriter, this, outboundFlow);
OkHttpClientStream.newStream(listener, frameWriter, this, outboundFlow, method.getType());

String defaultPath = "/" + method.getName();
List<Header> requestHeaders =
Expand Down Expand Up @@ -250,6 +251,11 @@ private void startStream(OkHttpClientStream stream, List<Header> requestHeaders)
stream.id(nextStreamId);
streams.put(stream.id(), stream);
frameWriter.synStream(false, false, stream.id(), 0, requestHeaders);
// For unary and server streaming, there will be a data frame soon, no need to flush the header.
if (stream.getType() != MethodType.UNARY
&& stream.getType() != MethodType.SERVER_STREAMING) {
frameWriter.flush();
}
if (nextStreamId >= Integer.MAX_VALUE - 2) {
onGoAway(Integer.MAX_VALUE, Status.INTERNAL.withDescription("Stream ids exhausted"));
} else {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.reset;
import static org.mockito.Mockito.timeout;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;

Expand All @@ -58,6 +59,7 @@

import io.grpc.Metadata;
import io.grpc.MethodDescriptor;
import io.grpc.MethodType;
import io.grpc.Status;
import io.grpc.transport.ClientStreamListener;
import io.grpc.transport.ClientTransport;
Expand All @@ -71,6 +73,7 @@
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
import org.mockito.ArgumentCaptor;
import org.mockito.Matchers;
import org.mockito.Mock;
import org.mockito.MockitoAnnotations;

Expand Down Expand Up @@ -126,6 +129,7 @@ public void setUp() {
frameHandler = clientTransport.getHandler();
streams = clientTransport.getStreams();
when(method.getName()).thenReturn("fakemethod");
when(method.getType()).thenReturn(MethodType.UNARY);
when(frameWriter.maxDataLength()).thenReturn(Integer.MAX_VALUE);
}

Expand Down Expand Up @@ -219,7 +223,7 @@ public void invalidInboundHeadersCancelStream() throws Exception {
@Test
public void readStatus() throws Exception {
MockStreamListener listener = new MockStreamListener();
clientTransport.newStream(method,new Metadata.Headers(), listener);
clientTransport.newStream(method, new Metadata.Headers(), listener);
assertTrue(streams.containsKey(3));
frameHandler.headers(true, true, 3, 0, grpcResponseTrailers(), HeadersMode.HTTP_20_HEADERS);
listener.waitUntilStreamClosed();
Expand All @@ -229,7 +233,7 @@ public void readStatus() throws Exception {
@Test
public void receiveReset() throws Exception {
MockStreamListener listener = new MockStreamListener();
clientTransport.newStream(method,new Metadata.Headers(), listener);
clientTransport.newStream(method, new Metadata.Headers(), listener);
assertTrue(streams.containsKey(3));
frameHandler.rstStream(3, ErrorCode.PROTOCOL_ERROR);
listener.waitUntilStreamClosed();
Expand All @@ -239,7 +243,7 @@ public void receiveReset() throws Exception {
@Test
public void cancelStream() throws Exception {
MockStreamListener listener = new MockStreamListener();
clientTransport.newStream(method,new Metadata.Headers(), listener);
clientTransport.newStream(method, new Metadata.Headers(), listener);
OkHttpClientStream stream = streams.get(3);
assertNotNull(stream);
stream.cancel();
Expand All @@ -253,7 +257,7 @@ public void cancelStream() throws Exception {
public void writeMessage() throws Exception {
final String message = "Hello Server";
MockStreamListener listener = new MockStreamListener();
clientTransport.newStream(method,new Metadata.Headers(), listener);
clientTransport.newStream(method, new Metadata.Headers(), listener);
OkHttpClientStream stream = streams.get(3);
InputStream input = new ByteArrayInputStream(message.getBytes(UTF_8));
assertEquals(12, input.available());
Expand Down Expand Up @@ -431,7 +435,7 @@ public void streamIdExhausted() throws Exception {
streams = transport.getStreams();

MockStreamListener listener1 = new MockStreamListener();
transport.newStream(method,new Metadata.Headers(), listener1);
transport.newStream(method, new Metadata.Headers(), listener1);

assertNewStreamFail(transport);

Expand Down Expand Up @@ -581,7 +585,7 @@ public void run() {
@Test
public void receivingWindowExceeded() throws Exception {
MockStreamListener listener = new MockStreamListener();
clientTransport.newStream(method,new Metadata.Headers(), listener).request(1);
clientTransport.newStream(method, new Metadata.Headers(), listener).request(1);

frameHandler.headers(false, false, 3, 0, grpcResponseHeaders(), HeadersMode.HTTP_20_HEADERS);

Expand All @@ -598,6 +602,43 @@ public void receivingWindowExceeded() throws Exception {
verify(frameWriter).rstStream(eq(3), eq(ErrorCode.FLOW_CONTROL_ERROR));
}

@Test
public void unaryHeadersShouldNotBeFlushed() throws Exception {
// By default the method is a Unary call
shouldHeadersBeFlushed(false);
}

@Test
public void serverStreamingHeadersShouldNotBeFlushed() throws Exception {
when(method.getType()).thenReturn(MethodType.SERVER_STREAMING);
shouldHeadersBeFlushed(false);
}

@Test
public void clientStreamingHeadersShouldBeFlushed() throws Exception {
when(method.getType()).thenReturn(MethodType.CLIENT_STREAMING);
shouldHeadersBeFlushed(true);
}

@Test
public void duplexStreamingHeadersShouldNotBeFlushed() throws Exception {
when(method.getType()).thenReturn(MethodType.DUPLEX_STREAMING);
shouldHeadersBeFlushed(true);
}

private void shouldHeadersBeFlushed(boolean shouldBeFlushed) throws Exception {
OkHttpClientStream stream = clientTransport.newStream(
method, new Metadata.Headers(), new MockStreamListener());
verify(frameWriter).synStream(
eq(false), eq(false), eq(3), eq(0), Matchers.anyListOf(Header.class));
if (shouldBeFlushed) {
verify(frameWriter).flush();
} else {
verify(frameWriter, times(0)).flush();
}
stream.cancel();
}

private void waitForStreamPending(int expected) throws Exception {
int duration = TIME_OUT_MS / 10;
for (int i = 0; i < 10; i++) {
Expand Down