Skip to content

Commit

Permalink
Compress and cache cluster state during validate join request
Browse files Browse the repository at this point in the history
Signed-off-by: Aman Khare <amkhar@amazon.com>
  • Loading branch information
Aman Khare committed Apr 28, 2023
1 parent dffd822 commit fd7eebe
Show file tree
Hide file tree
Showing 6 changed files with 155 additions and 24 deletions.
36 changes: 36 additions & 0 deletions server/src/main/java/org/opensearch/cluster/CompressionHelper.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
package org.opensearch.cluster;

import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.Version;
import org.opensearch.common.bytes.BytesReference;
import org.opensearch.common.compress.CompressorFactory;
import org.opensearch.common.io.stream.BytesStreamOutput;
import org.opensearch.common.io.stream.OutputStreamStreamOutput;
import org.opensearch.common.io.stream.StreamOutput;
import org.opensearch.common.io.stream.Writeable;

import java.io.IOException;

/**
 * A helper class to utilize the compressed stream.
 *
 * <p>This is a static-only utility holder: it carries no state besides its logger and is
 * not meant to be instantiated or subclassed.
 */
public final class CompressionHelper {
    private static final Logger logger = LogManager.getLogger(CompressionHelper.class);

    /** Not instantiable: every member of this class is static. */
    private CompressionHelper() {}

