Skip to content

Commit

Permalink
Add API for providing SNI AsyncMapping (#2172)
Browse files Browse the repository at this point in the history
Fixes #2156
  • Loading branch information
violetagg authored May 3, 2022

Verified

This commit was created on GitHub.com and signed with GitHub’s verified signature. The key has expired.
1 parent 44bafbc commit 0d440cb
Showing 4 changed files with 195 additions and 56 deletions.
101 changes: 69 additions & 32 deletions reactor-netty-core/src/main/java/reactor/netty/tcp/SniProvider.java
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2020-2021 VMware, Inc. or its affiliates, All Rights Reserved.
* Copyright (c) 2020-2022 VMware, Inc. or its affiliates, All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -15,14 +15,19 @@
*/
package reactor.netty.tcp;

import io.netty.buffer.ByteBufAllocator;
import io.netty.channel.Channel;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelPipeline;
import io.netty.handler.ssl.SniHandler;
import io.netty.handler.ssl.SslContext;
import io.netty.handler.codec.DecoderException;
import io.netty.handler.ssl.AbstractSniHandler;
import io.netty.handler.ssl.SslHandler;
import io.netty.util.AsyncMapping;
import io.netty.util.DomainWildcardMappingBuilder;
import io.netty.util.Mapping;
import io.netty.util.ReferenceCountUtil;
import io.netty.util.concurrent.Future;
import io.netty.util.concurrent.Promise;
import io.netty.util.internal.PlatformDependent;
import reactor.netty.NettyPipeline;

import java.util.Map;
@@ -52,49 +57,81 @@ void addSniHandler(Channel channel, boolean sslDebug) {
SslProvider.addSslReadHandler(pipeline, sslDebug);
}

final Map<String, SslProvider> confPerDomainName;
final SslProvider defaultSslProvider;
final AsyncMapping<String, SslProvider> mappings;

SniProvider(Map<String, SslProvider> confPerDomainName, SslProvider defaultSslProvider) {
this.confPerDomainName = confPerDomainName;
this.defaultSslProvider = defaultSslProvider;
SniProvider(AsyncMapping<String, SslProvider> mappings) {
this.mappings = mappings;
}

SniHandler newSniHandler() {
DomainWildcardMappingBuilder<SslContext> mappingsContextBuilder =
new DomainWildcardMappingBuilder<>(defaultSslProvider.getSslContext());
confPerDomainName.forEach((s, sslProvider) -> mappingsContextBuilder.add(s, sslProvider.getSslContext()));
SniProvider(Map<String, SslProvider> confPerDomainName, SslProvider defaultSslProvider) {
DomainWildcardMappingBuilder<SslProvider> mappingsSslProviderBuilder =
new DomainWildcardMappingBuilder<>(defaultSslProvider);
confPerDomainName.forEach(mappingsSslProviderBuilder::add);
return new AdvancedSniHandler(mappingsSslProviderBuilder.build(), defaultSslProvider, mappingsContextBuilder.build());
this.mappings = new AsyncMappingAdapter(mappingsSslProviderBuilder.build());
}

SniHandler newSniHandler() {
return new SniHandler(mappings);
}

static final class AsyncMappingAdapter implements AsyncMapping<String, SslProvider> {

final Mapping<String, SslProvider> mapping;

AsyncMappingAdapter(Mapping<String, SslProvider> mapping) {
this.mapping = mapping;
}

@Override
public Future<SslProvider> map(String input, Promise<SslProvider> promise) {
try {
return promise.setSuccess(mapping.map(input));
}
catch (Throwable cause) {
return promise.setFailure(cause);
}
}
}

static final class AdvancedSniHandler extends SniHandler {
static final class SniHandler extends AbstractSniHandler<SslProvider> {

final Mapping<? super String, ? extends SslProvider> confPerDomainName;
final SslProvider defaultSslProvider;
final AsyncMapping<String, SslProvider> mappings;

AdvancedSniHandler(
Mapping<? super String, ? extends SslProvider> confPerDomainName,
SslProvider defaultSslProvider,
Mapping<? super String, ? extends SslContext> mappings) {
super(mappings);
this.confPerDomainName = confPerDomainName;
this.defaultSslProvider = defaultSslProvider;
SniHandler(AsyncMapping<String, SslProvider> mappings) {
this.mappings = mappings;
}

@Override
protected SslHandler newSslHandler(SslContext context, ByteBufAllocator allocator) {
SslHandler sslHandler = super.newSslHandler(context, allocator);
String hostName = hostname();
if (hostName == null) {
defaultSslProvider.configure(sslHandler);
protected Future<SslProvider> lookup(ChannelHandlerContext ctx, String hostname) {
return mappings.map(hostname, ctx.executor().newPromise());
}

@Override
protected void onLookupComplete(ChannelHandlerContext ctx, String hostname, Future<SslProvider> future) {
if (!future.isSuccess()) {
final Throwable cause = future.cause();
if (cause instanceof Error) {
throw (Error) cause;
}
throw new DecoderException("failed to get the SslContext for " + hostname, cause);
}

SslProvider sslProvider = future.getNow();
SslHandler sslHandler = null;
try {
sslHandler = sslProvider.getSslContext().newHandler(ctx.alloc());
sslProvider.configure(sslHandler);
ctx.pipeline().replace(this, SslHandler.class.getName(), sslHandler);
sslHandler = null;
}
catch (Throwable cause) {
PlatformDependent.throwException(cause);
}
else {
confPerDomainName.map(hostname()).configure(sslHandler);
finally {
if (sslHandler != null) {
ReferenceCountUtil.safeRelease(sslHandler.engine());
}
}
return sslHandler;
}
}
}
Original file line number Diff line number Diff line change
@@ -46,6 +46,7 @@
import io.netty.handler.ssl.SslHandler;
import io.netty.handler.ssl.SslHandshakeCompletionEvent;
import io.netty.handler.ssl.SupportedCipherSuiteFilter;
import io.netty.util.AsyncMapping;
import reactor.core.Exceptions;
import reactor.netty.NettyPipeline;
import reactor.netty.ReactorNetty;
@@ -176,7 +177,9 @@ public interface Builder {
/**
* Adds a mapping for the given domain name to an {@link SslProvider} builder.
* If a mapping already exists, it will be overridden.
* Note: This configuration is applicable only when configuring the server.
* <p><strong>Note:</strong> This method is a sync alternative of {@link #setSniAsyncMappings(AsyncMapping)},
* which removes the async mappings.
* <p><strong>Note:</strong> This configuration is applicable only when configuring the server.
*
* @param domainName the domain name, it may contain wildcard
* @param sslProviderBuilder an {@link SslProvider} builder for building the {@link SslProvider}
@@ -187,7 +190,9 @@ public interface Builder {
/**
* Adds the provided mappings of domain names to {@link SslProvider} builders to the existing mappings.
* If a mapping already exists, it will be overridden.
* Note: This configuration is applicable only when configuring the server.
* <p><strong>Note:</strong> This method is a sync alternative of {@link #setSniAsyncMappings(AsyncMapping)},
* which removes the async mappings.
* <p><strong>Note:</strong> This configuration is applicable only when configuring the server.
*
* @param confPerDomainName mappings of domain names to {@link SslProvider} builders
* @return {@literal this}
@@ -197,13 +202,27 @@ public interface Builder {
/**
* Sets the provided mappings of domain names to {@link SslProvider} builders.
* The existing mappings will be removed.
* Note: This configuration is applicable only when configuring the server.
* <p><strong>Note:</strong> This method is a sync alternative of {@link #setSniAsyncMappings(AsyncMapping)},
* which removes the async mappings.
* <p><strong>Note:</strong> This configuration is applicable only when configuring the server.
*
* @param confPerDomainName mappings of domain names to {@link SslProvider} builders
* @return {@literal this}
*/
Builder setSniMappings(Map<String, Consumer<? super SslProvider.SslContextSpec>> confPerDomainName);

/**
* Sets the provided mappings of domain names to {@link SslProvider}.
* <p><strong>Note:</strong> This method is an alternative of {@link #addSniMapping(String, Consumer)},
* {@link #addSniMappings(Map)} and {@link #setSniMappings(Map)}.
* <p><strong>Note:</strong> This configuration is applicable only when configuring the server.
*
* @param mappings mappings of domain names to {@link SslProvider}
* @return {@literal this}
* @since 1.0.19
*/
Builder setSniAsyncMappings(AsyncMapping<String, SslProvider> mappings);

/**
* Sets the desired {@link SNIServerName}s.
* Note: This configuration is applicable only when configuring the client.
@@ -337,6 +356,8 @@ public interface ProtocolSslContextSpec {
final Consumer<? super SslHandler> handlerConfigurator;
final int builderHashCode;
final SniProvider sniProvider;
final Map<String, SslProvider> confPerDomainName;
final AsyncMapping<String, SslProvider> sniMappings;

SslProvider(SslProvider.Build builder) {
this.sslContextBuilder = builder.sslCtxBuilder;
@@ -386,14 +407,19 @@ else if (builder.protocolSslContextSpec != null) {
this.closeNotifyFlushTimeoutMillis = builder.closeNotifyFlushTimeoutMillis;
this.closeNotifyReadTimeoutMillis = builder.closeNotifyReadTimeoutMillis;
this.builderHashCode = builder.hashCode();
if (!builder.confPerDomainName.isEmpty()) {
this.confPerDomainName = builder.confPerDomainName;
this.sniMappings = builder.sniMappings;
if (!confPerDomainName.isEmpty()) {
if (this.type != null) {
this.sniProvider = updateAllSslProviderConfiguration(builder.confPerDomainName, this, type);
this.sniProvider = updateAllSslProviderConfiguration(confPerDomainName, this, type);
}
else {
this.sniProvider = new SniProvider(builder.confPerDomainName, this);
this.sniProvider = new SniProvider(confPerDomainName, this);
}
}
else if (sniMappings != null) {
this.sniProvider = new SniProvider(sniMappings);
}
else {
this.sniProvider = null;
}
@@ -416,6 +442,8 @@ else if (builder.protocolSslContextSpec != null) {
this.closeNotifyFlushTimeoutMillis = from.closeNotifyFlushTimeoutMillis;
this.closeNotifyReadTimeoutMillis = from.closeNotifyReadTimeoutMillis;
this.builderHashCode = from.builderHashCode;
this.confPerDomainName = from.confPerDomainName;
this.sniMappings = from.sniMappings;
this.sniProvider = from.sniProvider;
}

@@ -439,8 +467,15 @@ else if (builder.protocolSslContextSpec != null) {
this.closeNotifyFlushTimeoutMillis = from.closeNotifyFlushTimeoutMillis;
this.closeNotifyReadTimeoutMillis = from.closeNotifyReadTimeoutMillis;
this.builderHashCode = from.builderHashCode;
this.confPerDomainName = from.confPerDomainName;
this.sniMappings = from.sniMappings;
if (from.sniProvider != null) {
this.sniProvider = updateAllSslProviderConfiguration(from.sniProvider.confPerDomainName, this, type);
if (!confPerDomainName.isEmpty()) {
this.sniProvider = updateAllSslProviderConfiguration(confPerDomainName, this, type);
}
else {
this.sniProvider = new SniProvider(sniMappings);
}
}
else {
this.sniProvider = null;
@@ -613,6 +648,7 @@ static final class Build implements SslContextSpec, DefaultConfigurationSpec, Bu
long closeNotifyReadTimeoutMillis;
List<SNIServerName> serverNames;
final Map<String, SslProvider> confPerDomainName = new HashMap<>();
AsyncMapping<String, SslProvider> sniMappings;

// SslContextSpec

@@ -704,13 +740,15 @@ public final Builder closeNotifyReadTimeoutMillis(long closeNotifyReadTimeoutMil
@Override
public Builder addSniMapping(String domainName, Consumer<? super SslContextSpec> sslProviderBuilder) {
addInternal(domainName, sslProviderBuilder);
this.sniMappings = null;
return this;
}

@Override
public Builder addSniMappings(Map<String, Consumer<? super SslContextSpec>> confPerDomainName) {
Objects.requireNonNull(confPerDomainName);
confPerDomainName.forEach(this::addInternal);
this.sniMappings = null;
return this;
}

@@ -719,6 +757,14 @@ public Builder setSniMappings(Map<String, Consumer<? super SslContextSpec>> conf
Objects.requireNonNull(confPerDomainName);
this.confPerDomainName.clear();
confPerDomainName.forEach(this::addInternal);
this.sniMappings = null;
return this;
}

@Override
public Builder setSniAsyncMappings(AsyncMapping<String, SslProvider> mappings) {
this.sniMappings = Objects.requireNonNull(mappings);
this.confPerDomainName.clear();
return this;
}

Original file line number Diff line number Diff line change
@@ -117,6 +117,7 @@
import reactor.netty.http.client.PrematureCloseException;
import reactor.netty.resources.ConnectionProvider;
import reactor.netty.resources.LoopResources;
import reactor.netty.tcp.SslProvider;
import reactor.netty.tcp.TcpClient;
import reactor.netty.tcp.TcpServer;
import reactor.netty.transport.TransportConfig;
@@ -1976,6 +1977,53 @@ public void userEventTriggered(ChannelHandlerContext ctx, Object evt) {
assertThat(hostname.get()).isEqualTo("test.com");
}

@Test
void testSniSupportAsyncMappings() throws Exception {
SelfSignedCertificate defaultCert = new SelfSignedCertificate("default");
Http11SslContextSpec defaultSslContextBuilder =
Http11SslContextSpec.forServer(defaultCert.certificate(), defaultCert.privateKey());

SelfSignedCertificate testCert = new SelfSignedCertificate("test.com");
Http11SslContextSpec testSslContextBuilder =
Http11SslContextSpec.forServer(testCert.certificate(), testCert.privateKey());
SslProvider testSslProvider = SslProvider.builder().sslContext(testSslContextBuilder).build();

Http11SslContextSpec clientSslContextBuilder =
Http11SslContextSpec.forClient()
.configure(builder -> builder.trustManager(InsecureTrustManagerFactory.INSTANCE));

AtomicReference<String> hostname = new AtomicReference<>();
disposableServer =
createServer()
.secure(spec -> spec.sslContext(defaultSslContextBuilder)
.setSniAsyncMappings((input, promise) -> promise.setSuccess(testSslProvider)))
.doOnChannelInit((obs, channel, remoteAddress) ->
channel.pipeline()
.addAfter(NettyPipeline.SslHandler, "test", new ChannelInboundHandlerAdapter() {
@Override
public void userEventTriggered(ChannelHandlerContext ctx, Object evt) {
if (evt instanceof SniCompletionEvent) {
hostname.set(((SniCompletionEvent) evt).hostname());
}
ctx.fireUserEventTriggered(evt);
}
}))
.handle((req, res) -> res.sendString(Mono.just("testSniSupport")))
.bindNow();

createClient(disposableServer::address)
.secure(spec -> spec.sslContext(clientSslContextBuilder)
.serverNames(new SNIHostName("test.com")))
.get()
.uri("/")
.responseContent()
.aggregate()
.block(Duration.ofSeconds(30));

assertThat(hostname.get()).isNotNull();
assertThat(hostname.get()).isEqualTo("test.com");
}

@Test
void testIssue1286_HTTP11() throws Exception {
doTestIssue1286(Function.identity(), Function.identity(), false, false);
Loading

0 comments on commit 0d440cb

Please sign in to comment.