Skip to content

Commit

Permalink
Add client max size configuration (#381)
Browse files Browse the repository at this point in the history
  • Loading branch information
JamesNK authored Jul 10, 2019
1 parent ea580ae commit 2279d9a
Show file tree
Hide file tree
Showing 11 changed files with 284 additions and 22 deletions.
16 changes: 13 additions & 3 deletions src/Grpc.Net.Client/HttpClientCallInvoker.cs
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ namespace Grpc.Net.Client
public sealed class HttpClientCallInvoker : CallInvoker
{
private readonly HttpClient _client;
private readonly ILoggerFactory _loggerFactory;
internal ILoggerFactory LoggerFactory { get; }

// Override the current time for unit testing
internal ISystemClock Clock = SystemClock.Instance;
Expand All @@ -50,7 +50,7 @@ public HttpClientCallInvoker(HttpClient client, ILoggerFactory? loggerFactory)
}

_client = client;
_loggerFactory = loggerFactory ?? NullLoggerFactory.Instance;
LoggerFactory = loggerFactory ?? NullLoggerFactory.Instance;
Deadline = DateTime.MaxValue;
}

Expand All @@ -72,6 +72,16 @@ public HttpClientCallInvoker(HttpClient client, ILoggerFactory? loggerFactory)
/// </summary>
public DateTime Deadline { get; set; }

/// <summary>
/// Gets or sets the maximum message size in bytes that can be sent from the client.
/// </summary>
public int? SendMaxMessageSize { get; set; }

/// <summary>
/// Gets or sets the maximum message size in bytes that can be received by the client.
/// </summary>
public int? ReceiveMaxMessageSize { get; set; }

/// <summary>
/// Invokes a client streaming call asynchronously.
/// In client streaming scenario, client sends a stream of requests and server responds with a single response.
Expand Down Expand Up @@ -187,7 +197,7 @@ private GrpcCall<TRequest, TResponse> CreateGrpcCall<TRequest, TResponse>(
}
}

var call = new GrpcCall<TRequest, TResponse>(method, options, Clock, _loggerFactory);
var call = new GrpcCall<TRequest, TResponse>(method, options, this);

