From 60f834eeac06e599718543a1d967c4d068e40f9c Mon Sep 17 00:00:00 2001 From: Aman Khare Date: Mon, 8 May 2023 20:52:08 +0530 Subject: [PATCH] Add new handler for sending compressed cluster state in validate join flow and refactor code Signed-off-by: Aman Khare --- .../cluster/coordination/ZenDiscoveryIT.java | 7 +- .../opensearch/cluster/CompressionHelper.java | 49 ---- .../cluster/coordination/JoinHelper.java | 116 ++++----- .../PublicationTransportHandler.java | 71 +----- .../common/compress/CompressionHelper.java | 100 ++++++++ .../common/settings/ClusterSettings.java | 2 +- .../cluster/coordination/JoinHelperTests.java | 227 ++++++++++++++---- .../cluster/coordination/NodeJoinTests.java | 2 +- 8 files changed, 359 insertions(+), 215 deletions(-) delete mode 100644 server/src/main/java/org/opensearch/cluster/CompressionHelper.java create mode 100644 server/src/main/java/org/opensearch/common/compress/CompressionHelper.java diff --git a/server/src/internalClusterTest/java/org/opensearch/cluster/coordination/ZenDiscoveryIT.java b/server/src/internalClusterTest/java/org/opensearch/cluster/coordination/ZenDiscoveryIT.java index aaba53dcb2b07..dc9034ab53e73 100644 --- a/server/src/internalClusterTest/java/org/opensearch/cluster/coordination/ZenDiscoveryIT.java +++ b/server/src/internalClusterTest/java/org/opensearch/cluster/coordination/ZenDiscoveryIT.java @@ -42,6 +42,7 @@ import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.Priority; import org.opensearch.common.settings.Settings; +import org.opensearch.common.unit.TimeValue; import org.opensearch.core.xcontent.ToXContent; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.common.xcontent.XContentFactory; @@ -59,6 +60,7 @@ import java.util.concurrent.TimeoutException; import static org.opensearch.action.admin.cluster.node.stats.NodesStatsRequest.Metric.DISCOVERY; +import static org.opensearch.cluster.coordination.JoinHelper.CLUSTER_MANAGER_VALIDATE_JOIN_CACHE_INTERVAL; import static org.opensearch.test.NodeRoles.dataNode; import static org.opensearch.test.NodeRoles.clusterManagerOnlyNode; import static org.hamcrest.Matchers.containsString; @@ -106,7 +108,9 @@ public void testNoShardRelocationsOccurWhenElectedClusterManagerNodeFails() thro } public void testHandleNodeJoin_incompatibleClusterState() throws InterruptedException, ExecutionException, TimeoutException { - String clusterManagerNode = internalCluster().startClusterManagerOnlyNode(); + String clusterManagerNode = internalCluster().startClusterManagerOnlyNode( + Settings.builder().put(CLUSTER_MANAGER_VALIDATE_JOIN_CACHE_INTERVAL.getKey(), TimeValue.timeValueMillis(0)).build() + ); String node1 = internalCluster().startNode(); ClusterService clusterService = internalCluster().getInstance(ClusterService.class, node1); Coordinator coordinator = (Coordinator) internalCluster().getInstance(Discovery.class, clusterManagerNode); @@ -117,7 +121,6 @@ public void testHandleNodeJoin_incompatibleClusterState() throws InterruptedExce final CompletableFuture future = new CompletableFuture<>(); DiscoveryNode node = state.nodes().getLocalNode(); - coordinator.sendValidateJoinRequest( stateWithCustomMetadata, new JoinRequest(node, 0L, Optional.empty()), diff --git a/server/src/main/java/org/opensearch/cluster/CompressionHelper.java b/server/src/main/java/org/opensearch/cluster/CompressionHelper.java deleted file mode 100644 index d2206a77063bd..0000000000000 --- a/server/src/main/java/org/opensearch/cluster/CompressionHelper.java +++ /dev/null @@ -1,49 +0,0 @@ -/* - * SPDX-License-Identifier: Apache-2.0 - * - * The OpenSearch Contributors require contributions made to - * this file be licensed under the Apache-2.0 license or a - * compatible open source license. - */ - -package org.opensearch.cluster; - -import org.apache.logging.log4j.LogManager; -import org.apache.logging.log4j.Logger; -import org.opensearch.Version; -import org.opensearch.common.bytes.BytesReference; -import org.opensearch.common.compress.CompressorFactory; -import org.opensearch.common.io.stream.BytesStreamOutput; -import org.opensearch.common.io.stream.OutputStreamStreamOutput; -import org.opensearch.common.io.stream.StreamOutput; -import org.opensearch.common.io.stream.Writeable; - -import java.io.IOException; - -/** - * A helper class to utilize the compressed stream. - */ -public class CompressionHelper { - private static final Logger logger = LogManager.getLogger(CompressionHelper.class); - - /** - * It'll always use compression before writing on a newly created output stream. - * @param writer Object which is going to write the content - * @param nodeVersion version of cluster node - * @param streamBooleanFlag flag used at receiver end to make intelligent decisions. For example, ClusterState - * assumes full state of diff of the states based on this flag. - * @return reference to serialized bytes - * @throws IOException if writing on the compressed stream is failed. - */ - public static BytesReference serializedWrite(Writeable writer, Version nodeVersion, boolean streamBooleanFlag) throws IOException { - final BytesStreamOutput bStream = new BytesStreamOutput(); - try (StreamOutput stream = new OutputStreamStreamOutput(CompressorFactory.COMPRESSOR.threadLocalOutputStream(bStream))) { - stream.setVersion(nodeVersion); - stream.writeBoolean(streamBooleanFlag); - writer.writeTo(stream); - } - final BytesReference serializedByteRef = bStream.bytes(); - logger.trace("serialized writable object for node version [{}] with size [{}]", nodeVersion, serializedByteRef.length()); - return serializedByteRef; - } -} diff --git a/server/src/main/java/org/opensearch/cluster/coordination/JoinHelper.java b/server/src/main/java/org/opensearch/cluster/coordination/JoinHelper.java index aa64e096ceb72..37371be2fb659 100644 --- a/server/src/main/java/org/opensearch/cluster/coordination/JoinHelper.java +++ b/server/src/main/java/org/opensearch/cluster/coordination/JoinHelper.java @@ -41,7 +41,9 @@ import org.opensearch.cluster.ClusterState; import org.opensearch.cluster.ClusterStateTaskConfig; import org.opensearch.cluster.ClusterStateTaskListener; -import org.opensearch.cluster.CompressionHelper; +import org.opensearch.common.cache.Cache; +import org.opensearch.common.cache.CacheBuilder; +import org.opensearch.common.compress.CompressionHelper; import org.opensearch.cluster.NotClusterManagerException; import org.opensearch.cluster.coordination.Coordinator.Mode; import org.opensearch.cluster.decommission.NodeDecommissionedException; @@ -53,10 +55,6 @@ import org.opensearch.common.Priority; import org.opensearch.common.bytes.BytesReference; import org.opensearch.common.collect.Tuple; -import org.opensearch.common.compress.Compressor; -import org.opensearch.common.compress.CompressorFactory; -import org.opensearch.common.io.stream.InputStreamStreamInput; -import org.opensearch.common.io.stream.NamedWriteableAwareStreamInput; import org.opensearch.common.io.stream.NamedWriteableRegistry; import org.opensearch.common.io.stream.StreamInput; import org.opensearch.common.settings.Setting; @@ -79,6 +77,7 @@ import org.opensearch.transport.TransportService; import java.io.IOException; +import java.io.InvalidObjectException; import java.util.Collection; import java.util.Collections; import java.util.HashMap; @@ -108,6 +107,7 @@ public class JoinHelper { public static final String JOIN_ACTION_NAME = "internal:cluster/coordination/join"; public static final String VALIDATE_JOIN_ACTION_NAME = "internal:cluster/coordination/join/validate"; + public static final String VALIDATE_COMPRESSED_JOIN_ACTION_NAME = "internal:cluster/coordination/join" + "/validate_compressed"; public static final String START_JOIN_ACTION_NAME = "internal:cluster/coordination/start_join"; // the timeout for Zen1 join attempts @@ -133,10 +133,9 @@ public class JoinHelper { private final Supplier joinTaskExecutorGenerator; private final Consumer nodeCommissioned; private final NamedWriteableRegistry namedWriteableRegistry; - private final Map serializedStates = new HashMap<>(); - private long lastRefreshTime = 0L; - public static final Setting CLUSTER_MANAGER_JOIN_STATE_REFRESH_INTERVAL = Setting.timeSetting( - "cluster_manager.join.state.refresh_interval", + private Cache serializedStates; + public static final Setting CLUSTER_MANAGER_VALIDATE_JOIN_CACHE_INTERVAL = Setting.timeSetting( + "cluster_manager.validate_join.cache_interval", TimeValue.timeValueMillis(30000), TimeValue.timeValueMillis(0), TimeValue.timeValueMillis(60000), @@ -165,7 +164,13 @@ public class JoinHelper { this.joinTimeout = JOIN_TIMEOUT_SETTING.get(settings); this.nodeCommissioned = nodeCommissioned; this.namedWriteableRegistry = namedWriteableRegistry; - this.clusterStateRefreshInterval = CLUSTER_MANAGER_JOIN_STATE_REFRESH_INTERVAL.get(settings).getMillis(); + this.clusterStateRefreshInterval = CLUSTER_MANAGER_VALIDATE_JOIN_CACHE_INTERVAL.get(settings).getMillis(); + if (clusterStateRefreshInterval != 0) { + CacheBuilder cacheBuilder = CacheBuilder.builder(); + cacheBuilder.setExpireAfterWrite(CLUSTER_MANAGER_VALIDATE_JOIN_CACHE_INTERVAL.get(settings)); + this.serializedStates = cacheBuilder.build(); + } + this.joinTaskExecutorGenerator = () -> new JoinTaskExecutor(settings, allocationService, logger, rerouteService) { private final long term = currentTermSupplier.getAsLong(); @@ -230,12 +235,43 @@ public ClusterTasksResult execute(ClusterState currentSta transportService.registerRequestHandler( VALIDATE_JOIN_ACTION_NAME, ThreadPool.Names.GENERIC, + ValidateJoinRequest::new, + (request, channel, task) -> { + runJoinValidators(currentStateSupplier, request.getState(), joinValidators); + channel.sendResponse(Empty.INSTANCE); + } + ); + + transportService.registerRequestHandler( + VALIDATE_COMPRESSED_JOIN_ACTION_NAME, + ThreadPool.Names.GENERIC, BytesTransportRequest::new, (request, channel, task) -> { handleValidateJoinRequest(currentStateSupplier, joinValidators, request); channel.sendResponse(Empty.INSTANCE); } ); + + } + + private void runJoinValidators( + Supplier currentStateSupplier, + ClusterState incomingState, + Collection> joinValidators + ) { + final ClusterState localState = currentStateSupplier.get(); + if (localState.metadata().clusterUUIDCommitted() + && localState.metadata().clusterUUID().equals(incomingState.metadata().clusterUUID()) == false) { + throw new CoordinationStateRejectedException( + "join validation on cluster state" + + " with a different cluster uuid " + + incomingState.metadata().clusterUUID() + + " than local cluster uuid " + + localState.metadata().clusterUUID() + + ", rejecting" + ); + } + joinValidators.forEach(action -> action.accept(transportService.getLocalNode(), incomingState)); } private void handleValidateJoinRequest( @@ -243,38 +279,15 @@ private void handleValidateJoinRequest( Collection> joinValidators, BytesTransportRequest request ) throws IOException { - final Compressor compressor = CompressorFactory.compressor(request.bytes()); - StreamInput in = request.bytes().streamInput(); - final ClusterState incomingState; + StreamInput in = CompressionHelper.decompressClusterState(request, namedWriteableRegistry); + ClusterState incomingState; try { - if (compressor != null) { - in = new InputStreamStreamInput(compressor.threadLocalInputStream(in)); - } - in = new NamedWriteableAwareStreamInput(in, namedWriteableRegistry); - in.setVersion(request.version()); - // If true we received full cluster state - otherwise diffs if (in.readBoolean()) { - // Close early to release resources used by the de-compression as early as possible - try (StreamInput input = in) { - incomingState = ClusterState.readFrom(input, transportService.getLocalNode()); - } catch (Exception e) { - logger.warn("unexpected error while deserializing an incoming cluster state", e); - throw e; - } - - final ClusterState localState = currentStateSupplier.get(); - if (localState.metadata().clusterUUIDCommitted() - && localState.metadata().clusterUUID().equals(incomingState.metadata().clusterUUID()) == false) { - throw new CoordinationStateRejectedException( - "join validation on cluster state" - + " with a different cluster uuid " - + incomingState.metadata().clusterUUID() - + " than local cluster uuid " - + localState.metadata().clusterUUID() - + ", rejecting" - ); - } - joinValidators.forEach(action -> action.accept(transportService.getLocalNode(), incomingState)); + incomingState = CompressionHelper.deserializeFullClusterState(in, transportService); + runJoinValidators(currentStateSupplier, incomingState, joinValidators); + } else { + logger.error("validate new node join request requires full cluster state"); + throw new InvalidObjectException("validate new node join request requires full cluster state"); } } finally { IOUtils.close(in); @@ -463,24 +476,19 @@ public String executor() { public void sendValidateJoinRequest(DiscoveryNode node, ClusterState state, ActionListener listener) { try { - BytesReference bytes = serializedStates.get(node.getVersion()); - // Refresh serializedStates map every time if clusterStateRefreshInterval is 0 - if (bytes == null || (System.currentTimeMillis() >= lastRefreshTime + clusterStateRefreshInterval)) { - try { - // Re-getting current cluster state for validate join request - bytes = CompressionHelper.serializedWrite(state, node.getVersion(), true); - serializedStates.put(node.getVersion(), bytes); - lastRefreshTime = System.currentTimeMillis(); - } catch (Exception e) { - logger.warn(() -> new ParameterizedMessage("failed to serialize cluster state during validateJoin" + " {}", node), e); - listener.onFailure(e); - return; - } + BytesReference bytes; + if (clusterStateRefreshInterval == 0) { + bytes = CompressionHelper.serializeClusterState(state, node, true); + } else { + bytes = serializedStates.computeIfAbsent( + node.getVersion(), + key -> CompressionHelper.serializeClusterState(state, node, true) + ); } final BytesTransportRequest request = new BytesTransportRequest(bytes, node.getVersion()); transportService.sendRequest( node, - VALIDATE_JOIN_ACTION_NAME, + VALIDATE_COMPRESSED_JOIN_ACTION_NAME, request, new ActionListenerResponseHandler<>(listener, i -> Empty.INSTANCE, ThreadPool.Names.GENERIC) ); diff --git a/server/src/main/java/org/opensearch/cluster/coordination/PublicationTransportHandler.java b/server/src/main/java/org/opensearch/cluster/coordination/PublicationTransportHandler.java index aab91bc72ae83..00c1976619c2d 100644 --- a/server/src/main/java/org/opensearch/cluster/coordination/PublicationTransportHandler.java +++ b/server/src/main/java/org/opensearch/cluster/coordination/PublicationTransportHandler.java @@ -39,21 +39,14 @@ import org.opensearch.action.ActionListener; import org.opensearch.cluster.ClusterChangedEvent; import org.opensearch.cluster.ClusterState; -import org.opensearch.cluster.CompressionHelper; +import org.opensearch.common.compress.CompressionHelper; import org.opensearch.cluster.Diff; import org.opensearch.cluster.IncompatibleClusterStateVersionException; import org.opensearch.cluster.node.DiscoveryNode; import org.opensearch.cluster.node.DiscoveryNodes; import org.opensearch.common.bytes.BytesReference; -import org.opensearch.common.compress.Compressor; -import org.opensearch.common.compress.CompressorFactory; -import org.opensearch.common.io.stream.BytesStreamOutput; -import org.opensearch.common.io.stream.InputStreamStreamInput; -import org.opensearch.common.io.stream.NamedWriteableAwareStreamInput; import org.opensearch.common.io.stream.NamedWriteableRegistry; -import org.opensearch.common.io.stream.OutputStreamStreamOutput; import org.opensearch.common.io.stream.StreamInput; -import org.opensearch.common.io.stream.StreamOutput; import org.opensearch.common.util.io.IOUtils; import org.opensearch.threadpool.ThreadPool; import org.opensearch.transport.BytesTransportRequest; @@ -169,24 +162,11 @@ public PublishClusterStateStats stats() { } private PublishWithJoinResponse handleIncomingPublishRequest(BytesTransportRequest request) throws IOException { - final Compressor compressor = CompressorFactory.compressor(request.bytes()); - StreamInput in = request.bytes().streamInput(); + StreamInput in = CompressionHelper.decompressClusterState(request, namedWriteableRegistry); + ClusterState incomingState; try { - if (compressor != null) { - in = new InputStreamStreamInput(compressor.threadLocalInputStream(in)); - } - in = new NamedWriteableAwareStreamInput(in, namedWriteableRegistry); - in.setVersion(request.version()); - // If true we received full cluster state - otherwise diffs if (in.readBoolean()) { - final ClusterState incomingState; - // Close early to release resources used by the de-compression as early as possible - try (StreamInput input = in) { - incomingState = ClusterState.readFrom(input, transportService.getLocalNode()); - } catch (Exception e) { - logger.warn("unexpected error while deserializing an incoming cluster state", e); - throw e; - } + incomingState = CompressionHelper.deserializeFullClusterState(in, transportService); fullClusterStateReceivedCount.incrementAndGet(); logger.debug("received full cluster state version [{}] with size [{}]", incomingState.version(), request.bytes().length()); final PublishWithJoinResponse response = acceptState(incomingState); @@ -199,20 +179,12 @@ private PublishWithJoinResponse handleIncomingPublishRequest(BytesTransportReque incompatibleClusterStateDiffReceivedCount.incrementAndGet(); throw new IncompatibleClusterStateVersionException("have no local cluster state"); } else { - ClusterState incomingState; try { - final Diff diff; - // Close stream early to release resources used by the de-compression as early as possible - try (StreamInput input = in) { - diff = ClusterState.readDiffFrom(input, lastSeen.nodes().getLocalNode()); - } + Diff diff = CompressionHelper.deserializeClusterStateDiff(in, lastSeen.getNodes().getLocalNode()); incomingState = diff.apply(lastSeen); // might throw IncompatibleClusterStateVersionException } catch (IncompatibleClusterStateVersionException e) { incompatibleClusterStateDiffReceivedCount.incrementAndGet(); throw e; - } catch (Exception e) { - logger.warn("unexpected error while deserializing an incoming cluster state", e); - throw e; } compatibleClusterStateDiffReceivedCount.incrementAndGet(); logger.debug( @@ -254,33 +226,6 @@ public PublicationContext newPublicationContext(ClusterChangedEvent clusterChang return publicationContext; } - private static BytesReference serializeFullClusterState(ClusterState clusterState, Version nodeVersion) throws IOException { - final BytesStreamOutput bStream = new BytesStreamOutput(); - try (StreamOutput stream = new OutputStreamStreamOutput(CompressorFactory.COMPRESSOR.threadLocalOutputStream(bStream))) { - stream.setVersion(nodeVersion); - stream.writeBoolean(true); - clusterState.writeTo(stream); - } - final BytesReference serializedState = bStream.bytes(); - logger.trace( - "serialized full cluster state version [{}] for node version [{}] with size [{}]", - clusterState.version(), - nodeVersion, - serializedState.length() - ); - return serializedState; - } - - private static BytesReference serializeDiffClusterState(Diff diff, Version nodeVersion) throws IOException { - final BytesStreamOutput bStream = new BytesStreamOutput(); - try (StreamOutput stream = new OutputStreamStreamOutput(CompressorFactory.COMPRESSOR.threadLocalOutputStream(bStream))) { - stream.setVersion(nodeVersion); - stream.writeBoolean(false); - diff.writeTo(stream); - } - return bStream.bytes(); - } - /** * Publishing a cluster state typically involves sending the same cluster state (or diff) to every node, so the work of diffing, * serializing, and compressing the state can be done once and the results shared across publish requests. The @@ -310,7 +255,7 @@ void buildDiffAndSerializeStates() { try { if (sendFullVersion || previousState.nodes().nodeExists(node) == false) { if (serializedStates.containsKey(node.getVersion()) == false) { - serializedStates.put(node.getVersion(), CompressionHelper.serializedWrite(newState, node.getVersion(), true)); + serializedStates.put(node.getVersion(), CompressionHelper.serializeClusterState(newState, node, true)); } } else { // will send a diff @@ -318,7 +263,7 @@ void buildDiffAndSerializeStates() { diff = newState.diff(previousState); } if (serializedDiffs.containsKey(node.getVersion()) == false) { - final BytesReference serializedDiff = CompressionHelper.serializedWrite(newState, node.getVersion(), false); + final BytesReference serializedDiff = CompressionHelper.serializeClusterState(diff, node, false); serializedDiffs.put(node.getVersion(), serializedDiff); logger.trace( "serialized cluster state diff for version [{}] in for node version [{}] with size [{}]", @@ -414,7 +359,7 @@ private void sendFullClusterState(DiscoveryNode destination, ActionListener new ParameterizedMessage("failed to serialize cluster state during validateJoin" + " {}", node), e); + throw e; + } + } + + public static StreamInput decompressClusterState(BytesTransportRequest request, NamedWriteableRegistry namedWriteableRegistry) + throws IOException { + final Compressor compressor = CompressorFactory.compressor(request.bytes()); + StreamInput in = request.bytes().streamInput(); + if (compressor != null) { + in = new InputStreamStreamInput(compressor.threadLocalInputStream(in)); + } + in = new NamedWriteableAwareStreamInput(in, namedWriteableRegistry); + in.setVersion(request.version()); + return in; + } + + public static ClusterState deserializeFullClusterState(StreamInput in, TransportService transportService) throws IOException { + final ClusterState incomingState; + try (StreamInput input = in) { + incomingState = ClusterState.readFrom(input, transportService.getLocalNode()); + } catch (Exception e) { + logger.warn("unexpected error while deserializing an incoming cluster state", e); + throw e; + } + return incomingState; + } + + public static Diff deserializeClusterStateDiff(StreamInput in, DiscoveryNode localNode) throws IOException { + final Diff diff; + // Close stream early to release resources used by the de-compression as early as possible + try (StreamInput input = in) { + diff = ClusterState.readDiffFrom(input, localNode); + } catch (Exception e) { + logger.warn("unexpected error while deserializing an incoming cluster state diff", e); + throw e; + } + return diff; + } +} diff --git a/server/src/main/java/org/opensearch/common/settings/ClusterSettings.java b/server/src/main/java/org/opensearch/common/settings/ClusterSettings.java index a4ad52b4dfc6f..5f74cfa60c48a 100644 --- a/server/src/main/java/org/opensearch/common/settings/ClusterSettings.java +++ b/server/src/main/java/org/opensearch/common/settings/ClusterSettings.java @@ -639,7 +639,7 @@ public void apply(Settings value, Settings current, Settings previous) { // Settings related to Searchable Snapshots Node.NODE_SEARCH_CACHE_SIZE_SETTING, - JoinHelper.CLUSTER_MANAGER_JOIN_STATE_REFRESH_INTERVAL + JoinHelper.CLUSTER_MANAGER_VALIDATE_JOIN_CACHE_INTERVAL ) ) ); diff --git a/server/src/test/java/org/opensearch/cluster/coordination/JoinHelperTests.java b/server/src/test/java/org/opensearch/cluster/coordination/JoinHelperTests.java index 0ee2fbd7e1dd5..55b529bb99f99 100644 --- a/server/src/test/java/org/opensearch/cluster/coordination/JoinHelperTests.java +++ b/server/src/test/java/org/opensearch/cluster/coordination/JoinHelperTests.java @@ -33,17 +33,19 @@ import org.apache.logging.log4j.Level; import org.opensearch.Version; +import org.opensearch.action.ActionListener; import org.opensearch.action.ActionListenerResponseHandler; import org.opensearch.action.support.PlainActionFuture; import org.opensearch.cluster.ClusterName; import org.opensearch.cluster.ClusterState; -import org.opensearch.cluster.CompressionHelper; +import org.opensearch.common.compress.CompressionHelper; import org.opensearch.cluster.NotClusterManagerException; import org.opensearch.cluster.metadata.Metadata; import org.opensearch.cluster.node.DiscoveryNode; import org.opensearch.common.bytes.BytesReference; import org.opensearch.common.io.stream.NamedWriteableRegistry; import org.opensearch.common.settings.Settings; +import org.opensearch.common.util.concurrent.UncategorizedExecutionException; import org.opensearch.monitor.StatusInfo; import org.opensearch.test.OpenSearchTestCase; import org.opensearch.test.transport.CapturingTransport; @@ -52,14 +54,20 @@ import org.opensearch.transport.BytesTransportRequest; import org.opensearch.transport.RemoteTransportException; import org.opensearch.transport.TransportException; +import org.opensearch.transport.TransportRequest; import org.opensearch.transport.TransportResponse; import org.opensearch.transport.TransportService; import java.io.IOException; import java.util.Collections; import java.util.Optional; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; import java.util.concurrent.atomic.AtomicReference; +import static org.opensearch.cluster.coordination.JoinHelper.VALIDATE_COMPRESSED_JOIN_ACTION_NAME; import static org.opensearch.monitor.StatusInfo.Status.HEALTHY; import static org.opensearch.monitor.StatusInfo.Status.UNHEALTHY; import static org.opensearch.node.Node.NODE_NAME_SETTING; @@ -203,6 +211,10 @@ public void testFailedJoinAttemptLogLevel() { } public void testJoinValidationRejectsMismatchedClusterUUID() throws IOException { + assertJoinValidationRejectsMismatchedClusterUUID( + VALIDATE_COMPRESSED_JOIN_ACTION_NAME, + "join validation on cluster state with a different cluster uuid" + ); assertJoinValidationRejectsMismatchedClusterUUID( JoinHelper.VALIDATE_JOIN_ACTION_NAME, "join validation on cluster state with a different cluster uuid" @@ -210,61 +222,40 @@ public void testJoinValidationRejectsMismatchedClusterUUID() throws IOException } private void assertJoinValidationRejectsMismatchedClusterUUID(String actionName, String expectedMessage) throws IOException { - DeterministicTaskQueue deterministicTaskQueue = new DeterministicTaskQueue( - Settings.builder().put(NODE_NAME_SETTING.getKey(), "node0").build(), - random() - ); - MockTransport mockTransport = new MockTransport(); - DiscoveryNode localNode = new DiscoveryNode("node0", buildNewFakeTransportAddress(), Version.CURRENT); - - final ClusterState localClusterState = ClusterState.builder(ClusterName.DEFAULT) - .metadata(Metadata.builder().generateClusterUuidIfNeeded().clusterUUIDCommitted(true)) - .build(); - - TransportService transportService = mockTransport.createTransportService( - Settings.EMPTY, - deterministicTaskQueue.getThreadPool(), - TransportService.NOOP_TRANSPORT_INTERCEPTOR, - x -> localNode, - null, - Collections.emptySet() - ); - new JoinHelper(Settings.EMPTY, null, null, transportService, () -> 0L, () -> localClusterState, (joinRequest, joinCallback) -> { - throw new AssertionError(); - }, - startJoinRequest -> { throw new AssertionError(); }, - Collections.emptyList(), - (s, p, r) -> {}, - null, - nodeCommissioned -> {}, - namedWriteableRegistry - ); // registers - // request - // handler - transportService.start(); - transportService.acceptIncomingRequests(); + TestClusterSetup testCluster = getTestClusterSetup(); final ClusterState otherClusterState = ClusterState.builder(ClusterName.DEFAULT) .metadata(Metadata.builder().generateClusterUuidIfNeeded()) .build(); - + TransportRequest request; final PlainActionFuture future = new PlainActionFuture<>(); - BytesReference bytes = CompressionHelper.serializedWrite(otherClusterState, localNode.getVersion(), true); - final BytesTransportRequest request = new BytesTransportRequest(bytes, localNode.getVersion()); - transportService.sendRequest( - localNode, - actionName, - request, - new ActionListenerResponseHandler<>(future, in -> TransportResponse.Empty.INSTANCE) - ); - deterministicTaskQueue.runAllTasks(); + if (actionName.equals(VALIDATE_COMPRESSED_JOIN_ACTION_NAME)) { + BytesReference bytes = CompressionHelper.serializeClusterState(otherClusterState, testCluster.localNode, true); + request = new BytesTransportRequest(bytes, testCluster.localNode.getVersion()); + testCluster.transportService.sendRequest( + testCluster.localNode, + actionName, + request, + new ActionListenerResponseHandler<>(future, in -> TransportResponse.Empty.INSTANCE) + ); + } else if (actionName.equals(JoinHelper.VALIDATE_JOIN_ACTION_NAME)) { + request = new ValidateJoinRequest(otherClusterState); + testCluster.transportService.sendRequest( + testCluster.localNode, + actionName, + request, + new ActionListenerResponseHandler<>(future, in -> TransportResponse.Empty.INSTANCE) + ); + } + + testCluster.deterministicTaskQueue.runAllTasks(); final CoordinationStateRejectedException coordinationStateRejectedException = expectThrows( CoordinationStateRejectedException.class, future::actionGet ); assertThat(coordinationStateRejectedException.getMessage(), containsString(expectedMessage)); - assertThat(coordinationStateRejectedException.getMessage(), containsString(localClusterState.metadata().clusterUUID())); + assertThat(coordinationStateRejectedException.getMessage(), containsString(testCluster.localClusterState.metadata().clusterUUID())); assertThat(coordinationStateRejectedException.getMessage(), containsString(otherClusterState.metadata().clusterUUID())); } @@ -339,4 +330,150 @@ public void testJoinFailureOnUnhealthyNodes() { CapturedRequest capturedRequest1a = capturedRequests1a[0]; assertEquals(node1, capturedRequest1a.node); } + + public void testSendValidateJoinFailsOnCompressionHelperException() throws IOException, ExecutionException, InterruptedException, + TimeoutException { + TestClusterSetup testCluster = getTestClusterSetup(); + final CompletableFuture future = new CompletableFuture<>(); + testCluster.joinHelper.sendValidateJoinRequest(testCluster.localNode, null, new ActionListener<>() { + @Override + public void onResponse(TransportResponse.Empty empty) { + future.completeExceptionally(new AssertionError("validate join should have failed")); + } + + @Override + public void onFailure(Exception e) { + future.complete(e); + } + }); + Throwable t = future.get(10, TimeUnit.SECONDS); + assertTrue(t instanceof ExecutionException); + assertTrue(t.getCause() instanceof NullPointerException); + } + + public void testJoinValidationFailsOnCompressionHelperException() throws IOException { + TestClusterSetup testCluster = getTestClusterSetup(); + final ClusterState otherClusterState = ClusterState.builder(ClusterName.DEFAULT) + .metadata(Metadata.builder().generateClusterUuidIfNeeded()) + .build(); + TransportRequest request; + final PlainActionFuture future = new PlainActionFuture<>(); + BytesReference bytes = CompressionHelper.serializeClusterState(otherClusterState, testCluster.localNode, false); + request = new BytesTransportRequest(bytes, testCluster.localNode.getVersion()); + testCluster.transportService.sendRequest( + testCluster.localNode, + VALIDATE_COMPRESSED_JOIN_ACTION_NAME, + request, + new ActionListenerResponseHandler<>(future, in -> TransportResponse.Empty.INSTANCE) + ); + testCluster.deterministicTaskQueue.runAllTasks(); + final UncategorizedExecutionException invalidStateException = expectThrows( + UncategorizedExecutionException.class, + future::actionGet + ); + assertTrue(invalidStateException.getCause().getMessage().contains("requires full cluster state")); + } + + public void testJoinHelperCachingOnClusterState() throws IOException, ExecutionException, InterruptedException, TimeoutException { + TestClusterSetup testCluster = getTestClusterSetup(); + final CompletableFuture future = new CompletableFuture<>(); + ActionListener listener = new ActionListener<>() { + @Override + public void onResponse(TransportResponse.Empty empty) { + logger.info("validation successful"); + } + + @Override + public void onFailure(Exception e) { + future.completeExceptionally(new AssertionError("validate join should not fail here")); + } + }; + testCluster.joinHelper.sendValidateJoinRequest(testCluster.localNode, testCluster.localClusterState, listener); + // validation will pass due to cached cluster state + testCluster.joinHelper.sendValidateJoinRequest(testCluster.localNode, null, listener); + final CompletableFuture future2 = new CompletableFuture<>(); + ActionListener listener2 = new ActionListener<>() { + @Override + public void onResponse(TransportResponse.Empty empty) { + future2.completeExceptionally(new AssertionError("validation should fail now")); + } + + @Override + public void onFailure(Exception e) { + future2.complete(e); + } + }; + Thread.sleep(30 * 1000); + // now sending the validate join request will fail due to null cluster state + testCluster.joinHelper.sendValidateJoinRequest(testCluster.localNode, null, listener2); + Throwable t = future2.get(10, TimeUnit.SECONDS); + assertTrue(t instanceof ExecutionException); + assertTrue(t.getCause() instanceof NullPointerException); + } + + private TestClusterSetup getTestClusterSetup() { + DeterministicTaskQueue deterministicTaskQueue = new DeterministicTaskQueue( + Settings.builder().put(NODE_NAME_SETTING.getKey(), "node0").build(), + random() + ); + MockTransport mockTransport = new MockTransport(); + DiscoveryNode localNode = new DiscoveryNode("node0", buildNewFakeTransportAddress(), Version.CURRENT); + + final ClusterState localClusterState = ClusterState.builder(ClusterName.DEFAULT) + .metadata(Metadata.builder().generateClusterUuidIfNeeded().clusterUUIDCommitted(true)) + .build(); + + TransportService transportService = mockTransport.createTransportService( + Settings.EMPTY, + deterministicTaskQueue.getThreadPool(), + TransportService.NOOP_TRANSPORT_INTERCEPTOR, + x -> localNode, + null, + Collections.emptySet() + ); + JoinHelper joinHelper = new JoinHelper( + Settings.EMPTY, + null, + null, + transportService, + () -> 0L, + () -> localClusterState, + (joinRequest, joinCallback) -> { + throw new AssertionError(); + }, + startJoinRequest -> { throw new AssertionError(); }, + Collections.emptyList(), + (s, p, r) -> {}, + null, + nodeCommissioned -> {}, + namedWriteableRegistry + ); // registers + // request + // handler + transportService.start(); + transportService.acceptIncomingRequests(); + return new TestClusterSetup(deterministicTaskQueue, localNode, transportService, localClusterState, joinHelper); + } + + private static class TestClusterSetup { + public final DeterministicTaskQueue deterministicTaskQueue; + public final DiscoveryNode localNode; + public final TransportService transportService; + public final ClusterState localClusterState; + public final JoinHelper joinHelper; + + public TestClusterSetup( + DeterministicTaskQueue deterministicTaskQueue, + DiscoveryNode localNode, + TransportService transportService, + ClusterState localClusterState, + JoinHelper joinHelper + ) { + this.deterministicTaskQueue = deterministicTaskQueue; + this.localNode = localNode; + this.transportService = transportService; + this.localClusterState = localClusterState; + this.joinHelper = joinHelper; + } + } } diff --git a/server/src/test/java/org/opensearch/cluster/coordination/NodeJoinTests.java b/server/src/test/java/org/opensearch/cluster/coordination/NodeJoinTests.java index 2752f57b499b3..8f32f6166ce6a 100644 --- a/server/src/test/java/org/opensearch/cluster/coordination/NodeJoinTests.java +++ b/server/src/test/java/org/opensearch/cluster/coordination/NodeJoinTests.java @@ -213,7 +213,7 @@ protected void onSendRequest(long requestId, String action, TransportRequest req requestId, new TransportService.HandshakeResponse(destination, initialState.getClusterName(), destination.getVersion()) ); - } else if (action.equals(JoinHelper.VALIDATE_JOIN_ACTION_NAME)) { + } else if (action.equals(JoinHelper.VALIDATE_COMPRESSED_JOIN_ACTION_NAME)) { handleResponse(requestId, new TransportResponse.Empty()); } else { super.onSendRequest(requestId, action, request, destination);