diff --git a/reactor-netty-core/src/main/java/reactor/netty/tcp/SniProvider.java b/reactor-netty-core/src/main/java/reactor/netty/tcp/SniProvider.java index 8ce994b25d..ca628f25ab 100644 --- a/reactor-netty-core/src/main/java/reactor/netty/tcp/SniProvider.java +++ b/reactor-netty-core/src/main/java/reactor/netty/tcp/SniProvider.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020-2021 VMware, Inc. or its affiliates, All Rights Reserved. + * Copyright (c) 2020-2022 VMware, Inc. or its affiliates, All Rights Reserved. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -15,14 +15,19 @@ */ package reactor.netty.tcp; -import io.netty.buffer.ByteBufAllocator; import io.netty.channel.Channel; +import io.netty.channel.ChannelHandlerContext; import io.netty.channel.ChannelPipeline; -import io.netty.handler.ssl.SniHandler; -import io.netty.handler.ssl.SslContext; +import io.netty.handler.codec.DecoderException; +import io.netty.handler.ssl.AbstractSniHandler; import io.netty.handler.ssl.SslHandler; +import io.netty.util.AsyncMapping; import io.netty.util.DomainWildcardMappingBuilder; import io.netty.util.Mapping; +import io.netty.util.ReferenceCountUtil; +import io.netty.util.concurrent.Future; +import io.netty.util.concurrent.Promise; +import io.netty.util.internal.PlatformDependent; import reactor.netty.NettyPipeline; import java.util.Map; @@ -52,49 +57,81 @@ void addSniHandler(Channel channel, boolean sslDebug) { SslProvider.addSslReadHandler(pipeline, sslDebug); } - final Map confPerDomainName; - final SslProvider defaultSslProvider; + final AsyncMapping mappings; - SniProvider(Map confPerDomainName, SslProvider defaultSslProvider) { - this.confPerDomainName = confPerDomainName; - this.defaultSslProvider = defaultSslProvider; + SniProvider(AsyncMapping mappings) { + this.mappings = mappings; } - SniHandler newSniHandler() { - DomainWildcardMappingBuilder mappingsContextBuilder = - new DomainWildcardMappingBuilder<>(defaultSslProvider.getSslContext()); - confPerDomainName.forEach((s, sslProvider) -> mappingsContextBuilder.add(s, sslProvider.getSslContext())); + SniProvider(Map confPerDomainName, SslProvider defaultSslProvider) { DomainWildcardMappingBuilder mappingsSslProviderBuilder = new DomainWildcardMappingBuilder<>(defaultSslProvider); confPerDomainName.forEach(mappingsSslProviderBuilder::add); - return new AdvancedSniHandler(mappingsSslProviderBuilder.build(), defaultSslProvider, mappingsContextBuilder.build()); + this.mappings = new AsyncMappingAdapter(mappingsSslProviderBuilder.build()); + } + + SniHandler newSniHandler() { + return new SniHandler(mappings); + } + + static final class AsyncMappingAdapter implements AsyncMapping { + + final Mapping mapping; + + AsyncMappingAdapter(Mapping mapping) { + this.mapping = mapping; + } + + @Override + public Future map(String input, Promise promise) { + try { + return promise.setSuccess(mapping.map(input)); + } + catch (Throwable cause) { + return promise.setFailure(cause); + } + } } - static final class AdvancedSniHandler extends SniHandler { + static final class SniHandler extends AbstractSniHandler { - final Mapping confPerDomainName; - final SslProvider defaultSslProvider; + final AsyncMapping mappings; - AdvancedSniHandler( - Mapping confPerDomainName, - SslProvider defaultSslProvider, - Mapping mappings) { - super(mappings); - this.confPerDomainName = confPerDomainName; - this.defaultSslProvider = defaultSslProvider; + SniHandler(AsyncMapping mappings) { + this.mappings = mappings; } @Override - protected SslHandler newSslHandler(SslContext context, ByteBufAllocator allocator) { - SslHandler sslHandler = super.newSslHandler(context, allocator); - String hostName = hostname(); - if (hostName == null) { - defaultSslProvider.configure(sslHandler); + protected Future lookup(ChannelHandlerContext ctx, String hostname) { + return mappings.map(hostname, ctx.executor().newPromise()); + } + + @Override + protected void onLookupComplete(ChannelHandlerContext ctx, String hostname, Future future) { + if (!future.isSuccess()) { + final Throwable cause = future.cause(); + if (cause instanceof Error) { + throw (Error) cause; + } + throw new DecoderException("failed to get the SslContext for " + hostname, cause); + } + + SslProvider sslProvider = future.getNow(); + SslHandler sslHandler = null; + try { + sslHandler = sslProvider.getSslContext().newHandler(ctx.alloc()); + sslProvider.configure(sslHandler); + ctx.pipeline().replace(this, SslHandler.class.getName(), sslHandler); + sslHandler = null; + } + catch (Throwable cause) { + PlatformDependent.throwException(cause); } - else { - confPerDomainName.map(hostname()).configure(sslHandler); + finally { + if (sslHandler != null) { + ReferenceCountUtil.safeRelease(sslHandler.engine()); + } } - return sslHandler; } } } diff --git a/reactor-netty-core/src/main/java/reactor/netty/tcp/SslProvider.java b/reactor-netty-core/src/main/java/reactor/netty/tcp/SslProvider.java index 5443f418d9..fd4ff1d8aa 100644 --- a/reactor-netty-core/src/main/java/reactor/netty/tcp/SslProvider.java +++ b/reactor-netty-core/src/main/java/reactor/netty/tcp/SslProvider.java @@ -46,6 +46,7 @@ import io.netty.handler.ssl.SslHandler; import io.netty.handler.ssl.SslHandshakeCompletionEvent; import io.netty.handler.ssl.SupportedCipherSuiteFilter; +import io.netty.util.AsyncMapping; import reactor.core.Exceptions; import reactor.netty.NettyPipeline; import reactor.netty.ReactorNetty; @@ -176,7 +177,9 @@ public interface Builder { /** * Adds a mapping for the given domain name to an {@link SslProvider} builder. * If a mapping already exists, it will be overridden. - * Note: This configuration is applicable only when configuring the server. + *

Note: This method is a sync alternative of {@link #setSniAsyncMappings(AsyncMapping)}, + * which removes the async mappings. + *

Note: This configuration is applicable only when configuring the server. * * @param domainName the domain name, it may contain wildcard * @param sslProviderBuilder an {@link SslProvider} builder for building the {@link SslProvider} @@ -187,7 +190,9 @@ public interface Builder { /** * Adds the provided mappings of domain names to {@link SslProvider} builders to the existing mappings. * If a mapping already exists, it will be overridden. - * Note: This configuration is applicable only when configuring the server. + *

Note: This method is a sync alternative of {@link #setSniAsyncMappings(AsyncMapping)}, + * which removes the async mappings. + *

Note: This configuration is applicable only when configuring the server. * * @param confPerDomainName mappings of domain names to {@link SslProvider} builders * @return {@literal this} @@ -197,13 +202,27 @@ public interface Builder { /** * Sets the provided mappings of domain names to {@link SslProvider} builders. * The existing mappings will be removed. - * Note: This configuration is applicable only when configuring the server. + *

Note: This method is a sync alternative of {@link #setSniAsyncMappings(AsyncMapping)}, + * which removes the async mappings. + *

Note: This configuration is applicable only when configuring the server. * * @param confPerDomainName mappings of domain names to {@link SslProvider} builders * @return {@literal this} */ Builder setSniMappings(Map> confPerDomainName); + /** + * Sets the provided mappings of domain names to {@link SslProvider}. + *

Note: This method is an alternative of {@link #addSniMapping(String, Consumer)}, + * {@link #addSniMappings(Map)} and {@link #setSniMappings(Map)}. + *

Note: This configuration is applicable only when configuring the server. + * + * @param mappings mappings of domain names to {@link SslProvider} + * @return {@literal this} + * @since 1.0.19 + */ + Builder setSniAsyncMappings(AsyncMapping mappings); + /** * Sets the desired {@link SNIServerName}s. * Note: This configuration is applicable only when configuring the client. @@ -337,6 +356,8 @@ public interface ProtocolSslContextSpec { final Consumer handlerConfigurator; final int builderHashCode; final SniProvider sniProvider; + final Map confPerDomainName; + final AsyncMapping sniMappings; SslProvider(SslProvider.Build builder) { this.sslContextBuilder = builder.sslCtxBuilder; @@ -386,14 +407,19 @@ else if (builder.protocolSslContextSpec != null) { this.closeNotifyFlushTimeoutMillis = builder.closeNotifyFlushTimeoutMillis; this.closeNotifyReadTimeoutMillis = builder.closeNotifyReadTimeoutMillis; this.builderHashCode = builder.hashCode(); - if (!builder.confPerDomainName.isEmpty()) { + this.confPerDomainName = builder.confPerDomainName; + this.sniMappings = builder.sniMappings; + if (!confPerDomainName.isEmpty()) { if (this.type != null) { - this.sniProvider = updateAllSslProviderConfiguration(builder.confPerDomainName, this, type); + this.sniProvider = updateAllSslProviderConfiguration(confPerDomainName, this, type); } else { - this.sniProvider = new SniProvider(builder.confPerDomainName, this); + this.sniProvider = new SniProvider(confPerDomainName, this); } } + else if (sniMappings != null) { + this.sniProvider = new SniProvider(sniMappings); + } else { this.sniProvider = null; } @@ -416,6 +442,8 @@ else if (builder.protocolSslContextSpec != null) { this.closeNotifyFlushTimeoutMillis = from.closeNotifyFlushTimeoutMillis; this.closeNotifyReadTimeoutMillis = from.closeNotifyReadTimeoutMillis; this.builderHashCode = from.builderHashCode; + this.confPerDomainName = from.confPerDomainName; + this.sniMappings = from.sniMappings; this.sniProvider = from.sniProvider; } @@ -439,8 +467,15 @@ else if (builder.protocolSslContextSpec != null) { this.closeNotifyFlushTimeoutMillis = from.closeNotifyFlushTimeoutMillis; this.closeNotifyReadTimeoutMillis = from.closeNotifyReadTimeoutMillis; this.builderHashCode = from.builderHashCode; + this.confPerDomainName = from.confPerDomainName; + this.sniMappings = from.sniMappings; if (from.sniProvider != null) { - this.sniProvider = updateAllSslProviderConfiguration(from.sniProvider.confPerDomainName, this, type); + if (!confPerDomainName.isEmpty()) { + this.sniProvider = updateAllSslProviderConfiguration(confPerDomainName, this, type); + } + else { + this.sniProvider = new SniProvider(sniMappings); + } } else { this.sniProvider = null; @@ -613,6 +648,7 @@ static final class Build implements SslContextSpec, DefaultConfigurationSpec, Bu long closeNotifyReadTimeoutMillis; List serverNames; final Map confPerDomainName = new HashMap<>(); + AsyncMapping sniMappings; // SslContextSpec @@ -704,6 +740,7 @@ public final Builder closeNotifyReadTimeoutMillis(long closeNotifyReadTimeoutMil @Override public Builder addSniMapping(String domainName, Consumer sslProviderBuilder) { addInternal(domainName, sslProviderBuilder); + this.sniMappings = null; return this; } @@ -711,6 +748,7 @@ public Builder addSniMapping(String domainName, Consumer public Builder addSniMappings(Map> confPerDomainName) { Objects.requireNonNull(confPerDomainName); confPerDomainName.forEach(this::addInternal); + this.sniMappings = null; return this; } @@ -719,6 +757,14 @@ public Builder setSniMappings(Map> conf Objects.requireNonNull(confPerDomainName); this.confPerDomainName.clear(); confPerDomainName.forEach(this::addInternal); + this.sniMappings = null; + return this; + } + + @Override + public Builder setSniAsyncMappings(AsyncMapping mappings) { + this.sniMappings = Objects.requireNonNull(mappings); + this.confPerDomainName.clear(); return this; } diff --git a/reactor-netty-http/src/test/java/reactor/netty/http/server/HttpServerTests.java b/reactor-netty-http/src/test/java/reactor/netty/http/server/HttpServerTests.java index a2d7ca798a..440cef2488 100644 --- a/reactor-netty-http/src/test/java/reactor/netty/http/server/HttpServerTests.java +++ b/reactor-netty-http/src/test/java/reactor/netty/http/server/HttpServerTests.java @@ -117,6 +117,7 @@ import reactor.netty.http.client.PrematureCloseException; import reactor.netty.resources.ConnectionProvider; import reactor.netty.resources.LoopResources; +import reactor.netty.tcp.SslProvider; import reactor.netty.tcp.TcpClient; import reactor.netty.tcp.TcpServer; import reactor.netty.transport.TransportConfig; @@ -1976,6 +1977,53 @@ public void userEventTriggered(ChannelHandlerContext ctx, Object evt) { assertThat(hostname.get()).isEqualTo("test.com"); } + @Test + void testSniSupportAsyncMappings() throws Exception { + SelfSignedCertificate defaultCert = new SelfSignedCertificate("default"); + Http11SslContextSpec defaultSslContextBuilder = + Http11SslContextSpec.forServer(defaultCert.certificate(), defaultCert.privateKey()); + + SelfSignedCertificate testCert = new SelfSignedCertificate("test.com"); + Http11SslContextSpec testSslContextBuilder = + Http11SslContextSpec.forServer(testCert.certificate(), testCert.privateKey()); + SslProvider testSslProvider = SslProvider.builder().sslContext(testSslContextBuilder).build(); + + Http11SslContextSpec clientSslContextBuilder = + Http11SslContextSpec.forClient() + .configure(builder -> builder.trustManager(InsecureTrustManagerFactory.INSTANCE)); + + AtomicReference hostname = new AtomicReference<>(); + disposableServer = + createServer() + .secure(spec -> spec.sslContext(defaultSslContextBuilder) + .setSniAsyncMappings((input, promise) -> promise.setSuccess(testSslProvider))) + .doOnChannelInit((obs, channel, remoteAddress) -> + channel.pipeline() + .addAfter(NettyPipeline.SslHandler, "test", new ChannelInboundHandlerAdapter() { + @Override + public void userEventTriggered(ChannelHandlerContext ctx, Object evt) { + if (evt instanceof SniCompletionEvent) { + hostname.set(((SniCompletionEvent) evt).hostname()); + } + ctx.fireUserEventTriggered(evt); + } + })) + .handle((req, res) -> res.sendString(Mono.just("testSniSupport"))) + .bindNow(); + + createClient(disposableServer::address) + .secure(spec -> spec.sslContext(clientSslContextBuilder) + .serverNames(new SNIHostName("test.com"))) + .get() + .uri("/") + .responseContent() + .aggregate() + .block(Duration.ofSeconds(30)); + + assertThat(hostname.get()).isNotNull(); + assertThat(hostname.get()).isEqualTo("test.com"); + } + @Test void testIssue1286_HTTP11() throws Exception { doTestIssue1286(Function.identity(), Function.identity(), false, false); diff --git a/reactor-netty-http/src/test/java/reactor/netty/tcp/SslProviderTests.java b/reactor-netty-http/src/test/java/reactor/netty/tcp/SslProviderTests.java index d5d95c8346..7ed7afb16e 100644 --- a/reactor-netty-http/src/test/java/reactor/netty/tcp/SslProviderTests.java +++ b/reactor-netty-http/src/test/java/reactor/netty/tcp/SslProviderTests.java @@ -30,8 +30,7 @@ import io.netty.handler.ssl.SslHandler; import io.netty.handler.ssl.util.InsecureTrustManagerFactory; import io.netty.handler.ssl.util.SelfSignedCertificate; -import io.netty.util.DomainWildcardMappingBuilder; -import io.netty.util.Mapping; +import io.netty.util.concurrent.GlobalEventExecutor; import org.junit.jupiter.api.BeforeAll; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; @@ -245,12 +244,14 @@ void testAdd() { .addSniMapping("localhost", spec -> spec.sslContext(localhostSslContext)); SniProvider provider = builder.build().sniProvider; - assertThat(mappings(provider).map("localhost")).isSameAs(localhostSslContext); + assertThat(provider.mappings.map("localhost", GlobalEventExecutor.INSTANCE.newPromise()).getNow().sslContext) + .isSameAs(localhostSslContext); provider = builder.addSniMapping("localhost", spec -> spec.sslContext(anotherSslContext)) .build() .sniProvider; - assertThat(mappings(provider).map("localhost")).isSameAs(anotherSslContext); + assertThat(provider.mappings.map("localhost", GlobalEventExecutor.INSTANCE.newPromise()).getNow().sslContext) + .isSameAs(anotherSslContext); } @Test @@ -277,13 +278,16 @@ void testAddAll() { .addSniMappings(map); SniProvider provider = builder.build().sniProvider; - assertThat(mappings(provider).map("localhost")).isSameAs(localhostSslContext); + assertThat(provider.mappings.map("localhost", GlobalEventExecutor.INSTANCE.newPromise()).getNow().sslContext) + .isSameAs(localhostSslContext); map.put("another", spec -> spec.sslContext(anotherSslContext)); provider = builder.addSniMappings(map).build().sniProvider; - assertThat(mappings(provider).map("localhost")).isSameAs(localhostSslContext); - assertThat(mappings(provider).map("another")).isSameAs(anotherSslContext); + assertThat(provider.mappings.map("localhost", GlobalEventExecutor.INSTANCE.newPromise()).getNow().sslContext) + .isSameAs(localhostSslContext); + assertThat(provider.mappings.map("another", GlobalEventExecutor.INSTANCE.newPromise()).getNow().sslContext) + .isSameAs(anotherSslContext); } @Test @@ -306,14 +310,17 @@ void testSetAll() throws Exception { .setSniMappings(map); SniProvider provider = builder.build().sniProvider; - assertThat(mappings(provider).map("localhost")).isSameAs(localhostSslContext); + assertThat(provider.mappings.map("localhost", GlobalEventExecutor.INSTANCE.newPromise()).getNow().sslContext) + .isSameAs(localhostSslContext); map.clear(); map.put("another", spec -> spec.sslContext(anotherSslContext)); provider = builder.setSniMappings(map).build().sniProvider; - assertThat(mappings(provider).map("localhost")).isSameAs(defaultSslContext); - assertThat(mappings(provider).map("another")).isSameAs(anotherSslContext); + assertThat(provider.mappings.map("localhost", GlobalEventExecutor.INSTANCE.newPromise()).getNow().sslContext) + .isSameAs(defaultSslContext); + assertThat(provider.mappings.map("another", GlobalEventExecutor.INSTANCE.newPromise()).getNow().sslContext) + .isSameAs(anotherSslContext); } @Test @@ -324,6 +331,14 @@ void testSetAllBadValues() { .setSniMappings(null)); } + @Test + void testSetSniAsyncMappingsBadValues() { + assertThatExceptionOfType(NullPointerException.class) + .isThrownBy(() -> SslProvider.builder() + .sslContext(serverSslContextBuilder) + .setSniAsyncMappings(null)); + } + @Test void testServerNames() throws Exception { SslContext defaultSslContext = clientSslContextBuilder.sslContext(); @@ -351,11 +366,4 @@ void testServerNamesBadValues() throws Exception { .sslContext(defaultSslContext) .serverNames((SNIServerName[]) null)); } - - static Mapping mappings(SniProvider provider) { - DomainWildcardMappingBuilder mappingsBuilder = - new DomainWildcardMappingBuilder<>(provider.defaultSslProvider.getSslContext()); - provider.confPerDomainName.forEach((s, sslProvider) -> mappingsBuilder.add(s, sslProvider.getSslContext())); - return mappingsBuilder.build(); - } }