// Clean up linked cancellation token
disposeAction = linkedCts != null
Expand Down
12 changes: 7 additions & 5 deletions src/Grpc.Net.Client/Internal/GrpcCall.cs
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,6 @@ internal partial class GrpcCall<TRequest, TResponse> : IDisposable
{
private readonly CancellationTokenSource _callCts;
private readonly CancellationTokenRegistration? _ctsRegistration;
private readonly ISystemClock _clock;
private readonly TimeSpan? _timeout;
private readonly Uri _uri;
private readonly GrpcCallScope _logScope;
Expand All @@ -54,13 +53,14 @@ internal partial class GrpcCall<TRequest, TResponse> : IDisposable
public HttpResponseMessage? HttpResponse { get; private set; }
public CallOptions Options { get; }
public Method<TRequest, TResponse> Method { get; }
public HttpClientCallInvoker CallInvoker { get; }

public ILogger Logger { get; }
public Task? SendTask { get; private set; }
public HttpContentClientStreamWriter<TRequest, TResponse>? ClientStreamWriter { get; private set; }
public HttpContentClientStreamReader<TRequest, TResponse>? ClientStreamReader { get; private set; }

public GrpcCall(Method<TRequest, TResponse> method, CallOptions options, ISystemClock clock, ILoggerFactory loggerFactory)
public GrpcCall(Method<TRequest, TResponse> method, CallOptions options, HttpClientCallInvoker callInvoker)
{
// Validate deadline before creating any objects that require cleanup
ValidateDeadline(options.Deadline);
Expand All @@ -70,8 +70,8 @@ public GrpcCall(Method<TRequest, TResponse> method, CallOptions options, ISystem
_uri = new Uri(method.FullName, UriKind.Relative);
_logScope = new GrpcCallScope(method.Type, _uri);
Options = options;
_clock = clock;
Logger = loggerFactory.CreateLogger<GrpcCall<TRequest, TResponse>>();
CallInvoker = callInvoker;
Logger = callInvoker.LoggerFactory.CreateLogger<GrpcCall<TRequest, TResponse>>();

if (options.CancellationToken.CanBeCanceled)
{
Expand All @@ -87,7 +87,7 @@ public GrpcCall(Method<TRequest, TResponse> method, CallOptions options, ISystem

if (options.Deadline != null && options.Deadline != DateTime.MaxValue)
{
var timeout = options.Deadline.GetValueOrDefault() - _clock.UtcNow;
var timeout = options.Deadline.GetValueOrDefault() - CallInvoker.Clock.UtcNow;
_timeout = (timeout > TimeSpan.Zero) ? timeout : TimeSpan.Zero;
}
}
Expand Down Expand Up @@ -302,6 +302,7 @@ public async Task<TResponse> GetResponseAsync()
Logger,
Method.ResponseMarshaller.ContextualDeserializer,
GrpcProtocolHelpers.GetGrpcEncoding(HttpResponse),
CallInvoker.ReceiveMaxMessageSize,
_callCts.Token).ConfigureAwait(false);
FinishResponse();

Expand Down Expand Up @@ -384,6 +385,7 @@ private void SetMessageContent(TRequest request, HttpRequestMessage message)
request,
Method.RequestMarshaller.ContextualSerializer,
grpcEncoding,
CallInvoker.SendMaxMessageSize,
Options.CancellationToken);
},
GrpcProtocolConstants.GrpcContentTypeHeaderValue);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,7 @@ private async Task<bool> MoveNextCore(CancellationToken cancellationToken)
_call.Logger,
_call.Method.ResponseMarshaller.ContextualDeserializer,
GrpcProtocolHelpers.GetGrpcEncoding(_httpResponse),
_call.CallInvoker.ReceiveMaxMessageSize,
cancellationToken).ConfigureAwait(false);
if (Current == null)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,7 @@ await writeStream.WriteMessage<TRequest>(
message,
_call.Method.RequestMarshaller.ContextualSerializer,
_grpcEncoding,
_call.CallInvoker.SendMaxMessageSize,
_call.CancellationToken).ConfigureAwait(false);
}
catch (TaskCanceledException)
Expand Down
20 changes: 18 additions & 2 deletions src/Grpc.Net.Client/Internal/StreamExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@ internal static partial class StreamExtensions
private const int MessageDelimiterSize = 4; // how many bytes it takes to encode "Message-Length"
private const int HeaderSize = MessageDelimiterSize + 1; // message length + compression flag

private static readonly Status SendingMessageExceedsLimitStatus = new Status(StatusCode.ResourceExhausted, "Sending message exceeds the maximum configured message size.");
private static readonly Status ReceivedMessageExceedsLimitStatus = new Status(StatusCode.ResourceExhausted, "Received message exceeds the maximum configured message size.");
private static readonly Status NoMessageEncodingMessageStatus = new Status(StatusCode.Internal, "Request did not include grpc-encoding value with compressed message.");
private static readonly Status IdentityMessageEncodingMessageStatus = new Status(StatusCode.Internal, "Request sent 'identity' grpc-encoding value with compressed message.");
private static Status CreateUnknownMessageEncodingMessageStatus(string unsupportedEncoding, IEnumerable<string> supportedEncodings)
Expand All @@ -49,28 +51,31 @@ private static Status CreateUnknownMessageEncodingMessageStatus(string unsupport
ILogger logger,
Func<DeserializationContext, TResponse> deserializer,
string grpcEncoding,
int? maximumMessageSize,
CancellationToken cancellationToken)
where TResponse : class
{
return responseStream.ReadMessageCoreAsync(logger, deserializer, grpcEncoding, cancellationToken, true, true);
return responseStream.ReadMessageCoreAsync(logger, deserializer, grpcEncoding, maximumMessageSize, cancellationToken, true, true);
}

