Skip to content

Commit

Permalink
QUIC: fix unobserved exception from _connectionCloseTcs (#104894)
Browse files Browse the repository at this point in the history
  • Loading branch information
ManickaP authored Jul 16, 2024
1 parent bc9b3b6 commit 817a2fb
Show file tree
Hide file tree
Showing 7 changed files with 160 additions and 66 deletions.
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using System.Diagnostics;
using System.Diagnostics.CodeAnalysis;
using System.Net.Security;
using System.Net.Sockets;
using System.Runtime.CompilerServices;
using System.Security.Authentication;
using System.Threading;
using System.Threading.Tasks;
using static Microsoft.Quic.MsQuic;

namespace System.Net.Quic;
Expand All @@ -28,27 +30,14 @@ internal static QuicException GetOperationAbortedException(string? message = nul
return new QuicException(QuicError.OperationAborted, null, message ?? SR.net_quic_operationaborted);
}

internal static bool TryGetStreamExceptionForMsQuicStatus(int status, [NotNullWhen(true)] out Exception? exception, bool streamWasSuccessfullyStarted = true, string? message = null)
internal static bool TryGetStreamExceptionForMsQuicStatus(int status, [NotNullWhen(true)] out Exception? exception)
{
if (status == QUIC_STATUS_ABORTED)
{
// Connection has been closed by the peer (either at transport or application level),
if (streamWasSuccessfullyStarted)
{
// we will receive an event later, which will complete the stream with concrete
// information why the connection was aborted.
exception = null;
return false;
}
else
{
// we won't be receiving any event callback for shutdown on this stream, so we don't
// necessarily know which error to report. So we throw an exception which we can distinguish
// at the caller (ConnectionAborted normally has App error code) and throw the correct
// exception from there.
exception = new QuicException(QuicError.ConnectionAborted, null, "");
return true;
}
// If status == QUIC_STATUS_ABORTED, the connection was closed by transport or the peer.
// We will receive an event later with details for ConnectionAborted exception to complete the task source with.
exception = null;
return false;
}
else if (status == QUIC_STATUS_INVALID_STATE)
{
Expand All @@ -58,16 +47,13 @@ internal static bool TryGetStreamExceptionForMsQuicStatus(int status, [NotNullWh
}
else if (StatusFailed(status))
{
exception = GetExceptionForMsQuicStatus(status, message: message);
exception = GetExceptionForMsQuicStatus(status);
return true;
}
exception = null;
return false;
}

// see TryGetStreamExceptionForMsQuicStatus for explanation
internal static bool IsConnectionAbortedWhenStartingStreamException(Exception ex) => ex is QuicException qe && qe.QuicError == QuicError.ConnectionAborted && qe.ApplicationErrorCode is null;

internal static Exception GetExceptionForMsQuicStatus(int status, long? errorCode = default, string? message = null)
{
Exception ex = GetExceptionInternal(status, errorCode, message);
Expand Down Expand Up @@ -229,4 +215,26 @@ public static void ValidateNotNull(string argumentName, string resourceName, obj
throw new ArgumentNullException(argumentName, SR.Format(resourceName, propertyName));
}
}

public static void ObserveException(this Task task)
{
if (task.IsCompleted)
{
ObserveExceptionCore(task);
}
else
{
task.ContinueWith(static (t) => ObserveExceptionCore(t), CancellationToken.None, TaskContinuationOptions.ExecuteSynchronously | TaskContinuationOptions.OnlyOnFaulted, TaskScheduler.Default);
}

static void ObserveExceptionCore(Task task)
{
Debug.Assert(task.IsCompleted);
if (task.IsFaulted)
{
// Access Exception to avoid TaskScheduler.UnobservedTaskException firing.
Exception? e = task.Exception!.InnerException;
}
}
}
}
28 changes: 13 additions & 15 deletions src/libraries/System.Net.Quic/src/System/Net/Quic/QuicConnection.cs
Original file line number Diff line number Diff line change
Expand Up @@ -87,10 +87,10 @@ static async ValueTask<QuicConnection> StartConnectAsync(QuicClientConnectionOpt
{
await connection.DisposeAsync().ConfigureAwait(false);

// throw OCE with correct token if cancellation requested by user
// Throw OCE with correct token if cancellation requested by user.
cancellationToken.ThrowIfCancellationRequested();

// cancellation by the linkedCts.CancelAfter. Convert to Timeout
// Cancellation by the linkedCts.CancelAfter, convert to timeout.
throw new QuicException(QuicError.ConnectionTimeout, null, SR.Format(SR.net_quic_handshake_timeout, options.HandshakeTimeout));
}
catch
Expand All @@ -113,11 +113,6 @@ static async ValueTask<QuicConnection> StartConnectAsync(QuicClientConnectionOpt
/// </summary>
private int _disposed;

/// <summary>
/// Completed when connection shutdown is initiated.
/// </summary>
private TaskCompletionSource _connectionCloseTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously);

private readonly ValueTaskSource _connectedTcs = new ValueTaskSource();
private readonly ResettableValueTaskSource _shutdownTcs = new ResettableValueTaskSource()
{
Expand All @@ -140,6 +135,11 @@ static async ValueTask<QuicConnection> StartConnectAsync(QuicClientConnectionOpt
}
};

/// <summary>
/// Completed when connection shutdown is initiated.
/// </summary>
private readonly TaskCompletionSource _connectionCloseTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously);

private readonly CancellationTokenSource _shutdownTokenSource = new CancellationTokenSource();

// Token that fires when the connection is closed.
Expand Down Expand Up @@ -369,7 +369,7 @@ private async ValueTask FinishConnectAsync(QuicClientConnectionOptions options,
{
Debug.Assert(host is not null);

// Given just a ServerName to connect to, msquic would also use the first address after the resolution
// Given just a ServerName to connect to, MsQuic would also use the first address after the resolution
// (https://github.com/microsoft/msquic/issues/1181) and it would not return a well-known error code
// for resolution failures we could rely on. By doing the resolution in managed code, we can guarantee
// that a SocketException will surface to the user if the name resolution fails.
Expand Down Expand Up @@ -526,13 +526,9 @@ public async ValueTask<QuicStream> OpenOutboundStreamAsync(QuicStreamType type,
// Propagate ODE if disposed in the meantime.
ObjectDisposedException.ThrowIf(_disposed == 1, this);

// In case of an incoming race when the connection is closed by the peer just before we open the stream,
// we receive QUIC_STATUS_ABORTED from MsQuic, but we don't know how the connection was closed. We throw
// special exception and handle it here where we can determine the shutdown reason.
bool connectionAbortedByPeer = ThrowHelper.IsConnectionAbortedWhenStartingStreamException(ex);

// Propagate connection error if present.
if (_connectionCloseTcs.Task.IsFaulted || connectionAbortedByPeer)
// Propagate connection error when the connection was closed (remotely = ABORTED / locally = INVALID_STATE).
if (ex is QuicException qex && qex.QuicError == QuicError.InternalError &&
(qex.HResult == QUIC_STATUS_ABORTED || qex.HResult == QUIC_STATUS_INVALID_STATE))
{
await _connectionCloseTcs.Task.ConfigureAwait(false);
}
Expand Down Expand Up @@ -822,8 +818,10 @@ public async ValueTask DisposeAsync()
// Wait for SHUTDOWN_COMPLETE, the last event, so that all resources can be safely released.
await _shutdownTcs.GetFinalTask(this).ConfigureAwait(false);
Debug.Assert(_connectedTcs.IsCompleted);
Debug.Assert(_connectionCloseTcs.Task.IsCompleted);
_handle.Dispose();
_shutdownTokenSource.Dispose();
_connectionCloseTcs.Task.ObserveException();
_configuration?.Dispose();

// Dispose remote certificate only if it hasn't been accessed via getter, in which case the accessing code becomes the owner of the certificate lifetime.
Expand Down
29 changes: 13 additions & 16 deletions src/libraries/System.Net.Quic/src/System/Net/Quic/QuicStream.cs
Original file line number Diff line number Diff line change
Expand Up @@ -169,20 +169,17 @@ internal unsafe QuicStream(MsQuicContextSafeHandle connectionHandle, QuicStreamT
try
{
QUIC_HANDLE* handle;
int status = MsQuicApi.Api.StreamOpen(
ThrowHelper.ThrowIfMsQuicError(MsQuicApi.Api.StreamOpen(
connectionHandle,
type == QuicStreamType.Unidirectional ? QUIC_STREAM_OPEN_FLAGS.UNIDIRECTIONAL : QUIC_STREAM_OPEN_FLAGS.NONE,
&NativeCallback,
(void*)GCHandle.ToIntPtr(context),
&handle);

if (ThrowHelper.TryGetStreamExceptionForMsQuicStatus(status, out Exception? ex, streamWasSuccessfullyStarted: false, message: "StreamOpen failed"))
&handle),
"StreamOpen failed");
_handle = new MsQuicContextSafeHandle(handle, context, SafeHandleType.Stream, connectionHandle)
{
throw ex;
}

_handle = new MsQuicContextSafeHandle(handle, context, SafeHandleType.Stream, connectionHandle);
_handle.Disposable = _sendBuffers;
Disposable = _sendBuffers
};
}
catch
{
Expand Down Expand Up @@ -213,8 +210,10 @@ internal unsafe QuicStream(MsQuicContextSafeHandle connectionHandle, QUIC_HANDLE
GCHandle context = GCHandle.Alloc(this, GCHandleType.Weak);
try
{
_handle = new MsQuicContextSafeHandle(handle, context, SafeHandleType.Stream, connectionHandle);
_handle.Disposable = _sendBuffers;
_handle = new MsQuicContextSafeHandle(handle, context, SafeHandleType.Stream, connectionHandle)
{
Disposable = _sendBuffers
};
delegate* unmanaged[Cdecl]<QUIC_HANDLE*, void*, QUIC_STREAM_EVENT*, int> nativeCallback = &NativeCallback;
MsQuicApi.Api.SetCallbackHandler(
_handle,
Expand Down Expand Up @@ -261,14 +260,12 @@ internal ValueTask StartAsync(Action<QuicStreamType> decrementStreamCapacity, Ca
int status = MsQuicApi.Api.StreamStart(
_handle,
QUIC_STREAM_START_FLAGS.SHUTDOWN_ON_FAIL | QUIC_STREAM_START_FLAGS.INDICATE_PEER_ACCEPT);

if (ThrowHelper.TryGetStreamExceptionForMsQuicStatus(status, out Exception? exception, streamWasSuccessfullyStarted: false))
if (StatusFailed(status))
{
_decrementStreamCapacity = null;
_startedTcs.TrySetException(exception);
_startedTcs.TrySetException(ThrowHelper.GetExceptionForMsQuicStatus(status));
}
}

return valueTask;
}

Expand Down Expand Up @@ -637,7 +634,7 @@ private unsafe int HandleEventShutdownComplete(ref SHUTDOWN_COMPLETE_DATA data)
// It's local shutdown by app, this side called QuicConnection.CloseAsync, throw QuicError.OperationAborted.
(shutdownByApp: true, closedRemotely: false) => ThrowHelper.GetOperationAbortedException(),
// It's remote shutdown by transport, we received a CONNECTION_CLOSE frame with a QUIC transport error code, throw error based on the status.
(shutdownByApp: false, closedRemotely: true) => ThrowHelper.GetExceptionForMsQuicStatus(data.ConnectionCloseStatus, (long)data.ConnectionErrorCode, $"Shutdown by transport {data.ConnectionErrorCode}"),
(shutdownByApp: false, closedRemotely: true) => ThrowHelper.GetExceptionForMsQuicStatus(data.ConnectionCloseStatus, (long)data.ConnectionErrorCode),
// It's local shutdown by transport, most likely due to a timeout, throw error based on the status.
(shutdownByApp: false, closedRemotely: false) => ThrowHelper.GetExceptionForMsQuicStatus(data.ConnectionCloseStatus, (long)data.ConnectionErrorCode),
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -580,7 +580,7 @@ public async Task ConnectWithCertificate_MissingTargetHost_Succeeds()
return true;
};

await CreateQuicConnection(clientOptions);
await using QuicConnection connection = await CreateQuicConnection(clientOptions);
}
finally
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,67 @@ await RunClientServer(
});
}

[Theory]
[InlineData(true)]
[InlineData(false)]
[InlineData(null)]
public async Task CloseAsync_PendingOpenStream_Throws(bool? localClose)
{
byte[] data = new byte[10];

await using QuicListener listener = await CreateQuicListener(changeServerOptions: localClose is null ? options => options.IdleTimeout = TimeSpan.FromSeconds(10) : null);

// Allow client to accept a stream, one will be accepted and another will be pending while we close the server connection.
QuicClientConnectionOptions clientOptions = CreateQuicClientOptions(listener.LocalEndPoint);
clientOptions.MaxInboundBidirectionalStreams = 1;
await using QuicConnection clientConnection = await CreateQuicConnection(clientOptions);

await using QuicConnection serverConnection = await listener.AcceptConnectionAsync();

// Put one stream into server stream queue.
QuicStream queuedStream = await clientConnection.OpenOutboundStreamAsync(QuicStreamType.Bidirectional);
await queuedStream.WriteAsync(data.AsMemory(), completeWrites: true);

// Open one stream to the client that is allowed.
QuicStream firstStream = await serverConnection.OpenOutboundStreamAsync(QuicStreamType.Bidirectional);
await firstStream.WriteAsync(data.AsMemory(), completeWrites: true);

// Try to open another stream which should wait on capacity.
ValueTask<QuicStream> secondStreamTask = serverConnection.OpenOutboundStreamAsync(QuicStreamType.Bidirectional);
Assert.False(secondStreamTask.IsCompleted);

// Close the connection, second stream task should complete with appropriate error.
if (localClose is true)
{
await serverConnection.CloseAsync(123);
await AssertThrowsQuicExceptionAsync(QuicError.OperationAborted, async () => await secondStreamTask);

// Try to open yet another stream which should fail because of already closed connection.
await AssertThrowsQuicExceptionAsync(QuicError.OperationAborted, async () => await serverConnection.OpenOutboundStreamAsync(QuicStreamType.Bidirectional));
}
else if (localClose is false)
{
await clientConnection.CloseAsync(456);
QuicException ex1 = await AssertThrowsQuicExceptionAsync(QuicError.ConnectionAborted, async () => await secondStreamTask);
Assert.Equal(456, ex1.ApplicationErrorCode);

// Try to open yet another stream which should fail because of already closed connection.
QuicException ex2 = await AssertThrowsQuicExceptionAsync(QuicError.ConnectionAborted, async () => await serverConnection.OpenOutboundStreamAsync(QuicStreamType.Bidirectional));
Assert.Equal(456, ex2.ApplicationErrorCode);
}
else
{
await Task.Delay(TimeSpan.FromSeconds(15));

QuicException ex1 = await AssertThrowsQuicExceptionAsync(QuicError.ConnectionIdle, async () => await secondStreamTask);
Assert.Equal(1, ex1.TransportErrorCode);

// Try to open yet another stream which should fail because of already closed connection.
QuicException ex2 = await AssertThrowsQuicExceptionAsync(QuicError.ConnectionIdle, async () => await serverConnection.OpenOutboundStreamAsync(QuicStreamType.Bidirectional));
Assert.Equal(1, ex2.TransportErrorCode);
}
}

[Fact]
public async Task Dispose_WithPendingAcceptAndConnect_PendingAndSubsequentThrowOperationAbortedException()
{
Expand Down Expand Up @@ -228,6 +289,9 @@ public async Task GetStreamCapacity_OpenCloseStream_CountsCorrectly()
await streamsAvailableFired.WaitAsync();
Assert.Equal(0, bidiIncrement);
Assert.Equal(1, unidiIncrement);

await clientConnection.DisposeAsync();
await serverConnection.DisposeAsync();
}

[Theory]
Expand Down Expand Up @@ -298,6 +362,9 @@ public async Task GetStreamCapacity_OpenCloseStreamIntoNegative_CountsCorrectly(
Assert.False(streamsAvailableFired.CurrentCount > 0);
Assert.Equal(unidirectional ? QuicDefaults.DefaultServerMaxInboundBidirectionalStreams : QuicDefaults.DefaultServerMaxInboundBidirectionalStreams * 3, bidiTotal);
Assert.Equal(unidirectional ? QuicDefaults.DefaultServerMaxInboundUnidirectionalStreams * 3 : QuicDefaults.DefaultServerMaxInboundUnidirectionalStreams, unidiTotal);

await clientConnection.DisposeAsync();
await serverConnection.DisposeAsync();
}

[Theory]
Expand Down Expand Up @@ -368,6 +435,9 @@ public async Task GetStreamCapacity_OpenCloseStreamCanceledIntoNegative_CountsCo
Assert.False(streamsAvailableFired.CurrentCount > 0);
Assert.Equal(unidirectional ? QuicDefaults.DefaultServerMaxInboundBidirectionalStreams : QuicDefaults.DefaultServerMaxInboundBidirectionalStreams * 3, bidiTotal);
Assert.Equal(unidirectional ? QuicDefaults.DefaultServerMaxInboundUnidirectionalStreams * 3 : QuicDefaults.DefaultServerMaxInboundUnidirectionalStreams, unidiTotal);

await clientConnection.DisposeAsync();
await serverConnection.DisposeAsync();
}

[Fact]
Expand Down Expand Up @@ -434,6 +504,9 @@ public async Task GetStreamCapacity_SumInvariant()

// by now, we opened and closed 2 * Limit, and expect a budget of 'Limit' more
Assert.Equal(3 * Limit, maxStreamIndex);

await clientConnection.DisposeAsync();
await serverConnection.DisposeAsync();
}

[Fact]
Expand Down Expand Up @@ -634,6 +707,8 @@ public async Task AcceptStreamAsync_ConnectionDisposed_Throws()

var accept1Exception = await Assert.ThrowsAsync<ObjectDisposedException>(async () => await acceptTask1);
var accept2Exception = await Assert.ThrowsAsync<ObjectDisposedException>(async () => await acceptTask2);

await clientConnection.DisposeAsync();
}

[Theory]
Expand Down
Loading

0 comments on commit 817a2fb

Please sign in to comment.