Compress and cache cluster state during validate join request #7321

Merged
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -57,6 +57,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
- Add ZSTD compression for snapshotting ([#2996](https://github.com/opensearch-project/OpenSearch/pull/2996))
- Change `com.amazonaws.sdk.ec2MetadataServiceEndpointOverride` to `aws.ec2MetadataServiceEndpoint` ([#7372](https://github.com/opensearch-project/OpenSearch/pull/7372/))
- Change `com.amazonaws.sdk.stsEndpointOverride` to `aws.stsEndpointOverride` ([#7372](https://github.com/opensearch-project/OpenSearch/pull/7372/))
- Compress and cache cluster state during validate join request ([#7321](https://github.com/opensearch-project/OpenSearch/pull/7321))

### Deprecated

61 changes: 61 additions & 0 deletions server/src/main/java/org/opensearch/cluster/coordination/CompressedStreamUtils.java
@@ -0,0 +1,61 @@
/*
 * SPDX-License-Identifier: Apache-2.0
 *
 * The OpenSearch Contributors require contributions made to
 * this file be licensed under the Apache-2.0 license or a
 * compatible open source license.
 */

package org.opensearch.cluster.coordination;

import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.Version;
import org.opensearch.common.CheckedConsumer;
import org.opensearch.common.bytes.BytesReference;
import org.opensearch.common.compress.Compressor;
import org.opensearch.common.compress.CompressorFactory;
import org.opensearch.common.io.stream.BytesStreamOutput;
import org.opensearch.common.io.stream.InputStreamStreamInput;
import org.opensearch.common.io.stream.NamedWriteableAwareStreamInput;
import org.opensearch.common.io.stream.NamedWriteableRegistry;
import org.opensearch.common.io.stream.OutputStreamStreamOutput;
import org.opensearch.common.io.stream.StreamInput;
import org.opensearch.common.io.stream.StreamOutput;
import org.opensearch.transport.BytesTransportRequest;

import java.io.IOException;

/**
 * A helper class for working with compressed streams.
 *
 * @opensearch.internal
 */
public final class CompressedStreamUtils {
    private static final Logger logger = LogManager.getLogger(CompressedStreamUtils.class);

    public static BytesReference createCompressedStream(Version version, CheckedConsumer<StreamOutput, IOException> outputConsumer)
        throws IOException {
        final BytesStreamOutput bStream = new BytesStreamOutput();
        // Compress on the fly while the consumer serializes into the stream
        try (StreamOutput stream = new OutputStreamStreamOutput(CompressorFactory.COMPRESSOR.threadLocalOutputStream(bStream))) {
            stream.setVersion(version);
            outputConsumer.accept(stream);
        }
        final BytesReference serializedByteRef = bStream.bytes();
        logger.trace("serialized writable object for node version [{}] with size [{}]", version, serializedByteRef.length());
        return serializedByteRef;
    }

    public static StreamInput decompressBytes(BytesTransportRequest request, NamedWriteableRegistry namedWriteableRegistry)
        throws IOException {
        // Detect the compression scheme from the payload header; fall back to the raw stream if uncompressed
        final Compressor compressor = CompressorFactory.compressor(request.bytes());
        final StreamInput in;
        if (compressor != null) {
            in = new InputStreamStreamInput(compressor.threadLocalInputStream(request.bytes().streamInput()));
        } else {
            in = request.bytes().streamInput();
        }
        in.setVersion(request.version());
        return new NamedWriteableAwareStreamInput(in, namedWriteableRegistry);
    }
}
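A minimal sketch of how the helper's two halves compose, mirroring the join-validation path below; state, node, and registry are hypothetical stand-ins, while the CompressedStreamUtils and BytesTransportRequest calls are the ones this diff introduces:

    // Compress a cluster state, wrap it as a transport request, then decompress it back.
    // Assumes in scope: state (ClusterState), node (DiscoveryNode), registry (NamedWriteableRegistry).
    final BytesReference bytes = CompressedStreamUtils.createCompressedStream(node.getVersion(), state::writeTo);
    final BytesTransportRequest request = new BytesTransportRequest(bytes, node.getVersion());
    try (StreamInput in = CompressedStreamUtils.decompressBytes(request, registry)) {
        final ClusterState roundTripped = ClusterState.readFrom(in, node); // equals the original state
    }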
server/src/main/java/org/opensearch/cluster/coordination/Coordinator.java
@@ -222,7 +222,8 @@ public Coordinator(
this.onJoinValidators,
rerouteService,
nodeHealthService,
-this::onNodeCommissionStatusChange
this::onNodeCommissionStatusChange,
namedWriteableRegistry
);
this.persistedStateSupplier = persistedStateSupplier;
this.noClusterManagerBlockService = new NoClusterManagerBlockService(settings, clusterSettings);
server/src/main/java/org/opensearch/cluster/coordination/JoinHelper.java
@@ -35,6 +35,7 @@
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.apache.logging.log4j.message.ParameterizedMessage;
import org.opensearch.Version;
import org.opensearch.action.ActionListener;
import org.opensearch.action.ActionListenerResponseHandler;
import org.opensearch.cluster.ClusterState;
@@ -49,7 +50,9 @@
import org.opensearch.cluster.routing.allocation.AllocationService;
import org.opensearch.cluster.service.ClusterManagerService;
import org.opensearch.common.Priority;
import org.opensearch.common.bytes.BytesReference;
import org.opensearch.common.collect.Tuple;
import org.opensearch.common.io.stream.NamedWriteableRegistry;
import org.opensearch.common.io.stream.StreamInput;
import org.opensearch.common.settings.Setting;
import org.opensearch.common.settings.Settings;
@@ -58,6 +61,7 @@
import org.opensearch.monitor.StatusInfo;
import org.opensearch.threadpool.ThreadPool;
import org.opensearch.threadpool.ThreadPool.Names;
import org.opensearch.transport.BytesTransportRequest;
import org.opensearch.transport.RemoteTransportException;
import org.opensearch.transport.TransportChannel;
import org.opensearch.transport.TransportException;
@@ -98,6 +102,7 @@ public class JoinHelper {

public static final String JOIN_ACTION_NAME = "internal:cluster/coordination/join";
public static final String VALIDATE_JOIN_ACTION_NAME = "internal:cluster/coordination/join/validate";
public static final String VALIDATE_COMPRESSED_JOIN_ACTION_NAME = JOIN_ACTION_NAME + "/validate_compressed";
public static final String START_JOIN_ACTION_NAME = "internal:cluster/coordination/start_join";

// the timeout for Zen1 join attempts
@@ -122,6 +127,8 @@ public class JoinHelper {

private final Supplier<JoinTaskExecutor> joinTaskExecutorGenerator;
private final Consumer<Boolean> nodeCommissioned;
private final NamedWriteableRegistry namedWriteableRegistry;
private final AtomicReference<Tuple<Long, BytesReference>> serializedState = new AtomicReference<>();

JoinHelper(
Settings settings,
@@ -135,13 +142,16 @@ public class JoinHelper {
Collection<BiConsumer<DiscoveryNode, ClusterState>> joinValidators,
RerouteService rerouteService,
NodeHealthService nodeHealthService,
-Consumer<Boolean> nodeCommissioned
Consumer<Boolean> nodeCommissioned,
NamedWriteableRegistry namedWriteableRegistry
) {
this.clusterManagerService = clusterManagerService;
this.transportService = transportService;
this.nodeHealthService = nodeHealthService;
this.joinTimeout = JOIN_TIMEOUT_SETTING.get(settings);
this.nodeCommissioned = nodeCommissioned;
this.namedWriteableRegistry = namedWriteableRegistry;

this.joinTaskExecutorGenerator = () -> new JoinTaskExecutor(settings, allocationService, logger, rerouteService) {

private final long term = currentTermSupplier.getAsLong();
@@ -208,22 +218,52 @@ public ClusterTasksResult<JoinTaskExecutor.Task> execute(ClusterState currentSta
ThreadPool.Names.GENERIC,
ValidateJoinRequest::new,
(request, channel, task) -> {
-final ClusterState localState = currentStateSupplier.get();
-if (localState.metadata().clusterUUIDCommitted()
-&& localState.metadata().clusterUUID().equals(request.getState().metadata().clusterUUID()) == false) {
-throw new CoordinationStateRejectedException(
-"join validation on cluster state"
-+ " with a different cluster uuid "
-+ request.getState().metadata().clusterUUID()
-+ " than local cluster uuid "
-+ localState.metadata().clusterUUID()
-+ ", rejecting"
-);
-}
-joinValidators.forEach(action -> action.accept(transportService.getLocalNode(), request.getState()));
runJoinValidators(currentStateSupplier, request.getState(), joinValidators);
channel.sendResponse(Empty.INSTANCE);
}
);

transportService.registerRequestHandler(
VALIDATE_COMPRESSED_JOIN_ACTION_NAME,
ThreadPool.Names.GENERIC,
BytesTransportRequest::new,
(request, channel, task) -> {
handleCompressedValidateJoinRequest(currentStateSupplier, joinValidators, request);
channel.sendResponse(Empty.INSTANCE);
}
);

}

private void runJoinValidators(
Supplier<ClusterState> currentStateSupplier,
ClusterState incomingState,
Collection<BiConsumer<DiscoveryNode, ClusterState>> joinValidators
) {
final ClusterState localState = currentStateSupplier.get();
if (localState.metadata().clusterUUIDCommitted()
&& localState.metadata().clusterUUID().equals(incomingState.metadata().clusterUUID()) == false) {
throw new CoordinationStateRejectedException(
"join validation on cluster state"
+ " with a different cluster uuid "
+ incomingState.metadata().clusterUUID()
+ " than local cluster uuid "
+ localState.metadata().clusterUUID()
+ ", rejecting"
);
}
joinValidators.forEach(action -> action.accept(transportService.getLocalNode(), incomingState));
}

private void handleCompressedValidateJoinRequest(
Supplier<ClusterState> currentStateSupplier,
Collection<BiConsumer<DiscoveryNode, ClusterState>> joinValidators,
BytesTransportRequest request
) throws IOException {
try (StreamInput input = CompressedStreamUtils.decompressBytes(request, namedWriteableRegistry)) {
ClusterState incomingState = ClusterState.readFrom(input, transportService.getLocalNode());
runJoinValidators(currentStateSupplier, incomingState, joinValidators);
}
}

private JoinCallback transportJoinCallback(TransportRequest request, TransportChannel channel) {
@@ -407,12 +447,42 @@ public String executor() {
}

public void sendValidateJoinRequest(DiscoveryNode node, ClusterState state, ActionListener<TransportResponse.Empty> listener) {
-transportService.sendRequest(
-node,
-VALIDATE_JOIN_ACTION_NAME,
-new ValidateJoinRequest(state),
-new ActionListenerResponseHandler<>(listener, i -> Empty.INSTANCE, ThreadPool.Names.GENERIC)
-);
if (node.getVersion().before(Version.V_3_0_0)) {
transportService.sendRequest(
node,
VALIDATE_JOIN_ACTION_NAME,
new ValidateJoinRequest(state),
new ActionListenerResponseHandler<>(listener, i -> Empty.INSTANCE, ThreadPool.Names.GENERIC)
);
} else {
try {
final BytesReference bytes = serializedState.updateAndGet(cachedState -> {
if (cachedState == null || cachedState.v1() != state.version()) {
try {
return new Tuple<>(
state.version(),
CompressedStreamUtils.createCompressedStream(node.getVersion(), state::writeTo)
);
} catch (IOException e) {
// wrap in a RuntimeException: the updateAndGet lambda cannot throw the checked IOException.
throw new RuntimeException(e);
}
} else {
return cachedState;
}
}).v2();
final BytesTransportRequest request = new BytesTransportRequest(bytes, node.getVersion());
transportService.sendRequest(
node,
VALIDATE_COMPRESSED_JOIN_ACTION_NAME,
request,
new ActionListenerResponseHandler<>(listener, i -> Empty.INSTANCE, ThreadPool.Names.GENERIC)
);
} catch (Exception e) {
logger.warn("error sending cluster state to {}", node);
listener.onFailure(e);
}
}
}
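A note on the caching above: AtomicReference.updateAndGet takes a plain function, so the checked IOException from serialization has to be wrapped; below is a sketch of the same version-keyed memoization using java.io.UncheckedIOException to hand the original exception back to the caller (serializeOnce is a hypothetical helper, not part of this diff):

    // Reuse the compressed bytes while the cluster state version is unchanged.
    // Note: updateAndGet may re-run the function under contention, so it must stay side-effect free.
    private BytesReference serializeOnce(ClusterState state, Version nodeVersion) throws IOException {
        try {
            return serializedState.updateAndGet(cached -> {
                if (cached != null && cached.v1() == state.version()) {
                    return cached; // cache hit: this state version is already serialized
                }
                try {
                    return new Tuple<>(state.version(), CompressedStreamUtils.createCompressedStream(nodeVersion, state::writeTo));
                } catch (IOException e) {
                    throw new UncheckedIOException(e); // lambdas cannot throw checked exceptions
                }
            }).v2();
        } catch (UncheckedIOException e) {
            throw e.getCause(); // unwrap back to the original IOException
        }
    }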

/**
server/src/main/java/org/opensearch/cluster/coordination/PublicationTransportHandler.java
@@ -44,16 +44,8 @@
import org.opensearch.cluster.node.DiscoveryNode;
import org.opensearch.cluster.node.DiscoveryNodes;
import org.opensearch.common.bytes.BytesReference;
-import org.opensearch.common.compress.Compressor;
-import org.opensearch.common.compress.CompressorFactory;
-import org.opensearch.common.io.stream.BytesStreamOutput;
-import org.opensearch.common.io.stream.InputStreamStreamInput;
-import org.opensearch.common.io.stream.NamedWriteableAwareStreamInput;
import org.opensearch.common.io.stream.NamedWriteableRegistry;
-import org.opensearch.common.io.stream.OutputStreamStreamOutput;
import org.opensearch.common.io.stream.StreamInput;
-import org.opensearch.common.io.stream.StreamOutput;
-import org.opensearch.common.util.io.IOUtils;
import org.opensearch.threadpool.ThreadPool;
import org.opensearch.transport.BytesTransportRequest;
import org.opensearch.transport.TransportChannel;
@@ -168,17 +160,9 @@ public PublishClusterStateStats stats() {
}

private PublishWithJoinResponse handleIncomingPublishRequest(BytesTransportRequest request) throws IOException {
-final Compressor compressor = CompressorFactory.compressor(request.bytes());
-StreamInput in = request.bytes().streamInput();
-try {
-if (compressor != null) {
-in = new InputStreamStreamInput(compressor.threadLocalInputStream(in));
-}
-in = new NamedWriteableAwareStreamInput(in, namedWriteableRegistry);
-in.setVersion(request.version());
-// If true we received full cluster state - otherwise diffs
try (StreamInput in = CompressedStreamUtils.decompressBytes(request, namedWriteableRegistry)) {
-ClusterState incomingState;
if (in.readBoolean()) {
final ClusterState incomingState;
// Close early to release resources used by the de-compression as early as possible
try (StreamInput input = in) {
incomingState = ClusterState.readFrom(input, transportService.getLocalNode());
@@ -198,7 +182,6 @@ private PublishWithJoinResponse handleIncomingPublishRequest(BytesTransportReque
incompatibleClusterStateDiffReceivedCount.incrementAndGet();
throw new IncompatibleClusterStateVersionException("have no local cluster state");
} else {
-ClusterState incomingState;
try {
final Diff<ClusterState> diff;
// Close stream early to release resources used by the de-compression as early as possible
@@ -225,8 +208,6 @@ private PublishWithJoinResponse handleIncomingPublishRequest(BytesTransportReque
return response;
}
}
-} finally {
-IOUtils.close(in);
}
}

@@ -254,13 +235,10 @@ public PublicationContext newPublicationContext(ClusterChangedEvent clusterChang
}

private static BytesReference serializeFullClusterState(ClusterState clusterState, Version nodeVersion) throws IOException {
-final BytesStreamOutput bStream = new BytesStreamOutput();
-try (StreamOutput stream = new OutputStreamStreamOutput(CompressorFactory.COMPRESSOR.threadLocalOutputStream(bStream))) {
-stream.setVersion(nodeVersion);
final BytesReference serializedState = CompressedStreamUtils.createCompressedStream(nodeVersion, stream -> {
stream.writeBoolean(true);
clusterState.writeTo(stream);
-}
-final BytesReference serializedState = bStream.bytes();
});
logger.trace(
"serialized full cluster state version [{}] for node version [{}] with size [{}]",
clusterState.version(),
@@ -271,13 +249,10 @@ private static BytesReference serializeFullClusterState(ClusterStat
}

private static BytesReference serializeDiffClusterState(Diff<ClusterState> diff, Version nodeVersion) throws IOException {
-final BytesStreamOutput bStream = new BytesStreamOutput();
-try (StreamOutput stream = new OutputStreamStreamOutput(CompressorFactory.COMPRESSOR.threadLocalOutputStream(bStream))) {
-stream.setVersion(nodeVersion);
return CompressedStreamUtils.createCompressedStream(nodeVersion, stream -> {
stream.writeBoolean(false);
diff.writeTo(stream);
-}
-return bStream.bytes();
});
}

/**
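For context on the publish path above: after this refactor both serializers share CompressedStreamUtils and prefix the payload with a boolean that distinguishes a full state from a diff. A minimal sketch of that framing (clusterState, nodeVersion, request, namedWriteableRegistry, and localNode are stand-ins for the surrounding fields):

    // Sender: the leading boolean marks a full cluster state (true) or a diff (false).
    final BytesReference payload = CompressedStreamUtils.createCompressedStream(nodeVersion, out -> {
        out.writeBoolean(true); // full state follows
        clusterState.writeTo(out);
    });

    // Receiver: decompress, then branch on the marker, as handleIncomingPublishRequest does.
    try (StreamInput in = CompressedStreamUtils.decompressBytes(request, namedWriteableRegistry)) {
        if (in.readBoolean()) {
            final ClusterState incoming = ClusterState.readFrom(in, localNode); // full state
        } else {
            final Diff<ClusterState> diff = ClusterState.readDiffFrom(in, localNode); // diff against last seen state
        }
    }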
@@ -153,4 +153,8 @@ grant codeBase "file:${gradle.worker.jar}" {
grant {
// since the gradle test worker jar is on the test classpath, our tests should be able to read it
permission java.io.FilePermission "${gradle.worker.jar}", "read";
permission java.lang.RuntimePermission "accessDeclaredMembers";
permission java.lang.RuntimePermission "reflectionFactoryAccess";
permission java.lang.RuntimePermission "accessClassInPackage.sun.reflect";
permission java.lang.reflect.ReflectPermission "suppressAccessChecks";
};