    /**
     * Serializes {@code writer} into a compressed byte buffer.
     *
     * <p>The output is produced through {@link CompressorFactory#COMPRESSOR}. The stream version is
     * set to {@code nodeVersion}, then {@code streamBooleanFlag} is written as a single boolean,
     * followed by whatever {@code writer.writeTo} emits. The stream is closed before the bytes are
     * read back from the underlying buffer, so the returned reference holds the complete output.
     *
     * @param writer            the object to serialize after the leading boolean
     * @param nodeVersion       the wire version to set on the stream (the version of the receiving node)
     * @param streamBooleanFlag marker written ahead of the payload; callers visible alongside this class
     *                          pass {@code true} for a full cluster state and {@code false} for a diff
     *                          (NOTE(review): confirm this convention against all readers)
     * @return the compressed, serialized bytes
     * @throws IOException if writing to the stream fails
     */
    public static BytesReference serializedWrite(Writeable writer, Version nodeVersion, boolean streamBooleanFlag) throws IOException {
        final BytesStreamOutput bStream = new BytesStreamOutput();
        // try-with-resources closes the compressing stream so all pending output reaches bStream
        // before bStream.bytes() is read below.
        try (StreamOutput stream = new OutputStreamStreamOutput(CompressorFactory.COMPRESSOR.threadLocalOutputStream(bStream))) {
            stream.setVersion(nodeVersion);
            stream.writeBoolean(streamBooleanFlag);
            writer.writeTo(stream);
        }
        final BytesReference serializedByteRef = bStream.bytes();
        logger.trace(
            "serialized writable object for node version [{}] with size [{}]",
            nodeVersion,
            serializedByteRef.length()
        );
        return serializedByteRef;
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,8 @@ public Coordinator(
this.onJoinValidators,
rerouteService,
nodeHealthService,
this::onNodeCommissionStatusChange
this::onNodeCommissionStatusChange,
namedWriteableRegistry
);
this.persistedStateSupplier = persistedStateSupplier;
this.noClusterManagerBlockService = new NoClusterManagerBlockService(settings, clusterSettings);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,11 +35,13 @@
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.apache.logging.log4j.message.ParameterizedMessage;
import org.opensearch.Version;
import org.opensearch.action.ActionListener;
import org.opensearch.action.ActionListenerResponseHandler;
import org.opensearch.cluster.ClusterState;
import org.opensearch.cluster.ClusterStateTaskConfig;
import org.opensearch.cluster.ClusterStateTaskListener;
import org.opensearch.cluster.CompressionHelper;
import org.opensearch.cluster.NotClusterManagerException;
import org.opensearch.cluster.coordination.Coordinator.Mode;
import org.opensearch.cluster.decommission.NodeDecommissionedException;
Expand All @@ -49,15 +51,23 @@
import org.opensearch.cluster.routing.allocation.AllocationService;
import org.opensearch.cluster.service.ClusterManagerService;
import org.opensearch.common.Priority;
import org.opensearch.common.bytes.BytesReference;
import org.opensearch.common.collect.Tuple;
import org.opensearch.common.compress.Compressor;
import org.opensearch.common.compress.CompressorFactory;
import org.opensearch.common.io.stream.InputStreamStreamInput;
import org.opensearch.common.io.stream.NamedWriteableAwareStreamInput;
import org.opensearch.common.io.stream.NamedWriteableRegistry;
import org.opensearch.common.io.stream.StreamInput;
import org.opensearch.common.settings.Setting;
import org.opensearch.common.settings.Settings;
import org.opensearch.common.unit.TimeValue;
import org.opensearch.common.util.io.IOUtils;
import org.opensearch.monitor.NodeHealthService;
import org.opensearch.monitor.StatusInfo;
import org.opensearch.threadpool.ThreadPool;
import org.opensearch.threadpool.ThreadPool.Names;
import org.opensearch.transport.BytesTransportRequest;
import org.opensearch.transport.RemoteTransportException;
import org.opensearch.transport.TransportChannel;
import org.opensearch.transport.TransportException;
Expand All @@ -78,6 +88,7 @@
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.BiConsumer;
import java.util.function.Consumer;
Expand Down Expand Up @@ -122,6 +133,17 @@ public class JoinHelper {

private final Supplier<JoinTaskExecutor> joinTaskExecutorGenerator;
private final Consumer<Boolean> nodeCommissioned;
private final NamedWriteableRegistry namedWriteableRegistry;
private final Map<Version, BytesReference> serializedStates = new HashMap<>();
private long lastRefreshTime = 0L;
public static final Setting<TimeValue> CLUSTER_MANAGER_JOIN_STATE_REFRESH_INTERVAL = Setting.timeSetting(
"cluster_manager.join.state.refresh_interval",
TimeValue.timeValueMillis(30000),
TimeValue.timeValueMillis(0),
TimeValue.timeValueMillis(60000),
Setting.Property.NodeScope
);
private final long clusterStateRefreshInterval;

JoinHelper(
Settings settings,
Expand All @@ -135,13 +157,16 @@ public class JoinHelper {
Collection<BiConsumer<DiscoveryNode, ClusterState>> joinValidators,
RerouteService rerouteService,
NodeHealthService nodeHealthService,
Consumer<Boolean> nodeCommissioned
Consumer<Boolean> nodeCommissioned,
NamedWriteableRegistry namedWriteableRegistry
) {
this.clusterManagerService = clusterManagerService;
this.transportService = transportService;
this.nodeHealthService = nodeHealthService;
this.joinTimeout = JOIN_TIMEOUT_SETTING.get(settings);
this.nodeCommissioned = nodeCommissioned;
this.namedWriteableRegistry = namedWriteableRegistry;
this.clusterStateRefreshInterval = CLUSTER_MANAGER_JOIN_STATE_REFRESH_INTERVAL.get(settings).getMillis();
this.joinTaskExecutorGenerator = () -> new JoinTaskExecutor(settings, allocationService, logger, rerouteService) {

private final long term = currentTermSupplier.getAsLong();
Expand Down Expand Up @@ -206,24 +231,53 @@ public ClusterTasksResult<JoinTaskExecutor.Task> execute(ClusterState currentSta
transportService.registerRequestHandler(
VALIDATE_JOIN_ACTION_NAME,
ThreadPool.Names.GENERIC,
ValidateJoinRequest::new,
BytesTransportRequest::new,
(request, channel, task) -> {
handleValidateJoinRequest(currentStateSupplier, joinValidators, request);
channel.sendResponse(Empty.INSTANCE);
}
);
}

/**
 * Handles an incoming validate-join request whose payload is a raw byte buffer.
 *
 * <p>The bytes are probed with {@code CompressorFactory.compressor(...)}; when a compressor is
 * detected the stream is wrapped for decompression. The stream is then wrapped in a
 * {@code NamedWriteableAwareStreamInput}, its version is set from the request, and a leading
 * boolean is read. When that boolean is {@code true} a {@code ClusterState} is read from the
 * stream, its cluster UUID is compared with the local state's (only when the local cluster UUID
 * is committed), and every join validator is invoked with the local node and the incoming state.
 *
 * <p>NOTE(review): this excerpt appears to be a rendered diff in which replaced and replacing
 * lines sit next to each other (e.g. the {@code request.getState()} and {@code incomingState}
 * variants) — confirm against the actual file before relying on the exact statement list.
 *
 * @param currentStateSupplier supplies the local cluster state used for the cluster UUID comparison
 * @param joinValidators       validators invoked with the local node and the incoming state
 * @param request              the transport request carrying the serialized bytes and their version
 * @throws IOException if reading the stream fails (deserialization errors are logged and rethrown)
 */
private void handleValidateJoinRequest(Supplier<ClusterState> currentStateSupplier,
Collection<BiConsumer<DiscoveryNode, ClusterState>> joinValidators,
BytesTransportRequest request) throws IOException {
// Returns null when the payload is not recognised as compressed; see the null check below.
final Compressor compressor = CompressorFactory.compressor(request.bytes());
StreamInput in = request.bytes().streamInput();
final ClusterState incomingState;
try {
if (compressor != null) {
in = new InputStreamStreamInput(compressor.threadLocalInputStream(in));
}
// Wrap with the registry so named writeables can be resolved while reading the state.
in = new NamedWriteableAwareStreamInput(in, namedWriteableRegistry);
in.setVersion(request.version());
// If true we received full cluster state - otherwise diffs
// NOTE(review): there is no else branch here, so a payload whose flag is false is accepted
// without any validation — confirm that is intended.
if (in.readBoolean()) {
// Close early to release resources used by the de-compression as early as possible
try (StreamInput input = in) {
incomingState = ClusterState.readFrom(input, transportService.getLocalNode());
} catch (Exception e) {
logger.warn("unexpected error while deserializing an incoming cluster state", e);
throw e;
}

final ClusterState localState = currentStateSupplier.get();
// Reject only when the local cluster UUID is committed and differs from the incoming one.
if (localState.metadata().clusterUUIDCommitted()
&& localState.metadata().clusterUUID().equals(request.getState().metadata().clusterUUID()) == false) {
&& localState.metadata().clusterUUID().equals(incomingState.metadata().clusterUUID()) == false) {
throw new CoordinationStateRejectedException(
"join validation on cluster state"
+ " with a different cluster uuid "
+ request.getState().metadata().clusterUUID()
+ incomingState.metadata().clusterUUID()
+ " than local cluster uuid "
+ localState.metadata().clusterUUID()
+ ", rejecting"
);
}
joinValidators.forEach(action -> action.accept(transportService.getLocalNode(), request.getState()));
channel.sendResponse(Empty.INSTANCE);
joinValidators.forEach(action -> action.accept(transportService.getLocalNode(), incomingState));
}
);
} finally {
// On the full-state path the same stream was already closed by the try-with-resources above;
// this close covers every other exit from the outer try.
IOUtils.close(in);
}
}

private JoinCallback transportJoinCallback(TransportRequest request, TransportChannel channel) {
Expand Down Expand Up @@ -407,12 +461,37 @@ public String executor() {
}

/**
 * Sends a validate-join request carrying the given cluster state to {@code node}.
 *
 * <p>The state is sent as compressed bytes inside a {@code BytesTransportRequest}. Serialized
 * bytes are cached in {@code serializedStates}, keyed by the target node's version, and are
 * re-serialized when no entry exists for that version or when {@code clusterStateRefreshInterval}
 * milliseconds have elapsed since {@code lastRefreshTime}. Any failure to serialize or send is
 * logged and reported through {@code listener.onFailure}.
 *
 * <p>NOTE(review): this excerpt appears to be a rendered diff; the direct
 * {@code transportService.sendRequest} call with {@code ValidateJoinRequest} immediately below
 * looks like the replaced code shown next to its replacement — confirm against the actual file.
 *
 * @param node     the node to validate the join against; its version selects the cache entry and wire version
 * @param state    the cluster state that is serialized when the cache is missing or due for refresh
 * @param listener notified with the empty response, or with the failure
 */
public void sendValidateJoinRequest(DiscoveryNode node, ClusterState state, ActionListener<TransportResponse.Empty> listener) {
transportService.sendRequest(
node,
VALIDATE_JOIN_ACTION_NAME,
new ValidateJoinRequest(state),
new ActionListenerResponseHandler<>(listener, i -> Empty.INSTANCE, ThreadPool.Names.GENERIC)
);
try {
// NOTE(review): serializedStates is declared as a plain HashMap and lastRefreshTime as a plain
// long; nothing in this method synchronizes access to them. If this method can run on more than
// one thread at a time, that is a data race — confirm the threading model.
BytesReference bytes = serializedStates.get(node.getVersion());
// Refresh serializedStates map every time if clusterStateRefreshInterval is 0
// NOTE(review): lastRefreshTime is a single timestamp shared by all versions, so refreshing one
// version's entry also postpones the refresh of every other cached version — confirm intended.
if (bytes == null || (System.currentTimeMillis() >= lastRefreshTime + clusterStateRefreshInterval)) {
try {
// Serialize the state passed by the caller for this node's version and cache the result
bytes = CompressionHelper.serializedWrite(state,
node.getVersion(), true);
serializedStates.put(node.getVersion(), bytes);
lastRefreshTime = System.currentTimeMillis();
} catch (Exception e) {
logger.warn(
() -> new ParameterizedMessage("failed to serialize cluster state during validateJoin" +
" {}", node),
e
);
// Report the failure to the caller and skip sending; nothing was cached for this attempt.
listener.onFailure(e);
return;
}
}
final BytesTransportRequest request = new BytesTransportRequest(bytes, node.getVersion());
transportService.sendRequest(
node,
VALIDATE_JOIN_ACTION_NAME,
request,
new ActionListenerResponseHandler<>(listener, i -> Empty.INSTANCE, ThreadPool.Names.GENERIC)
);
} catch (Exception e) {
logger.warn(() -> new ParameterizedMessage("error sending cluster state to {}", node), e);
listener.onFailure(e);
}
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
import org.opensearch.action.ActionListener;
import org.opensearch.cluster.ClusterChangedEvent;
import org.opensearch.cluster.ClusterState;
import org.opensearch.cluster.CompressionHelper;
import org.opensearch.cluster.Diff;
import org.opensearch.cluster.IncompatibleClusterStateVersionException;
import org.opensearch.cluster.node.DiscoveryNode;
Expand Down Expand Up @@ -309,15 +310,17 @@ void buildDiffAndSerializeStates() {
try {
if (sendFullVersion || previousState.nodes().nodeExists(node) == false) {
if (serializedStates.containsKey(node.getVersion()) == false) {
serializedStates.put(node.getVersion(), serializeFullClusterState(newState, node.getVersion()));
serializedStates.put(node.getVersion(), CompressionHelper.serializedWrite(newState,
node.getVersion(), true));
}
} else {
// will send a diff
if (diff == null) {
diff = newState.diff(previousState);
}
if (serializedDiffs.containsKey(node.getVersion()) == false) {
final BytesReference serializedDiff = serializeDiffClusterState(diff, node.getVersion());
final BytesReference serializedDiff = CompressionHelper.serializedWrite(newState,
node.getVersion(), false);
serializedDiffs.put(node.getVersion(), serializedDiff);
logger.trace(
"serialized cluster state diff for version [{}] in for node version [{}] with size [{}]",
Expand Down Expand Up @@ -413,7 +416,7 @@ private void sendFullClusterState(DiscoveryNode destination, ActionListener<Publ
BytesReference bytes = serializedStates.get(destination.getVersion());
if (bytes == null) {
try {
bytes = serializeFullClusterState(newState, destination.getVersion());
bytes = CompressionHelper.serializedWrite(newState, destination.getVersion(), true);
serializedStates.put(destination.getVersion(), bytes);
} catch (Exception e) {
logger.warn(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -638,7 +638,8 @@ public void apply(Settings value, Settings current, Settings previous) {
SegmentReplicationPressureService.MAX_ALLOWED_STALE_SHARDS,

// Settings related to Searchable Snapshots
Node.NODE_SEARCH_CACHE_SIZE_SETTING
Node.NODE_SEARCH_CACHE_SIZE_SETTING,
JoinHelper.CLUSTER_MANAGER_JOIN_STATE_REFRESH_INTERVAL
)
)
);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,20 +37,25 @@
import org.opensearch.action.support.PlainActionFuture;
import org.opensearch.cluster.ClusterName;
import org.opensearch.cluster.ClusterState;
import org.opensearch.cluster.CompressionHelper;
import org.opensearch.cluster.NotClusterManagerException;
import org.opensearch.cluster.metadata.Metadata;
import org.opensearch.cluster.node.DiscoveryNode;
import org.opensearch.common.bytes.BytesReference;
import org.opensearch.common.io.stream.NamedWriteableRegistry;
import org.opensearch.common.settings.Settings;
import org.opensearch.monitor.StatusInfo;
import org.opensearch.test.OpenSearchTestCase;
import org.opensearch.test.transport.CapturingTransport;
import org.opensearch.test.transport.CapturingTransport.CapturedRequest;
import org.opensearch.test.transport.MockTransport;
import org.opensearch.transport.BytesTransportRequest;
import org.opensearch.transport.RemoteTransportException;
import org.opensearch.transport.TransportException;
import org.opensearch.transport.TransportResponse;
import org.opensearch.transport.TransportService;

import java.io.IOException;
import java.util.Collections;
import java.util.Optional;
import java.util.concurrent.atomic.AtomicReference;
Expand All @@ -63,6 +68,7 @@
import static org.hamcrest.core.Is.is;

public class JoinHelperTests extends OpenSearchTestCase {
private final NamedWriteableRegistry namedWriteableRegistry = DEFAULT_NAMED_WRITABLE_REGISTRY;

public void testJoinDeduplication() {
DeterministicTaskQueue deterministicTaskQueue = new DeterministicTaskQueue(
Expand Down Expand Up @@ -93,7 +99,8 @@ public void testJoinDeduplication() {
Collections.emptyList(),
(s, p, r) -> {},
() -> new StatusInfo(HEALTHY, "info"),
nodeCommissioned -> {}
nodeCommissioned -> {},
namedWriteableRegistry
);
transportService.start();

Expand Down Expand Up @@ -195,14 +202,14 @@ public void testFailedJoinAttemptLogLevel() {
);
}

/**
 * Runs the shared mismatched-cluster-UUID assertion against {@code JoinHelper.VALIDATE_JOIN_ACTION_NAME},
 * expecting the rejection message to contain the text passed as the second argument.
 *
 * <p>NOTE(review): two signature lines appear here (with and without {@code throws IOException});
 * this looks like a rendered diff showing the replaced and replacing line — confirm against the file.
 */
public void testJoinValidationRejectsMismatchedClusterUUID() {
public void testJoinValidationRejectsMismatchedClusterUUID() throws IOException {
assertJoinValidationRejectsMismatchedClusterUUID(
JoinHelper.VALIDATE_JOIN_ACTION_NAME,
"join validation on cluster state with a different cluster uuid"
);
}

private void assertJoinValidationRejectsMismatchedClusterUUID(String actionName, String expectedMessage) {
private void assertJoinValidationRejectsMismatchedClusterUUID(String actionName, String expectedMessage) throws IOException {
DeterministicTaskQueue deterministicTaskQueue = new DeterministicTaskQueue(
Settings.builder().put(NODE_NAME_SETTING.getKey(), "node0").build(),
random()
Expand All @@ -224,7 +231,8 @@ private void assertJoinValidationRejectsMismatchedClusterUUID(String actionName,
);
new JoinHelper(Settings.EMPTY, null, null, transportService, () -> 0L, () -> localClusterState, (joinRequest, joinCallback) -> {
throw new AssertionError();
}, startJoinRequest -> { throw new AssertionError(); }, Collections.emptyList(), (s, p, r) -> {}, null, nodeCommissioned -> {}); // registers
}, startJoinRequest -> { throw new AssertionError(); }, Collections.emptyList(), (s, p, r) -> {}, null,
nodeCommissioned -> {}, namedWriteableRegistry); // registers
// request
// handler
transportService.start();
Expand All @@ -235,10 +243,12 @@ private void assertJoinValidationRejectsMismatchedClusterUUID(String actionName,
.build();

final PlainActionFuture<TransportResponse.Empty> future = new PlainActionFuture<>();
BytesReference bytes = CompressionHelper.serializedWrite(otherClusterState, localNode.getVersion(), true);
final BytesTransportRequest request = new BytesTransportRequest(bytes, localNode.getVersion());
transportService.sendRequest(
localNode,
actionName,
new ValidateJoinRequest(otherClusterState),
request,
new ActionListenerResponseHandler<>(future, in -> TransportResponse.Empty.INSTANCE)
);
deterministicTaskQueue.runAllTasks();
Expand Down Expand Up @@ -282,7 +292,8 @@ public void testJoinFailureOnUnhealthyNodes() {
Collections.emptyList(),
(s, p, r) -> {},
() -> nodeHealthServiceStatus.get(),
nodeCommissioned -> {}
nodeCommissioned -> {},
namedWriteableRegistry
);
transportService.start();

Expand Down

0 comments on commit fd7eebe

Please sign in to comment.