public static Task<TResponse?> ReadStreamedMessageAsync<TResponse>(
this Stream responseStream,
ILogger logger,
Func<DeserializationContext, TResponse> deserializer,
string grpcEncoding,
int? maximumMessageSize,
CancellationToken cancellationToken)
where TResponse : class
{
return responseStream.ReadMessageCoreAsync(logger, deserializer, grpcEncoding, cancellationToken, true, false);
return responseStream.ReadMessageCoreAsync(logger, deserializer, grpcEncoding, maximumMessageSize, cancellationToken, true, false);
}

private static async Task<TResponse?> ReadMessageCoreAsync<TResponse>(
this Stream responseStream,
ILogger logger,
Func<DeserializationContext, TResponse> deserializer,
string grpcEncoding,
int? maximumMessageSize,
CancellationToken cancellationToken,
bool canBeEmpty,
bool singleMessage)
Expand Down Expand Up @@ -117,6 +122,11 @@ private static Status CreateUnknownMessageEncodingMessageStatus(string unsupport
throw new InvalidDataException("Message too large.");
}

if (length > maximumMessageSize)
{
throw new RpcException(ReceivedMessageExceedsLimitStatus);
}

// Read message content until content length is reached
byte[] messageData;
if (length > 0)
Expand Down Expand Up @@ -212,6 +222,7 @@ public static async Task WriteMessage<TMessage>(
TMessage message,
Action<TMessage, SerializationContext> serializer,
string grpcEncoding,
int? maximumMessageSize,
CancellationToken cancellationToken)
{
try
Expand All @@ -228,6 +239,11 @@ public static async Task WriteMessage<TMessage>(
throw new InvalidOperationException("Serialization did not return a payload.");
}

if (data.Length > maximumMessageSize)
{
throw new RpcException(SendingMessageExceedsLimitStatus);
}

var isCompressed = !string.Equals(grpcEncoding, GrpcProtocolConstants.IdentityGrpcEncoding, StringComparison.Ordinal);

if (isCompressed)
Expand Down
4 changes: 2 additions & 2 deletions test/Grpc.Net.Client.Tests/AsyncClientStreamingCallTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -112,9 +112,9 @@ public async Task AsyncClientStreamingCall_Success_RequestContentSent()
await call.RequestStream.CompleteAsync().DefaultTimeout();

var requestContent = await streamTask.DefaultTimeout();
var requestMessage = await requestContent.ReadStreamedMessageAsync(NullLogger.Instance, ClientTestHelpers.ServiceMethod.RequestMarshaller.ContextualDeserializer, GrpcProtocolConstants.IdentityGrpcEncoding, CancellationToken.None).DefaultTimeout();
var requestMessage = await requestContent.ReadStreamedMessageAsync(NullLogger.Instance, ClientTestHelpers.ServiceMethod.RequestMarshaller.ContextualDeserializer, GrpcProtocolConstants.IdentityGrpcEncoding, maximumMessageSize: null, CancellationToken.None).DefaultTimeout();
Assert.AreEqual("1", requestMessage.Name);
requestMessage = await requestContent.ReadStreamedMessageAsync(NullLogger.Instance, ClientTestHelpers.ServiceMethod.RequestMarshaller.ContextualDeserializer, GrpcProtocolConstants.IdentityGrpcEncoding, CancellationToken.None).DefaultTimeout();
requestMessage = await requestContent.ReadStreamedMessageAsync(NullLogger.Instance, ClientTestHelpers.ServiceMethod.RequestMarshaller.ContextualDeserializer, GrpcProtocolConstants.IdentityGrpcEncoding, maximumMessageSize: null, CancellationToken.None).DefaultTimeout();
Assert.AreEqual("2", requestMessage.Name);

var responseMessage = await responseTask.DefaultTimeout();
Expand Down
4 changes: 2 additions & 2 deletions test/Grpc.Net.Client.Tests/AsyncDuplexStreamingCallTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -117,9 +117,9 @@ public async Task AsyncDuplexStreamingCall_MessagesStreamed_MessagesReceived()

Assert.IsNotNull(content);
var requestContent = await content!.ReadAsStreamAsync().DefaultTimeout();
var requestMessage = await requestContent.ReadStreamedMessageAsync(NullLogger.Instance, ClientTestHelpers.ServiceMethod.RequestMarshaller.ContextualDeserializer, GrpcProtocolConstants.IdentityGrpcEncoding, CancellationToken.None).DefaultTimeout();
var requestMessage = await requestContent.ReadStreamedMessageAsync(NullLogger.Instance, ClientTestHelpers.ServiceMethod.RequestMarshaller.ContextualDeserializer, GrpcProtocolConstants.IdentityGrpcEncoding, maximumMessageSize: null, CancellationToken.None).DefaultTimeout();
Assert.AreEqual("1", requestMessage.Name);
requestMessage = await requestContent.ReadStreamedMessageAsync(NullLogger.Instance, ClientTestHelpers.ServiceMethod.RequestMarshaller.ContextualDeserializer, GrpcProtocolConstants.IdentityGrpcEncoding, CancellationToken.None).DefaultTimeout();
requestMessage = await requestContent.ReadStreamedMessageAsync(NullLogger.Instance, ClientTestHelpers.ServiceMethod.RequestMarshaller.ContextualDeserializer, GrpcProtocolConstants.IdentityGrpcEncoding, maximumMessageSize: null, CancellationToken.None).DefaultTimeout();
Assert.AreEqual("2", requestMessage.Name);

Assert.IsNull(responseStream.Current);
Expand Down
2 changes: 1 addition & 1 deletion test/Grpc.Net.Client.Tests/AsyncUnaryCallTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ public async Task AsyncUnaryCall_Success_RequestContentSent()
Assert.IsNotNull(content);

var requestContent = await content!.ReadAsStreamAsync().DefaultTimeout();
var requestMessage = await requestContent.ReadSingleMessageAsync(NullLogger.Instance, ClientTestHelpers.ServiceMethod.RequestMarshaller.ContextualDeserializer, GrpcProtocolConstants.IdentityGrpcEncoding, CancellationToken.None).DefaultTimeout();
var requestMessage = await requestContent.ReadSingleMessageAsync(NullLogger.Instance, ClientTestHelpers.ServiceMethod.RequestMarshaller.ContextualDeserializer, GrpcProtocolConstants.IdentityGrpcEncoding, maximumMessageSize: null, CancellationToken.None).DefaultTimeout();

Assert.AreEqual("World", requestMessage.Name);
}
Expand Down
12 changes: 8 additions & 4 deletions test/Grpc.Net.Client.Tests/CompressionTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ namespace Grpc.Net.Client.Tests
public class CompressionTests
{
[Test]
public void AsyncUnaryCall_UnknownCompressMetadataSentWithRequest_ThrowsError()
public async Task AsyncUnaryCall_UnknownCompressMetadataSentWithRequest_ThrowsError()
{
// Arrange
HttpRequestMessage? httpRequestMessage = null;
Expand All @@ -55,6 +55,7 @@ public void AsyncUnaryCall_UnknownCompressMetadataSentWithRequest_ThrowsError()
NullLogger.Instance,
ClientTestHelpers.ServiceMethod.RequestMarshaller.ContextualDeserializer,
"gzip",
maximumMessageSize: null,
CancellationToken.None);

HelloReply reply = new HelloReply
Expand All @@ -76,7 +77,7 @@ public void AsyncUnaryCall_UnknownCompressMetadataSentWithRequest_ThrowsError()
});

