diff --git a/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/ConnectHelper.cs b/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/ConnectHelper.cs index 97e8478fa7487e..84d6ff8b345bba 100644 --- a/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/ConnectHelper.cs +++ b/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/ConnectHelper.cs @@ -33,7 +33,7 @@ public CertificateCallbackMapper(Func EstablishSslConnectionAsync(SslClientAuthenticationOptions sslOptions, HttpRequestMessage request, bool async, Stream stream, CancellationToken cancellationToken) + private static SslClientAuthenticationOptions SetUpRemoteCertificateValidationCallback(SslClientAuthenticationOptions sslOptions, HttpRequestMessage request) { // If there's a cert validation callback, and if it came from HttpClientHandler, // wrap the original delegate in order to change the sender to be the request message (expected by HttpClientHandler's delegate). @@ -52,12 +52,13 @@ public static ValueTask EstablishSslConnectionAsync(SslClientAuthenti }; } - // Create the SslStream, authenticate, and return it. - return EstablishSslConnectionAsyncCore(async, stream, sslOptions, cancellationToken); + return sslOptions; } - private static async ValueTask EstablishSslConnectionAsyncCore(bool async, Stream stream, SslClientAuthenticationOptions sslOptions, CancellationToken cancellationToken) + public static async ValueTask EstablishSslConnectionAsync(SslClientAuthenticationOptions sslOptions, HttpRequestMessage request, bool async, Stream stream, CancellationToken cancellationToken) { + sslOptions = SetUpRemoteCertificateValidationCallback(sslOptions, request); + SslStream sslStream = new SslStream(stream); try @@ -104,8 +105,10 @@ private static async ValueTask EstablishSslConnectionAsyncCore(bool a [SupportedOSPlatform("windows")] [SupportedOSPlatform("linux")] [SupportedOSPlatform("macos")] - public static async ValueTask ConnectQuicAsync(QuicImplementationProvider quicImplementationProvider, DnsEndPoint endPoint, SslClientAuthenticationOptions? clientAuthenticationOptions, CancellationToken cancellationToken) + public static async ValueTask ConnectQuicAsync(HttpRequestMessage request, QuicImplementationProvider quicImplementationProvider, DnsEndPoint endPoint, SslClientAuthenticationOptions clientAuthenticationOptions, CancellationToken cancellationToken) { + clientAuthenticationOptions = SetUpRemoteCertificateValidationCallback(clientAuthenticationOptions, request); + QuicConnection con = new QuicConnection(quicImplementationProvider, endPoint, clientAuthenticationOptions); try { diff --git a/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/HttpConnectionPool.cs b/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/HttpConnectionPool.cs index e84f21dce3b460..480724dd8eb5da 100644 --- a/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/HttpConnectionPool.cs +++ b/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/HttpConnectionPool.cs @@ -725,7 +725,7 @@ private async ValueTask GetHttp3ConnectionAsync(HttpRequestMess QuicConnection quicConnection; try { - quicConnection = await ConnectHelper.ConnectQuicAsync(Settings._quicImplementationProvider ?? QuicImplementationProviders.Default, new DnsEndPoint(authority.IdnHost, authority.Port), _sslOptionsHttp3, cancellationToken).ConfigureAwait(false); + quicConnection = await ConnectHelper.ConnectQuicAsync(request, Settings._quicImplementationProvider ?? QuicImplementationProviders.Default, new DnsEndPoint(authority.IdnHost, authority.Port), _sslOptionsHttp3!, cancellationToken).ConfigureAwait(false); } catch { diff --git a/src/libraries/System.Net.Http/tests/FunctionalTests/HttpClientHandlerTest.Http3.cs b/src/libraries/System.Net.Http/tests/FunctionalTests/HttpClientHandlerTest.Http3.cs index d4d806f6881361..b3551199a2994e 100644 --- a/src/libraries/System.Net.Http/tests/FunctionalTests/HttpClientHandlerTest.Http3.cs +++ b/src/libraries/System.Net.Http/tests/FunctionalTests/HttpClientHandlerTest.Http3.cs @@ -314,6 +314,64 @@ public async Task ReservedFrameType_Throws() await new[] { clientTask, serverTask }.WhenAllOrAnyFailed(20_000); } + [Fact] + public async Task ServerCertificateCustomValidationCallback_Succeeds() + { + // Mock doesn't make use of cart validation callback. + if (UseQuicImplementationProvider == QuicImplementationProviders.Mock) + { + return; + } + + HttpRequestMessage? callbackRequest = null; + int invocationCount = 0; + + var httpClientHandler = CreateHttpClientHandler(); + httpClientHandler.ServerCertificateCustomValidationCallback = (request, _, _, _) => + { + callbackRequest = request; + ++invocationCount; + return true; + }; + + using Http3LoopbackServer server = CreateHttp3LoopbackServer(); + using HttpClient client = CreateHttpClient(httpClientHandler); + + Task serverTask = Task.Run(async () => + { + using Http3LoopbackConnection connection = (Http3LoopbackConnection)await server.EstablishGenericConnectionAsync(); + using Http3LoopbackStream stream = await connection.AcceptRequestStreamAsync(); + await stream.HandleRequestAsync(); + using Http3LoopbackStream stream2 = await connection.AcceptRequestStreamAsync(); + await stream2.HandleRequestAsync(); + }); + + var request = new HttpRequestMessage(HttpMethod.Get, server.Address); + request.Version = HttpVersion.Version30; + request.VersionPolicy = HttpVersionPolicy.RequestVersionExact; + + var response = await client.SendAsync(request); + + response.EnsureSuccessStatusCode(); + Assert.Equal(HttpVersion.Version30, response.Version); + Assert.Same(request, callbackRequest); + Assert.Equal(1, invocationCount); + + // Second request, the callback shouldn't be hit at all. + callbackRequest = null; + + request = new HttpRequestMessage(HttpMethod.Get, server.Address); + request.Version = HttpVersion.Version30; + request.VersionPolicy = HttpVersionPolicy.RequestVersionExact; + + response = await client.SendAsync(request); + + response.EnsureSuccessStatusCode(); + Assert.Equal(HttpVersion.Version30, response.Version); + Assert.Null(callbackRequest); + Assert.Equal(1, invocationCount); + } + [OuterLoop] [ConditionalTheory(nameof(IsMsQuicSupported))] [MemberData(nameof(InteropUris))] diff --git a/src/libraries/System.Net.Quic/src/System/Net/Quic/Implementations/MsQuic/MsQuicConnection.cs b/src/libraries/System.Net.Quic/src/System/Net/Quic/Implementations/MsQuic/MsQuicConnection.cs index ec1710316c1259..3eea4d19c692bb 100644 --- a/src/libraries/System.Net.Quic/src/System/Net/Quic/Implementations/MsQuic/MsQuicConnection.cs +++ b/src/libraries/System.Net.Quic/src/System/Net/Quic/Implementations/MsQuic/MsQuicConnection.cs @@ -418,6 +418,10 @@ private static uint HandleEventPeerCertificateReceived(State state, ref Connecti if (connection._remoteCertificateValidationCallback != null) { bool success = connection._remoteCertificateValidationCallback(connection, certificate, chain, sslPolicyErrors); + // Unset the callback to prevent multiple invocations of the callback per a single connection. + // Return the same value as the custom callback just did. + connection._remoteCertificateValidationCallback = (_, _, _, _) => success; + if (!success && NetEventSource.Log.IsEnabled()) NetEventSource.Error(state, $"{state.TraceId} Remote certificate rejected by verification callback"); return success ? MsQuicStatusCodes.Success : MsQuicStatusCodes.HandshakeFailure;