diff --git a/modules/transport-netty4/src/main/java/org/elasticsearch/transport/Netty4Plugin.java b/modules/transport-netty4/src/main/java/org/elasticsearch/transport/Netty4Plugin.java index 0bb424d954ad5..17862e496fa06 100644 --- a/modules/transport-netty4/src/main/java/org/elasticsearch/transport/Netty4Plugin.java +++ b/modules/transport-netty4/src/main/java/org/elasticsearch/transport/Netty4Plugin.java @@ -19,6 +19,7 @@ package org.elasticsearch.transport; +import org.elasticsearch.Version; import org.elasticsearch.common.io.stream.NamedWriteableRegistry; import org.elasticsearch.common.network.NetworkModule; import org.elasticsearch.common.network.NetworkService; @@ -83,8 +84,8 @@ public Map> getTransports(Settings settings, ThreadP CircuitBreakerService circuitBreakerService, NamedWriteableRegistry namedWriteableRegistry, NetworkService networkService) { - return Collections.singletonMap(NETTY_TRANSPORT_NAME, () -> new Netty4Transport(settings, threadPool, networkService, bigArrays, - namedWriteableRegistry, circuitBreakerService)); + return Collections.singletonMap(NETTY_TRANSPORT_NAME, () -> new Netty4Transport(settings, Version.CURRENT, threadPool, + networkService, bigArrays, namedWriteableRegistry, circuitBreakerService)); } @Override diff --git a/modules/transport-netty4/src/main/java/org/elasticsearch/transport/netty4/Netty4Transport.java b/modules/transport-netty4/src/main/java/org/elasticsearch/transport/netty4/Netty4Transport.java index a6104bed48af8..1d0b11f59c621 100644 --- a/modules/transport-netty4/src/main/java/org/elasticsearch/transport/netty4/Netty4Transport.java +++ b/modules/transport-netty4/src/main/java/org/elasticsearch/transport/netty4/Netty4Transport.java @@ -38,7 +38,7 @@ import org.apache.logging.log4j.message.ParameterizedMessage; import org.elasticsearch.ElasticsearchException; import org.elasticsearch.ExceptionsHelper; -import org.elasticsearch.action.ActionListener; +import org.elasticsearch.Version; import org.elasticsearch.cluster.node.DiscoveryNode; import org.elasticsearch.common.SuppressForbidden; import org.elasticsearch.common.io.stream.NamedWriteableRegistry; @@ -101,9 +101,9 @@ public class Netty4Transport extends TcpTransport { private volatile Bootstrap clientBootstrap; private volatile NioEventLoopGroup eventLoopGroup; - public Netty4Transport(Settings settings, ThreadPool threadPool, NetworkService networkService, BigArrays bigArrays, + public Netty4Transport(Settings settings, Version version, ThreadPool threadPool, NetworkService networkService, BigArrays bigArrays, NamedWriteableRegistry namedWriteableRegistry, CircuitBreakerService circuitBreakerService) { - super("netty", settings, threadPool, bigArrays, circuitBreakerService, namedWriteableRegistry, networkService); + super("netty", settings, version, threadPool, bigArrays, circuitBreakerService, namedWriteableRegistry, networkService); Netty4Utils.setAvailableProcessors(EsExecutors.PROCESSORS_SETTING.get(settings)); this.workerCount = WORKER_COUNT.get(settings); @@ -221,44 +221,31 @@ protected final void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) } @Override - protected NettyTcpChannel initiateChannel(DiscoveryNode node, ActionListener listener) throws IOException { + protected NettyTcpChannel initiateChannel(DiscoveryNode node) throws IOException { InetSocketAddress address = node.getAddress().address(); Bootstrap bootstrapWithHandler = clientBootstrap.clone(); bootstrapWithHandler.handler(getClientChannelInitializer(node)); bootstrapWithHandler.remoteAddress(address); - ChannelFuture channelFuture = bootstrapWithHandler.connect(); + ChannelFuture connectFuture = bootstrapWithHandler.connect(); - Channel channel = channelFuture.channel(); + Channel channel = connectFuture.channel(); if (channel == null) { - ExceptionsHelper.maybeDieOnAnotherThread(channelFuture.cause()); - throw new IOException(channelFuture.cause()); + ExceptionsHelper.maybeDieOnAnotherThread(connectFuture.cause()); + throw new IOException(connectFuture.cause()); } addClosedExceptionLogger(channel); - NettyTcpChannel nettyChannel = new NettyTcpChannel(channel); + NettyTcpChannel nettyChannel = new NettyTcpChannel(channel, "default", connectFuture); channel.attr(CHANNEL_KEY).set(nettyChannel); - channelFuture.addListener(f -> { - if (f.isSuccess()) { - listener.onResponse(null); - } else { - Throwable cause = f.cause(); - if (cause instanceof Error) { - ExceptionsHelper.maybeDieOnAnotherThread(cause); - listener.onFailure(new Exception(cause)); - } else { - listener.onFailure((Exception) cause); - } - } - }); - return nettyChannel; } @Override protected NettyTcpChannel bind(String name, InetSocketAddress address) { Channel channel = serverBootstraps.get(name).bind(address).syncUninterruptibly().channel(); - NettyTcpChannel esChannel = new NettyTcpChannel(channel); + // TODO: Switch to same server channels + NettyTcpChannel esChannel = new NettyTcpChannel(channel, "server", channel.newSucceededFuture()); channel.attr(CHANNEL_KEY).set(esChannel); return esChannel; } @@ -314,7 +301,8 @@ protected ServerChannelInitializer(String name) { @Override protected void initChannel(Channel ch) throws Exception { addClosedExceptionLogger(ch); - NettyTcpChannel nettyTcpChannel = new NettyTcpChannel(ch); + NettyTcpChannel nettyTcpChannel = new NettyTcpChannel(ch, name, ch.newSucceededFuture()); + ch.attr(CHANNEL_KEY).set(nettyTcpChannel); serverAcceptedChannel(nettyTcpChannel); ch.pipeline().addLast("logging", new ESLoggingHandler()); diff --git a/modules/transport-netty4/src/main/java/org/elasticsearch/transport/netty4/NettyTcpChannel.java b/modules/transport-netty4/src/main/java/org/elasticsearch/transport/netty4/NettyTcpChannel.java index 291f98f19e55c..d441870cc5344 100644 --- a/modules/transport-netty4/src/main/java/org/elasticsearch/transport/netty4/NettyTcpChannel.java +++ b/modules/transport-netty4/src/main/java/org/elasticsearch/transport/netty4/NettyTcpChannel.java @@ -20,14 +20,19 @@ package org.elasticsearch.transport.netty4; import io.netty.channel.Channel; +import io.netty.channel.ChannelException; +import io.netty.channel.ChannelFuture; import io.netty.channel.ChannelOption; import io.netty.channel.ChannelPromise; import org.elasticsearch.ExceptionsHelper; import org.elasticsearch.action.ActionListener; +import org.elasticsearch.common.Nullable; import org.elasticsearch.common.bytes.BytesReference; +import org.elasticsearch.common.concurrent.CompletableContext; import org.elasticsearch.transport.TcpChannel; import org.elasticsearch.transport.TransportException; +import java.io.IOException; import java.net.InetSocketAddress; import java.util.concurrent.CompletableFuture; @@ -35,9 +40,13 @@ public class NettyTcpChannel implements TcpChannel { private final Channel channel; private final CompletableFuture closeContext = new CompletableFuture<>(); + private final CompletableContext connectContext; + private final String profile; - NettyTcpChannel(Channel channel) { + NettyTcpChannel(Channel channel, String profile, @Nullable ChannelFuture connectFuture) { this.channel = channel; + this.profile = profile; + this.connectContext = new CompletableContext<>(); this.channel.closeFuture().addListener(f -> { if (f.isSuccess()) { closeContext.complete(null); @@ -51,6 +60,20 @@ public class NettyTcpChannel implements TcpChannel { } } }); + + connectFuture.addListener(f -> { + if (f.isSuccess()) { + connectContext.complete(null); + } else { + Throwable cause = f.cause(); + if (cause instanceof Error) { + ExceptionsHelper.maybeDieOnAnotherThread(cause); + connectContext.completeExceptionally(new Exception(cause)); + } else { + connectContext.completeExceptionally((Exception) cause); + } + } + }); } @Override @@ -63,9 +86,19 @@ public void addCloseListener(ActionListener listener) { closeContext.whenComplete(ActionListener.toBiConsumer(listener)); } + public void addConnectListener(ActionListener listener) { + connectContext.addListener(ActionListener.toBiConsumer(listener)); + } + @Override - public void setSoLinger(int value) { - channel.config().setOption(ChannelOption.SO_LINGER, value); + public void setSoLinger(int value) throws IOException { + if (channel.isOpen()) { + try { + channel.config().setOption(ChannelOption.SO_LINGER, value); + } catch (ChannelException e) { + throw new IOException(e); + } + } } @Override diff --git a/modules/transport-netty4/src/test/java/org/elasticsearch/transport/netty4/Netty4ScheduledPingTests.java b/modules/transport-netty4/src/test/java/org/elasticsearch/transport/netty4/Netty4ScheduledPingTests.java index 9d2aa0d9e2add..ab2f9eadb714e 100644 --- a/modules/transport-netty4/src/test/java/org/elasticsearch/transport/netty4/Netty4ScheduledPingTests.java +++ b/modules/transport-netty4/src/test/java/org/elasticsearch/transport/netty4/Netty4ScheduledPingTests.java @@ -18,6 +18,7 @@ */ package org.elasticsearch.transport.netty4; +import org.elasticsearch.Version; import org.elasticsearch.cluster.node.DiscoveryNode; import org.elasticsearch.common.io.stream.NamedWriteableRegistry; import org.elasticsearch.common.io.stream.StreamInput; @@ -60,15 +61,15 @@ public void testScheduledPing() throws Exception { CircuitBreakerService circuitBreakerService = new NoneCircuitBreakerService(); NamedWriteableRegistry registry = new NamedWriteableRegistry(Collections.emptyList()); - final Netty4Transport nettyA = new Netty4Transport(settings, threadPool, new NetworkService(Collections.emptyList()), - BigArrays.NON_RECYCLING_INSTANCE, registry, circuitBreakerService); + final Netty4Transport nettyA = new Netty4Transport(settings, Version.CURRENT, threadPool, + new NetworkService(Collections.emptyList()), BigArrays.NON_RECYCLING_INSTANCE, registry, circuitBreakerService); MockTransportService serviceA = new MockTransportService(settings, nettyA, threadPool, TransportService.NOOP_TRANSPORT_INTERCEPTOR, null); serviceA.start(); serviceA.acceptIncomingRequests(); - final Netty4Transport nettyB = new Netty4Transport(settings, threadPool, new NetworkService(Collections.emptyList()), - BigArrays.NON_RECYCLING_INSTANCE, registry, circuitBreakerService); + final Netty4Transport nettyB = new Netty4Transport(settings, Version.CURRENT, threadPool, + new NetworkService(Collections.emptyList()), BigArrays.NON_RECYCLING_INSTANCE, registry, circuitBreakerService); MockTransportService serviceB = new MockTransportService(settings, nettyB, threadPool, TransportService.NOOP_TRANSPORT_INTERCEPTOR, null); diff --git a/modules/transport-netty4/src/test/java/org/elasticsearch/transport/netty4/Netty4SizeHeaderFrameDecoderTests.java b/modules/transport-netty4/src/test/java/org/elasticsearch/transport/netty4/Netty4SizeHeaderFrameDecoderTests.java index dcd730fbc4ebe..db57730058fee 100644 --- a/modules/transport-netty4/src/test/java/org/elasticsearch/transport/netty4/Netty4SizeHeaderFrameDecoderTests.java +++ b/modules/transport-netty4/src/test/java/org/elasticsearch/transport/netty4/Netty4SizeHeaderFrameDecoderTests.java @@ -19,6 +19,7 @@ package org.elasticsearch.transport.netty4; +import org.elasticsearch.Version; import org.elasticsearch.common.io.stream.NamedWriteableRegistry; import org.elasticsearch.common.network.NetworkService; import org.elasticsearch.common.settings.Settings; @@ -65,7 +66,7 @@ public void startThreadPool() { threadPool = new ThreadPool(settings); NetworkService networkService = new NetworkService(Collections.emptyList()); BigArrays bigArrays = new MockBigArrays(new MockPageCacheRecycler(Settings.EMPTY), new NoneCircuitBreakerService()); - nettyTransport = new Netty4Transport(settings, threadPool, networkService, bigArrays, + nettyTransport = new Netty4Transport(settings, Version.CURRENT, threadPool, networkService, bigArrays, new NamedWriteableRegistry(Collections.emptyList()), new NoneCircuitBreakerService()); nettyTransport.start(); diff --git a/modules/transport-netty4/src/test/java/org/elasticsearch/transport/netty4/Netty4TransportIT.java b/modules/transport-netty4/src/test/java/org/elasticsearch/transport/netty4/Netty4TransportIT.java index b81c8efcb47ee..b93e09b53649e 100644 --- a/modules/transport-netty4/src/test/java/org/elasticsearch/transport/netty4/Netty4TransportIT.java +++ b/modules/transport-netty4/src/test/java/org/elasticsearch/transport/netty4/Netty4TransportIT.java @@ -108,7 +108,7 @@ public ExceptionThrowingNetty4Transport( BigArrays bigArrays, NamedWriteableRegistry namedWriteableRegistry, CircuitBreakerService circuitBreakerService) { - super(settings, threadPool, networkService, bigArrays, namedWriteableRegistry, circuitBreakerService); + super(settings, Version.CURRENT, threadPool, networkService, bigArrays, namedWriteableRegistry, circuitBreakerService); } @Override diff --git a/modules/transport-netty4/src/test/java/org/elasticsearch/transport/netty4/NettyTransportMultiPortTests.java b/modules/transport-netty4/src/test/java/org/elasticsearch/transport/netty4/NettyTransportMultiPortTests.java index a49df3caaba4e..785c4cfb114bc 100644 --- a/modules/transport-netty4/src/test/java/org/elasticsearch/transport/netty4/NettyTransportMultiPortTests.java +++ b/modules/transport-netty4/src/test/java/org/elasticsearch/transport/netty4/NettyTransportMultiPortTests.java @@ -18,6 +18,7 @@ */ package org.elasticsearch.transport.netty4; +import org.elasticsearch.Version; import org.elasticsearch.common.component.Lifecycle; import org.elasticsearch.common.io.stream.NamedWriteableRegistry; import org.elasticsearch.common.network.NetworkService; @@ -118,7 +119,7 @@ public void testThatDefaultProfilePortOverridesGeneralConfiguration() throws Exc private TcpTransport startTransport(Settings settings, ThreadPool threadPool) { BigArrays bigArrays = new MockBigArrays(new MockPageCacheRecycler(Settings.EMPTY), new NoneCircuitBreakerService()); - TcpTransport transport = new Netty4Transport(settings, threadPool, new NetworkService(Collections.emptyList()), + TcpTransport transport = new Netty4Transport(settings, Version.CURRENT, threadPool, new NetworkService(Collections.emptyList()), bigArrays, new NamedWriteableRegistry(Collections.emptyList()), new NoneCircuitBreakerService()); transport.start(); diff --git a/modules/transport-netty4/src/test/java/org/elasticsearch/transport/netty4/SimpleNetty4TransportTests.java b/modules/transport-netty4/src/test/java/org/elasticsearch/transport/netty4/SimpleNetty4TransportTests.java index e7faac8ae01db..4c651c31bee7e 100644 --- a/modules/transport-netty4/src/test/java/org/elasticsearch/transport/netty4/SimpleNetty4TransportTests.java +++ b/modules/transport-netty4/src/test/java/org/elasticsearch/transport/netty4/SimpleNetty4TransportTests.java @@ -20,6 +20,7 @@ package org.elasticsearch.transport.netty4; import org.elasticsearch.Version; +import org.elasticsearch.action.ActionListener; import org.elasticsearch.cluster.node.DiscoveryNode; import org.elasticsearch.common.io.stream.NamedWriteableRegistry; import org.elasticsearch.common.network.NetworkService; @@ -40,7 +41,6 @@ import org.elasticsearch.transport.Transport; import org.elasticsearch.transport.TransportService; -import java.io.IOException; import java.net.InetAddress; import java.net.UnknownHostException; import java.util.Collections; @@ -54,23 +54,17 @@ public class SimpleNetty4TransportTests extends AbstractSimpleTransportTestCase public static MockTransportService nettyFromThreadPool(Settings settings, ThreadPool threadPool, final Version version, ClusterSettings clusterSettings, boolean doHandshake) { NamedWriteableRegistry namedWriteableRegistry = new NamedWriteableRegistry(Collections.emptyList()); - Transport transport = new Netty4Transport(settings, threadPool, new NetworkService(Collections.emptyList()), + Transport transport = new Netty4Transport(settings, version, threadPool, new NetworkService(Collections.emptyList()), BigArrays.NON_RECYCLING_INSTANCE, namedWriteableRegistry, new NoneCircuitBreakerService()) { @Override - public Version executeHandshake(DiscoveryNode node, TcpChannel channel, TimeValue timeout) throws IOException, - InterruptedException { + public void executeHandshake(DiscoveryNode node, TcpChannel channel, TimeValue timeout, ActionListener listener) { if (doHandshake) { - return super.executeHandshake(node, channel, timeout); + super.executeHandshake(node, channel, timeout, listener); } else { - return version.minimumCompatibilityVersion(); + listener.onResponse(version.minimumCompatibilityVersion()); } } - - @Override - protected Version getCurrentVersion() { - return version; - } }; MockTransportService mockTransportService = MockTransportService.createNewService(settings, transport, version, threadPool, clusterSettings, Collections.emptySet()); diff --git a/plugins/discovery-ec2/src/test/java/org/elasticsearch/discovery/ec2/Ec2DiscoveryTests.java b/plugins/discovery-ec2/src/test/java/org/elasticsearch/discovery/ec2/Ec2DiscoveryTests.java index aa619409c16eb..98f2febd79516 100644 --- a/plugins/discovery-ec2/src/test/java/org/elasticsearch/discovery/ec2/Ec2DiscoveryTests.java +++ b/plugins/discovery-ec2/src/test/java/org/elasticsearch/discovery/ec2/Ec2DiscoveryTests.java @@ -20,7 +20,6 @@ package org.elasticsearch.discovery.ec2; import com.amazonaws.services.ec2.model.Tag; -import org.elasticsearch.Version; import org.elasticsearch.common.io.stream.NamedWriteableRegistry; import org.elasticsearch.common.network.NetworkService; import org.elasticsearch.common.settings.Settings; @@ -74,8 +73,7 @@ public static void stopThreadPool() throws InterruptedException { public void createTransportService() { NamedWriteableRegistry namedWriteableRegistry = new NamedWriteableRegistry(Collections.emptyList()); final Transport transport = new MockTcpTransport(Settings.EMPTY, threadPool, BigArrays.NON_RECYCLING_INSTANCE, - new NoneCircuitBreakerService(), namedWriteableRegistry, new NetworkService(Collections.emptyList()), - Version.CURRENT) { + new NoneCircuitBreakerService(), namedWriteableRegistry, new NetworkService(Collections.emptyList())) { @Override public TransportAddress[] addressesFromString(String address, int perAddressLimit) throws UnknownHostException { // we just need to ensure we don't resolve DNS here diff --git a/server/src/main/java/org/elasticsearch/transport/TcpChannel.java b/server/src/main/java/org/elasticsearch/transport/TcpChannel.java index 22453ac43b4ea..858e0361fc0ca 100644 --- a/server/src/main/java/org/elasticsearch/transport/TcpChannel.java +++ b/server/src/main/java/org/elasticsearch/transport/TcpChannel.java @@ -62,7 +62,6 @@ public interface TcpChannel extends Releasable { */ void addCloseListener(ActionListener listener); - /** * This sets the low level socket option {@link java.net.StandardSocketOptions} SO_LINGER on a channel. * @@ -71,7 +70,6 @@ public interface TcpChannel extends Releasable { */ void setSoLinger(int value) throws IOException; - /** * Indicates whether a channel is currently open * @@ -95,6 +93,8 @@ public interface TcpChannel extends Releasable { */ void sendMessage(BytesReference reference, ActionListener listener); + void addConnectListener(ActionListener listener); + /** * Closes the channel. * @@ -128,45 +128,6 @@ static void closeChannels(List channels, boolean block } } - /** - * Awaits for all of the pending connections to complete. Will throw an exception if at least one of the - * connections fails. - * - * @param discoveryNode the node for the pending connections - * @param connectionFutures representing the pending connections - * @param connectTimeout to wait for a connection - * @throws ConnectTransportException if one of the connections fails - */ - static void awaitConnected(DiscoveryNode discoveryNode, List> connectionFutures, TimeValue connectTimeout) - throws ConnectTransportException { - Exception connectionException = null; - boolean allConnected = true; - - for (ActionFuture connectionFuture : connectionFutures) { - try { - connectionFuture.get(connectTimeout.getMillis(), TimeUnit.MILLISECONDS); - } catch (TimeoutException e) { - allConnected = false; - break; - } catch (InterruptedException e) { - Thread.currentThread().interrupt(); - throw new IllegalStateException(e); - } catch (ExecutionException e) { - allConnected = false; - connectionException = (Exception) e.getCause(); - break; - } - } - - if (allConnected == false) { - if (connectionException == null) { - throw new ConnectTransportException(discoveryNode, "connect_timeout[" + connectTimeout + "]"); - } else { - throw new ConnectTransportException(discoveryNode, "connect_exception", connectionException); - } - } - } - static void blockOnFutures(List> futures) { for (ActionFuture future : futures) { try { diff --git a/server/src/main/java/org/elasticsearch/transport/TcpTransport.java b/server/src/main/java/org/elasticsearch/transport/TcpTransport.java index f8ef3dd8d218d..cd86989241304 100644 --- a/server/src/main/java/org/elasticsearch/transport/TcpTransport.java +++ b/server/src/main/java/org/elasticsearch/transport/TcpTransport.java @@ -23,7 +23,6 @@ import org.apache.logging.log4j.message.ParameterizedMessage; import org.elasticsearch.ElasticsearchException; import org.elasticsearch.Version; -import org.elasticsearch.action.ActionFuture; import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.NotifyOnceListener; import org.elasticsearch.action.support.PlainActionFuture; @@ -45,7 +44,6 @@ import org.elasticsearch.common.io.stream.NamedWriteableRegistry; import org.elasticsearch.common.io.stream.ReleasableBytesStreamOutput; import org.elasticsearch.common.io.stream.StreamInput; -import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.common.metrics.CounterMetric; import org.elasticsearch.common.metrics.MeanMetric; import org.elasticsearch.common.network.NetworkAddress; @@ -60,6 +58,7 @@ import org.elasticsearch.common.unit.TimeValue; import org.elasticsearch.common.util.BigArrays; import org.elasticsearch.common.util.concurrent.AbstractRunnable; +import org.elasticsearch.common.util.concurrent.CountDown; import org.elasticsearch.common.util.concurrent.ThreadContext; import org.elasticsearch.core.internal.io.IOUtils; import org.elasticsearch.indices.breaker.CircuitBreakerService; @@ -87,7 +86,6 @@ import java.util.List; import java.util.Map; import java.util.Objects; -import java.util.Optional; import java.util.Set; import java.util.TreeSet; import java.util.concurrent.ConcurrentHashMap; @@ -99,7 +97,6 @@ import java.util.concurrent.atomic.AtomicReference; import java.util.concurrent.locks.ReadWriteLock; import java.util.concurrent.locks.ReentrantReadWriteLock; -import java.util.function.Consumer; import java.util.function.Function; import java.util.regex.Matcher; import java.util.regex.Pattern; @@ -165,6 +162,8 @@ public abstract class TcpTransport extends AbstractLifecycleComponent implements public static final Setting.AffixSetting PUBLISH_PORT_PROFILE = affixKeySetting("transport.profiles.", "publish_port", key -> intSetting(key, -1, -1, Setting.Property.NodeScope)); + // This is the number of bytes necessary to read the message size + public static final int BYTES_NEEDED_FOR_MESSAGE_SIZE = TcpHeader.MARKER_BYTES_SIZE + TcpHeader.MESSAGE_LENGTH_SIZE; public static final int PING_DATA_SIZE = -1; protected final CounterMetric successfulPings = new CounterMetric(); protected final CounterMetric failedPings = new CounterMetric(); @@ -176,6 +175,7 @@ public abstract class TcpTransport extends AbstractLifecycleComponent implements protected final Settings settings; private final CircuitBreakerService circuitBreakerService; + private final Version version; protected final ThreadPool threadPool; private final BigArrays bigArrays; protected final NetworkService networkService; @@ -198,23 +198,22 @@ public abstract class TcpTransport extends AbstractLifecycleComponent implements private volatile BoundTransportAddress boundAddress; private final String transportName; - private final ConcurrentMap pendingHandshakes = new ConcurrentHashMap<>(); - private final CounterMetric numHandshakes = new CounterMetric(); - private static final String HANDSHAKE_ACTION_NAME = "internal:tcp/handshake"; - private final MeanMetric readBytesMetric = new MeanMetric(); private final MeanMetric transmittedBytesMetric = new MeanMetric(); private volatile Map> requestHandlers = Collections.emptyMap(); private final ResponseHandlers responseHandlers = new ResponseHandlers(); + private final TcpTransportHandshaker handshaker; + private final TransportLogger transportLogger; private final BytesReference pingMessage; private final String nodeName; - public TcpTransport(String transportName, Settings settings, ThreadPool threadPool, BigArrays bigArrays, + public TcpTransport(String transportName, Settings settings, Version version, ThreadPool threadPool, BigArrays bigArrays, CircuitBreakerService circuitBreakerService, NamedWriteableRegistry namedWriteableRegistry, NetworkService networkService) { super(settings); this.settings = settings; this.profileSettings = getProfileSettings(settings); + this.version = version; this.threadPool = threadPool; this.bigArrays = bigArrays; this.circuitBreakerService = circuitBreakerService; @@ -222,6 +221,13 @@ public TcpTransport(String transportName, Settings settings, ThreadPool threadPo this.compress = Transport.TRANSPORT_TCP_COMPRESS.get(settings); this.networkService = networkService; this.transportName = transportName; + this.transportLogger = new TransportLogger(); + this.handshaker = new TcpTransportHandshaker(version, threadPool, + (node, channel, requestId, v) -> sendRequestToChannel(node, channel, requestId, + TcpTransportHandshaker.HANDSHAKE_ACTION_NAME, TransportRequest.Empty.INSTANCE, TransportRequestOptions.EMPTY, v, + TransportStatus.setHandshake((byte) 0)), + (v, features, channel, response, requestId) -> sendResponse(v, features, channel, response, requestId, + TcpTransportHandshaker.HANDSHAKE_ACTION_NAME, TransportResponseOptions.EMPTY, TransportStatus.setHandshake((byte) 0))); this.nodeName = Node.NODE_NAME_SETTING.get(settings); final Settings defaultFeatures = DEFAULT_FEATURES_SETTING.get(settings); if (defaultFeatures == null) { @@ -272,41 +278,6 @@ public synchronized void registerRequestHandl requestHandlers = MapBuilder.newMapBuilder(requestHandlers).put(reg.getAction(), reg).immutableMap(); } - private static class HandshakeResponseHandler implements TransportResponseHandler { - final AtomicReference versionRef = new AtomicReference<>(); - final CountDownLatch latch = new CountDownLatch(1); - final AtomicReference exceptionRef = new AtomicReference<>(); - final TcpChannel channel; - - HandshakeResponseHandler(TcpChannel channel) { - this.channel = channel; - } - - @Override - public VersionHandshakeResponse read(StreamInput in) throws IOException { - return new VersionHandshakeResponse(in); - } - - @Override - public void handleResponse(VersionHandshakeResponse response) { - final boolean success = versionRef.compareAndSet(null, response.version); - latch.countDown(); - assert success; - } - - @Override - public void handleException(TransportException exp) { - final boolean success = exceptionRef.compareAndSet(null, exp); - latch.countDown(); - assert success; - } - - @Override - public String executor() { - return ThreadPool.Names.SAME; - } - } - public final class NodeChannels extends CloseableConnection { private final Map typeMapping; private final List channels; @@ -428,83 +399,59 @@ public NodeChannels openConnection(DiscoveryNode node, ConnectionProfile connect if (node == null) { throw new ConnectTransportException(null, "can't open connection to a null node"); } - boolean success = false; - NodeChannels nodeChannels = null; connectionProfile = maybeOverrideConnectionProfile(connectionProfile); closeLock.readLock().lock(); // ensure we don't open connections while we are closing try { ensureOpen(); - try { - int numConnections = connectionProfile.getNumConnections(); - assert numConnections > 0 : "A connection profile must be configured with at least one connection"; - List channels = new ArrayList<>(numConnections); - List> connectionFutures = new ArrayList<>(numConnections); - for (int i = 0; i < numConnections; ++i) { - try { - PlainActionFuture connectFuture = PlainActionFuture.newFuture(); - connectionFutures.add(connectFuture); - TcpChannel channel = initiateChannel(node, connectFuture); - logger.trace(() -> new ParameterizedMessage("Tcp transport client channel opened: {}", channel)); - channels.add(channel); - } catch (Exception e) { - // If there was an exception when attempting to instantiate the raw channels, we close all of the channels - TcpChannel.closeChannels(channels, false); - throw e; - } - } - - // If we make it past the block above, we successfully instantiated all of the channels - try { - TcpChannel.awaitConnected(node, connectionFutures, connectionProfile.getConnectTimeout()); - } catch (Exception ex) { - TcpChannel.closeChannels(channels, false); - throw ex; - } + PlainActionFuture connectionFuture = PlainActionFuture.newFuture(); + List pendingChannels = initiateConnection(node, connectionProfile, connectionFuture); - // If we make it past the block above, we have successfully established connections for all of the channels - final TcpChannel handshakeChannel = channels.get(0); // one channel is guaranteed by the connection profile - handshakeChannel.addCloseListener(ActionListener.wrap(() -> cancelHandshakeForChannel(handshakeChannel))); - Version version; - try { - version = executeHandshake(node, handshakeChannel, connectionProfile.getHandshakeTimeout()); - } catch (Exception ex) { - TcpChannel.closeChannels(channels, false); - throw ex; + try { + return connectionFuture.actionGet(); + } catch (IllegalStateException e) { + // If the future was interrupted we can close the channels to improve the shutdown of the MockTcpTransport + if (e.getCause() instanceof InterruptedException) { + TcpChannel.closeChannels(pendingChannels, false); } + throw e; + } + } finally { + closeLock.readLock().unlock(); + } + } - // If we make it past the block above, we have successfully completed the handshake and the connection is now open. - // At this point we should construct the connection, notify the transport service, and attach close listeners to the - // underlying channels. - nodeChannels = new NodeChannels(node, channels, connectionProfile, version); - final NodeChannels finalNodeChannels = nodeChannels; + private List initiateConnection(DiscoveryNode node, ConnectionProfile connectionProfile, + ActionListener listener) { + int numConnections = connectionProfile.getNumConnections(); + assert numConnections > 0 : "A connection profile must be configured with at least one connection"; - Consumer onClose = c -> { - assert c.isOpen() == false : "channel is still open when onClose is called"; - finalNodeChannels.close(); - }; + final List channels = new ArrayList<>(numConnections); - nodeChannels.channels.forEach(ch -> ch.addCloseListener(ActionListener.wrap(() -> onClose.accept(ch)))); - success = true; - return nodeChannels; + for (int i = 0; i < numConnections; ++i) { + try { + TcpChannel channel = initiateChannel(node); + logger.trace(() -> new ParameterizedMessage("Tcp transport client channel opened: {}", channel)); + channels.add(channel); } catch (ConnectTransportException e) { - throw e; + TcpChannel.closeChannels(channels, false); + listener.onFailure(e); + return channels; } catch (Exception e) { - // ConnectTransportExceptions are handled specifically on the caller end - we wrap the actual exception to ensure - // only relevant exceptions are logged on the caller end.. this is the same as in connectToNode - throw new ConnectTransportException(node, "general node connection failure", e); - } finally { - if (success == false) { - IOUtils.closeWhileHandlingException(nodeChannels); - } + TcpChannel.closeChannels(channels, false); + listener.onFailure(new ConnectTransportException(node, "general node connection failure", e)); + return channels; } - } finally { - closeLock.readLock().unlock(); } - } - protected Version getCurrentVersion() { - // this is just for tests to mock stuff like the nodes version - tests can override this internally - return Version.CURRENT; + ChannelsConnectedListener channelsConnectedListener = new ChannelsConnectedListener(node, connectionProfile, channels, listener); + + for (TcpChannel channel : channels) { + channel.addConnectListener(channelsConnectedListener); + } + + TimeValue connectTimeout = connectionProfile.getConnectTimeout(); + threadPool.schedule(connectTimeout, ThreadPool.Names.GENERIC, channelsConnectedListener::onTimeout); + return channels; } @Override @@ -672,7 +619,9 @@ public TransportAddress[] addressesFromString(String address, int perAddressLimi // not perfect, but PortsRange should take care of any port range validation, not a regex private static final Pattern BRACKET_PATTERN = Pattern.compile("^\\[(.*:.*)\\](?::([\\d\\-]*))?$"); - /** parse a hostname+port range spec into its equivalent addresses */ + /** + * parse a hostname+port range spec into its equivalent addresses + */ static TransportAddress[] parse(String hostPortString, String defaultPortRange, int perAddressLimit) throws UnknownHostException { Objects.requireNonNull(hostPortString); String host; @@ -770,7 +719,7 @@ protected void onException(TcpChannel channel, Exception e) { if (isCloseConnectionException(e)) { logger.trace(() -> new ParameterizedMessage( - "close connection exception caught on transport layer [{}], disconnecting from relevant node", channel), e); + "close connection exception caught on transport layer [{}], disconnecting from relevant node", channel), e); // close the channel, which will cause a node to be disconnected if relevant TcpChannel.closeChannel(channel, false); } else if (isConnectException(e)) { @@ -783,7 +732,7 @@ protected void onException(TcpChannel channel, Exception e) { TcpChannel.closeChannel(channel, false); } else if (e instanceof CancelledKeyException) { logger.trace(() -> new ParameterizedMessage( - "cancelled key exception caught on transport layer [{}], disconnecting from relevant node", channel), e); + "cancelled key exception caught on transport layer [{}], disconnecting from relevant node", channel), e); // close the channel as safe measure, which will cause a node to be disconnected if relevant TcpChannel.closeChannel(channel, false); } else if (e instanceof TcpTransport.HttpOnTransportException) { @@ -811,9 +760,23 @@ protected void innerOnFailure(Exception e) { } } + protected void onServerException(TcpServerChannel channel, Exception e) { + logger.error(new ParameterizedMessage("exception from server channel caught on transport layer [channel={}]", channel), e); + } + + /** + * Exception handler for exceptions that are not associated with a specific channel. + * + * @param exception the exception + */ + protected void onNonChannelException(Exception exception) { + logger.warn(new ParameterizedMessage("exception caught on transport layer [thread={}]", Thread.currentThread().getName()), + exception); + } + protected void serverAcceptedChannel(TcpChannel channel) { boolean addedOnThisCall = acceptedChannels.add(channel); - assert addedOnThisCall : "Channel should only be added to accept channel set once"; + assert addedOnThisCall : "Channel should only be added to accepted channel set once"; channel.addCloseListener(ActionListener.wrap(() -> acceptedChannels.remove(channel))); logger.trace(() -> new ParameterizedMessage("Tcp transport channel accepted: {}", channel)); } @@ -830,11 +793,10 @@ protected void serverAcceptedChannel(TcpChannel channel) { * Initiate a single tcp socket channel. * * @param node for the initiated connection - * @param connectListener listener to be called when connection complete * @return the pending connection * @throws IOException if an I/O exception occurs while opening the channel */ - protected abstract TcpChannel initiateChannel(DiscoveryNode node, ActionListener connectListener) throws IOException; + protected abstract TcpChannel initiateChannel(DiscoveryNode node) throws IOException; /** * Called to tear down internal resources @@ -869,7 +831,7 @@ private void sendRequestToChannel(final DiscoveryNode node, final TcpChannel cha // we pick the smallest of the 2, to support both backward and forward compatibility // note, this is the only place we need to do this, since from here on, we use the serialized version // as the version to use also when the node receiving this request will send the response with - Version version = Version.min(getCurrentVersion(), channelVersion); + Version version = Version.min(this.version, channelVersion); stream.setVersion(version); threadPool.getThreadContext().writeTo(stream); @@ -915,12 +877,12 @@ private void internalSendMessage(TcpChannel channel, BytesReference message, Sen * @param action the action this response replies to */ public void sendErrorResponse( - final Version nodeVersion, - final Set features, - final TcpChannel channel, - final Exception error, - final long requestId, - final String action) throws IOException { + final Version nodeVersion, + final Set features, + final TcpChannel channel, + final Exception error, + final long requestId, + final String action) throws IOException { try (BytesStreamOutput stream = new BytesStreamOutput()) { stream.setVersion(nodeVersion); stream.setFeatures(features); @@ -946,25 +908,25 @@ public void sendErrorResponse( * @see #sendErrorResponse(Version, Set, TcpChannel, Exception, long, String) for sending back errors to the caller */ public void sendResponse( - final Version nodeVersion, - final Set features, - final TcpChannel channel, - final TransportResponse response, - final long requestId, - final String action, - final TransportResponseOptions options) throws IOException { + final Version nodeVersion, + final Set features, + final TcpChannel channel, + final TransportResponse response, + final long requestId, + final String action, + final TransportResponseOptions options) throws IOException { sendResponse(nodeVersion, features, channel, response, requestId, action, options, (byte) 0); } private void sendResponse( - final Version nodeVersion, - final Set features, - final TcpChannel channel, - final TransportResponse response, - final long requestId, - final String action, - TransportResponseOptions options, - byte status) throws IOException { + final Version nodeVersion, + final Set features, + final TcpChannel channel, + final TransportResponse response, + final long requestId, + final String action, + TransportResponseOptions options, + byte status) throws IOException { if (compress) { options = TransportResponseOptions.builder(options).withCompress(true).build(); } @@ -1040,72 +1002,143 @@ private BytesReference buildMessage(long requestId, byte status, Version nodeVer } /** - * Validates the first N bytes of the message header and returns false if the message is - * a ping message and has no payload ie. isn't a real user level message. + * Handles inbound message that has been decoded. * - * @throws IllegalStateException if the message is too short, less than the header or less that the header plus the message size - * @throws HttpOnTransportException if the message has no valid header and appears to be an HTTP message - * @throws IllegalArgumentException if the message is greater that the maximum allowed frame size. This is dependent on the available - * memory. + * @param channel the channel the message if fomr + * @param message the message */ - public static boolean validateMessageHeader(BytesReference buffer) throws IOException { - final int sizeHeaderLength = TcpHeader.MARKER_BYTES_SIZE + TcpHeader.MESSAGE_LENGTH_SIZE; - if (buffer.length() < sizeHeaderLength) { - throw new IllegalStateException("message size must be >= to the header size"); - } - int offset = 0; - if (buffer.get(offset) != 'E' || buffer.get(offset + 1) != 'S') { - // special handling for what is probably HTTP - if (bufferStartsWith(buffer, offset, "GET ") || - bufferStartsWith(buffer, offset, "POST ") || - bufferStartsWith(buffer, offset, "PUT ") || - bufferStartsWith(buffer, offset, "HEAD ") || - bufferStartsWith(buffer, offset, "DELETE ") || - bufferStartsWith(buffer, offset, "OPTIONS ") || - bufferStartsWith(buffer, offset, "PATCH ") || - bufferStartsWith(buffer, offset, "TRACE ")) { - - throw new HttpOnTransportException("This is not an HTTP port"); + public void inboundMessage(TcpChannel channel, BytesReference message) { + try { + transportLogger.logInboundMessage(channel, message); + // Message length of 0 is a ping + if (message.length() != 0) { + messageReceived(message, channel); } + } catch (Exception e) { + onException(channel, e); + } + } - // we have 6 readable bytes, show 4 (should be enough) - throw new StreamCorruptedException("invalid internal transport message format, got (" - + Integer.toHexString(buffer.get(offset) & 0xFF) + "," - + Integer.toHexString(buffer.get(offset + 1) & 0xFF) + "," - + Integer.toHexString(buffer.get(offset + 2) & 0xFF) + "," - + Integer.toHexString(buffer.get(offset + 3) & 0xFF) + ")"); + /** + * Consumes bytes that are available from network reads. This method returns the number of bytes consumed + * in this call. + * + * @param channel the channel read from + * @param bytesReference the bytes available to consume + * @return the number of bytes consumed + * @throws StreamCorruptedException if the message header format is not recognized + * @throws TcpTransport.HttpOnTransportException if the message header appears to be an HTTP message + * @throws IllegalArgumentException if the message length is greater that the maximum allowed frame size. + * This is dependent on the available memory. + */ + public int consumeNetworkReads(TcpChannel channel, BytesReference bytesReference) throws IOException { + BytesReference message = decodeFrame(bytesReference); + + if (message == null) { + return 0; + } else { + inboundMessage(channel, message); + return message.length() + BYTES_NEEDED_FOR_MESSAGE_SIZE; } + } - final int dataLen; - try (StreamInput input = buffer.streamInput()) { - input.skip(TcpHeader.MARKER_BYTES_SIZE); - dataLen = input.readInt(); - if (dataLen == PING_DATA_SIZE) { - // discard the messages we read and continue, this is achieved by skipping the bytes - // and returning null - return false; + /** + * Attempts to a decode a message from the provided bytes. If a full message is not available, null is + * returned. If the message is a ping, an empty {@link BytesReference} will be returned. + * + * @param networkBytes the will be read + * @return the message decoded + * @throws StreamCorruptedException if the message header format is not recognized + * @throws TcpTransport.HttpOnTransportException if the message header appears to be an HTTP message + * @throws IllegalArgumentException if the message length is greater that the maximum allowed frame size. + * This is dependent on the available memory. + */ + static BytesReference decodeFrame(BytesReference networkBytes) throws IOException { + int messageLength = readMessageLength(networkBytes); + if (messageLength == -1) { + return null; + } else { + int totalLength = messageLength + BYTES_NEEDED_FOR_MESSAGE_SIZE; + if (totalLength > networkBytes.length()) { + return null; + } else if (totalLength == 6) { + return EMPTY_BYTES_REFERENCE; + } else { + return networkBytes.slice(BYTES_NEEDED_FOR_MESSAGE_SIZE, messageLength); } } + } - if (dataLen <= 0) { - throw new StreamCorruptedException("invalid data length: " + dataLen); + /** + * Validates the first 6 bytes of the message header and returns the length of the message. If 6 bytes + * are not available, it returns -1. + * + * @param networkBytes the will be read + * @return the length of the message + * @throws StreamCorruptedException if the message header format is not recognized + * @throws TcpTransport.HttpOnTransportException if the message header appears to be an HTTP message + * @throws IllegalArgumentException if the message length is greater that the maximum allowed frame size. + * This is dependent on the available memory. + */ + public static int readMessageLength(BytesReference networkBytes) throws IOException { + if (networkBytes.length() < BYTES_NEEDED_FOR_MESSAGE_SIZE) { + return -1; + } else { + return readHeaderBuffer(networkBytes); } - // safety against too large frames being sent - if (dataLen > NINETY_PER_HEAP_SIZE) { - throw new IllegalArgumentException("transport content length received [" + new ByteSizeValue(dataLen) + "] exceeded [" - + new ByteSizeValue(NINETY_PER_HEAP_SIZE) + "]"); + } + + private static int readHeaderBuffer(BytesReference headerBuffer) throws IOException { + if (headerBuffer.get(0) != 'E' || headerBuffer.get(1) != 'S') { + if (appearsToBeHTTP(headerBuffer)) { + throw new TcpTransport.HttpOnTransportException("This is not an HTTP port"); + } + + throw new StreamCorruptedException("invalid internal transport message format, got (" + + Integer.toHexString(headerBuffer.get(0) & 0xFF) + "," + + Integer.toHexString(headerBuffer.get(1) & 0xFF) + "," + + Integer.toHexString(headerBuffer.get(2) & 0xFF) + "," + + Integer.toHexString(headerBuffer.get(3) & 0xFF) + ")"); + } + final int messageLength; + try (StreamInput input = headerBuffer.streamInput()) { + input.skip(TcpHeader.MARKER_BYTES_SIZE); + messageLength = input.readInt(); } - if (buffer.length() < dataLen + sizeHeaderLength) { - throw new IllegalStateException("buffer must be >= to the message size but wasn't"); + if (messageLength == TcpTransport.PING_DATA_SIZE) { + // This is a ping + return 0; } - return true; + + if (messageLength <= 0) { + throw new StreamCorruptedException("invalid data length: " + messageLength); + } + + if (messageLength > NINETY_PER_HEAP_SIZE) { + throw new IllegalArgumentException("transport content length received [" + new ByteSizeValue(messageLength) + "] exceeded [" + + new ByteSizeValue(NINETY_PER_HEAP_SIZE) + "]"); + } + + return messageLength; + } + + private static boolean appearsToBeHTTP(BytesReference headerBuffer) { + return bufferStartsWith(headerBuffer, "GET") || + bufferStartsWith(headerBuffer, "POST") || + bufferStartsWith(headerBuffer, "PUT") || + bufferStartsWith(headerBuffer, "HEAD") || + bufferStartsWith(headerBuffer, "DELETE") || + // Actually 'OPTIONS'. But we are only guaranteed to have read six bytes at this point. + bufferStartsWith(headerBuffer, "OPTION") || + bufferStartsWith(headerBuffer, "PATCH") || + bufferStartsWith(headerBuffer, "TRACE"); } - private static boolean bufferStartsWith(BytesReference buffer, int offset, String method) { + private static boolean bufferStartsWith(BytesReference buffer, String method) { char[] chars = method.toCharArray(); for (int i = 0; i < chars.length; i++) { - if (buffer.get(offset + i) != chars[i]) { + if (buffer.get(i) != chars[i]) { return false; } } @@ -1167,7 +1200,7 @@ public final void messageReceived(BytesReference reference, TcpChannel channel, streamIn = compressor.streamInput(streamIn); } final boolean isHandshake = TransportStatus.isHandshake(status); - ensureVersionCompatibility(version, getCurrentVersion(), isHandshake); + ensureVersionCompatibility(version, this.version, isHandshake); streamIn = new NamedWriteableAwareStreamInput(streamIn, namedWriteableRegistry); streamIn.setVersion(version); threadPool.getThreadContext().readHeaders(streamIn); @@ -1177,12 +1210,12 @@ public final void messageReceived(BytesReference reference, TcpChannel channel, } else { final TransportResponseHandler handler; if (isHandshake) { - handler = pendingHandshakes.remove(requestId); + handler = handshaker.removeHandlerForHandshake(requestId); } else { TransportResponseHandler theHandler = responseHandlers.onResponseReceived(requestId, messageListener); if (theHandler == null && TransportStatus.isError(status)) { - handler = pendingHandshakes.remove(requestId); + handler = handshaker.removeHandlerForHandshake(requestId); } else { handler = theHandler; } @@ -1227,7 +1260,7 @@ static void ensureVersionCompatibility(Version version, Version currentVersion, } private void handleResponse(InetSocketAddress remoteAddress, final StreamInput stream, - final TransportResponseHandler handler) { + final TransportResponseHandler handler) { final T response; try { response = handler.read(stream); @@ -1292,9 +1325,7 @@ protected String handleRequest(TcpChannel channel, String profileName, final Str TransportChannel transportChannel = null; try { if (TransportStatus.isHandshake(status)) { - final VersionHandshakeResponse response = new VersionHandshakeResponse(getCurrentVersion()); - sendResponse(version, features, channel, response, requestId, HANDSHAKE_ACTION_NAME, TransportResponseOptions.EMPTY, - TransportStatus.setHandshake((byte) 0)); + handshaker.handleHandshake(version, features, channel, requestId); } else { final RequestHandlerRegistry reg = getRequestHandler(action); if (reg == null) { @@ -1317,7 +1348,7 @@ protected String handleRequest(TcpChannel channel, String profileName, final Str // the circuit breaker tripped if (transportChannel == null) { transportChannel = - new TcpTransportChannel(this, channel, transportName, action, requestId, version, features, profileName, 0); + new TcpTransportChannel(this, channel, transportName, action, requestId, version, features, profileName, 0); } try { transportChannel.sendResponse(e); @@ -1370,100 +1401,22 @@ public void onFailure(Exception e) { } catch (Exception inner) { inner.addSuppressed(e); logger.warn(() -> new ParameterizedMessage( - "Failed to send error message back to client for action [{}]", reg.getAction()), inner); + "Failed to send error message back to client for action [{}]", reg.getAction()), inner); } } } } - private static final class VersionHandshakeResponse extends TransportResponse { - private final Version version; - - private VersionHandshakeResponse(Version version) { - this.version = version; - } - - private VersionHandshakeResponse(StreamInput in) throws IOException { - super.readFrom(in); - version = Version.readVersion(in); - } - - @Override - public void writeTo(StreamOutput out) throws IOException { - super.writeTo(out); - assert version != null; - Version.writeVersion(version, out); - } - } - - public Version executeHandshake(DiscoveryNode node, TcpChannel channel, TimeValue timeout) - throws IOException, InterruptedException { - numHandshakes.inc(); - final long requestId = responseHandlers.newRequestId(); - final HandshakeResponseHandler handler = new HandshakeResponseHandler(channel); - AtomicReference versionRef = handler.versionRef; - AtomicReference exceptionRef = handler.exceptionRef; - pendingHandshakes.put(requestId, handler); - boolean success = false; - try { - if (channel.isOpen() == false) { - // we have to protect us here since sendRequestToChannel won't barf if the channel is closed. - // it's weird but to change it will cause a lot of impact on the exception handling code all over the codebase. - // yet, if we don't check the state here we might have registered a pending handshake handler but the close - // listener calling #onChannelClosed might have already run and we are waiting on the latch below unitl we time out. - throw new IllegalStateException("handshake failed, channel already closed"); - } - // for the request we use the minCompatVersion since we don't know what's the version of the node we talk to - // we also have no payload on the request but the response will contain the actual version of the node we talk - // to as the payload. - final Version minCompatVersion = getCurrentVersion().minimumCompatibilityVersion(); - sendRequestToChannel(node, channel, requestId, HANDSHAKE_ACTION_NAME, TransportRequest.Empty.INSTANCE, - TransportRequestOptions.EMPTY, minCompatVersion, TransportStatus.setHandshake((byte) 0)); - if (handler.latch.await(timeout.millis(), TimeUnit.MILLISECONDS) == false) { - throw new ConnectTransportException(node, "handshake_timeout[" + timeout + "]"); - } - success = true; - if (exceptionRef.get() != null) { - throw new IllegalStateException("handshake failed", exceptionRef.get()); - } else { - Version version = versionRef.get(); - if (getCurrentVersion().isCompatible(version) == false) { - throw new IllegalStateException("Received message from unsupported version: [" + version - + "] minimal compatible version is: [" + getCurrentVersion().minimumCompatibilityVersion() + "]"); - } - return version; - } - } finally { - final TransportResponseHandler removedHandler = pendingHandshakes.remove(requestId); - // in the case of a timeout or an exception on the send part the handshake has not been removed yet. - // but the timeout is tricky since it's basically a race condition so we only assert on the success case. - assert success && removedHandler == null || success == false : "handler for requestId [" + requestId + "] is not been removed"; - } + public void executeHandshake(DiscoveryNode node, TcpChannel channel, TimeValue timeout, ActionListener listener) { + handshaker.sendHandshake(responseHandlers.newRequestId(), node, channel, timeout, listener); } - final int getNumPendingHandshakes() { // for testing - return pendingHandshakes.size(); + final int getNumPendingHandshakes() { + return handshaker.getNumPendingHandshakes(); } final long getNumHandshakes() { - return numHandshakes.count(); // for testing - } - - /** - * Called once the channel is closed for instance due to a disconnect or a closed socket etc. - */ - private void cancelHandshakeForChannel(TcpChannel channel) { - final Optional first = pendingHandshakes.entrySet().stream() - .filter((entry) -> entry.getValue().channel == channel).map(Map.Entry::getKey).findFirst(); - if (first.isPresent()) { - final Long requestId = first.get(); - final HandshakeResponseHandler handler = pendingHandshakes.remove(requestId); - if (handler != null) { - // there might be a race removing this or this method might be called twice concurrently depending on how - // the channel is closed ie. due to connection reset or broken pipes - handler.handleException(new TransportException("connection reset")); - } - } + return handshaker.getNumHandshakes(); } /** @@ -1643,4 +1596,69 @@ public final ResponseHandlers getResponseHandlers() { public final RequestHandlerRegistry getRequestHandler(String action) { return requestHandlers.get(action); } + + private final class ChannelsConnectedListener implements ActionListener { + + private final DiscoveryNode node; + private final ConnectionProfile connectionProfile; + private final List channels; + private final ActionListener listener; + private final CountDown countDown; + + private ChannelsConnectedListener(DiscoveryNode node, ConnectionProfile connectionProfile, List channels, + ActionListener listener) { + this.node = node; + this.connectionProfile = connectionProfile; + this.channels = channels; + this.listener = listener; + this.countDown = new CountDown(channels.size()); + } + + @Override + public void onResponse(Void v) { + // Returns true if all connections have completed successfully + if (countDown.countDown()) { + final TcpChannel handshakeChannel = channels.get(0); + try { + executeHandshake(node, handshakeChannel, connectionProfile.getHandshakeTimeout(), new ActionListener() { + @Override + public void onResponse(Version version) { + NodeChannels nodeChannels = new NodeChannels(node, channels, connectionProfile, version); + nodeChannels.channels.forEach(ch -> ch.addCloseListener(ActionListener.wrap(nodeChannels::close))); + listener.onResponse(nodeChannels); + } + + @Override + public void onFailure(Exception e) { + TcpChannel.closeChannels(channels, false); + + if (e instanceof ConnectTransportException) { + listener.onFailure(e); + } else { + listener.onFailure(new ConnectTransportException(node, "general node connection failure", e)); + } + } + }); + } catch (Exception ex) { + TcpChannel.closeChannels(channels, false); + listener.onFailure(ex); + } + } + } + + @Override + public void onFailure(Exception ex) { + if (countDown.fastForward()) { + CloseableChannel.closeChannels(channels, false); + listener.onFailure(new ConnectTransportException(node, "connect_exception", ex)); + } + } + + public void onTimeout() { + if (countDown.fastForward()) { + CloseableChannel.closeChannels(channels, false); + listener.onFailure(new ConnectTransportException(node, "connect_timeout[" + connectionProfile.getConnectTimeout() + "]")); + } + } + } } diff --git a/server/src/main/java/org/elasticsearch/transport/TcpTransportHandshaker.java b/server/src/main/java/org/elasticsearch/transport/TcpTransportHandshaker.java new file mode 100644 index 0000000000000..d1037d2bcb5bd --- /dev/null +++ b/server/src/main/java/org/elasticsearch/transport/TcpTransportHandshaker.java @@ -0,0 +1,185 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.elasticsearch.transport; + +import org.elasticsearch.Version; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.cluster.node.DiscoveryNode; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.metrics.CounterMetric; +import org.elasticsearch.common.unit.TimeValue; +import org.elasticsearch.threadpool.ThreadPool; + +import java.io.IOException; +import java.util.Set; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ConcurrentMap; +import java.util.concurrent.atomic.AtomicBoolean; + +/** + * Sends and receives transport-level connection handshakes. This class will send the initial handshake, + * manage state/timeouts while the handshake is in transit, and handle the eventual response. + */ +final class TcpTransportHandshaker { + + static final String HANDSHAKE_ACTION_NAME = "internal:tcp/handshake"; + private final ConcurrentMap pendingHandshakes = new ConcurrentHashMap<>(); + private final CounterMetric numHandshakes = new CounterMetric(); + + private final Version version; + private final ThreadPool threadPool; + private final HandshakeRequestSender handshakeRequestSender; + private final HandshakeResponseSender handshakeResponseSender; + + TcpTransportHandshaker(Version version, ThreadPool threadPool, HandshakeRequestSender handshakeRequestSender, + HandshakeResponseSender handshakeResponseSender) { + this.version = version; + this.threadPool = threadPool; + this.handshakeRequestSender = handshakeRequestSender; + this.handshakeResponseSender = handshakeResponseSender; + } + + void sendHandshake(long requestId, DiscoveryNode node, TcpChannel channel, TimeValue timeout, ActionListener listener) { + numHandshakes.inc(); + final HandshakeResponseHandler handler = new HandshakeResponseHandler(requestId, version, listener); + pendingHandshakes.put(requestId, handler); + channel.addCloseListener(ActionListener.wrap( + () -> handler.handleLocalException(new TransportException("handshake failed because connection reset")))); + boolean success = false; + try { + // for the request we use the minCompatVersion since we don't know what's the version of the node we talk to + // we also have no payload on the request but the response will contain the actual version of the node we talk + // to as the payload. + final Version minCompatVersion = version.minimumCompatibilityVersion(); + handshakeRequestSender.sendRequest(node, channel, requestId, minCompatVersion); + + threadPool.schedule(timeout, ThreadPool.Names.GENERIC, + () -> handler.handleLocalException(new ConnectTransportException(node, "handshake_timeout[" + timeout + "]"))); + success = true; + } catch (Exception e) { + handler.handleLocalException(new ConnectTransportException(node, "failure to send " + HANDSHAKE_ACTION_NAME, e)); + } finally { + if (success == false) { + TransportResponseHandler removed = pendingHandshakes.remove(requestId); + assert removed == null : "Handshake should not be pending if exception was thrown"; + } + } + } + + void handleHandshake(Version version, Set features, TcpChannel channel, long requestId) throws IOException { + handshakeResponseSender.sendResponse(version, features, channel, new VersionHandshakeResponse(this.version), requestId); + } + + TransportResponseHandler removeHandlerForHandshake(long requestId) { + return pendingHandshakes.remove(requestId); + } + + int getNumPendingHandshakes() { + return pendingHandshakes.size(); + } + + long getNumHandshakes() { + return numHandshakes.count(); + } + + private class HandshakeResponseHandler implements TransportResponseHandler { + + private final long requestId; + private final Version currentVersion; + private final ActionListener listener; + private final AtomicBoolean isDone = new AtomicBoolean(false); + + private HandshakeResponseHandler(long requestId, Version currentVersion, ActionListener listener) { + this.requestId = requestId; + this.currentVersion = currentVersion; + this.listener = listener; + } + + @Override + public VersionHandshakeResponse read(StreamInput in) throws IOException { + return new VersionHandshakeResponse(in); + } + + @Override + public void handleResponse(VersionHandshakeResponse response) { + if (isDone.compareAndSet(false, true)) { + Version version = response.version; + if (currentVersion.isCompatible(version) == false) { + listener.onFailure(new IllegalStateException("Received message from unsupported version: [" + version + + "] minimal compatible version is: [" + currentVersion.minimumCompatibilityVersion() + "]")); + } else { + listener.onResponse(version); + } + } + } + + @Override + public void handleException(TransportException e) { + if (isDone.compareAndSet(false, true)) { + listener.onFailure(new IllegalStateException("handshake failed", e)); + } + } + + void handleLocalException(TransportException e) { + if (removeHandlerForHandshake(requestId) != null && isDone.compareAndSet(false, true)) { + listener.onFailure(e); + } + } + + @Override + public String executor() { + return ThreadPool.Names.SAME; + } + } + + static final class VersionHandshakeResponse extends TransportResponse { + + private final Version version; + + VersionHandshakeResponse(Version version) { + this.version = version; + } + + private VersionHandshakeResponse(StreamInput in) throws IOException { + super.readFrom(in); + version = Version.readVersion(in); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); + assert version != null; + Version.writeVersion(version, out); + } + } + + @FunctionalInterface + interface HandshakeRequestSender { + + void sendRequest(DiscoveryNode node, TcpChannel channel, long requestId, Version version) throws IOException; + } + + @FunctionalInterface + interface HandshakeResponseSender { + + void sendResponse(Version version, Set features, TcpChannel channel, TransportResponse response, long requestId) + throws IOException; + } +} diff --git a/server/src/test/java/org/elasticsearch/action/support/replication/TransportReplicationActionTests.java b/server/src/test/java/org/elasticsearch/action/support/replication/TransportReplicationActionTests.java index 4e8f7cb3e4d1a..aadffe3d363c7 100644 --- a/server/src/test/java/org/elasticsearch/action/support/replication/TransportReplicationActionTests.java +++ b/server/src/test/java/org/elasticsearch/action/support/replication/TransportReplicationActionTests.java @@ -21,7 +21,6 @@ import org.apache.lucene.store.AlreadyClosedException; import org.elasticsearch.ElasticsearchException; -import org.elasticsearch.Version; import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.UnavailableShardsException; import org.elasticsearch.action.admin.indices.close.CloseIndexRequest; @@ -985,8 +984,7 @@ public void testRetryOnReplicaWithRealTransport() throws Exception { final ReplicationTask task = maybeTask(); NamedWriteableRegistry namedWriteableRegistry = new NamedWriteableRegistry(Collections.emptyList()); final Transport transport = new MockTcpTransport(Settings.EMPTY, threadPool, BigArrays.NON_RECYCLING_INSTANCE, - new NoneCircuitBreakerService(), namedWriteableRegistry, new NetworkService(Collections.emptyList()), - Version.CURRENT); + new NoneCircuitBreakerService(), namedWriteableRegistry, new NetworkService(Collections.emptyList())); transportService = new MockTransportService(Settings.EMPTY, transport, threadPool, TransportService.NOOP_TRANSPORT_INTERCEPTOR, x -> clusterService.localNode(), null, Collections.emptySet()); transportService.start(); diff --git a/server/src/test/java/org/elasticsearch/discovery/zen/UnicastZenPingTests.java b/server/src/test/java/org/elasticsearch/discovery/zen/UnicastZenPingTests.java index 4ab738f5c7bc3..ed310ee305acf 100644 --- a/server/src/test/java/org/elasticsearch/discovery/zen/UnicastZenPingTests.java +++ b/server/src/test/java/org/elasticsearch/discovery/zen/UnicastZenPingTests.java @@ -377,8 +377,7 @@ public void testPortLimit() throws InterruptedException { BigArrays.NON_RECYCLING_INSTANCE, new NoneCircuitBreakerService(), new NamedWriteableRegistry(Collections.emptyList()), - networkService, - Version.CURRENT) { + networkService) { @Override public BoundTransportAddress boundAddress() { @@ -419,8 +418,7 @@ public void testRemovingLocalAddresses() throws InterruptedException { BigArrays.NON_RECYCLING_INSTANCE, new NoneCircuitBreakerService(), new NamedWriteableRegistry(Collections.emptyList()), - networkService, - Version.CURRENT) { + networkService) { @Override public BoundTransportAddress boundAddress() { @@ -465,8 +463,7 @@ public void testUnknownHost() throws InterruptedException { BigArrays.NON_RECYCLING_INSTANCE, new NoneCircuitBreakerService(), new NamedWriteableRegistry(Collections.emptyList()), - networkService, - Version.CURRENT) { + networkService) { @Override public BoundTransportAddress boundAddress() { @@ -512,8 +509,7 @@ public void testResolveTimeout() throws InterruptedException { BigArrays.NON_RECYCLING_INSTANCE, new NoneCircuitBreakerService(), new NamedWriteableRegistry(Collections.emptyList()), - networkService, - Version.CURRENT) { + networkService) { @Override public BoundTransportAddress boundAddress() { @@ -578,8 +574,7 @@ public void testResolveReuseExistingNodeConnections() throws ExecutionException, BigArrays.NON_RECYCLING_INSTANCE, new NoneCircuitBreakerService(), new NamedWriteableRegistry(Collections.emptyList()), - networkService, - v); + networkService); NetworkHandle handleA = startServices(settings, threadPool, "UZP_A", Version.CURRENT, supplier, EnumSet.allOf(Role.class)); closeables.push(handleA.transportService); diff --git a/server/src/test/java/org/elasticsearch/transport/TcpTransportHandshakerTests.java b/server/src/test/java/org/elasticsearch/transport/TcpTransportHandshakerTests.java new file mode 100644 index 0000000000000..23e3870842e20 --- /dev/null +++ b/server/src/test/java/org/elasticsearch/transport/TcpTransportHandshakerTests.java @@ -0,0 +1,135 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.elasticsearch.transport; + +import org.elasticsearch.Version; +import org.elasticsearch.action.support.PlainActionFuture; +import org.elasticsearch.cluster.node.DiscoveryNode; +import org.elasticsearch.common.unit.TimeValue; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.threadpool.TestThreadPool; +import org.mockito.ArgumentCaptor; + +import java.io.IOException; +import java.util.Collections; +import java.util.concurrent.TimeUnit; + +import static org.hamcrest.Matchers.containsString; +import static org.mockito.Matchers.eq; +import static org.mockito.Mockito.doThrow; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; + +public class TcpTransportHandshakerTests extends ESTestCase { + + private TcpTransportHandshaker handshaker; + private DiscoveryNode node; + private TcpChannel channel; + private TestThreadPool threadPool; + private TcpTransportHandshaker.HandshakeRequestSender requestSender; + private TcpTransportHandshaker.HandshakeResponseSender responseSender; + + @Override + public void setUp() throws Exception { + super.setUp(); + String nodeId = "node-id"; + channel = mock(TcpChannel.class); + requestSender = mock(TcpTransportHandshaker.HandshakeRequestSender.class); + responseSender = mock(TcpTransportHandshaker.HandshakeResponseSender.class); + node = new DiscoveryNode(nodeId, nodeId, nodeId, "host", "host_address", buildNewFakeTransportAddress(), Collections.emptyMap(), + Collections.emptySet(), Version.CURRENT); + threadPool = new TestThreadPool("thread-poll"); + handshaker = new TcpTransportHandshaker(Version.CURRENT, threadPool, requestSender, responseSender); + } + + @Override + public void tearDown() throws Exception { + threadPool.shutdown(); + super.tearDown(); + } + + public void testHandshakeRequestAndResponse() throws IOException { + PlainActionFuture versionFuture = PlainActionFuture.newFuture(); + long reqId = randomLongBetween(1, 10); + handshaker.sendHandshake(reqId, node, channel, new TimeValue(30, TimeUnit.SECONDS), versionFuture); + + verify(requestSender).sendRequest(node, channel, reqId, Version.CURRENT.minimumCompatibilityVersion()); + + assertFalse(versionFuture.isDone()); + + TcpChannel mockChannel = mock(TcpChannel.class); + handshaker.handleHandshake(Version.CURRENT, Collections.emptySet(), mockChannel, reqId); + + + ArgumentCaptor responseCaptor = ArgumentCaptor.forClass(TransportResponse.class); + verify(responseSender).sendResponse(eq(Version.CURRENT), eq(Collections.emptySet()), eq(mockChannel), responseCaptor.capture(), + eq(reqId)); + + TransportResponseHandler handler = handshaker.removeHandlerForHandshake(reqId); + handler.handleResponse((TcpTransportHandshaker.VersionHandshakeResponse) responseCaptor.getValue()); + + assertTrue(versionFuture.isDone()); + assertEquals(Version.CURRENT, versionFuture.actionGet()); + } + + public void testHandshakeError() throws IOException { + PlainActionFuture versionFuture = PlainActionFuture.newFuture(); + long reqId = randomLongBetween(1, 10); + handshaker.sendHandshake(reqId, node, channel, new TimeValue(30, TimeUnit.SECONDS), versionFuture); + + verify(requestSender).sendRequest(node, channel, reqId, Version.CURRENT.minimumCompatibilityVersion()); + + assertFalse(versionFuture.isDone()); + + TransportResponseHandler handler = handshaker.removeHandlerForHandshake(reqId); + handler.handleException(new TransportException("failed")); + + assertTrue(versionFuture.isDone()); + IllegalStateException ise = expectThrows(IllegalStateException.class, versionFuture::actionGet); + assertThat(ise.getMessage(), containsString("handshake failed")); + } + + public void testSendRequestThrowsException() throws IOException { + PlainActionFuture versionFuture = PlainActionFuture.newFuture(); + long reqId = randomLongBetween(1, 10); + Version compatibilityVersion = Version.CURRENT.minimumCompatibilityVersion(); + doThrow(new IOException("boom")).when(requestSender).sendRequest(node, channel, reqId, compatibilityVersion); + + handshaker.sendHandshake(reqId, node, channel, new TimeValue(30, TimeUnit.SECONDS), versionFuture); + + + assertTrue(versionFuture.isDone()); + ConnectTransportException cte = expectThrows(ConnectTransportException.class, versionFuture::actionGet); + assertThat(cte.getMessage(), containsString("failure to send internal:tcp/handshake")); + assertNull(handshaker.removeHandlerForHandshake(reqId)); + } + + public void testHandshakeTimeout() throws IOException { + PlainActionFuture versionFuture = PlainActionFuture.newFuture(); + long reqId = randomLongBetween(1, 10); + handshaker.sendHandshake(reqId, node, channel, new TimeValue(100, TimeUnit.MILLISECONDS), versionFuture); + + verify(requestSender).sendRequest(node, channel, reqId, Version.CURRENT.minimumCompatibilityVersion()); + + ConnectTransportException cte = expectThrows(ConnectTransportException.class, versionFuture::actionGet); + assertThat(cte.getMessage(), containsString("handshake_timeout")); + + assertNull(handshaker.removeHandlerForHandshake(reqId)); + } +} diff --git a/server/src/test/java/org/elasticsearch/transport/TcpTransportTests.java b/server/src/test/java/org/elasticsearch/transport/TcpTransportTests.java index 533a8feb7808c..1d94f8e81d3a8 100644 --- a/server/src/test/java/org/elasticsearch/transport/TcpTransportTests.java +++ b/server/src/test/java/org/elasticsearch/transport/TcpTransportTests.java @@ -176,7 +176,7 @@ public void testCompressRequest() throws IOException { AtomicReference messageCaptor = new AtomicReference<>(); try { TcpTransport transport = new TcpTransport( - "test", Settings.builder().put("transport.tcp.compress", compressed).build(), threadPool, + "test", Settings.builder().put("transport.tcp.compress", compressed).build(), Version.CURRENT, threadPool, new BigArrays(new PageCacheRecycler(Settings.EMPTY), null), null, null, null) { @Override @@ -185,7 +185,7 @@ protected FakeChannel bind(String name, InetSocketAddress address) throws IOExce } @Override - protected FakeChannel initiateChannel(DiscoveryNode node, ActionListener connectListener) throws IOException { + protected FakeChannel initiateChannel(DiscoveryNode node) throws IOException { return new FakeChannel(messageCaptor); } @@ -253,6 +253,10 @@ public void close() { public void addCloseListener(ActionListener listener) { } + @Override + public void addConnectListener(ActionListener listener) { + } + @Override public void setSoLinger(int value) throws IOException { } diff --git a/test/framework/src/main/java/org/elasticsearch/transport/AbstractSimpleTransportTestCase.java b/test/framework/src/main/java/org/elasticsearch/transport/AbstractSimpleTransportTestCase.java index 230b055908243..de9bf4a1b0b4e 100644 --- a/test/framework/src/main/java/org/elasticsearch/transport/AbstractSimpleTransportTestCase.java +++ b/test/framework/src/main/java/org/elasticsearch/transport/AbstractSimpleTransportTestCase.java @@ -196,13 +196,7 @@ public void tearDown() throws Exception { assertNoPendingHandshakes(serviceA.getOriginalTransport()); assertNoPendingHandshakes(serviceB.getOriginalTransport()); } finally { - IOUtils.close(serviceA, serviceB, () -> { - try { - terminate(threadPool); - } catch (InterruptedException e) { - Thread.currentThread().interrupt(); - } - }); + IOUtils.close(serviceA, serviceB, () -> terminate(threadPool)); } } @@ -2035,9 +2029,10 @@ protected String handleRequest(TcpChannel mockChannel, String profileName, Strea TcpTransport.NodeChannels connection = originalTransport.openConnection( new DiscoveryNode("TS_TPC", "TS_TPC", service.boundAddress().publishAddress(), emptyMap(), emptySet(), version0), connectionProfile)) { - Version version = originalTransport.executeHandshake(connection.getNode(), - connection.channel(TransportRequestOptions.Type.PING), TimeValue.timeValueSeconds(10)); - assertEquals(version, Version.CURRENT); + PlainActionFuture listener = PlainActionFuture.newFuture(); + originalTransport.executeHandshake(connection.getNode(), connection.channel(TransportRequestOptions.Type.PING), + TimeValue.timeValueSeconds(10), listener); + assertEquals(listener.actionGet(), Version.CURRENT); } } diff --git a/test/framework/src/main/java/org/elasticsearch/transport/MockTcpTransport.java b/test/framework/src/main/java/org/elasticsearch/transport/MockTcpTransport.java index 105961aa3230f..586109f7aba97 100644 --- a/test/framework/src/main/java/org/elasticsearch/transport/MockTcpTransport.java +++ b/test/framework/src/main/java/org/elasticsearch/transport/MockTcpTransport.java @@ -20,6 +20,7 @@ import org.elasticsearch.cli.SuppressForbidden; import org.elasticsearch.cluster.node.DiscoveryNode; +import org.elasticsearch.common.concurrent.CompletableContext; import org.elasticsearch.core.internal.io.IOUtils; import org.elasticsearch.Version; import org.elasticsearch.action.ActionListener; @@ -88,7 +89,6 @@ public class MockTcpTransport extends TcpTransport { } private final ExecutorService executor; - private final Version mockVersion; public MockTcpTransport(Settings settings, ThreadPool threadPool, BigArrays bigArrays, CircuitBreakerService circuitBreakerService, NamedWriteableRegistry namedWriteableRegistry, @@ -100,11 +100,11 @@ public MockTcpTransport(Settings settings, ThreadPool threadPool, BigArrays bigA public MockTcpTransport(Settings settings, ThreadPool threadPool, BigArrays bigArrays, CircuitBreakerService circuitBreakerService, NamedWriteableRegistry namedWriteableRegistry, NetworkService networkService, Version mockVersion) { - super("mock-tcp-transport", settings, threadPool, bigArrays, circuitBreakerService, namedWriteableRegistry, networkService); + super("mock-tcp-transport", settings, mockVersion, threadPool, bigArrays, circuitBreakerService, namedWriteableRegistry, + networkService); // we have our own crazy cached threadpool this one is not bounded at all... // using the ES thread factory here is crucial for tests otherwise disruption tests won't block that thread executor = Executors.newCachedThreadPool(EsExecutors.daemonThreadFactory(settings, Transports.TEST_MOCK_TRANSPORT_THREAD_PREFIX)); - this.mockVersion = mockVersion; } @Override @@ -157,20 +157,13 @@ private void readMessage(MockChannel mockChannel, StreamInput input) throws IOEx output.write(minimalHeader); output.writeInt(msgSize); output.write(buffer); - final BytesReference bytes = output.bytes(); - if (TcpTransport.validateMessageHeader(bytes)) { - InetSocketAddress remoteAddress = (InetSocketAddress) socket.getRemoteSocketAddress(); - messageReceived(bytes.slice(TcpHeader.MARKER_BYTES_SIZE + TcpHeader.MESSAGE_LENGTH_SIZE, msgSize), - mockChannel, mockChannel.profile, remoteAddress, msgSize); - } else { - // ping message - we just drop all stuff - } + consumeNetworkReads(mockChannel, output.bytes()); } } @Override @SuppressForbidden(reason = "real socket for mocking remote connections") - protected MockChannel initiateChannel(DiscoveryNode node, ActionListener connectListener) throws IOException { + protected MockChannel initiateChannel(DiscoveryNode node) throws IOException { InetSocketAddress address = node.getAddress().address(); final MockSocket socket = new MockSocket(); final MockChannel channel = new MockChannel(socket, address, "none"); @@ -183,16 +176,16 @@ protected MockChannel initiateChannel(DiscoveryNode node, ActionListener c if (success == false) { IOUtils.close(socket); } - } executor.submit(() -> { try { socket.connect(address); + socket.setSoLinger(false, 0); + channel.connectFuture.complete(null); channel.loopRead(executor); - connectListener.onResponse(null); } catch (Exception ex) { - connectListener.onFailure(ex); + channel.connectFuture.completeExceptionally(ex); } }); @@ -243,7 +236,8 @@ public final class MockChannel implements Closeable, TcpChannel { private final Socket activeChannel; private final String profile; private final CancellableThreads cancellableThreads = new CancellableThreads(); - private final CompletableFuture closeFuture = new CompletableFuture<>(); + private final CompletableContext closeFuture = new CompletableContext<>(); + private final CompletableContext connectFuture = new CompletableContext<>(); /** * Constructs a new MockChannel instance intended for handling the actual incoming / outgoing traffic. @@ -384,7 +378,12 @@ public void close() { @Override public void addCloseListener(ActionListener listener) { - closeFuture.whenComplete(ActionListener.toBiConsumer(listener)); + closeFuture.addListener(ActionListener.toBiConsumer(listener)); + } + + @Override + public void addConnectListener(ActionListener listener) { + connectFuture.addListener(ActionListener.toBiConsumer(listener)); } @Override @@ -392,7 +391,6 @@ public void setSoLinger(int value) throws IOException { if (activeChannel != null && activeChannel.isClosed() == false) { activeChannel.setSoLinger(true, value); } - } @Override @@ -448,10 +446,5 @@ protected void stopInternal() { assert openChannels.isEmpty() : "there are still open channels: " + openChannels; } } - - @Override - protected Version getCurrentVersion() { - return mockVersion; - } } diff --git a/test/framework/src/test/java/org/elasticsearch/transport/MockTcpTransportTests.java b/test/framework/src/test/java/org/elasticsearch/transport/MockTcpTransportTests.java index e8b5f38b88df1..1e5c6092687a6 100644 --- a/test/framework/src/test/java/org/elasticsearch/transport/MockTcpTransportTests.java +++ b/test/framework/src/test/java/org/elasticsearch/transport/MockTcpTransportTests.java @@ -19,6 +19,7 @@ package org.elasticsearch.transport; import org.elasticsearch.Version; +import org.elasticsearch.action.ActionListener; import org.elasticsearch.cluster.node.DiscoveryNode; import org.elasticsearch.common.io.stream.NamedWriteableRegistry; import org.elasticsearch.common.network.NetworkService; @@ -29,7 +30,6 @@ import org.elasticsearch.indices.breaker.NoneCircuitBreakerService; import org.elasticsearch.test.transport.MockTransportService; -import java.io.IOException; import java.util.Collections; public class MockTcpTransportTests extends AbstractSimpleTransportTestCase { @@ -39,13 +39,13 @@ protected MockTransportService build(Settings settings, Version version, Cluster NamedWriteableRegistry namedWriteableRegistry = new NamedWriteableRegistry(Collections.emptyList()); Transport transport = new MockTcpTransport(settings, threadPool, BigArrays.NON_RECYCLING_INSTANCE, new NoneCircuitBreakerService(), namedWriteableRegistry, new NetworkService(Collections.emptyList()), version) { + @Override - public Version executeHandshake(DiscoveryNode node, TcpChannel mockChannel, TimeValue timeout) throws IOException, - InterruptedException { + public void executeHandshake(DiscoveryNode node, TcpChannel channel, TimeValue timeout, ActionListener listener) { if (doHandshake) { - return super.executeHandshake(node, mockChannel, timeout); + super.executeHandshake(node, channel, timeout, listener); } else { - return version.minimumCompatibilityVersion(); + listener.onResponse(version.minimumCompatibilityVersion()); } } }; diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/XPackClientPlugin.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/XPackClientPlugin.java index df1f985d46b97..1e813065e80b7 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/XPackClientPlugin.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/XPackClientPlugin.java @@ -5,6 +5,7 @@ */ package org.elasticsearch.xpack.core; +import org.elasticsearch.Version; import org.elasticsearch.action.GenericAction; import org.elasticsearch.cluster.ClusterState; import org.elasticsearch.cluster.NamedDiff; @@ -470,7 +471,7 @@ public Map> getTransports( } catch (Exception e) { throw new RuntimeException(e); } - return Collections.singletonMap(SecurityField.NAME4, () -> new SecurityNetty4Transport(settings, threadPool, + return Collections.singletonMap(SecurityField.NAME4, () -> new SecurityNetty4Transport(settings, Version.CURRENT, threadPool, networkService, bigArrays, namedWriteableRegistry, circuitBreakerService, sslService)); } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/security/transport/netty4/SecurityNetty4Transport.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/security/transport/netty4/SecurityNetty4Transport.java index 9253f29741c8e..46baa2925ea55 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/security/transport/netty4/SecurityNetty4Transport.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/security/transport/netty4/SecurityNetty4Transport.java @@ -12,6 +12,7 @@ import io.netty.channel.ChannelPromise; import io.netty.handler.ssl.SslHandler; import org.apache.logging.log4j.message.ParameterizedMessage; +import org.elasticsearch.Version; import org.elasticsearch.cluster.node.DiscoveryNode; import org.elasticsearch.common.io.stream.NamedWriteableRegistry; import org.elasticsearch.common.network.NetworkService; @@ -53,13 +54,14 @@ public class SecurityNetty4Transport extends Netty4Transport { public SecurityNetty4Transport( final Settings settings, + final Version version, final ThreadPool threadPool, final NetworkService networkService, final BigArrays bigArrays, final NamedWriteableRegistry namedWriteableRegistry, final CircuitBreakerService circuitBreakerService, final SSLService sslService) { - super(settings, threadPool, networkService, bigArrays, namedWriteableRegistry, circuitBreakerService); + super(settings, version, threadPool, networkService, bigArrays, namedWriteableRegistry, circuitBreakerService); this.sslService = sslService; this.sslEnabled = XPackSettings.TRANSPORT_SSL_ENABLED.get(settings); if (sslEnabled) { diff --git a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/Security.java b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/Security.java index e9df82fdd5ea4..a0137721c8b44 100644 --- a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/Security.java +++ b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/Security.java @@ -1002,8 +1002,8 @@ public Map> getTransports(Settings settings, ThreadP if (transportClientMode || enabled == false) { // don't register anything if we are not enabled, or in transport client mode return Collections.emptyMap(); } - return Collections.singletonMap(Security.NAME4, () -> new SecurityNetty4ServerTransport(settings, threadPool, networkService, - bigArrays, namedWriteableRegistry, circuitBreakerService, ipFilter.get(), getSslService())); + return Collections.singletonMap(Security.NAME4, () -> new SecurityNetty4ServerTransport(settings, Version.CURRENT, threadPool, + networkService, bigArrays, namedWriteableRegistry, circuitBreakerService, ipFilter.get(), getSslService())); } @Override diff --git a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/transport/netty4/SecurityNetty4ServerTransport.java b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/transport/netty4/SecurityNetty4ServerTransport.java index e0794d037e33d..d74aa65e94bee 100644 --- a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/transport/netty4/SecurityNetty4ServerTransport.java +++ b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/transport/netty4/SecurityNetty4ServerTransport.java @@ -7,6 +7,7 @@ import io.netty.channel.Channel; import io.netty.channel.ChannelHandler; +import org.elasticsearch.Version; import org.elasticsearch.common.Nullable; import org.elasticsearch.common.io.stream.NamedWriteableRegistry; import org.elasticsearch.common.network.NetworkService; @@ -25,6 +26,7 @@ public class SecurityNetty4ServerTransport extends SecurityNetty4Transport { public SecurityNetty4ServerTransport( final Settings settings, + final Version version, final ThreadPool threadPool, final NetworkService networkService, final BigArrays bigArrays, @@ -32,7 +34,7 @@ public SecurityNetty4ServerTransport( final CircuitBreakerService circuitBreakerService, @Nullable final IPFilter authenticator, final SSLService sslService) { - super(settings, threadPool, networkService, bigArrays, namedWriteableRegistry, circuitBreakerService, sslService); + super(settings, version, threadPool, networkService, bigArrays, namedWriteableRegistry, circuitBreakerService, sslService); this.authenticator = authenticator; } diff --git a/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/transport/AbstractSimpleSecurityTransportTestCase.java b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/transport/AbstractSimpleSecurityTransportTestCase.java index 54b7b6ba7330b..3bd9d7f2e9ec7 100644 --- a/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/transport/AbstractSimpleSecurityTransportTestCase.java +++ b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/transport/AbstractSimpleSecurityTransportTestCase.java @@ -6,6 +6,7 @@ package org.elasticsearch.xpack.security.transport; import org.elasticsearch.Version; +import org.elasticsearch.action.support.PlainActionFuture; import org.elasticsearch.cluster.node.DiscoveryNode; import org.elasticsearch.common.SuppressForbidden; import org.elasticsearch.common.io.stream.OutputStreamStreamOutput; @@ -106,7 +107,7 @@ public void testBindUnavailableAddress() { } @Override - public void testTcpHandshake() throws IOException, InterruptedException { + public void testTcpHandshake() throws InterruptedException { assumeTrue("only tcp transport has a handshake method", serviceA.getOriginalTransport() instanceof TcpTransport); TcpTransport originalTransport = (TcpTransport) serviceA.getOriginalTransport(); @@ -115,9 +116,10 @@ public void testTcpHandshake() throws IOException, InterruptedException { TcpTransport.NodeChannels connection = originalTransport.openConnection( new DiscoveryNode("TS_TPC", "TS_TPC", service.boundAddress().publishAddress(), emptyMap(), emptySet(), version0), connectionProfile)) { - Version version = originalTransport.executeHandshake(connection.getNode(), - connection.channel(TransportRequestOptions.Type.PING), TimeValue.timeValueSeconds(10)); - assertEquals(version, Version.CURRENT); + PlainActionFuture listener = PlainActionFuture.newFuture(); + originalTransport.executeHandshake(connection.getNode(), connection.channel(TransportRequestOptions.Type.PING), + TimeValue.timeValueSeconds(10), listener); + assertEquals(listener.actionGet(), Version.CURRENT); } } diff --git a/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/transport/netty4/SecurityNetty4ServerTransportTests.java b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/transport/netty4/SecurityNetty4ServerTransportTests.java index e9d91f5bd2d6a..dc6bffe5c7271 100644 --- a/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/transport/netty4/SecurityNetty4ServerTransportTests.java +++ b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/transport/netty4/SecurityNetty4ServerTransportTests.java @@ -8,6 +8,7 @@ import io.netty.channel.ChannelHandler; import io.netty.channel.embedded.EmbeddedChannel; import io.netty.handler.ssl.SslHandler; +import org.elasticsearch.Version; import org.elasticsearch.common.io.stream.NamedWriteableRegistry; import org.elasticsearch.common.network.NetworkService; import org.elasticsearch.common.settings.MockSecureSettings; @@ -68,6 +69,7 @@ private SecurityNetty4Transport createTransport(Settings additionalSettings) { .build(); return new SecurityNetty4ServerTransport( settings, + Version.CURRENT, mock(ThreadPool.class), new NetworkService(Collections.emptyList()), mock(BigArrays.class), diff --git a/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/transport/netty4/SimpleSecurityNetty4ServerTransportTests.java b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/transport/netty4/SimpleSecurityNetty4ServerTransportTests.java index 291b39f4b05ba..8c4dcf9e2fac5 100644 --- a/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/transport/netty4/SimpleSecurityNetty4ServerTransportTests.java +++ b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/transport/netty4/SimpleSecurityNetty4ServerTransportTests.java @@ -13,6 +13,7 @@ import io.netty.channel.socket.nio.NioServerSocketChannel; import io.netty.handler.ssl.SslHandler; import org.elasticsearch.Version; +import org.elasticsearch.action.ActionListener; import org.elasticsearch.cluster.node.DiscoveryNode; import org.elasticsearch.common.io.stream.NamedWriteableRegistry; import org.elasticsearch.common.network.NetworkService; @@ -39,7 +40,6 @@ import javax.net.ssl.SNIServerName; import javax.net.ssl.SSLEngine; import javax.net.ssl.SSLParameters; -import java.io.IOException; import java.net.InetSocketAddress; import java.util.Collections; import java.util.EnumSet; @@ -72,25 +72,18 @@ public MockTransportService nettyFromThreadPool(Settings settings, ThreadPool th Settings settings1 = Settings.builder() .put(settings) .put("xpack.security.transport.ssl.enabled", true).build(); - Transport transport = new SecurityNetty4ServerTransport(settings1, threadPool, + Transport transport = new SecurityNetty4ServerTransport(settings1, version, threadPool, networkService, BigArrays.NON_RECYCLING_INSTANCE, namedWriteableRegistry, new NoneCircuitBreakerService(), null, createSSLService(settings1)) { @Override - public Version executeHandshake(DiscoveryNode node, TcpChannel channel, TimeValue timeout) throws IOException, - InterruptedException { + public void executeHandshake(DiscoveryNode node, TcpChannel channel, TimeValue timeout, ActionListener listener) { if (doHandshake) { - return super.executeHandshake(node, channel, timeout); + super.executeHandshake(node, channel, timeout, listener); } else { - return version.minimumCompatibilityVersion(); + listener.onResponse(version.minimumCompatibilityVersion()); } } - - @Override - protected Version getCurrentVersion() { - return version; - } - }; MockTransportService mockTransportService = MockTransportService.createNewService(settings, transport, version, threadPool, clusterSettings,