// Assert
var ex = Assert.ThrowsAsync<InvalidOperationException>(async () => await call.ResponseAsync.DefaultTimeout());
var ex = await ExceptionAssert.ThrowsAsync<InvalidOperationException>(() => call.ResponseAsync).DefaultTimeout();
Assert.AreEqual("Could not find compression provider for 'not-supported'.", ex.Message);
}

Expand All @@ -98,6 +99,7 @@ public async Task AsyncUnaryCall_CompressMetadataSentWithRequest_RequestMessageC
NullLogger.Instance,
ClientTestHelpers.ServiceMethod.RequestMarshaller.ContextualDeserializer,
"gzip",
maximumMessageSize: null,
CancellationToken.None);

HelloReply reply = new HelloReply
Expand Down Expand Up @@ -149,6 +151,7 @@ public async Task AsyncUnaryCall_CompressedResponse_ResponseMessageDecompressed(
NullLogger.Instance,
ClientTestHelpers.ServiceMethod.RequestMarshaller.ContextualDeserializer,
"gzip",
maximumMessageSize: null,
CancellationToken.None);

HelloReply reply = new HelloReply
Expand Down Expand Up @@ -176,7 +179,7 @@ public async Task AsyncUnaryCall_CompressedResponse_ResponseMessageDecompressed(
}

