diff --git a/src/libraries/System.Net.Quic/src/System/Net/Quic/Implementations/MsQuic/Interop/MsQuicStatusCodes.cs b/src/libraries/System.Net.Quic/src/System/Net/Quic/Implementations/MsQuic/Interop/MsQuicStatusCodes.cs index 50f736d429f7f6..3c1118e73a4a6d 100644 --- a/src/libraries/System.Net.Quic/src/System/Net/Quic/Implementations/MsQuic/Interop/MsQuicStatusCodes.cs +++ b/src/libraries/System.Net.Quic/src/System/Net/Quic/Implementations/MsQuic/Interop/MsQuicStatusCodes.cs @@ -10,6 +10,7 @@ internal static class MsQuicStatusCodes internal static uint InternalError => OperatingSystem.IsWindows() ? Windows.InternalError : Posix.InternalError; internal static uint InvalidState => OperatingSystem.IsWindows() ? Windows.InvalidState : Posix.InvalidState; internal static uint HandshakeFailure => OperatingSystem.IsWindows() ? Windows.HandshakeFailure : Posix.HandshakeFailure; + internal static uint UserCanceled => OperatingSystem.IsWindows() ? Windows.UserCanceled : Posix.UserCanceled; // TODO return better error messages here. public static string GetError(uint status) => OperatingSystem.IsWindows() ? Windows.GetError(status) : Posix.GetError(status); diff --git a/src/libraries/System.Net.Quic/src/System/Net/Quic/Implementations/MsQuic/MsQuicConnection.cs b/src/libraries/System.Net.Quic/src/System/Net/Quic/Implementations/MsQuic/MsQuicConnection.cs index 5e98e27654df01..6324acd8dd133c 100644 --- a/src/libraries/System.Net.Quic/src/System/Net/Quic/Implementations/MsQuic/MsQuicConnection.cs +++ b/src/libraries/System.Net.Quic/src/System/Net/Quic/Implementations/MsQuic/MsQuicConnection.cs @@ -29,8 +29,7 @@ internal sealed class MsQuicConnection : QuicConnectionProvider private readonly SafeMsQuicConfigurationHandle? _configuration; private readonly State _state = new State(); - private GCHandle _stateHandle; - private bool _disposed; + private int _disposed; private IPEndPoint? _localEndPoint; private readonly EndPoint _remoteEndPoint; @@ -43,6 +42,7 @@ internal sealed class MsQuicConnection : QuicConnectionProvider internal sealed class State { public SafeMsQuicConnectionHandle Handle = null!; // set inside of MsQuicConnection ctor. + public GCHandle StateGCHandle; // These exists to prevent GC of the MsQuicConnection in the middle of an async op (Connect or Shutdown). public MsQuicConnection? Connection; @@ -59,6 +59,8 @@ internal sealed class State public bool Connected; public long AbortErrorCode = -1; + public int StreamCount; + private bool _closing; // Queue for accepted streams. // Backlog limit is managed by MsQuic so it can be unbounded here. @@ -67,30 +69,83 @@ internal sealed class State SingleReader = true, SingleWriter = true, }); + + public void RemoveStream(MsQuicStream stream) + { + bool releaseHandles; + lock (this) + { + StreamCount--; + Debug.Assert(StreamCount >= 0); + releaseHandles = _closing && StreamCount == 0; + } + + if (releaseHandles) + { + Handle?.Dispose(); + StateGCHandle.Free(); + } + } + + public bool TryQueueNewStream(SafeMsQuicStreamHandle streamHandle, QUIC_STREAM_OPEN_FLAGS flags) + { + var stream = new MsQuicStream(this, streamHandle, flags); + if (AcceptQueue.Writer.TryWrite(stream)) + { + return true; + } + else + { + stream.Dispose(); + return false; + } + } + + public bool TryAddStream(MsQuicStream stream) + { + lock (this) + { + if (_closing) + { + return false; + } + + StreamCount++; + return true; + } + } + + // This is called under lock from connection dispose + public void SetClosing() + { + lock (this) + { + _closing = true; + } + } } // constructor for inbound connections public MsQuicConnection(IPEndPoint localEndPoint, IPEndPoint remoteEndPoint, SafeMsQuicConnectionHandle handle) { _state.Handle = handle; + _state.StateGCHandle = GCHandle.Alloc(_state); _state.Connected = true; _localEndPoint = localEndPoint; _remoteEndPoint = remoteEndPoint; _remoteCertificateRequired = false; _isServer = true; - _stateHandle = GCHandle.Alloc(_state); - try { MsQuicApi.Api.SetCallbackHandlerDelegate( _state.Handle, s_connectionDelegate, - GCHandle.ToIntPtr(_stateHandle)); + GCHandle.ToIntPtr(_state.StateGCHandle)); } catch { - _stateHandle.Free(); + _state.StateGCHandle.Free(); throw; } @@ -113,7 +168,7 @@ public MsQuicConnection(QuicClientConnectionOptions options) _remoteCertificateValidationCallback = options.ClientAuthenticationOptions.RemoteCertificateValidationCallback; } - _stateHandle = GCHandle.Alloc(_state); + _state.StateGCHandle = GCHandle.Alloc(_state); try { // this handle is ref counted by MsQuic, so safe to dispose here. @@ -122,14 +177,14 @@ public MsQuicConnection(QuicClientConnectionOptions options) uint status = MsQuicApi.Api.ConnectionOpenDelegate( MsQuicApi.Api.Registration, s_connectionDelegate, - GCHandle.ToIntPtr(_stateHandle), + GCHandle.ToIntPtr(_state.StateGCHandle), out _state.Handle); QuicExceptionHelpers.ThrowIfFailed(status, "Could not open the connection."); } catch { - _stateHandle.Free(); + _state.StateGCHandle.Free(); throw; } @@ -224,9 +279,13 @@ private static uint HandleEventShutdownComplete(State state, ref ConnectionEvent private static uint HandleEventNewStream(State state, ref ConnectionEvent connectionEvent) { var streamHandle = new SafeMsQuicStreamHandle(connectionEvent.Data.PeerStreamStarted.Stream); - var stream = new MsQuicStream(state, streamHandle, connectionEvent.Data.PeerStreamStarted.Flags); + if (!state.TryQueueNewStream(streamHandle, connectionEvent.Data.PeerStreamStarted.Flags)) + { + // This will call StreamCloseDelegate and free the stream. + // We will return Success to the MsQuic to prevent double free. + streamHandle.Dispose(); + } - state.AcceptQueue.Writer.TryWrite(stream); return MsQuicStatusCodes.Success; } @@ -598,17 +657,45 @@ public override void Dispose() Dispose(false); } + private async Task FlushAcceptQueue() + { + _state.AcceptQueue.Writer.TryComplete(); + await foreach (MsQuicStream item in _state.AcceptQueue.Reader.ReadAllAsync().ConfigureAwait(false)) + { + item.Dispose(); + } + } + private void Dispose(bool disposing) { - if (_disposed) + int disposed = Interlocked.Exchange(ref _disposed, 1); + if (disposed != 0) { return; } + bool releaseHandles = false; + lock (_state) + { + _state.Connection = null; + if (_state.StreamCount == 0) + { + releaseHandles = true; + } + else + { + // We have pending streams so we need to defer cleanup until last one is gone. + _state.SetClosing(); + } + } + + FlushAcceptQueue().GetAwaiter().GetResult(); _configuration?.Dispose(); - _state?.Handle?.Dispose(); - if (_stateHandle.IsAllocated) _stateHandle.Free(); - _disposed = true; + if (releaseHandles) + { + _state!.Handle?.Dispose(); + if (_state.StateGCHandle.IsAllocated) _state.StateGCHandle.Free(); + } } // TODO: this appears abortive and will cause prior successfully shutdown and closed streams to drop data. @@ -622,7 +709,7 @@ internal override ValueTask CloseAsync(long errorCode, CancellationToken cancell private void ThrowIfDisposed() { - if (_disposed) + if (_disposed == 1) { throw new ObjectDisposedException(nameof(MsQuicStream)); } diff --git a/src/libraries/System.Net.Quic/src/System/Net/Quic/Implementations/MsQuic/MsQuicStream.cs b/src/libraries/System.Net.Quic/src/System/Net/Quic/Implementations/MsQuic/MsQuicStream.cs index f6c768d9d4d4ef..2813f860318e4f 100644 --- a/src/libraries/System.Net.Quic/src/System/Net/Quic/Implementations/MsQuic/MsQuicStream.cs +++ b/src/libraries/System.Net.Quic/src/System/Net/Quic/Implementations/MsQuic/MsQuicStream.cs @@ -93,6 +93,12 @@ internal MsQuicStream(MsQuicConnection.State connectionState, SafeMsQuicStreamHa throw; } + if (!connectionState.TryAddStream(this)) + { + _stateHandle.Free(); + throw new ObjectDisposedException(nameof(QuicConnection)); + } + if (NetEventSource.Log.IsEnabled()) { NetEventSource.Info( @@ -133,6 +139,13 @@ internal MsQuicStream(MsQuicConnection.State connectionState, QUIC_STREAM_OPEN_F throw; } + if (!connectionState.TryAddStream(this)) + { + _state.Handle?.Dispose(); + _stateHandle.Free(); + throw new ObjectDisposedException(nameof(QuicConnection)); + } + if (NetEventSource.Log.IsEnabled()) { NetEventSource.Info( @@ -321,7 +334,6 @@ internal override async ValueTask ReadAsync(Memory destination, Cance { shouldComplete = true; } - state.ReadState = ReadState.Aborted; } @@ -557,6 +569,8 @@ private void Dispose(bool disposing) Marshal.FreeHGlobal(_state.SendQuicBuffers); if (_stateHandle.IsAllocated) _stateHandle.Free(); CleanupSendState(_state); + Debug.Assert(_state.ConnectionState != null); + _state.ConnectionState?.RemoveStream(this); if (NetEventSource.Log.IsEnabled()) { diff --git a/src/libraries/System.Net.Quic/tests/FunctionalTests/MsQuicTests.cs b/src/libraries/System.Net.Quic/tests/FunctionalTests/MsQuicTests.cs index 443f759d0fbe85..58fd8954f6582d 100644 --- a/src/libraries/System.Net.Quic/tests/FunctionalTests/MsQuicTests.cs +++ b/src/libraries/System.Net.Quic/tests/FunctionalTests/MsQuicTests.cs @@ -442,7 +442,6 @@ await RunClientServer( await (new[] { t1, t2 }).WhenAllOrAnyFailed(millisecondsTimeout: 1000000); } - [ActiveIssue("https://github.com/dotnet/runtime/issues/52048")] [Fact] public async Task ManagedAVE_MinimalFailingTest() { @@ -461,6 +460,32 @@ async Task GetStreamIdWithoutStartWorks() // TODO: stream that is opened by client but left unaccepted by server may cause AccessViolationException in its Finalizer } + await GetStreamIdWithoutStartWorks().WaitAsync(TimeSpan.FromSeconds(15)); + + GC.Collect(); + } + + [Fact] + public async Task DisposingConnection_OK() + { + async Task GetStreamIdWithoutStartWorks() + { + using QuicListener listener = CreateQuicListener(); + using QuicConnection clientConnection = CreateQuicConnection(listener.ListenEndPoint); + + ValueTask clientTask = clientConnection.ConnectAsync(); + using QuicConnection serverConnection = await listener.AcceptConnectionAsync(); + await clientTask; + + using QuicStream clientStream = clientConnection.OpenBidirectionalStream(); + Assert.Equal(0, clientStream.StreamId); + + // Dispose all connections before the streams; + clientConnection.Dispose(); + serverConnection.Dispose(); + listener.Dispose(); + } + await GetStreamIdWithoutStartWorks(); GC.Collect();