From 9162149683739362ce60383ffbc01f651ad6016e Mon Sep 17 00:00:00 2001 From: Aman Khare <85096200+amkhar@users.noreply.github.com> Date: Wed, 7 Jun 2023 01:24:37 +0530 Subject: [PATCH] Compress and cache cluster state during validate join request (#7321) * Compress and cache cluster state during validate join request Signed-off-by: Aman Khare * Add changelog and license Signed-off-by: Aman Khare * Add javadoc and correct styling Signed-off-by: Aman Khare * Add new handler for sending compressed cluster state in validate join flow and refactor code Signed-off-by: Aman Khare * Refactor util method Signed-off-by: Aman Khare * optimize imports Signed-off-by: Aman Khare * Use cluster state version based cache instead of time based cache Signed-off-by: Aman Khare * style fix Signed-off-by: Aman Khare * fix styling 2 Signed-off-by: Aman Khare * Use concurrent hashmap instead of cache, add UT class for ClusterStateUtils Signed-off-by: Aman Khare * style fix Signed-off-by: Aman Khare * Use AtomicReference instead of ConcurrentHashMap Signed-off-by: Aman Khare * Use method overloading to simplify the caller code Signed-off-by: Aman Khare * Resolve conflicts Signed-off-by: Aman Khare * Change code structure to separate the flow for JoinHelper and PublicationTransportHelper Signed-off-by: Aman Khare * Remove unnecessary input.setVersion line Co-authored-by: Andrew Ross Signed-off-by: Aman Khare <85096200+amkhar@users.noreply.github.com> --------- Signed-off-by: Aman Khare Signed-off-by: Aman Khare <85096200+amkhar@users.noreply.github.com> Co-authored-by: Aman Khare Co-authored-by: Andrew Ross (cherry picked from commit b17c88c8c33b9f627cbfbbd0cffba233bb17d6bb) --- CHANGELOG.md | 1 + .../coordination/CompressedStreamUtils.java | 61 ++++ .../cluster/coordination/Coordinator.java | 3 +- .../cluster/coordination/JoinHelper.java | 110 +++++-- .../PublicationTransportHandler.java | 37 +-- .../bootstrap/test-framework.policy | 4 + .../CompressedStreamUtilsTests.java | 65 ++++ 
.../cluster/coordination/JoinHelperTests.java | 282 +++++++++++++++--- .../cluster/coordination/NodeJoinTests.java | 23 +- 9 files changed, 485 insertions(+), 101 deletions(-) create mode 100644 server/src/main/java/org/opensearch/cluster/coordination/CompressedStreamUtils.java create mode 100644 server/src/test/java/org/opensearch/cluster/coordination/CompressedStreamUtilsTests.java diff --git a/CHANGELOG.md b/CHANGELOG.md index 7a22b02bdf060..6f518bfeda2fa 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -53,6 +53,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), - [Snapshot Interop] Add Changes in Create Snapshot Flow for remote store interoperability. ([#8071](https://github.com/opensearch-project/OpenSearch/pull/8071)) - Allow insecure string settings to warn-log usage and advise to migration of a newer secure variant ([#5496](https://github.com/opensearch-project/OpenSearch/pull/5496)) - Pass localNode info to all plugins on node start ([#7919](https://github.com/opensearch-project/OpenSearch/pull/7919) +- Compress and cache cluster state during validate join request ([#7321](https://github.com/opensearch-project/OpenSearch/pull/7321)) ### Deprecated diff --git a/server/src/main/java/org/opensearch/cluster/coordination/CompressedStreamUtils.java b/server/src/main/java/org/opensearch/cluster/coordination/CompressedStreamUtils.java new file mode 100644 index 0000000000000..57359f553b5a5 --- /dev/null +++ b/server/src/main/java/org/opensearch/cluster/coordination/CompressedStreamUtils.java @@ -0,0 +1,61 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. 
+ */ + +package org.opensearch.cluster.coordination; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.Version; +import org.opensearch.common.CheckedConsumer; +import org.opensearch.common.bytes.BytesReference; +import org.opensearch.common.compress.Compressor; +import org.opensearch.common.compress.CompressorFactory; +import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.common.io.stream.InputStreamStreamInput; +import org.opensearch.common.io.stream.NamedWriteableAwareStreamInput; +import org.opensearch.common.io.stream.NamedWriteableRegistry; +import org.opensearch.common.io.stream.OutputStreamStreamOutput; +import org.opensearch.common.io.stream.StreamInput; +import org.opensearch.common.io.stream.StreamOutput; +import org.opensearch.transport.BytesTransportRequest; + +import java.io.IOException; + +/** + * A helper class to utilize the compressed stream. + * + * @opensearch.internal + */ +public final class CompressedStreamUtils { + private static final Logger logger = LogManager.getLogger(CompressedStreamUtils.class); + + public static BytesReference createCompressedStream(Version version, CheckedConsumer outputConsumer) + throws IOException { + final BytesStreamOutput bStream = new BytesStreamOutput(); + try (StreamOutput stream = new OutputStreamStreamOutput(CompressorFactory.COMPRESSOR.threadLocalOutputStream(bStream))) { + stream.setVersion(version); + outputConsumer.accept(stream); + } + final BytesReference serializedByteRef = bStream.bytes(); + logger.trace("serialized writable object for node version [{}] with size [{}]", version, serializedByteRef.length()); + return serializedByteRef; + } + + public static StreamInput decompressBytes(BytesTransportRequest request, NamedWriteableRegistry namedWriteableRegistry) + throws IOException { + final Compressor compressor = CompressorFactory.compressor(request.bytes()); + final StreamInput in; + if (compressor != null) { + 
in = new InputStreamStreamInput(compressor.threadLocalInputStream(request.bytes().streamInput())); + } else { + in = request.bytes().streamInput(); + } + in.setVersion(request.version()); + return new NamedWriteableAwareStreamInput(in, namedWriteableRegistry); + } +} diff --git a/server/src/main/java/org/opensearch/cluster/coordination/Coordinator.java b/server/src/main/java/org/opensearch/cluster/coordination/Coordinator.java index 1ce0f9d4901fe..d2e4e97e46b73 100644 --- a/server/src/main/java/org/opensearch/cluster/coordination/Coordinator.java +++ b/server/src/main/java/org/opensearch/cluster/coordination/Coordinator.java @@ -223,7 +223,8 @@ public Coordinator( this.onJoinValidators, rerouteService, nodeHealthService, - this::onNodeCommissionStatusChange + this::onNodeCommissionStatusChange, + namedWriteableRegistry ); this.persistedStateSupplier = persistedStateSupplier; this.noClusterManagerBlockService = new NoClusterManagerBlockService(settings, clusterSettings); diff --git a/server/src/main/java/org/opensearch/cluster/coordination/JoinHelper.java b/server/src/main/java/org/opensearch/cluster/coordination/JoinHelper.java index a66152b8016ee..392ff7045cb8d 100644 --- a/server/src/main/java/org/opensearch/cluster/coordination/JoinHelper.java +++ b/server/src/main/java/org/opensearch/cluster/coordination/JoinHelper.java @@ -35,6 +35,7 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.apache.logging.log4j.message.ParameterizedMessage; +import org.opensearch.Version; import org.opensearch.action.ActionListener; import org.opensearch.action.ActionListenerResponseHandler; import org.opensearch.cluster.ClusterState; @@ -49,7 +50,9 @@ import org.opensearch.cluster.routing.allocation.AllocationService; import org.opensearch.cluster.service.ClusterManagerService; import org.opensearch.common.Priority; +import org.opensearch.common.bytes.BytesReference; import org.opensearch.common.collect.Tuple; +import 
org.opensearch.common.io.stream.NamedWriteableRegistry; import org.opensearch.common.io.stream.StreamInput; import org.opensearch.common.settings.Setting; import org.opensearch.common.settings.Settings; @@ -58,6 +61,7 @@ import org.opensearch.monitor.StatusInfo; import org.opensearch.threadpool.ThreadPool; import org.opensearch.threadpool.ThreadPool.Names; +import org.opensearch.transport.BytesTransportRequest; import org.opensearch.transport.RemoteTransportException; import org.opensearch.transport.TransportChannel; import org.opensearch.transport.TransportException; @@ -98,6 +102,7 @@ public class JoinHelper { public static final String JOIN_ACTION_NAME = "internal:cluster/coordination/join"; public static final String VALIDATE_JOIN_ACTION_NAME = "internal:cluster/coordination/join/validate"; + public static final String VALIDATE_COMPRESSED_JOIN_ACTION_NAME = JOIN_ACTION_NAME + "/validate_compressed"; public static final String START_JOIN_ACTION_NAME = "internal:cluster/coordination/start_join"; // the timeout for Zen1 join attempts @@ -122,6 +127,8 @@ public class JoinHelper { private final Supplier joinTaskExecutorGenerator; private final Consumer nodeCommissioned; + private final NamedWriteableRegistry namedWriteableRegistry; + private final AtomicReference> serializedState = new AtomicReference<>(); JoinHelper( Settings settings, @@ -135,13 +142,16 @@ public class JoinHelper { Collection> joinValidators, RerouteService rerouteService, NodeHealthService nodeHealthService, - Consumer nodeCommissioned + Consumer nodeCommissioned, + NamedWriteableRegistry namedWriteableRegistry ) { this.clusterManagerService = clusterManagerService; this.transportService = transportService; this.nodeHealthService = nodeHealthService; this.joinTimeout = JOIN_TIMEOUT_SETTING.get(settings); this.nodeCommissioned = nodeCommissioned; + this.namedWriteableRegistry = namedWriteableRegistry; + this.joinTaskExecutorGenerator = () -> new JoinTaskExecutor(settings, allocationService, 
logger, rerouteService, transportService) { private final long term = currentTermSupplier.getAsLong(); @@ -208,22 +218,52 @@ public ClusterTasksResult execute(ClusterState currentSta ThreadPool.Names.GENERIC, ValidateJoinRequest::new, (request, channel, task) -> { - final ClusterState localState = currentStateSupplier.get(); - if (localState.metadata().clusterUUIDCommitted() - && localState.metadata().clusterUUID().equals(request.getState().metadata().clusterUUID()) == false) { - throw new CoordinationStateRejectedException( - "join validation on cluster state" - + " with a different cluster uuid " - + request.getState().metadata().clusterUUID() - + " than local cluster uuid " - + localState.metadata().clusterUUID() - + ", rejecting" - ); - } - joinValidators.forEach(action -> action.accept(transportService.getLocalNode(), request.getState())); + runJoinValidators(currentStateSupplier, request.getState(), joinValidators); channel.sendResponse(Empty.INSTANCE); } ); + + transportService.registerRequestHandler( + VALIDATE_COMPRESSED_JOIN_ACTION_NAME, + ThreadPool.Names.GENERIC, + BytesTransportRequest::new, + (request, channel, task) -> { + handleCompressedValidateJoinRequest(currentStateSupplier, joinValidators, request); + channel.sendResponse(Empty.INSTANCE); + } + ); + + } + + private void runJoinValidators( + Supplier currentStateSupplier, + ClusterState incomingState, + Collection> joinValidators + ) { + final ClusterState localState = currentStateSupplier.get(); + if (localState.metadata().clusterUUIDCommitted() + && localState.metadata().clusterUUID().equals(incomingState.metadata().clusterUUID()) == false) { + throw new CoordinationStateRejectedException( + "join validation on cluster state" + + " with a different cluster uuid " + + incomingState.metadata().clusterUUID() + + " than local cluster uuid " + + localState.metadata().clusterUUID() + + ", rejecting" + ); + } + joinValidators.forEach(action -> action.accept(transportService.getLocalNode(), 
incomingState)); + } + + private void handleCompressedValidateJoinRequest( + Supplier currentStateSupplier, + Collection> joinValidators, + BytesTransportRequest request + ) throws IOException { + try (StreamInput input = CompressedStreamUtils.decompressBytes(request, namedWriteableRegistry)) { + ClusterState incomingState = ClusterState.readFrom(input, transportService.getLocalNode()); + runJoinValidators(currentStateSupplier, incomingState, joinValidators); + } } private JoinCallback transportJoinCallback(TransportRequest request, TransportChannel channel) { @@ -407,12 +447,42 @@ public String executor() { } public void sendValidateJoinRequest(DiscoveryNode node, ClusterState state, ActionListener listener) { - transportService.sendRequest( - node, - VALIDATE_JOIN_ACTION_NAME, - new ValidateJoinRequest(state), - new ActionListenerResponseHandler<>(listener, i -> Empty.INSTANCE, ThreadPool.Names.GENERIC) - ); + if (node.getVersion().before(Version.V_3_0_0)) { + transportService.sendRequest( + node, + VALIDATE_JOIN_ACTION_NAME, + new ValidateJoinRequest(state), + new ActionListenerResponseHandler<>(listener, i -> Empty.INSTANCE, ThreadPool.Names.GENERIC) + ); + } else { + try { + final BytesReference bytes = serializedState.updateAndGet(cachedState -> { + if (cachedState == null || cachedState.v1() != state.version()) { + try { + return new Tuple<>( + state.version(), + CompressedStreamUtils.createCompressedStream(node.getVersion(), state::writeTo) + ); + } catch (IOException e) { + // mandatory as AtomicReference doesn't rethrow IOException. 
+ throw new RuntimeException(e); + } + } else { + return cachedState; + } + }).v2(); + final BytesTransportRequest request = new BytesTransportRequest(bytes, node.getVersion()); + transportService.sendRequest( + node, + VALIDATE_COMPRESSED_JOIN_ACTION_NAME, + request, + new ActionListenerResponseHandler<>(listener, i -> Empty.INSTANCE, ThreadPool.Names.GENERIC) + ); + } catch (Exception e) { + logger.warn("error sending cluster state to {}", node); + listener.onFailure(e); + } + } } /** diff --git a/server/src/main/java/org/opensearch/cluster/coordination/PublicationTransportHandler.java b/server/src/main/java/org/opensearch/cluster/coordination/PublicationTransportHandler.java index e836f30d21ff8..21ef89e9d5790 100644 --- a/server/src/main/java/org/opensearch/cluster/coordination/PublicationTransportHandler.java +++ b/server/src/main/java/org/opensearch/cluster/coordination/PublicationTransportHandler.java @@ -44,16 +44,8 @@ import org.opensearch.cluster.node.DiscoveryNode; import org.opensearch.cluster.node.DiscoveryNodes; import org.opensearch.common.bytes.BytesReference; -import org.opensearch.common.compress.Compressor; -import org.opensearch.common.compress.CompressorFactory; -import org.opensearch.common.io.stream.BytesStreamOutput; -import org.opensearch.common.io.stream.InputStreamStreamInput; -import org.opensearch.common.io.stream.NamedWriteableAwareStreamInput; import org.opensearch.common.io.stream.NamedWriteableRegistry; -import org.opensearch.common.io.stream.OutputStreamStreamOutput; import org.opensearch.common.io.stream.StreamInput; -import org.opensearch.common.io.stream.StreamOutput; -import org.opensearch.common.util.io.IOUtils; import org.opensearch.threadpool.ThreadPool; import org.opensearch.transport.BytesTransportRequest; import org.opensearch.transport.TransportChannel; @@ -168,17 +160,9 @@ public PublishClusterStateStats stats() { } private PublishWithJoinResponse handleIncomingPublishRequest(BytesTransportRequest request) throws 
IOException { - final Compressor compressor = CompressorFactory.compressor(request.bytes()); - StreamInput in = request.bytes().streamInput(); - try { - if (compressor != null) { - in = new InputStreamStreamInput(compressor.threadLocalInputStream(in)); - } - in = new NamedWriteableAwareStreamInput(in, namedWriteableRegistry); - in.setVersion(request.version()); - // If true we received full cluster state - otherwise diffs + try (StreamInput in = CompressedStreamUtils.decompressBytes(request, namedWriteableRegistry)) { + ClusterState incomingState; if (in.readBoolean()) { - final ClusterState incomingState; // Close early to release resources used by the de-compression as early as possible try (StreamInput input = in) { incomingState = ClusterState.readFrom(input, transportService.getLocalNode()); @@ -198,7 +182,6 @@ private PublishWithJoinResponse handleIncomingPublishRequest(BytesTransportReque incompatibleClusterStateDiffReceivedCount.incrementAndGet(); throw new IncompatibleClusterStateVersionException("have no local cluster state"); } else { - ClusterState incomingState; try { final Diff diff; // Close stream early to release resources used by the de-compression as early as possible @@ -225,8 +208,6 @@ private PublishWithJoinResponse handleIncomingPublishRequest(BytesTransportReque return response; } } - } finally { - IOUtils.close(in); } } @@ -254,13 +235,10 @@ public PublicationContext newPublicationContext(ClusterChangedEvent clusterChang } private static BytesReference serializeFullClusterState(ClusterState clusterState, Version nodeVersion) throws IOException { - final BytesStreamOutput bStream = new BytesStreamOutput(); - try (StreamOutput stream = new OutputStreamStreamOutput(CompressorFactory.defaultCompressor().threadLocalOutputStream(bStream))) { - stream.setVersion(nodeVersion); + final BytesReference serializedState = CompressedStreamUtils.createCompressedStream(nodeVersion, stream -> { stream.writeBoolean(true); clusterState.writeTo(stream); - } - 
final BytesReference serializedState = bStream.bytes(); + }); logger.trace( "serialized full cluster state version [{}] for node version [{}] with size [{}]", clusterState.version(), @@ -271,13 +249,10 @@ private static BytesReference serializeFullClusterState(ClusterState clusterStat } private static BytesReference serializeDiffClusterState(Diff diff, Version nodeVersion) throws IOException { - final BytesStreamOutput bStream = new BytesStreamOutput(); - try (StreamOutput stream = new OutputStreamStreamOutput(CompressorFactory.defaultCompressor().threadLocalOutputStream(bStream))) { - stream.setVersion(nodeVersion); + return CompressedStreamUtils.createCompressedStream(nodeVersion, stream -> { stream.writeBoolean(false); diff.writeTo(stream); - } - return bStream.bytes(); + }); } /** diff --git a/server/src/main/resources/org/opensearch/bootstrap/test-framework.policy b/server/src/main/resources/org/opensearch/bootstrap/test-framework.policy index 60d0e9d15215a..52205ee498eb3 100644 --- a/server/src/main/resources/org/opensearch/bootstrap/test-framework.policy +++ b/server/src/main/resources/org/opensearch/bootstrap/test-framework.policy @@ -143,4 +143,8 @@ grant codeBase "file:${gradle.worker.jar}" { grant { // since the gradle test worker jar is on the test classpath, our tests should be able to read it permission java.io.FilePermission "${gradle.worker.jar}", "read"; + permission java.lang.RuntimePermission "accessDeclaredMembers"; + permission java.lang.RuntimePermission "reflectionFactoryAccess"; + permission java.lang.RuntimePermission "accessClassInPackage.sun.reflect"; + permission java.lang.reflect.ReflectPermission "suppressAccessChecks"; }; diff --git a/server/src/test/java/org/opensearch/cluster/coordination/CompressedStreamUtilsTests.java b/server/src/test/java/org/opensearch/cluster/coordination/CompressedStreamUtilsTests.java new file mode 100644 index 0000000000000..e8faa73315e85 --- /dev/null +++ 
b/server/src/test/java/org/opensearch/cluster/coordination/CompressedStreamUtilsTests.java @@ -0,0 +1,65 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.cluster.coordination; + +import org.opensearch.Version; +import org.opensearch.cluster.ClusterName; +import org.opensearch.cluster.ClusterState; +import org.opensearch.cluster.metadata.Metadata; +import org.opensearch.cluster.node.DiscoveryNode; +import org.opensearch.common.bytes.BytesReference; +import org.opensearch.common.io.stream.StreamInput; +import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.transport.BytesTransportRequest; + +import java.io.IOException; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.RETURNS_DEEP_STUBS; +import static org.mockito.Mockito.doThrow; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +/** + * Contains tests for {@link CompressedStreamUtils} + */ +public class CompressedStreamUtilsTests extends OpenSearchTestCase { + + public void testCreateCompressedStream() throws IOException { + // serialization success with normal state + final ClusterState localClusterState = ClusterState.builder(ClusterName.DEFAULT) + .metadata(Metadata.builder().generateClusterUuidIfNeeded().clusterUUIDCommitted(true)) + .build(); + DiscoveryNode localNode = new DiscoveryNode("node0", buildNewFakeTransportAddress(), Version.CURRENT); + BytesReference bytes = CompressedStreamUtils.createCompressedStream(localNode.getVersion(), localClusterState::writeTo); + assertNotNull(bytes); + + // Fail on write failure on mocked cluster state's writeTo exception + ClusterState mockedState = mock(ClusterState.class); + doThrow(IOException.class).when(mockedState).writeTo(any()); + assertThrows(IOException.class, () -> 
CompressedStreamUtils.createCompressedStream(localNode.getVersion(), mockedState::writeTo)); + } + + public void testDecompressBytes() throws IOException { + // Decompression works fine + final ClusterState localClusterState = ClusterState.builder(ClusterName.DEFAULT) + .metadata(Metadata.builder().generateClusterUuidIfNeeded().clusterUUIDCommitted(true)) + .build(); + DiscoveryNode localNode = new DiscoveryNode("node0", buildNewFakeTransportAddress(), Version.CURRENT); + BytesReference bytes = CompressedStreamUtils.createCompressedStream(localNode.getVersion(), localClusterState::writeTo); + BytesTransportRequest request = new BytesTransportRequest(bytes, localNode.getVersion()); + StreamInput input = CompressedStreamUtils.decompressBytes(request, DEFAULT_NAMED_WRITABLE_REGISTRY); + assertEquals(request.version(), input.getVersion()); + + // Decompression fails with AssertionError on non-compressed request + BytesTransportRequest mockedRequest = mock(BytesTransportRequest.class, RETURNS_DEEP_STUBS); + when(mockedRequest.bytes().streamInput()).thenThrow(IOException.class); + assertThrows(AssertionError.class, () -> CompressedStreamUtils.decompressBytes(mockedRequest, DEFAULT_NAMED_WRITABLE_REGISTRY)); + } +} diff --git a/server/src/test/java/org/opensearch/cluster/coordination/JoinHelperTests.java b/server/src/test/java/org/opensearch/cluster/coordination/JoinHelperTests.java index b56abe101bd4c..ad9dfd564d648 100644 --- a/server/src/test/java/org/opensearch/cluster/coordination/JoinHelperTests.java +++ b/server/src/test/java/org/opensearch/cluster/coordination/JoinHelperTests.java @@ -33,6 +33,7 @@ import org.apache.logging.log4j.Level; import org.opensearch.Version; +import org.opensearch.action.ActionListener; import org.opensearch.action.ActionListenerResponseHandler; import org.opensearch.action.support.PlainActionFuture; import org.opensearch.cluster.ClusterName; @@ -40,21 +41,32 @@ import org.opensearch.cluster.NotClusterManagerException; import 
org.opensearch.cluster.metadata.Metadata; import org.opensearch.cluster.node.DiscoveryNode; +import org.opensearch.common.bytes.BytesReference; +import org.opensearch.common.io.stream.NamedWriteableRegistry; import org.opensearch.common.settings.Settings; import org.opensearch.monitor.StatusInfo; import org.opensearch.test.OpenSearchTestCase; import org.opensearch.test.transport.CapturingTransport; import org.opensearch.test.transport.CapturingTransport.CapturedRequest; import org.opensearch.test.transport.MockTransport; +import org.opensearch.transport.BytesTransportRequest; import org.opensearch.transport.RemoteTransportException; import org.opensearch.transport.TransportException; +import org.opensearch.transport.TransportRequest; import org.opensearch.transport.TransportResponse; import org.opensearch.transport.TransportService; +import java.io.IOException; import java.util.Collections; import java.util.Optional; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; import java.util.concurrent.atomic.AtomicReference; +import static org.opensearch.cluster.coordination.JoinHelper.VALIDATE_COMPRESSED_JOIN_ACTION_NAME; +import static org.opensearch.cluster.coordination.JoinHelper.VALIDATE_JOIN_ACTION_NAME; import static org.opensearch.monitor.StatusInfo.Status.HEALTHY; import static org.opensearch.monitor.StatusInfo.Status.UNHEALTHY; import static org.opensearch.node.Node.NODE_NAME_SETTING; @@ -63,6 +75,7 @@ import static org.hamcrest.core.Is.is; public class JoinHelperTests extends OpenSearchTestCase { + private final NamedWriteableRegistry namedWriteableRegistry = DEFAULT_NAMED_WRITABLE_REGISTRY; public void testJoinDeduplication() { DeterministicTaskQueue deterministicTaskQueue = new DeterministicTaskQueue( @@ -93,7 +106,8 @@ public void testJoinDeduplication() { Collections.emptyList(), (s, p, r) -> {}, () -> new StatusInfo(HEALTHY, 
"info"), - nodeCommissioned -> {} + nodeCommissioned -> {}, + namedWriteableRegistry ); transportService.start(); @@ -195,60 +209,55 @@ public void testFailedJoinAttemptLogLevel() { ); } - public void testJoinValidationRejectsMismatchedClusterUUID() { + public void testJoinValidationRejectsMismatchedClusterUUID() throws IOException { assertJoinValidationRejectsMismatchedClusterUUID( - JoinHelper.VALIDATE_JOIN_ACTION_NAME, + VALIDATE_COMPRESSED_JOIN_ACTION_NAME, "join validation on cluster state with a different cluster uuid" ); - } - - private void assertJoinValidationRejectsMismatchedClusterUUID(String actionName, String expectedMessage) { - DeterministicTaskQueue deterministicTaskQueue = new DeterministicTaskQueue( - Settings.builder().put(NODE_NAME_SETTING.getKey(), "node0").build(), - random() + assertJoinValidationRejectsMismatchedClusterUUID( + VALIDATE_JOIN_ACTION_NAME, + "join validation on cluster state with a different cluster uuid" ); - MockTransport mockTransport = new MockTransport(); - DiscoveryNode localNode = new DiscoveryNode("node0", buildNewFakeTransportAddress(), Version.CURRENT); - - final ClusterState localClusterState = ClusterState.builder(ClusterName.DEFAULT) - .metadata(Metadata.builder().generateClusterUuidIfNeeded().clusterUUIDCommitted(true)) - .build(); + } - TransportService transportService = mockTransport.createTransportService( - Settings.EMPTY, - deterministicTaskQueue.getThreadPool(), - TransportService.NOOP_TRANSPORT_INTERCEPTOR, - x -> localNode, - null, - Collections.emptySet() - ); - new JoinHelper(Settings.EMPTY, null, null, transportService, () -> 0L, () -> localClusterState, (joinRequest, joinCallback) -> { - throw new AssertionError(); - }, startJoinRequest -> { throw new AssertionError(); }, Collections.emptyList(), (s, p, r) -> {}, null, nodeCommissioned -> {}); // registers - // request - // handler - transportService.start(); - transportService.acceptIncomingRequests(); + private void 
assertJoinValidationRejectsMismatchedClusterUUID(String actionName, String expectedMessage) throws IOException { + TestClusterSetup testCluster = getTestClusterSetup(null, false); final ClusterState otherClusterState = ClusterState.builder(ClusterName.DEFAULT) .metadata(Metadata.builder().generateClusterUuidIfNeeded()) .build(); - + TransportRequest request; final PlainActionFuture future = new PlainActionFuture<>(); - transportService.sendRequest( - localNode, - actionName, - new ValidateJoinRequest(otherClusterState), - new ActionListenerResponseHandler<>(future, in -> TransportResponse.Empty.INSTANCE) - ); - deterministicTaskQueue.runAllTasks(); + if (actionName.equals(VALIDATE_COMPRESSED_JOIN_ACTION_NAME)) { + BytesReference bytes = CompressedStreamUtils.createCompressedStream( + testCluster.localNode.getVersion(), + otherClusterState::writeTo + ); + request = new BytesTransportRequest(bytes, testCluster.localNode.getVersion()); + testCluster.transportService.sendRequest( + testCluster.localNode, + actionName, + request, + new ActionListenerResponseHandler<>(future, in -> TransportResponse.Empty.INSTANCE) + ); + } else if (actionName.equals(VALIDATE_JOIN_ACTION_NAME)) { + request = new ValidateJoinRequest(otherClusterState); + testCluster.transportService.sendRequest( + testCluster.localNode, + actionName, + request, + new ActionListenerResponseHandler<>(future, in -> TransportResponse.Empty.INSTANCE) + ); + } + + testCluster.deterministicTaskQueue.runAllTasks(); final CoordinationStateRejectedException coordinationStateRejectedException = expectThrows( CoordinationStateRejectedException.class, future::actionGet ); assertThat(coordinationStateRejectedException.getMessage(), containsString(expectedMessage)); - assertThat(coordinationStateRejectedException.getMessage(), containsString(localClusterState.metadata().clusterUUID())); + assertThat(coordinationStateRejectedException.getMessage(), containsString(testCluster.localClusterState.metadata().clusterUUID())); 
assertThat(coordinationStateRejectedException.getMessage(), containsString(otherClusterState.metadata().clusterUUID()));
     }
@@ -282,7 +291,8 @@ public void testJoinFailureOnUnhealthyNodes() {
             Collections.emptyList(),
             (s, p, r) -> {},
             () -> nodeHealthServiceStatus.get(),
-            nodeCommissioned -> {}
+            nodeCommissioned -> {},
+            namedWriteableRegistry
         );
         transportService.start();

@@ -322,4 +332,227 @@ public void testJoinFailureOnUnhealthyNodes() {
         CapturedRequest capturedRequest1a = capturedRequests1a[0];
         assertEquals(node1, capturedRequest1a.node);
     }
+
+    /**
+     * Sending a validate-join request with a {@code null} cluster state must fail the listener with a
+     * {@link NullPointerException} rather than report success.
+     */
+    public void testSendCompressedValidateJoinFailOnSerializeFailure() throws ExecutionException, InterruptedException, TimeoutException {
+        TestClusterSetup testCluster = getTestClusterSetup(Version.CURRENT, false);
+        final CompletableFuture<Throwable> future = new CompletableFuture<>();
+        testCluster.joinHelper.sendValidateJoinRequest(testCluster.localNode, null, new ActionListener<>() {
+            @Override
+            public void onResponse(TransportResponse.Empty empty) {
+                future.completeExceptionally(new AssertionError("validate join should have failed"));
+            }
+
+            @Override
+            public void onFailure(Exception e) {
+                future.complete(e);
+            }
+        });
+        Throwable t = future.get(10, TimeUnit.SECONDS);
+        assertTrue(t instanceof NullPointerException);
+    }
+
+    /**
+     * Nodes on 2.1.0, 2.7.0 and 2.8.0 must be sent {@code VALIDATE_JOIN_ACTION_NAME}; a node on
+     * {@link Version#CURRENT} must be sent {@code VALIDATE_COMPRESSED_JOIN_ACTION_NAME}.
+     */
+    public void testValidateJoinSentWithCorrectActionForVersions() {
+        verifyValidateJoinActionSent(VALIDATE_JOIN_ACTION_NAME, Version.V_2_1_0);
+        verifyValidateJoinActionSent(VALIDATE_JOIN_ACTION_NAME, Version.V_2_7_0);
+        verifyValidateJoinActionSent(VALIDATE_JOIN_ACTION_NAME, Version.V_2_8_0);
+        verifyValidateJoinActionSent(VALIDATE_COMPRESSED_JOIN_ACTION_NAME, Version.CURRENT);
+    }
+
+    /**
+     * Sends one validate-join request through a capturing transport to a node running {@code version} and
+     * asserts that exactly one request was captured and that it used {@code expectedActionName}.
+     */
+    private void verifyValidateJoinActionSent(String expectedActionName, Version version) {
+        TestClusterSetup testCluster = getTestClusterSetup(version, true);
+        // NOTE(review): completed in onFailure below but never read by this helper
+        final CompletableFuture<Throwable> future = new CompletableFuture<>();
+        DiscoveryNode node1 = new DiscoveryNode("node1", buildNewFakeTransportAddress(), version);
+        testCluster.joinHelper.sendValidateJoinRequest(node1, testCluster.localClusterState, new ActionListener<>() {
+            @Override
+            public void onResponse(TransportResponse.Empty empty) {
+                throw new AssertionError("capturing transport shouldn't run");
+            }
+
+            @Override
+            public void onFailure(Exception e) {
+                future.complete(e);
+            }
+        });
+
+        CapturedRequest[] validateRequests = testCluster.capturingTransport.getCapturedRequestsAndClear();
+        assertEquals(1, validateRequests.length);
+        assertEquals(expectedActionName, validateRequests[0].action);
+    }
+
+    /**
+     * A {@link BytesTransportRequest} carrying {@code null} bytes sent on the compressed validate-join action
+     * must surface as a {@link NullPointerException} on the response future.
+     */
+    public void testJoinValidationFailsOnDecompressionFailure() {
+        TestClusterSetup testCluster = getTestClusterSetup(Version.CURRENT, false);
+        TransportRequest request;
+        final PlainActionFuture<TransportResponse.Empty> future = new PlainActionFuture<>();
+        request = new BytesTransportRequest(null, testCluster.localNode.getVersion());
+        testCluster.transportService.sendRequest(
+            testCluster.localNode,
+            VALIDATE_COMPRESSED_JOIN_ACTION_NAME,
+            request,
+            new ActionListenerResponseHandler<>(future, in -> TransportResponse.Empty.INSTANCE)
+        );
+        testCluster.deterministicTaskQueue.runAllTasks();
+        expectThrows(NullPointerException.class, future::actionGet);
+    }
+
+    /**
+     * Exercises the version-keyed reuse of the serialized cluster state in {@code sendValidateJoinRequest}:
+     * a second request carrying a different state with the same version is still expected to validate,
+     * while a state with a bumped version is expected to be rejected with "different cluster uuid".
+     */
+    public void testJoinHelperCachingOnClusterState() throws ExecutionException, InterruptedException, TimeoutException {
+        TestClusterSetup testCluster = getTestClusterSetup(Version.CURRENT, false);
+        final CompletableFuture<Throwable> future = new CompletableFuture<>();
+        ActionListener<TransportResponse.Empty> listener = new ActionListener<>() {
+            @Override
+            public void onResponse(TransportResponse.Empty empty) {
+                logger.info("validation successful");
+            }
+
+            @Override
+            public void onFailure(Exception e) {
+                future.completeExceptionally(new AssertionError("validate join should not fail here"));
+            }
+        };
+        testCluster.joinHelper.sendValidateJoinRequest(testCluster.localNode, testCluster.localClusterState, listener);
+        // validation will pass due to cached cluster state
+        ClusterState randomState = ClusterState.builder(new ClusterName("random"))
+            .stateUUID("random2")
+            .version(testCluster.localClusterState.version())
+            .build();
+        testCluster.joinHelper.sendValidateJoinRequest(testCluster.localNode, randomState, listener);
+
+        final CompletableFuture<Throwable> future2 = new CompletableFuture<>();
+        ActionListener<TransportResponse.Empty> listener2 = new ActionListener<>() {
+            @Override
+            public void onResponse(TransportResponse.Empty empty) {
+                future2.completeExceptionally(new AssertionError("validation should fail now"));
+            }
+
+            @Override
+            public void onFailure(Exception e) {
+                future2.complete(e);
+            }
+        };
+        ClusterState randomState2 = ClusterState.builder(new ClusterName("random"))
+            .stateUUID("random2")
+            .version(testCluster.localClusterState.version() + 1)
+            .build();
+        // now sending the validate join request will fail due to random cluster uuid because version is changed
+        // and cache will be invalidated
+        testCluster.joinHelper.sendValidateJoinRequest(testCluster.localNode, randomState2, listener2);
+        testCluster.deterministicTaskQueue.runAllTasks();
+
+        // the first two requests share the listener above; neither of them may have reported a failure
+        assertFalse("validate join should not fail here", future.isCompletedExceptionally());
+        Throwable t = future2.get(10, TimeUnit.SECONDS);
+        assertTrue(t instanceof RemoteTransportException);
+        assertTrue(t.getCause() instanceof CoordinationStateRejectedException);
+        assertTrue(t.getCause().getMessage().contains("different cluster uuid"));
+    }
+
+    /**
+     * Builds a {@link JoinHelper} for local node "node0" whose transport service runs on a {@link DeterministicTaskQueue}.
+     *
+     * @param version version of the local node; {@code null} falls back to {@link Version#CURRENT}
+     * @param isCapturingTransport {@code true} to create the transport service from a {@link CapturingTransport},
+     *                             {@code false} to create it from a {@link MockTransport}
+     * @return the helper together with the fixtures it was built from, with the transport service already started
+     */
+    private TestClusterSetup getTestClusterSetup(Version version, boolean isCapturingTransport) {
+        version = version == null ? Version.CURRENT : version;
+        DeterministicTaskQueue deterministicTaskQueue = new DeterministicTaskQueue(
+            Settings.builder().put(NODE_NAME_SETTING.getKey(), "node0").build(),
+            random()
+        );
+        MockTransport mockTransport = new MockTransport();
+        CapturingTransport capturingTransport = new CapturingTransport();
+        DiscoveryNode localNode = new DiscoveryNode("node0", buildNewFakeTransportAddress(), version);
+
+        final ClusterState localClusterState = ClusterState.builder(ClusterName.DEFAULT)
+            .metadata(Metadata.builder().generateClusterUuidIfNeeded().clusterUUIDCommitted(true))
+            .build();
+        TransportService transportService;
+        if (isCapturingTransport) {
+            transportService = capturingTransport.createTransportService(
+                Settings.EMPTY,
+                deterministicTaskQueue.getThreadPool(),
+                TransportService.NOOP_TRANSPORT_INTERCEPTOR,
+                x -> localNode,
+                null,
+                Collections.emptySet()
+            );
+        } else {
+            transportService = mockTransport.createTransportService(
+                Settings.EMPTY,
+                deterministicTaskQueue.getThreadPool(),
+                TransportService.NOOP_TRANSPORT_INTERCEPTOR,
+                x -> localNode,
+                null,
+                Collections.emptySet()
+            );
+        }
+        JoinHelper joinHelper = new JoinHelper(
+            Settings.EMPTY,
+            null,
+            null,
+            transportService,
+            () -> 0L,
+            () -> localClusterState,
+            (joinRequest, joinCallback) -> {
+                throw new AssertionError();
+            },
+            startJoinRequest -> { throw new AssertionError(); },
+            Collections.emptyList(),
+            (s, p, r) -> {},
+            null,
+            nodeCommissioned -> {},
+            namedWriteableRegistry
+        ); // registers request handler
+        transportService.start();
+        transportService.acceptIncomingRequests();
+        return new TestClusterSetup(deterministicTaskQueue, localNode, transportService, localClusterState, joinHelper, capturingTransport);
+    }
+
+    /** Holder for the fixtures created by {@code getTestClusterSetup}. */
+    private static class TestClusterSetup {
+        public final DeterministicTaskQueue deterministicTaskQueue;
+        public final DiscoveryNode localNode;
+        public TransportService transportService;
+        public final ClusterState localClusterState;
+        public final JoinHelper joinHelper;
+        public final CapturingTransport capturingTransport;
+
+        public TestClusterSetup(
+            DeterministicTaskQueue deterministicTaskQueue,
+            DiscoveryNode localNode,
+            TransportService transportService,
+            ClusterState localClusterState,
+            JoinHelper joinHelper,
+            CapturingTransport capturingTransport
+        ) {
+            this.deterministicTaskQueue = deterministicTaskQueue;
+            this.localNode = localNode;
+            this.transportService = transportService;
+            this.localClusterState = localClusterState;
+            this.joinHelper = joinHelper;
+            this.capturingTransport = capturingTransport;
+        }
+    }
 }
diff --git a/server/src/test/java/org/opensearch/cluster/coordination/NodeJoinTests.java b/server/src/test/java/org/opensearch/cluster/coordination/NodeJoinTests.java
index 2752f57b499b3..fb2e7cd73d3bf 100644
--- a/server/src/test/java/org/opensearch/cluster/coordination/NodeJoinTests.java
+++ b/server/src/test/java/org/opensearch/cluster/coordination/NodeJoinTests.java
@@ -208,15 +208,20 @@ private void setupClusterManagerServiceAndCoordinator(
         CapturingTransport capturingTransport = new CapturingTransport() {
             @Override
             protected void onSendRequest(long requestId, String action, TransportRequest request, DiscoveryNode destination) {
-                if (action.equals(HANDSHAKE_ACTION_NAME)) {
-                    handleResponse(
-                        requestId,
-                        new TransportService.HandshakeResponse(destination, initialState.getClusterName(), destination.getVersion())
-                    );
-                } else if (action.equals(JoinHelper.VALIDATE_JOIN_ACTION_NAME)) {
-                    handleResponse(requestId, new TransportResponse.Empty());
-                } else {
-                    super.onSendRequest(requestId, action, request, destination);
+                switch (action) {
+                    case HANDSHAKE_ACTION_NAME:
+                        handleResponse(
+                            requestId,
+                            new TransportService.HandshakeResponse(destination, initialState.getClusterName(), destination.getVersion())
+                        );
+                        break;
+                    case JoinHelper.VALIDATE_JOIN_ACTION_NAME:
+                    case JoinHelper.VALIDATE_COMPRESSED_JOIN_ACTION_NAME:
+                        handleResponse(requestId, new TransportResponse.Empty());
+                        break;
+                    default:
+                        super.onSendRequest(requestId, action, request, destination);
+                        break;
                 }
             }