[Test]
public void AsyncUnaryCall_CompressedResponseWithUnknownEncoding_ErrorThrown()
public async Task AsyncUnaryCall_CompressedResponseWithUnknownEncoding_ErrorThrown()
{
// Arrange
HttpRequestMessage? httpRequestMessage = null;
Expand All @@ -193,6 +196,7 @@ public void AsyncUnaryCall_CompressedResponseWithUnknownEncoding_ErrorThrown()
NullLogger.Instance,
ClientTestHelpers.ServiceMethod.RequestMarshaller.ContextualDeserializer,
"gzip",
maximumMessageSize: null,
CancellationToken.None);

HelloReply reply = new HelloReply
Expand All @@ -214,7 +218,7 @@ public void AsyncUnaryCall_CompressedResponseWithUnknownEncoding_ErrorThrown()
});

// Assert
var ex = Assert.ThrowsAsync<RpcException>(async () => await call.ResponseAsync.DefaultTimeout());
var ex = await ExceptionAssert.ThrowsAsync<RpcException>(() => call.ResponseAsync).DefaultTimeout();
Assert.AreEqual(StatusCode.Unimplemented, ex.StatusCode);
Assert.AreEqual("Unsupported grpc-encoding value 'not-supported'. Supported encodings: gzip", ex.Status.Detail);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,8 @@ public async Task MoveNext_TokenCanceledBeforeCall_ThrowError()
return Task.FromResult(ResponseUtils.CreateResponse(HttpStatusCode.OK, content));
});

var call = new GrpcCall<HelloRequest, HelloReply>(ClientTestHelpers.ServiceMethod, new CallOptions(), SystemClock.Instance, NullLoggerFactory.Instance);
var httpClientCallInvoker = new HttpClientCallInvoker(httpClient, null);
var call = new GrpcCall<HelloRequest, HelloReply>(ClientTestHelpers.ServiceMethod, new CallOptions(), httpClientCallInvoker);
call.StartServerStreaming(httpClient, new HelloRequest());

// Act
Expand All @@ -73,7 +74,8 @@ public async Task MoveNext_TokenCanceledDuringCall_ThrowError()
return Task.FromResult(ResponseUtils.CreateResponse(HttpStatusCode.OK, content));
});

var call = new GrpcCall<HelloRequest, HelloReply>(ClientTestHelpers.ServiceMethod, new CallOptions(), SystemClock.Instance, NullLoggerFactory.Instance);
var httpClientCallInvoker = new HttpClientCallInvoker(httpClient, null);
var call = new GrpcCall<HelloRequest, HelloReply>(ClientTestHelpers.ServiceMethod, new CallOptions(), httpClientCallInvoker);
call.StartServerStreaming(httpClient, new HelloRequest());

// Act
Expand All @@ -99,7 +101,8 @@ public async Task MoveNext_MultipleCallsWithoutAwait_ThrowError()
return Task.FromResult(ResponseUtils.CreateResponse(HttpStatusCode.OK, content));
});

var call = new GrpcCall<HelloRequest, HelloReply>(ClientTestHelpers.ServiceMethod, new CallOptions(), SystemClock.Instance, NullLoggerFactory.Instance);
var httpClientCallInvoker = new HttpClientCallInvoker(httpClient, null);
var call = new GrpcCall<HelloRequest, HelloReply>(ClientTestHelpers.ServiceMethod, new CallOptions(), httpClientCallInvoker);
call.StartServerStreaming(httpClient, new HelloRequest());

// Act
Expand Down
Loading

0 comments on commit 2279d9a

Please sign in to comment.