diff --git a/server/src/main/java/org/opensearch/indices/replication/SegmentReplicationTargetService.java b/server/src/main/java/org/opensearch/indices/replication/SegmentReplicationTargetService.java index f6ed113019897..34664c570132e 100644 --- a/server/src/main/java/org/opensearch/indices/replication/SegmentReplicationTargetService.java +++ b/server/src/main/java/org/opensearch/indices/replication/SegmentReplicationTargetService.java @@ -11,9 +11,7 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.apache.logging.log4j.message.ParameterizedMessage; -import org.apache.lucene.index.CorruptIndexException; import org.opensearch.ExceptionsHelper; -import org.opensearch.OpenSearchCorruptionException; import org.opensearch.action.support.ChannelActionListener; import org.opensearch.cluster.ClusterChangedEvent; import org.opensearch.cluster.ClusterStateListener; @@ -24,7 +22,6 @@ import org.opensearch.common.lifecycle.AbstractLifecycleComponent; import org.opensearch.common.settings.Settings; import org.opensearch.common.util.CancellableThreads; -import org.opensearch.common.util.concurrent.AbstractRunnable; import org.opensearch.common.util.concurrent.ConcurrentCollections; import org.opensearch.core.action.ActionListener; import org.opensearch.core.index.shard.ShardId; @@ -33,7 +30,6 @@ import org.opensearch.index.shard.IndexEventListener; import org.opensearch.index.shard.IndexShard; import org.opensearch.index.shard.IndexShardState; -import org.opensearch.index.store.Store; import org.opensearch.indices.IndicesService; import org.opensearch.indices.recovery.FileChunkRequest; import org.opensearch.indices.recovery.ForceSyncRequest; @@ -61,7 +57,7 @@ import static org.opensearch.indices.replication.SegmentReplicationSourceService.Actions.UPDATE_VISIBLE_CHECKPOINT; /** - * Service class that orchestrates replication events on replicas. + * Service class that handles incoming checkpoints to initiate replication events on replicas. * * @opensearch.internal */ @@ -72,10 +68,6 @@ public class SegmentReplicationTargetService extends AbstractLifecycleComponent private final ThreadPool threadPool; private final RecoverySettings recoverySettings; - private final ReplicationCollection onGoingReplications; - - private final Map completedReplications = ConcurrentCollections.newConcurrentMap(); - private final SegmentReplicationSourceFactory sourceFactory; protected final Map latestReceivedCheckpoint = ConcurrentCollections.newConcurrentMap(); @@ -83,6 +75,7 @@ public class SegmentReplicationTargetService extends AbstractLifecycleComponent private final IndicesService indicesService; private final ClusterService clusterService; private final TransportService transportService; + private final SegmentReplicator replicator; /** * The internal actions @@ -94,6 +87,7 @@ public static class Actions { public static final String FORCE_SYNC = "internal:index/shard/replication/segments_sync"; } + @Deprecated public SegmentReplicationTargetService( final ThreadPool threadPool, final RecoverySettings recoverySettings, @@ -113,6 +107,7 @@ public SegmentReplicationTargetService( ); } + @Deprecated public SegmentReplicationTargetService( final ThreadPool threadPool, final RecoverySettings recoverySettings, @@ -121,14 +116,34 @@ public SegmentReplicationTargetService( final IndicesService indicesService, final ClusterService clusterService, final ReplicationCollection ongoingSegmentReplications + ) { + this( + threadPool, + recoverySettings, + transportService, + sourceFactory, + indicesService, + clusterService, + new SegmentReplicator(threadPool) + ); + } + + public SegmentReplicationTargetService( + final ThreadPool threadPool, + final RecoverySettings recoverySettings, + final TransportService transportService, + final SegmentReplicationSourceFactory sourceFactory, + final IndicesService indicesService, + final ClusterService clusterService, + final SegmentReplicator replicator ) { this.threadPool = threadPool; this.recoverySettings = recoverySettings; - this.onGoingReplications = ongoingSegmentReplications; this.sourceFactory = sourceFactory; this.indicesService = indicesService; this.clusterService = clusterService; this.transportService = transportService; + this.replicator = replicator; transportService.registerRequestHandler( Actions.FILE_CHUNK, @@ -154,7 +169,7 @@ protected void doStart() { @Override protected void doStop() { if (DiscoveryNode.isDataNode(clusterService.getSettings())) { - assert onGoingReplications.size() == 0 : "Replication collection should be empty on shutdown"; + assert replicator.size() == 0 : "Replication collection should be empty on shutdown"; clusterService.removeListener(this); } } @@ -199,7 +214,7 @@ public void clusterChanged(ClusterChangedEvent event) { @Override public void beforeIndexShardClosed(ShardId shardId, @Nullable IndexShard indexShard, Settings indexSettings) { if (indexShard != null && indexShard.indexSettings().isSegRepEnabledOrRemoteNode()) { - onGoingReplications.cancelForShard(indexShard.shardId(), "Shard closing"); + replicator.cancel(indexShard.shardId(), "Shard closing"); latestReceivedCheckpoint.remove(shardId); } } @@ -224,7 +239,7 @@ public void shardRoutingChanged(IndexShard indexShard, @Nullable ShardRouting ol && indexShard.indexSettings().isSegRepEnabledOrRemoteNode() && oldRouting.primary() == false && newRouting.primary()) { - onGoingReplications.cancelForShard(indexShard.shardId(), "Shard has been promoted to primary"); + replicator.cancel(indexShard.shardId(), "Shard has been promoted to primary"); latestReceivedCheckpoint.remove(indexShard.shardId()); } } @@ -234,9 +249,7 @@ public void shardRoutingChanged(IndexShard indexShard, @Nullable ShardRouting ol */ @Nullable public SegmentReplicationState getOngoingEventSegmentReplicationState(ShardId shardId) { - return Optional.ofNullable(onGoingReplications.getOngoingReplicationTarget(shardId)) - .map(SegmentReplicationTarget::state) - .orElse(null); + return Optional.ofNullable(replicator.get(shardId)).map(SegmentReplicationTarget::state).orElse(null); } /** @@ -244,7 +257,7 @@ public SegmentReplicationState getOngoingEventSegmentReplicationState(ShardId sh */ @Nullable public SegmentReplicationState getlatestCompletedEventSegmentReplicationState(ShardId shardId) { - return completedReplications.get(shardId); + return replicator.getCompleted(shardId); } /** @@ -257,11 +270,11 @@ public SegmentReplicationState getSegmentReplicationState(ShardId shardId) { } public ReplicationRef get(long replicationId) { - return onGoingReplications.get(replicationId); + return replicator.get(replicationId); } public SegmentReplicationTarget get(ShardId shardId) { - return onGoingReplications.getOngoingReplicationTarget(shardId); + return replicator.get(shardId); } /** @@ -285,7 +298,7 @@ public synchronized void onNewCheckpoint(final ReplicationCheckpoint receivedChe // checkpoint to be replayed once the shard is Active. if (replicaShard.state().equals(IndexShardState.STARTED) == true) { // Checks if received checkpoint is already present and ahead then it replaces old received checkpoint - SegmentReplicationTarget ongoingReplicationTarget = onGoingReplications.getOngoingReplicationTarget(replicaShard.shardId()); + SegmentReplicationTarget ongoingReplicationTarget = replicator.get(replicaShard.shardId()); if (ongoingReplicationTarget != null) { if (ongoingReplicationTarget.getCheckpoint().getPrimaryTerm() < receivedCheckpoint.getPrimaryTerm()) { logger.debug( @@ -504,28 +517,12 @@ public SegmentReplicationTarget startReplication( final ReplicationCheckpoint checkpoint, final SegmentReplicationListener listener ) { - final SegmentReplicationTarget target = new SegmentReplicationTarget( - indexShard, - checkpoint, - sourceFactory.get(indexShard), - listener - ); - startReplication(target); - return target; + return replicator.startReplication(indexShard, checkpoint, sourceFactory.get(indexShard), listener); } // pkg-private for integration tests void startReplication(final SegmentReplicationTarget target) { - final long replicationId; - try { - replicationId = onGoingReplications.startSafe(target, recoverySettings.activityTimeout()); - } catch (ReplicationFailedException e) { - // replication already running for shard. - target.fail(e, false); - return; - } - logger.trace(() -> new ParameterizedMessage("Added new replication to collection {}", target.description())); - threadPool.generic().execute(new ReplicationRunner(replicationId)); + replicator.startReplication(target); } /** @@ -550,81 +547,6 @@ default void onFailure(ReplicationState state, ReplicationFailedException e, boo void onReplicationFailure(SegmentReplicationState state, ReplicationFailedException e, boolean sendShardFailure); } - /** - * Runnable implementation to trigger a replication event. - */ - private class ReplicationRunner extends AbstractRunnable { - - final long replicationId; - - public ReplicationRunner(long replicationId) { - this.replicationId = replicationId; - } - - @Override - public void onFailure(Exception e) { - onGoingReplications.fail(replicationId, new ReplicationFailedException("Unexpected Error during replication", e), false); - } - - @Override - public void doRun() { - start(replicationId); - } - } - - private void start(final long replicationId) { - final SegmentReplicationTarget target; - try (ReplicationRef replicationRef = onGoingReplications.get(replicationId)) { - // This check is for handling edge cases where the reference is removed before the ReplicationRunner is started by the - // threadpool. - if (replicationRef == null) { - return; - } - target = replicationRef.get(); - } - target.startReplication(new ActionListener<>() { - @Override - public void onResponse(Void o) { - logger.debug(() -> new ParameterizedMessage("Finished replicating {} marking as done.", target.description())); - onGoingReplications.markAsDone(replicationId); - if (target.state().getIndex().recoveredFileCount() != 0 && target.state().getIndex().recoveredBytes() != 0) { - completedReplications.put(target.shardId(), target.state()); - } - } - - @Override - public void onFailure(Exception e) { - logger.debug("Replication failed {}", target.description()); - if (isStoreCorrupt(target) || e instanceof CorruptIndexException || e instanceof OpenSearchCorruptionException) { - onGoingReplications.fail(replicationId, new ReplicationFailedException("Store corruption during replication", e), true); - return; - } - onGoingReplications.fail(replicationId, new ReplicationFailedException("Segment Replication failed", e), false); - } - }); - } - - private boolean isStoreCorrupt(SegmentReplicationTarget target) { - // ensure target is not already closed. In that case - // we can assume the store is not corrupt and that the replication - // event completed successfully. - if (target.refCount() > 0) { - final Store store = target.store(); - if (store.tryIncRef()) { - try { - return store.isMarkedCorrupted(); - } catch (IOException ex) { - logger.warn("Unable to determine if store is corrupt", ex); - return false; - } finally { - store.decRef(); - } - } - } - // store already closed. - return false; - } - private class FileChunkTransportRequestHandler implements TransportRequestHandler { // How many bytes we've copied since we last called RateLimiter.pause @@ -632,7 +554,7 @@ private class FileChunkTransportRequestHandler implements TransportRequestHandle @Override public void messageReceived(final FileChunkRequest request, TransportChannel channel, Task task) throws Exception { - try (ReplicationRef ref = onGoingReplications.getSafe(request.recoveryId(), request.shardId())) { + try (ReplicationRef ref = replicator.get(request.recoveryId(), request.shardId())) { final SegmentReplicationTarget target = ref.get(); final ActionListener listener = target.createOrFinishListener(channel, Actions.FILE_CHUNK, request); target.handleFileChunk(request, target, bytesSinceLastPause, recoverySettings.replicationRateLimiter(), listener); diff --git a/server/src/main/java/org/opensearch/indices/replication/SegmentReplicator.java b/server/src/main/java/org/opensearch/indices/replication/SegmentReplicator.java new file mode 100644 index 0000000000000..3d25b9ad53e9c --- /dev/null +++ b/server/src/main/java/org/opensearch/indices/replication/SegmentReplicator.java @@ -0,0 +1,183 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.indices.replication; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.apache.logging.log4j.message.ParameterizedMessage; +import org.apache.lucene.index.CorruptIndexException; +import org.opensearch.OpenSearchCorruptionException; +import org.opensearch.common.util.concurrent.AbstractRunnable; +import org.opensearch.common.util.concurrent.ConcurrentCollections; +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.index.shard.ShardId; +import org.opensearch.index.shard.IndexShard; +import org.opensearch.index.store.Store; +import org.opensearch.indices.replication.checkpoint.ReplicationCheckpoint; +import org.opensearch.indices.replication.common.ReplicationCollection; +import org.opensearch.indices.replication.common.ReplicationFailedException; +import org.opensearch.indices.replication.common.ReplicationListener; +import org.opensearch.threadpool.ThreadPool; + +import java.io.IOException; +import java.util.Map; + +/** + * This class is responsible for managing segment replication events on replicas. + * It uses a {@link ReplicationCollection} to track ongoing replication events and + * manages the state of each replication event. + * + * @opensearch.internal + */ +public class SegmentReplicator { + + private static final Logger logger = LogManager.getLogger(SegmentReplicator.class); + + private final ReplicationCollection onGoingReplications; + private final Map completedReplications = ConcurrentCollections.newConcurrentMap(); + private final ThreadPool threadPool; + + public SegmentReplicator(ThreadPool threadPool) { + this.onGoingReplications = new ReplicationCollection<>(logger, threadPool); + this.threadPool = threadPool; + } + + // TODO: Add public entrypoint for replication on an interval to be invoked via IndexService + + /** + * Start a round of replication and sync to at least the given checkpoint. + * @param indexShard - {@link IndexShard} replica shard + * @param checkpoint - {@link ReplicationCheckpoint} checkpoint to sync to + * @param listener - {@link ReplicationListener} + * @return {@link SegmentReplicationTarget} target event orchestrating the event. + */ + SegmentReplicationTarget startReplication( + final IndexShard indexShard, + final ReplicationCheckpoint checkpoint, + final SegmentReplicationSource source, + final SegmentReplicationTargetService.SegmentReplicationListener listener + ) { + final SegmentReplicationTarget target = new SegmentReplicationTarget(indexShard, checkpoint, source, listener); + startReplication(target); + return target; + } + + /** + * Runnable implementation to trigger a replication event. + */ + private class ReplicationRunner extends AbstractRunnable { + + final long replicationId; + + public ReplicationRunner(long replicationId) { + this.replicationId = replicationId; + } + + @Override + public void onFailure(Exception e) { + onGoingReplications.fail(replicationId, new ReplicationFailedException("Unexpected Error during replication", e), false); + } + + @Override + public void doRun() { + start(replicationId); + } + } + + private void start(final long replicationId) { + final SegmentReplicationTarget target; + try (ReplicationCollection.ReplicationRef replicationRef = onGoingReplications.get(replicationId)) { + // This check is for handling edge cases where the reference is removed before the ReplicationRunner is started by the + // threadpool. + if (replicationRef == null) { + return; + } + target = replicationRef.get(); + } + target.startReplication(new ActionListener<>() { + @Override + public void onResponse(Void o) { + logger.debug(() -> new ParameterizedMessage("Finished replicating {} marking as done.", target.description())); + onGoingReplications.markAsDone(replicationId); + if (target.state().getIndex().recoveredFileCount() != 0 && target.state().getIndex().recoveredBytes() != 0) { + completedReplications.put(target.shardId(), target.state()); + } + } + + @Override + public void onFailure(Exception e) { + logger.debug("Replication failed {}", target.description()); + if (isStoreCorrupt(target) || e instanceof CorruptIndexException || e instanceof OpenSearchCorruptionException) { + onGoingReplications.fail(replicationId, new ReplicationFailedException("Store corruption during replication", e), true); + return; + } + onGoingReplications.fail(replicationId, new ReplicationFailedException("Segment Replication failed", e), false); + } + }); + } + + // pkg-private for integration tests + void startReplication(final SegmentReplicationTarget target) { + final long replicationId; + try { + replicationId = onGoingReplications.startSafe(target, target.indexShard().getRecoverySettings().internalActionTimeout()); + } catch (ReplicationFailedException e) { + // replication already running for shard. + target.fail(e, false); + return; + } + logger.trace(() -> new ParameterizedMessage("Added new replication to collection {}", target.description())); + threadPool.generic().execute(new ReplicationRunner(replicationId)); + } + + private boolean isStoreCorrupt(SegmentReplicationTarget target) { + // ensure target is not already closed. In that case + // we can assume the store is not corrupt and that the replication + // event completed successfully. + if (target.refCount() > 0) { + final Store store = target.store(); + if (store.tryIncRef()) { + try { + return store.isMarkedCorrupted(); + } catch (IOException ex) { + logger.warn("Unable to determine if store is corrupt", ex); + return false; + } finally { + store.decRef(); + } + } + } + // store already closed. + return false; + } + + int size() { + return onGoingReplications.size(); + } + + void cancel(ShardId shardId, String reason) { + onGoingReplications.cancelForShard(shardId, reason); + } + + SegmentReplicationTarget get(ShardId shardId) { + return onGoingReplications.getOngoingReplicationTarget(shardId); + } + + ReplicationCollection.ReplicationRef get(long id) { + return onGoingReplications.get(id); + } + + SegmentReplicationState getCompleted(ShardId shardId) { + return completedReplications.get(shardId); + } + + ReplicationCollection.ReplicationRef get(long id, ShardId shardId) { + return onGoingReplications.getSafe(id, shardId); + } +} diff --git a/server/src/main/java/org/opensearch/node/Node.java b/server/src/main/java/org/opensearch/node/Node.java index 7e867d3966ff5..3ba761f797632 100644 --- a/server/src/main/java/org/opensearch/node/Node.java +++ b/server/src/main/java/org/opensearch/node/Node.java @@ -178,6 +178,7 @@ import org.opensearch.indices.replication.SegmentReplicationSourceFactory; import org.opensearch.indices.replication.SegmentReplicationSourceService; import org.opensearch.indices.replication.SegmentReplicationTargetService; +import org.opensearch.indices.replication.SegmentReplicator; import org.opensearch.indices.store.IndicesStore; import org.opensearch.ingest.IngestService; import org.opensearch.monitor.MonitorService; @@ -1411,7 +1412,8 @@ protected Node( transportService, new SegmentReplicationSourceFactory(transportService, recoverySettings, clusterService), indicesService, - clusterService + clusterService, + new SegmentReplicator(threadPool) ) ); b.bind(SegmentReplicationSourceService.class)