From d1cc1201c8849aa350a6f426fb9e6e08e350ce29 Mon Sep 17 00:00:00 2001
From: Marc Handalian
Date: Mon, 3 Jul 2023 10:24:25 -0700
Subject: [PATCH] [Segment Replication] Fix bug where ReplicationListeners would not complete on target cancellation.

This change updates cancellation with Segment Replication to ensure all listeners are resolved.
It does this by requesting cancellation before shard closure instead of using ReplicationCollection's
cancelForShard, which immediately removes the target from the collection and causes the underlying
ReplicationListener to never be invoked on close.

This change includes new tests using suite scope to catch any open tasks. These tests caught other
locations where the same leak was possible:

1. On a replica during force sync: if the shard was closed while resolving its listeners, it would
   never call back to the primary when an exception was caught in the onDone method. Fixed by
   refactoring those paths to use a ChannelActionListener and always reply to the primary.
2. On the primary during force sync: the primary would not successfully cancel before shard close.
   Fixed by wrapping the synchronous recoveryTarget::forceSync call in cancellableThreads.

Signed-off-by: Marc Handalian
---
 .../replication/SegmentReplicationBaseIT.java |  17 +-
 .../SegmentReplicationSuiteIT.java            |  88 ++++++++
 .../index/engine/NRTReplicationEngine.java    |   7 +-
 .../org/opensearch/index/store/Store.java     |  23 +-
 .../recovery/RecoverySourceHandler.java       |   2 +-
 .../replication/SegmentReplicationState.java  |  11 +-
 .../replication/SegmentReplicationTarget.java | 186 +++++++---------
 .../SegmentReplicationTargetService.java      | 202 +++++++++++-------
 .../common/ReplicationCollection.java         |  36 ++++
 .../replication/common/ReplicationTarget.java |   1 +
 .../SegmentReplicationIndexShardTests.java    | 124 ++++++++++-
 .../SegmentReplicationTargetServiceTests.java | 137 +++++++-----
 .../SegmentReplicationTargetTests.java        |   6 +-
 .../recovery/ReplicationCollectionTests.java  |  29 ++-
 14 files changed, 592 insertions(+), 277 deletions(-)
 create mode 100644 server/src/internalClusterTest/java/org/opensearch/indices/replication/SegmentReplicationSuiteIT.java

diff --git a/server/src/internalClusterTest/java/org/opensearch/indices/replication/SegmentReplicationBaseIT.java b/server/src/internalClusterTest/java/org/opensearch/indices/replication/SegmentReplicationBaseIT.java
index 52fe85b51cebd..49ee7f1f3f594 100644
--- a/server/src/internalClusterTest/java/org/opensearch/indices/replication/SegmentReplicationBaseIT.java
+++ b/server/src/internalClusterTest/java/org/opensearch/indices/replication/SegmentReplicationBaseIT.java
@@ -25,6 +25,7 @@ import org.opensearch.index.SegmentReplicationPerGroupStats;
 import org.opensearch.index.SegmentReplicationShardStats;
 import org.opensearch.index.shard.IndexShard;
+import org.opensearch.index.shard.ShardId;
 import org.opensearch.index.store.Store;
 import org.opensearch.index.store.StoreFileMetadata;
 import org.opensearch.indices.IndicesService;
@@ -186,9 +187,23 @@ protected void verifyStoreContent() throws Exception {
     }
 
     private IndexShard getIndexShard(ClusterState state, ShardRouting routing, String indexName) {
-        return getIndexShard(state.nodes().get(routing.currentNodeId()).getName(), indexName);
+        return getIndexShard(state.nodes().get(routing.currentNodeId()).getName(), routing.shardId(), indexName);
     }
 
+    /**
+     * Fetch IndexShard by shardId, multiple shards per node allowed.
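+     * Resolves the IndexService for the index on the given node and returns the shard whose id matches
+     * the routing's shardId; assumes the shard is currently allocated to that node.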
+     */
+    protected IndexShard getIndexShard(String node, ShardId shardId, String indexName) {
+        final Index index = resolveIndex(indexName);
+        IndicesService indicesService = internalCluster().getInstance(IndicesService.class, node);
+        IndexService indexService = indicesService.indexServiceSafe(index);
+        final Optional<Integer> id = indexService.shardIds().stream().filter(sid -> sid == shardId.id()).findFirst();
+        return indexService.getShard(id.get());
+    }
+
+    /**
+     * Fetch IndexShard, assumes only a single shard per node.
+     */
     protected IndexShard getIndexShard(String node, String indexName) {
         final Index index = resolveIndex(indexName);
         IndicesService indicesService = internalCluster().getInstance(IndicesService.class, node);
diff --git a/server/src/internalClusterTest/java/org/opensearch/indices/replication/SegmentReplicationSuiteIT.java b/server/src/internalClusterTest/java/org/opensearch/indices/replication/SegmentReplicationSuiteIT.java
new file mode 100644
index 0000000000000..9025c1cc79884
--- /dev/null
+++ b/server/src/internalClusterTest/java/org/opensearch/indices/replication/SegmentReplicationSuiteIT.java
@@ -0,0 +1,88 @@
+/*
+ * SPDX-License-Identifier: Apache-2.0
+ *
+ * The OpenSearch Contributors require contributions made to
+ * this file be licensed under the Apache-2.0 license or a
+ * compatible open source license.
+ */
+
+package org.opensearch.indices.replication;
+
+import org.junit.Before;
+import org.opensearch.action.admin.indices.delete.DeleteIndexRequest;
+import org.opensearch.cluster.metadata.IndexMetadata;
+import org.opensearch.common.settings.Settings;
+import org.opensearch.indices.replication.common.ReplicationType;
+import org.opensearch.test.OpenSearchIntegTestCase;
+
+@OpenSearchIntegTestCase.ClusterScope(scope = OpenSearchIntegTestCase.Scope.SUITE, minNumDataNodes = 2)
+public class SegmentReplicationSuiteIT extends SegmentReplicationBaseIT {
+
+    @Before
+    public void setup() {
+        internalCluster().startClusterManagerOnlyNode();
+        createIndex(INDEX_NAME);
+    }
+
+    @Override
+    public Settings indexSettings() {
+        final Settings.Builder builder = Settings.builder()
+            .put(super.indexSettings())
+            // reset shard & replica count to random values set by OpenSearchIntegTestCase.
+            .put(IndexMetadata.SETTING_NUMBER_OF_SHARDS, numberOfShards())
+            .put(IndexMetadata.SETTING_NUMBER_OF_REPLICAS, numberOfReplicas())
+            .put(IndexMetadata.SETTING_REPLICATION_TYPE, ReplicationType.SEGMENT);
+
+        // TODO: Randomly enable remote store on these tests.
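+        // Suite scope reuses a single cluster for every test in this class, so a replication task left
+        // running by one test will surface as an open task at suite teardown. Shard and replica counts
+        // are randomized once per suite by OpenSearchIntegTestCase.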
+ return builder.build(); + } + + public void testBasicReplication() throws Exception { + final int docCount = scaledRandomIntBetween(10, 200); + for (int i = 0; i < docCount; i++) { + client().prepareIndex(INDEX_NAME).setId(Integer.toString(i)).setSource("field", "value" + i).execute().get(); + } + refresh(); + ensureGreen(INDEX_NAME); + verifyStoreContent(); + } + + public void testDropRandomNodeDuringReplication() throws Exception { + internalCluster().ensureAtLeastNumDataNodes(2); + internalCluster().startClusterManagerOnlyNodes(1); + + final int docCount = scaledRandomIntBetween(10, 200); + for (int i = 0; i < docCount; i++) { + client().prepareIndex(INDEX_NAME).setId(Integer.toString(i)).setSource("field", "value" + i).execute().get(); + } + refresh(); + + internalCluster().restartRandomDataNode(); + + ensureYellow(INDEX_NAME); + client().prepareIndex(INDEX_NAME).setId(Integer.toString(docCount)).setSource("field", "value" + docCount).execute().get(); + internalCluster().startDataOnlyNode(); + client().admin().indices().delete(new DeleteIndexRequest(INDEX_NAME)).actionGet(); + } + + public void testDeleteIndexWhileReplicating() throws Exception { + internalCluster().startClusterManagerOnlyNode(); + final int docCount = scaledRandomIntBetween(10, 200); + for (int i = 0; i < docCount; i++) { + client().prepareIndex(INDEX_NAME).setId(Integer.toString(i)).setSource("field", "value" + i).execute().get(); + } + refresh(INDEX_NAME); + client().admin().indices().delete(new DeleteIndexRequest(INDEX_NAME)).actionGet(); + } + + public void testFullRestartDuringReplication() throws Exception { + internalCluster().startNode(); + final int docCount = scaledRandomIntBetween(10, 200); + for (int i = 0; i < docCount; i++) { + client().prepareIndex(INDEX_NAME).setId(Integer.toString(i)).setSource("field", "value" + i).execute().get(); + } + refresh(INDEX_NAME); + internalCluster().fullRestart(); + ensureGreen(INDEX_NAME); + } +} diff --git a/server/src/main/java/org/opensearch/index/engine/NRTReplicationEngine.java b/server/src/main/java/org/opensearch/index/engine/NRTReplicationEngine.java index 50b5fbb8596a6..e982ceae2dc07 100644 --- a/server/src/main/java/org/opensearch/index/engine/NRTReplicationEngine.java +++ b/server/src/main/java/org/opensearch/index/engine/NRTReplicationEngine.java @@ -126,12 +126,7 @@ private NRTReplicationReaderManager buildReaderManager() throws IOException { (files) -> { store.decRefFileDeleter(files); try { - store.cleanupAndPreserveLatestCommitPoint( - "On reader closed", - getLatestSegmentInfos(), - getLastCommittedSegmentInfos(), - false - ); + store.cleanupAndPreserveLatestCommitPoint("On reader closed", getLatestSegmentInfos(), getLastCommittedSegmentInfos()); } catch (IOException e) { // Log but do not rethrow - we can try cleaning up again after next replication cycle. // If that were to fail, the shard will as well. diff --git a/server/src/main/java/org/opensearch/index/store/Store.java b/server/src/main/java/org/opensearch/index/store/Store.java index 90832b4c77756..8e267144d1277 100644 --- a/server/src/main/java/org/opensearch/index/store/Store.java +++ b/server/src/main/java/org/opensearch/index/store/Store.java @@ -799,7 +799,7 @@ public void cleanupAndVerify(String reason, MetadataSnapshot sourceMetadata) thr * @throws IllegalStateException if the latest snapshot in this store differs from the given one after the cleanup. 
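+     * Note: files belonging to an in-flight replication (prefixed with REPLICATION_PREFIX) are always
+     * preserved by this cleanup; see cleanupFiles below.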
*/ public void cleanupAndPreserveLatestCommitPoint(String reason, SegmentInfos infos) throws IOException { - this.cleanupAndPreserveLatestCommitPoint(reason, infos, readLastCommittedSegmentsInfo(), true); + this.cleanupAndPreserveLatestCommitPoint(reason, infos, readLastCommittedSegmentsInfo()); } /** @@ -816,33 +816,24 @@ public void cleanupAndPreserveLatestCommitPoint(String reason, SegmentInfos info * @param reason the reason for this cleanup operation logged for each deleted file * @param infos {@link SegmentInfos} Files from this infos will be preserved on disk if present. * @param lastCommittedSegmentInfos {@link SegmentInfos} Last committed segment infos - * @param deleteTempFiles Does this clean up delete temporary replication files * * @throws IllegalStateException if the latest snapshot in this store differs from the given one after the cleanup. */ - public void cleanupAndPreserveLatestCommitPoint( - String reason, - SegmentInfos infos, - SegmentInfos lastCommittedSegmentInfos, - boolean deleteTempFiles - ) throws IOException { + public void cleanupAndPreserveLatestCommitPoint(String reason, SegmentInfos infos, SegmentInfos lastCommittedSegmentInfos) + throws IOException { assert indexSettings.isSegRepEnabled(); // fetch a snapshot from the latest on disk Segments_N file. This can be behind // the passed in local in memory snapshot, so we want to ensure files it references are not removed. metadataLock.writeLock().lock(); try (Lock writeLock = directory.obtainLock(IndexWriter.WRITE_LOCK_NAME)) { - cleanupFiles(reason, lastCommittedSegmentInfos.files(true), infos.files(true), deleteTempFiles); + cleanupFiles(reason, lastCommittedSegmentInfos.files(true), infos.files(true)); } finally { metadataLock.writeLock().unlock(); } } - private void cleanupFiles( - String reason, - Collection localSnapshot, - @Nullable Collection additionalFiles, - boolean deleteTempFiles - ) throws IOException { + private void cleanupFiles(String reason, Collection localSnapshot, @Nullable Collection additionalFiles) + throws IOException { assert metadataLock.isWriteLockedByCurrentThread(); for (String existingFile : directory.listAll()) { if (Store.isAutogenerated(existingFile) @@ -851,7 +842,7 @@ private void cleanupFiles( // also ensure we are not deleting a file referenced by an active reader. 
|| replicaFileTracker != null && replicaFileTracker.canDelete(existingFile) == false // prevent temporary file deletion during reader cleanup - || deleteTempFiles == false && existingFile.startsWith(REPLICATION_PREFIX)) { + || existingFile.startsWith(REPLICATION_PREFIX)) { // don't delete snapshot file, or the checksums file (note, this is extra protection since the Store won't delete // checksum) continue; diff --git a/server/src/main/java/org/opensearch/indices/recovery/RecoverySourceHandler.java b/server/src/main/java/org/opensearch/indices/recovery/RecoverySourceHandler.java index 5e278f06cfb8f..0b343fb0b0871 100644 --- a/server/src/main/java/org/opensearch/indices/recovery/RecoverySourceHandler.java +++ b/server/src/main/java/org/opensearch/indices/recovery/RecoverySourceHandler.java @@ -835,7 +835,7 @@ void finalizeRecovery(long targetLocalCheckpoint, long trimAboveSeqNo, ActionLis } else { // Force round of segment replication to update its checkpoint to primary's if (shard.indexSettings().isSegRepEnabled()) { - recoveryTarget.forceSegmentFileSync(); + cancellableThreads.execute(recoveryTarget::forceSegmentFileSync); } } stopWatch.stop(); diff --git a/server/src/main/java/org/opensearch/indices/replication/SegmentReplicationState.java b/server/src/main/java/org/opensearch/indices/replication/SegmentReplicationState.java index 7a996ec7aedaa..226ccbaf01afa 100644 --- a/server/src/main/java/org/opensearch/indices/replication/SegmentReplicationState.java +++ b/server/src/main/java/org/opensearch/indices/replication/SegmentReplicationState.java @@ -45,8 +45,7 @@ public enum Stage { GET_CHECKPOINT_INFO((byte) 3), FILE_DIFF((byte) 4), GET_FILES((byte) 5), - FINALIZE_REPLICATION((byte) 6), - CANCELLED((byte) 7); + FINALIZE_REPLICATION((byte) 6); private static final Stage[] STAGES = new Stage[Stage.values().length]; @@ -245,14 +244,6 @@ public void setStage(Stage stage) { overallTimer.stop(); timingData.put("OVERALL", overallTimer.time()); break; - case CANCELLED: - if (this.stage == Stage.DONE) { - throw new IllegalStateException("can't move replication to Cancelled state from Done."); - } - this.stage = Stage.CANCELLED; - overallTimer.stop(); - timingData.put("OVERALL", overallTimer.time()); - break; default: throw new IllegalArgumentException("unknown SegmentReplicationState.Stage [" + stage + "]"); } diff --git a/server/src/main/java/org/opensearch/indices/replication/SegmentReplicationTarget.java b/server/src/main/java/org/opensearch/indices/replication/SegmentReplicationTarget.java index 22c68ad46fea6..c2ef735ecab49 100644 --- a/server/src/main/java/org/opensearch/indices/replication/SegmentReplicationTarget.java +++ b/server/src/main/java/org/opensearch/indices/replication/SegmentReplicationTarget.java @@ -13,14 +13,16 @@ import org.apache.lucene.index.IndexFormatTooNewException; import org.apache.lucene.index.IndexFormatTooOldException; import org.apache.lucene.index.SegmentInfos; +import org.apache.lucene.store.AlreadyClosedException; import org.apache.lucene.store.BufferedChecksumIndexInput; import org.apache.lucene.store.ByteBuffersDataInput; import org.apache.lucene.store.ByteBuffersIndexInput; import org.apache.lucene.store.ChecksumIndexInput; -import org.opensearch.ExceptionsHelper; +import org.opensearch.OpenSearchCorruptionException; import org.opensearch.OpenSearchException; import org.opensearch.action.ActionListener; import org.opensearch.action.StepListener; +import org.opensearch.action.support.PlainActionFuture; import org.opensearch.common.UUIDs; import 
org.opensearch.common.bytes.BytesReference; import org.opensearch.common.lucene.Lucene; @@ -38,6 +40,8 @@ import java.io.IOException; import java.nio.ByteBuffer; import java.util.Arrays; +import java.util.List; +import java.util.Locale; /** * Represents the target of a replication event. @@ -101,17 +105,11 @@ public SegmentReplicationTarget retryCopy() { @Override public String description() { - return "Segment replication from " + source.toString(); + return String.format(Locale.ROOT, "Id:[%d] Shard:[%s] Source:[%s]", getId(), shardId(), source.getDescription()); } @Override public void notifyListener(ReplicationFailedException e, boolean sendShardFailure) { - // Cancellations still are passed to our SegmentReplicationListener as failures, if we have failed because of cancellation - // update the stage. - final Throwable cancelledException = ExceptionsHelper.unwrap(e, CancellableThreads.ExecutionCancelledException.class); - if (cancelledException != null) { - state.setStage(SegmentReplicationState.Stage.CANCELLED); - } listener.onFailure(state(), e, sendShardFailure); } @@ -140,144 +138,119 @@ public void writeFileChunk( /** * Start the Replication event. + * * @param listener {@link ActionListener} listener. */ public void startReplication(ActionListener listener) { cancellableThreads.setOnCancel((reason, beforeCancelEx) -> { - // This method only executes when cancellation is triggered by this node and caught by a call to checkForCancel, - // SegmentReplicationSource does not share CancellableThreads. - final CancellableThreads.ExecutionCancelledException executionCancelledException = - new CancellableThreads.ExecutionCancelledException("replication was canceled reason [" + reason + "]"); - notifyListener(new ReplicationFailedException("Segment replication failed", executionCancelledException), false); - throw executionCancelledException; + throw new CancellableThreads.ExecutionCancelledException("replication was canceled reason [" + reason + "]"); }); + // TODO: Remove this useless state. state.setStage(SegmentReplicationState.Stage.REPLICATING); final StepListener checkpointInfoListener = new StepListener<>(); final StepListener getFilesListener = new StepListener<>(); - final StepListener finalizeListener = new StepListener<>(); - cancellableThreads.checkForCancel(); - logger.trace("[shardId {}] Replica starting replication [id {}]", shardId().getId(), getId()); + logger.trace(new ParameterizedMessage("Starting Replication Target: {}", description())); // Get list of files to copy from this checkpoint. 
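+        // The stages below are chained through StepListeners: checkpoint metadata -> file diff/copy ->
+        // finalize. A failure or cancellation at any step propagates to the outer listener, so the
+        // listener always resolves.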
state.setStage(SegmentReplicationState.Stage.GET_CHECKPOINT_INFO); + cancellableThreads.checkForCancel(); source.getCheckpointMetadata(getId(), checkpoint, checkpointInfoListener); - checkpointInfoListener.whenComplete(checkpointInfo -> getFiles(checkpointInfo, getFilesListener), listener::onFailure); - getFilesListener.whenComplete( - response -> finalizeReplication(checkpointInfoListener.result(), finalizeListener), - listener::onFailure - ); - finalizeListener.whenComplete(r -> listener.onResponse(null), listener::onFailure); + checkpointInfoListener.whenComplete(checkpointInfo -> { + final List filesToFetch = getFiles(checkpointInfo); + state.setStage(SegmentReplicationState.Stage.GET_FILES); + cancellableThreads.checkForCancel(); + source.getSegmentFiles(getId(), checkpointInfo.getCheckpoint(), filesToFetch, indexShard, getFilesListener); + }, listener::onFailure); + + getFilesListener.whenComplete(response -> { + finalizeReplication(checkpointInfoListener.result()); + listener.onResponse(null); + }, listener::onFailure); } - private void getFiles(CheckpointInfoResponse checkpointInfo, StepListener getFilesListener) - throws IOException { + private List getFiles(CheckpointInfoResponse checkpointInfo) throws IOException { cancellableThreads.checkForCancel(); state.setStage(SegmentReplicationState.Stage.FILE_DIFF); final Store.RecoveryDiff diff = Store.segmentReplicationDiff(checkpointInfo.getMetadataMap(), indexShard.getSegmentMetadataMap()); - logger.trace("Replication diff for checkpoint {} {}", checkpointInfo.getCheckpoint(), diff); + logger.trace(() -> new ParameterizedMessage("Replication diff for checkpoint {} {}", checkpointInfo.getCheckpoint(), diff)); /* * Segments are immutable. So if the replica has any segments with the same name that differ from the one in the incoming * snapshot from source that means the local copy of the segment has been corrupted/changed in some way and we throw an * IllegalStateException to fail the shard */ if (diff.different.isEmpty() == false) { - IllegalStateException illegalStateException = new IllegalStateException( + throw new OpenSearchCorruptionException( new ParameterizedMessage( "Shard {} has local copies of segments that differ from the primary {}", indexShard.shardId(), diff.different ).getFormattedMessage() ); - ReplicationFailedException rfe = new ReplicationFailedException( - indexShard.shardId(), - "different segment files", - illegalStateException - ); - fail(rfe, true); - throw rfe; } for (StoreFileMetadata file : diff.missing) { state.getIndex().addFileDetail(file.name(), file.length(), false); } - // always send a req even if not fetching files so the primary can clear the copyState for this shard. 
- state.setStage(SegmentReplicationState.Stage.GET_FILES); - cancellableThreads.checkForCancel(); - source.getSegmentFiles(getId(), checkpointInfo.getCheckpoint(), diff.missing, indexShard, getFilesListener); + return diff.missing; } - private void finalizeReplication(CheckpointInfoResponse checkpointInfoResponse, ActionListener listener) { + private void finalizeReplication(CheckpointInfoResponse checkpointInfoResponse) throws OpenSearchCorruptionException { // TODO: Refactor the logic so that finalize doesn't have to be invoked for remote store as source if (source instanceof RemoteStoreReplicationSource) { - ActionListener.completeWith(listener, () -> { - state.setStage(SegmentReplicationState.Stage.FINALIZE_REPLICATION); - return null; - }); + state.setStage(SegmentReplicationState.Stage.FINALIZE_REPLICATION); return; } - ActionListener.completeWith(listener, () -> { + cancellableThreads.checkForCancel(); + state.setStage(SegmentReplicationState.Stage.FINALIZE_REPLICATION); + Store store = null; + try { + multiFileWriter.renameAllTempFiles(); + store = store(); + store.incRef(); + // Deserialize the new SegmentInfos object sent from the primary. + final ReplicationCheckpoint responseCheckpoint = checkpointInfoResponse.getCheckpoint(); + SegmentInfos infos = SegmentInfos.readCommit( + store.directory(), + toIndexInput(checkpointInfoResponse.getInfosBytes()), + responseCheckpoint.getSegmentsGen() + ); cancellableThreads.checkForCancel(); - state.setStage(SegmentReplicationState.Stage.FINALIZE_REPLICATION); - Store store = null; + indexShard.finalizeReplication(infos); + } catch (CorruptIndexException | IndexFormatTooNewException | IndexFormatTooOldException ex) { + // this is a fatal exception at this stage. + // this means we transferred files from the remote that have not be checksummed and they are + // broken. We have to clean up this shard entirely, remove all files and bubble it up to the + // source shard since this index might be broken there as well? The Source can handle this and checks + // its content on disk if possible. try { - multiFileWriter.renameAllTempFiles(); - store = store(); - store.incRef(); - // Deserialize the new SegmentInfos object sent from the primary. - final ReplicationCheckpoint responseCheckpoint = checkpointInfoResponse.getCheckpoint(); - SegmentInfos infos = SegmentInfos.readCommit( - store.directory(), - toIndexInput(checkpointInfoResponse.getInfosBytes()), - responseCheckpoint.getSegmentsGen() - ); - cancellableThreads.checkForCancel(); - indexShard.finalizeReplication(infos); - } catch (CorruptIndexException | IndexFormatTooNewException | IndexFormatTooOldException ex) { - // this is a fatal exception at this stage. - // this means we transferred files from the remote that have not be checksummed and they are - // broken. We have to clean up this shard entirely, remove all files and bubble it up to the - // source shard since this index might be broken there as well? The Source can handle this and checks - // its content on disk if possible. 
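+            // i.e. the transferred files were never checksum-verified, so the local copy cannot be trusted.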
try { - try { - store.removeCorruptionMarker(); - } finally { - Lucene.cleanLuceneIndex(store.directory()); // clean up and delete all files - } - } catch (Exception e) { - logger.debug("Failed to clean lucene index", e); - ex.addSuppressed(e); - } - ReplicationFailedException rfe = new ReplicationFailedException( - indexShard.shardId(), - "failed to clean after replication", - ex - ); - fail(rfe, true); - throw rfe; - } catch (OpenSearchException ex) { - /* - Ignore closed replication target as it can happen due to index shard closed event in a separate thread. - In such scenario, ignore the exception - */ - assert cancellableThreads.isCancelled() : "Replication target closed but segment replication not cancelled"; - logger.info("Replication target closed", ex); - } catch (Exception ex) { - ReplicationFailedException rfe = new ReplicationFailedException( - indexShard.shardId(), - "failed to clean after replication", - ex - ); - fail(rfe, true); - throw rfe; - } finally { - if (store != null) { - store.decRef(); + store.removeCorruptionMarker(); + } finally { + Lucene.cleanLuceneIndex(store.directory()); // clean up and delete all files } + } catch (Exception e) { + logger.debug("Failed to clean lucene index", e); + ex.addSuppressed(e); } - return null; - }); + throw new OpenSearchCorruptionException(ex); + } catch (AlreadyClosedException ex) { + // In this case the shard is closed at some point while updating the reader. + // This can happen when the engine is closed in a separate thread. + logger.warn("Shard is already closed, closing replication"); + } catch (OpenSearchException ex) { + /* + Ignore closed replication target as it can happen due to index shard closed event in a separate thread. + In such scenario, ignore the exception + */ + assert cancellableThreads.isCancelled() : "Replication target closed but segment replication not cancelled"; + } catch (Exception ex) { + throw new OpenSearchCorruptionException(ex); + } finally { + if (store != null) { + store.decRef(); + } + } } /** @@ -290,10 +263,15 @@ private ChecksumIndexInput toIndexInput(byte[] input) { ); } + /** + * Trigger a cancellation, this method will not close the target a subsequent call to #fail is required from target service. 
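+     * Cancelling interrupts the running replication threads and the source, but leaves the target in the
+     * ReplicationCollection; the target service must still call #fail so the attached listener resolves.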
+ */ @Override - protected void onCancel(String reason) { - cancellableThreads.cancel(reason); - source.cancel(); - multiFileWriter.close(); + public void cancel(String reason) { + if (finished.get() == false) { + logger.trace(new ParameterizedMessage("Cancelling replication for target {}", description())); + cancellableThreads.cancel(reason); + source.cancel(); + } } } diff --git a/server/src/main/java/org/opensearch/indices/replication/SegmentReplicationTargetService.java b/server/src/main/java/org/opensearch/indices/replication/SegmentReplicationTargetService.java index a7e0c0ec887ab..e0fcbcd7ef79f 100644 --- a/server/src/main/java/org/opensearch/indices/replication/SegmentReplicationTargetService.java +++ b/server/src/main/java/org/opensearch/indices/replication/SegmentReplicationTargetService.java @@ -11,14 +11,15 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.apache.logging.log4j.message.ParameterizedMessage; -import org.opensearch.ExceptionsHelper; +import org.opensearch.OpenSearchCorruptionException; import org.opensearch.action.ActionListener; +import org.opensearch.action.support.ChannelActionListener; import org.opensearch.cluster.node.DiscoveryNode; import org.opensearch.cluster.routing.ShardRouting; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.Nullable; import org.opensearch.common.settings.Settings; -import org.opensearch.common.util.CancellableThreads; +import org.opensearch.common.util.concurrent.AbstractRunnable; import org.opensearch.common.util.concurrent.ConcurrentCollections; import org.opensearch.index.shard.IndexEventListener; import org.opensearch.index.shard.IndexShard; @@ -43,10 +44,8 @@ import org.opensearch.transport.TransportResponse; import org.opensearch.transport.TransportService; -import java.io.IOException; import java.util.Map; import java.util.Optional; -import java.util.concurrent.TimeoutException; import java.util.concurrent.atomic.AtomicLong; import static org.opensearch.indices.replication.SegmentReplicationSourceService.Actions.UPDATE_VISIBLE_CHECKPOINT; @@ -145,7 +144,7 @@ public SegmentReplicationTargetService( @Override public void beforeIndexShardClosed(ShardId shardId, @Nullable IndexShard indexShard, Settings indexSettings) { if (indexShard != null && indexShard.indexSettings().isSegRepEnabled()) { - onGoingReplications.cancelForShard(shardId, "shard closed"); + onGoingReplications.requestCancel(indexShard.shardId(), "Shard closing"); latestReceivedCheckpoint.remove(shardId); } } @@ -167,7 +166,7 @@ public void afterIndexShardStarted(IndexShard indexShard) { @Override public void shardRoutingChanged(IndexShard indexShard, @Nullable ShardRouting oldRouting, ShardRouting newRouting) { if (oldRouting != null && indexShard.indexSettings().isSegRepEnabled() && oldRouting.primary() == false && newRouting.primary()) { - onGoingReplications.cancelForShard(indexShard.shardId(), "shard has been promoted to primary"); + onGoingReplications.requestCancel(indexShard.shardId(), "Shard closing"); latestReceivedCheckpoint.remove(indexShard.shardId()); } } @@ -221,14 +220,17 @@ public synchronized void onNewCheckpoint(final ReplicationCheckpoint receivedChe if (replicaShard.state().equals(IndexShardState.STARTED) == true) { // Checks if received checkpoint is already present and ahead then it replaces old received checkpoint SegmentReplicationTarget ongoingReplicationTarget = onGoingReplications.getOngoingReplicationTarget(replicaShard.shardId()); + if 
(ongoingReplicationTarget != null) { if (ongoingReplicationTarget.getCheckpoint().getPrimaryTerm() < receivedCheckpoint.getPrimaryTerm()) { logger.trace( - "Cancelling ongoing replication from old primary with primary term {}", - ongoingReplicationTarget.getCheckpoint().getPrimaryTerm() + () -> new ParameterizedMessage( + "Cancelling ongoing replication {} from old primary with primary term {}", + ongoingReplicationTarget.description(), + ongoingReplicationTarget.getCheckpoint().getPrimaryTerm() + ) ); - onGoingReplications.cancel(ongoingReplicationTarget.getId(), "Cancelling stuck target after new primary"); - completedReplications.put(replicaShard.shardId(), ongoingReplicationTarget); + ongoingReplicationTarget.cancel("Cancelling ongoing target after new primary"); } else { logger.trace( () -> new ParameterizedMessage( @@ -268,21 +270,20 @@ public void onReplicationFailure( ReplicationFailedException e, boolean sendShardFailure ) { - logger.trace( + logger.error( () -> new ParameterizedMessage( "[shardId {}] [replication id {}] Replication failed, timing data: {}", replicaShard.shardId().getId(), state.getReplicationId(), state.getTimingData() - ) + ), + e ); if (sendShardFailure == true) { - logger.error("replication failure", e); - replicaShard.failShard("replication failure", e); + failShard(e, replicaShard); } } }); - } } else { logger.trace( @@ -305,7 +306,14 @@ protected void updateVisibleCheckpoint(long replicationId, IndexShard replicaSha final TransportRequestOptions options = TransportRequestOptions.builder() .withTimeout(recoverySettings.internalActionTimeout()) .build(); - logger.debug("Updating replication checkpoint to {}", request.getCheckpoint()); + logger.trace( + () -> new ParameterizedMessage( + "Updating Primary shard that replica {}-{} is synced to checkpoint {}", + replicaShard.shardId(), + replicaShard.routingEntry().allocationId(), + request.getCheckpoint() + ) + ); RetryableTransportClient transportClient = new RetryableTransportClient( transportService, getPrimaryNode(primaryShard), @@ -315,19 +323,23 @@ protected void updateVisibleCheckpoint(long replicationId, IndexShard replicaSha final ActionListener listener = new ActionListener<>() { @Override public void onResponse(Void unused) { - logger.debug( - "Successfully updated replication checkpoint {} for replica {}", - replicaShard.shardId(), - request.getCheckpoint() + logger.trace( + () -> new ParameterizedMessage( + "Successfully updated replication checkpoint {} for replica {}", + replicaShard.shardId(), + request.getCheckpoint() + ) ); } @Override public void onFailure(Exception e) { logger.error( - "Failed to update visible checkpoint for replica {}, {}: {}", - replicaShard.shardId(), - request.getCheckpoint(), + () -> new ParameterizedMessage( + "Failed to update visible checkpoint for replica {}, {}:", + replicaShard.shardId(), + request.getCheckpoint() + ), e ); } @@ -350,6 +362,13 @@ private DiscoveryNode getPrimaryNode(ShardRouting primaryShard) { protected boolean processLatestReceivedCheckpoint(IndexShard replicaShard, Thread thread) { final ReplicationCheckpoint latestPublishedCheckpoint = latestReceivedCheckpoint.get(replicaShard.shardId()); if (latestPublishedCheckpoint != null && latestPublishedCheckpoint.isAheadOf(replicaShard.getLatestReplicationCheckpoint())) { + logger.trace( + () -> new ParameterizedMessage( + "Processing latest received checkpoint for shard {} {}", + replicaShard.shardId(), + latestPublishedCheckpoint + ) + ); Runnable runnable = () -> 
onNewCheckpoint(latestReceivedCheckpoint.get(replicaShard.shardId()), replicaShard); // Checks if we are using same thread and forks if necessary. if (thread == Thread.currentThread()) { @@ -381,7 +400,15 @@ public SegmentReplicationTarget startReplication(final IndexShard indexShard, fi // pkg-private for integration tests void startReplication(final SegmentReplicationTarget target) { - final long replicationId = onGoingReplications.start(target, recoverySettings.activityTimeout()); + final long replicationId; + try { + replicationId = onGoingReplications.startSafe(target, recoverySettings.activityTimeout()); + } catch (ReplicationFailedException e) { + // replication already running for shard. + target.fail(e, false); + return; + } + logger.trace(() -> new ParameterizedMessage("Added new replication to collection {}", target.description())); threadPool.generic().execute(new ReplicationRunner(replicationId)); } @@ -410,7 +437,7 @@ default void onFailure(ReplicationState state, ReplicationFailedException e, boo /** * Runnable implementation to trigger a replication event. */ - private class ReplicationRunner implements Runnable { + private class ReplicationRunner extends AbstractRunnable { final long replicationId; @@ -419,47 +446,49 @@ public ReplicationRunner(long replicationId) { } @Override - public void run() { + public void onFailure(Exception e) { + try (final ReplicationRef ref = onGoingReplications.get(replicationId)) { + logger.error(() -> new ParameterizedMessage("Error during segment replication, {}", ref.get().description()), e); + } + onGoingReplications.fail(replicationId, new ReplicationFailedException("Unexpected Error during replication"), false); + } + + @Override + public void doRun() { start(replicationId); } } private void start(final long replicationId) { + final SegmentReplicationTarget target; try (ReplicationRef replicationRef = onGoingReplications.get(replicationId)) { // This check is for handling edge cases where the reference is removed before the ReplicationRunner is started by the // threadpool. if (replicationRef == null) { return; } - SegmentReplicationTarget target = onGoingReplications.getTarget(replicationId); - replicationRef.get().startReplication(new ActionListener<>() { - @Override - public void onResponse(Void o) { - onGoingReplications.markAsDone(replicationId); - if (target.state().getIndex().recoveredFileCount() != 0 && target.state().getIndex().recoveredBytes() != 0) { - completedReplications.put(target.shardId(), target); - } - + target = onGoingReplications.getTarget(replicationId); + } + target.startReplication(new ActionListener<>() { + @Override + public void onResponse(Void o) { + logger.trace(() -> new ParameterizedMessage("Finished replicating {} marking as done.", target.description())); + onGoingReplications.markAsDone(replicationId); + if (target.state().getIndex().recoveredFileCount() != 0 && target.state().getIndex().recoveredBytes() != 0) { + completedReplications.put(target.shardId(), target); } + } - @Override - public void onFailure(Exception e) { - Throwable cause = ExceptionsHelper.unwrapCause(e); - if (cause instanceof CancellableThreads.ExecutionCancelledException) { - if (onGoingReplications.getTarget(replicationId) != null) { - IndexShard indexShard = onGoingReplications.getTarget(replicationId).indexShard(); - // if the target still exists in our collection, the primary initiated the cancellation, fail the replication - // but do not fail the shard. 
Cancellations initiated by this node from Index events will be removed with - // onGoingReplications.cancel and not appear in the collection when this listener resolves. - onGoingReplications.fail(replicationId, new ReplicationFailedException(indexShard, cause), false); - completedReplications.put(target.shardId(), target); - } - } else { - onGoingReplications.fail(replicationId, new ReplicationFailedException("Segment Replication failed", e), false); - } + @Override + public void onFailure(Exception e) { + logger.error(() -> new ParameterizedMessage("Exception replicating {} marking as failed.", target.description()), e); + if (e instanceof OpenSearchCorruptionException) { + onGoingReplications.fail(replicationId, new ReplicationFailedException("Store corruption during replication", e), true); + return; } - }); - } + onGoingReplications.fail(replicationId, new ReplicationFailedException("Segment Replication failed", e), false); + } + }); } private class FileChunkTransportRequestHandler implements TransportRequestHandler { @@ -484,27 +513,31 @@ public void messageReceived(final FileChunkRequest request, TransportChannel cha private class ForceSyncTransportRequestHandler implements TransportRequestHandler { @Override public void messageReceived(final ForceSyncRequest request, TransportChannel channel, Task task) throws Exception { - assert indicesService != null; - final IndexShard indexShard = indicesService.getShardOrNull(request.getShardId()); - // Proceed with round of segment replication only when it is allowed - if (indexShard == null || indexShard.getReplicationEngine().isEmpty()) { - logger.info("Ignore force segment replication sync as it is not allowed"); - channel.sendResponse(TransportResponse.Empty.INSTANCE); - return; - } + forceReplication(request, new ChannelActionListener<>(channel, Actions.FORCE_SYNC, request)); + } + } + + private void forceReplication(ForceSyncRequest request, ActionListener listener) { + final ShardId shardId = request.getShardId(); + assert indicesService != null; + final IndexShard indexShard = indicesService.getShardOrNull(shardId); + // Proceed with round of segment replication only when it is allowed + if (indexShard == null || indexShard.getReplicationEngine().isEmpty()) { + listener.onResponse(TransportResponse.Empty.INSTANCE); + } else { startReplication(indexShard, new SegmentReplicationTargetService.SegmentReplicationListener() { @Override public void onReplicationDone(SegmentReplicationState state) { - logger.trace( - () -> new ParameterizedMessage( - "[shardId {}] [replication id {}] Replication complete to {}, timing data: {}", - indexShard.shardId().getId(), - state.getReplicationId(), - indexShard.getLatestReplicationCheckpoint(), - state.getTimingData() - ) - ); try { + logger.trace( + () -> new ParameterizedMessage( + "[shardId {}] [replication id {}] Force replication Sync complete to {}, timing data: {}", + shardId, + state.getReplicationId(), + indexShard.getLatestReplicationCheckpoint(), + state.getTimingData() + ) + ); // Promote engine type for primary target if (indexShard.recoveryState().getPrimary() == true) { indexShard.resetToWriteableEngine(); @@ -512,33 +545,40 @@ public void onReplicationDone(SegmentReplicationState state) { // Update the replica's checkpoint on primary's replication tracker. 
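+                        // Whether this step succeeds or throws, the ChannelActionListener wrapping this
+                        // handler replies to the primary, so its forceSync call can never hang on this replica.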
updateVisibleCheckpoint(state.getReplicationId(), indexShard); } - channel.sendResponse(TransportResponse.Empty.INSTANCE); - } catch (InterruptedException | TimeoutException | IOException e) { - throw new RuntimeException(e); + listener.onResponse(TransportResponse.Empty.INSTANCE); + } catch (Exception e) { + logger.error("Error while marking replication completed", e); + listener.onFailure(e); } } @Override public void onReplicationFailure(SegmentReplicationState state, ReplicationFailedException e, boolean sendShardFailure) { - logger.trace( + logger.error( () -> new ParameterizedMessage( "[shardId {}] [replication id {}] Replication failed, timing data: {}", indexShard.shardId().getId(), state.getReplicationId(), state.getTimingData() - ) + ), + e ); - if (sendShardFailure == true) { - indexShard.failShard("replication failure", e); - } - try { - channel.sendResponse(e); - } catch (IOException ex) { - throw new RuntimeException(ex); + if (sendShardFailure) { + failShard(e, indexShard); } + listener.onFailure(e); } }); } } + private void failShard(ReplicationFailedException e, IndexShard indexShard) { + try { + indexShard.failShard("unrecoverable replication failure", e); + } catch (Exception inner) { + logger.error("Error attempting to fail shard", inner); + e.addSuppressed(inner); + } + } + } diff --git a/server/src/main/java/org/opensearch/indices/replication/common/ReplicationCollection.java b/server/src/main/java/org/opensearch/indices/replication/common/ReplicationCollection.java index e918ac0a79691..c65ef27969154 100644 --- a/server/src/main/java/org/opensearch/indices/replication/common/ReplicationCollection.java +++ b/server/src/main/java/org/opensearch/indices/replication/common/ReplicationCollection.java @@ -70,6 +70,26 @@ public ReplicationCollection(Logger logger, ThreadPool threadPool) { this.threadPool = threadPool; } + /** + * Starts a new target event for a given shard, fails the given target if this shard is already replicating. + * @param target ReplicationTarget to start + * @param activityTimeout timeout for entire replication event + * @return The replication id + */ + public long startSafe(T target, TimeValue activityTimeout) { + synchronized (onGoingTargetEvents) { + final boolean isPresent = onGoingTargetEvents.values() + .stream() + .map(ReplicationTarget::shardId) + .anyMatch(t -> t.equals(target.shardId())); + if (isPresent) { + throw new ReplicationFailedException("Shard " + target.shardId() + " is already replicating"); + } else { + return start(target, activityTimeout); + } + } + } + /** * Starts a new target event for the given shard, source node and state * @@ -234,6 +254,22 @@ public boolean cancelForShard(ShardId shardId, String reason) { return cancelled; } + /** + * Trigger cancel on the target but do not remove it from the collection. + * This is intended to be called to ensure replication events are removed from the collection + * only when the target has closed. 
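+     * Contrast with cancelForShard, which removes the target from the collection immediately and can
+     * leave its ReplicationListener unresolved.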
+ * + * @param shardId {@link ShardId} shard events to cancel + * @param reason {@link String} reason for cancellation + */ + public void requestCancel(ShardId shardId, String reason) { + for (T value : onGoingTargetEvents.values()) { + if (value.shardId().equals(shardId)) { + value.cancel(reason); + } + } + } + /** * Get target for shard * diff --git a/server/src/main/java/org/opensearch/indices/replication/common/ReplicationTarget.java b/server/src/main/java/org/opensearch/indices/replication/common/ReplicationTarget.java index 4d75ff4896706..344a4040be119 100644 --- a/server/src/main/java/org/opensearch/indices/replication/common/ReplicationTarget.java +++ b/server/src/main/java/org/opensearch/indices/replication/common/ReplicationTarget.java @@ -173,6 +173,7 @@ public void cancel(String reason) { public void fail(ReplicationFailedException e, boolean sendShardFailure) { if (finished.compareAndSet(false, true)) { try { + logger.debug("marking target " + description() + " as failed", e); notifyListener(e, sendShardFailure); } finally { try { diff --git a/server/src/test/java/org/opensearch/index/shard/SegmentReplicationIndexShardTests.java b/server/src/test/java/org/opensearch/index/shard/SegmentReplicationIndexShardTests.java index 0c859c5f6a64a..dd6eda62512bc 100644 --- a/server/src/test/java/org/opensearch/index/shard/SegmentReplicationIndexShardTests.java +++ b/server/src/test/java/org/opensearch/index/shard/SegmentReplicationIndexShardTests.java @@ -10,6 +10,7 @@ import org.apache.lucene.codecs.Codec; import org.apache.lucene.index.SegmentInfos; +import org.apache.lucene.store.AlreadyClosedException; import org.junit.Assert; import org.opensearch.ExceptionsHelper; import org.opensearch.action.ActionListener; @@ -78,11 +79,12 @@ import static org.hamcrest.Matchers.hasToString; import static org.hamcrest.Matchers.instanceOf; import static org.mockito.Mockito.any; +import static org.mockito.Mockito.doThrow; import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.spy; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; -import static org.mockito.Mockito.spy; public class SegmentReplicationIndexShardTests extends OpenSearchIndexLevelReplicationTestCase { @@ -974,6 +976,125 @@ public void getSegmentFiles( } } + public void testCloseShardDuringFinalize() throws Exception { + try (ReplicationGroup shards = createGroup(1, settings, new NRTReplicationEngineFactory())) { + shards.startAll(); + IndexShard primary = shards.getPrimary(); + final IndexShard replica = shards.getReplicas().get(0); + final IndexShard replicaSpy = spy(replica); + + primary.refresh("Test"); + + doThrow(AlreadyClosedException.class).when(replicaSpy).finalizeReplication(any()); + + replicateSegments(primary, List.of(replicaSpy)); + } + } + + public void testCloseShardWhileGettingCheckpoint() throws Exception { + try (ReplicationGroup shards = createGroup(1, settings, new NRTReplicationEngineFactory())) { + shards.startAll(); + IndexShard primary = shards.getPrimary(); + final IndexShard replica = shards.getReplicas().get(0); + + primary.refresh("Test"); + + final SegmentReplicationSourceFactory sourceFactory = mock(SegmentReplicationSourceFactory.class); + final SegmentReplicationTargetService targetService = newTargetService(sourceFactory); + SegmentReplicationSource source = new TestReplicationSource() { + + ActionListener listener; + + @Override + public void getCheckpointMetadata( + long replicationId, + ReplicationCheckpoint 
checkpoint, + ActionListener listener + ) { + // set the listener, we will only fail it once we cancel the source. + this.listener = listener; + // shard is closing while we are copying files. + targetService.beforeIndexShardClosed(replica.shardId, replica, Settings.EMPTY); + } + + @Override + public void getSegmentFiles( + long replicationId, + ReplicationCheckpoint checkpoint, + List filesToFetch, + IndexShard indexShard, + ActionListener listener + ) { + Assert.fail("Unreachable"); + } + + @Override + public void cancel() { + // simulate listener resolving, but only after we have issued a cancel from beforeIndexShardClosed . + final RuntimeException exception = new CancellableThreads.ExecutionCancelledException("retryable action was cancelled"); + listener.onFailure(exception); + } + }; + when(sourceFactory.get(any())).thenReturn(source); + startReplicationAndAssertCancellation(replica, targetService); + + shards.removeReplica(replica); + closeShards(replica); + } + } + + public void testBeforeIndexShardClosedWhileCopyingFiles() throws Exception { + try (ReplicationGroup shards = createGroup(1, settings, new NRTReplicationEngineFactory())) { + shards.startAll(); + IndexShard primary = shards.getPrimary(); + final IndexShard replica = shards.getReplicas().get(0); + + primary.refresh("Test"); + + final SegmentReplicationSourceFactory sourceFactory = mock(SegmentReplicationSourceFactory.class); + final SegmentReplicationTargetService targetService = newTargetService(sourceFactory); + SegmentReplicationSource source = new TestReplicationSource() { + + ActionListener listener; + + @Override + public void getCheckpointMetadata( + long replicationId, + ReplicationCheckpoint checkpoint, + ActionListener listener + ) { + resolveCheckpointInfoResponseListener(listener, primary); + } + + @Override + public void getSegmentFiles( + long replicationId, + ReplicationCheckpoint checkpoint, + List filesToFetch, + IndexShard indexShard, + ActionListener listener + ) { + // set the listener, we will only fail it once we cancel the source. + this.listener = listener; + // shard is closing while we are copying files. + targetService.beforeIndexShardClosed(replica.shardId, replica, Settings.EMPTY); + } + + @Override + public void cancel() { + // simulate listener resolving, but only after we have issued a cancel from beforeIndexShardClosed . 
+ final RuntimeException exception = new CancellableThreads.ExecutionCancelledException("retryable action was cancelled"); + listener.onFailure(exception); + } + }; + when(sourceFactory.get(any())).thenReturn(source); + startReplicationAndAssertCancellation(replica, targetService); + + shards.removeReplica(replica); + closeShards(replica); + } + } + public void testPrimaryCancelsExecution() throws Exception { try (ReplicationGroup shards = createGroup(1, settings, new NRTReplicationEngineFactory())) { shards.startAll(); @@ -1063,7 +1184,6 @@ public void onReplicationDone(SegmentReplicationState state) { @Override public void onReplicationFailure(SegmentReplicationState state, ReplicationFailedException e, boolean sendShardFailure) { assertFalse(sendShardFailure); - assertEquals(SegmentReplicationState.Stage.CANCELLED, state.getStage()); latch.countDown(); } } diff --git a/server/src/test/java/org/opensearch/indices/replication/SegmentReplicationTargetServiceTests.java b/server/src/test/java/org/opensearch/indices/replication/SegmentReplicationTargetServiceTests.java index c632f2843cba2..2e234643ddd2f 100644 --- a/server/src/test/java/org/opensearch/indices/replication/SegmentReplicationTargetServiceTests.java +++ b/server/src/test/java/org/opensearch/indices/replication/SegmentReplicationTargetServiceTests.java @@ -8,8 +8,8 @@ package org.opensearch.indices.replication; +import org.apache.lucene.store.AlreadyClosedException; import org.junit.Assert; -import org.mockito.Mockito; import org.opensearch.OpenSearchException; import org.opensearch.Version; import org.opensearch.action.ActionListener; @@ -49,6 +49,7 @@ import java.util.concurrent.TimeUnit; import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyLong; import static org.mockito.Mockito.atLeastOnce; import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.doNothing; @@ -62,7 +63,6 @@ import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verifyNoInteractions; import static org.mockito.Mockito.when; -import static org.opensearch.indices.replication.SegmentReplicationState.Stage.CANCELLED; public class SegmentReplicationTargetServiceTests extends IndexShardTestCase { @@ -240,71 +240,82 @@ public void testAlreadyOnNewCheckpoint() { verify(spy, times(0)).startReplication(any(), any()); } - public void testShardAlreadyReplicating() throws InterruptedException { - // Create a spy of Target Service so that we can verify invocation of startReplication call with specific checkpoint on it. - SegmentReplicationTargetService serviceSpy = spy(sut); - final SegmentReplicationTarget target = new SegmentReplicationTarget( - replicaShard, - replicationSource, - mock(SegmentReplicationTargetService.SegmentReplicationListener.class) - ); - // Create a Mockito spy of target to stub response of few method calls. - final SegmentReplicationTarget targetSpy = Mockito.spy(target); - CountDownLatch latch = new CountDownLatch(1); - // Mocking response when startReplication is called on targetSpy we send a new checkpoint to serviceSpy and later reduce countdown - // of latch. - doAnswer(invocation -> { - final ActionListener listener = invocation.getArgument(0); - // a new checkpoint arrives before we've completed. - serviceSpy.onNewCheckpoint(aheadCheckpoint, replicaShard); - listener.onResponse(null); - latch.countDown(); - return null; - }).when(targetSpy).startReplication(any()); - doNothing().when(targetSpy).onDone(); - - // start replication of this shard the first time. 
- serviceSpy.startReplication(targetSpy); + public void testShardAlreadyReplicating() { + sut.startReplication(replicaShard, mock(SegmentReplicationTargetService.SegmentReplicationListener.class)); + sut.startReplication(replicaShard, new SegmentReplicationTargetService.SegmentReplicationListener() { + @Override + public void onReplicationDone(SegmentReplicationState state) { + Assert.fail("Should not succeed"); + } - // wait for the new checkpoint to arrive, before the listener completes. - latch.await(30, TimeUnit.SECONDS); - verify(targetSpy, times(0)).cancel(any()); - verify(serviceSpy, times(0)).startReplication(eq(replicaShard), any()); + @Override + public void onReplicationFailure(SegmentReplicationState state, ReplicationFailedException e, boolean sendShardFailure) { + assertEquals("Shard " + replicaShard.shardId() + " is already replicating", e.getMessage()); + assertFalse(sendShardFailure); + } + }); } - public void testOnNewCheckpointFromNewPrimaryCancelOngoingReplication() throws IOException, InterruptedException { + public void testOnNewCheckpointFromNewPrimaryCancelOngoingReplication() throws InterruptedException { // Create a spy of Target Service so that we can verify invocation of startReplication call with specific checkpoint on it. SegmentReplicationTargetService serviceSpy = spy(sut); + doNothing().when(serviceSpy).updateVisibleCheckpoint(anyLong(), any()); + // skip post replication actions so we can assert execution counts. This will continue to process bc replica's pterm is not advanced + // post replication. + doReturn(true).when(serviceSpy).processLatestReceivedCheckpoint(any(), any()); // Create a Mockito spy of target to stub response of few method calls. - final SegmentReplicationTarget targetSpy = spy( - new SegmentReplicationTarget( - replicaShard, - replicationSource, - mock(SegmentReplicationTargetService.SegmentReplicationListener.class) - ) - ); CountDownLatch latch = new CountDownLatch(1); - // Mocking response when startReplication is called on targetSpy we send a new checkpoint to serviceSpy and later reduce countdown - // of latch. - doAnswer(invocation -> { - // short circuit loop on new checkpoint request - doReturn(null).when(serviceSpy).startReplication(eq(replicaShard), any()); - // a new checkpoint arrives before we've completed. - serviceSpy.onNewCheckpoint(newPrimaryCheckpoint, replicaShard); - try { - invocation.callRealMethod(); - } catch (CancellableThreads.ExecutionCancelledException e) { + SegmentReplicationSource source = new TestReplicationSource() { + + ActionListener listener; + + @Override + public void getCheckpointMetadata( + long replicationId, + ReplicationCheckpoint checkpoint, + ActionListener listener + ) { + // set the listener, we will only fail it once we cancel the source. + logger.info("In test source"); + this.listener = listener; latch.countDown(); + // do not resolve this listener yet, wait for cancel to hit. } - return null; - }).when(targetSpy).startReplication(any()); + + @Override + public void getSegmentFiles( + long replicationId, + ReplicationCheckpoint checkpoint, + List filesToFetch, + IndexShard indexShard, + ActionListener listener + ) { + Assert.fail("Unreachable"); + } + + @Override + public void cancel() { + // simulate listener resolving, but only after we have issued a cancel from beforeIndexShardClosed . 
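+                // In this test the cancel is triggered by onNewCheckpoint observing a higher primary term,
+                // not by beforeIndexShardClosed.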
+ final RuntimeException exception = new CancellableThreads.ExecutionCancelledException("retryable action was cancelled"); + listener.onFailure(exception); + } + }; + + final SegmentReplicationTarget targetSpy = spy( + new SegmentReplicationTarget(replicaShard, source, mock(SegmentReplicationTargetService.SegmentReplicationListener.class)) + ); // start replication. This adds the target to on-ongoing replication collection serviceSpy.startReplication(targetSpy); + + // wait until we get to getCheckpoint step. latch.await(); - // wait for the new checkpoint to arrive, before the listener completes. - assertEquals(CANCELLED, targetSpy.state().getStage()); + + // new checkpoint arrives with higher pterm. + serviceSpy.onNewCheckpoint(newPrimaryCheckpoint, replicaShard); + + // ensure the old target is cancelled. and new iteration kicks off. verify(targetSpy, times(1)).cancel("Cancelling stuck target after new primary"); verify(serviceSpy, times(1)).startReplication(eq(replicaShard), any()); } @@ -467,6 +478,7 @@ public void testForceSegmentSyncHandler() throws Exception { } public void testForceSegmentSyncHandlerWithFailure() throws Exception { + allowShardFailures(); IndexShard spyReplicaShard = spy(replicaShard); ForceSyncRequest forceSyncRequest = new ForceSyncRequest(1L, 1L, replicaShard.shardId()); when(indicesService.getShardOrNull(forceSyncRequest.getShardId())).thenReturn(spyReplicaShard); @@ -488,4 +500,23 @@ public void testForceSegmentSyncHandlerWithFailure() throws Exception { assertTrue(nestedException instanceof IOException); assertTrue(nestedException.getMessage().contains("dummy failure")); } + + public void testForceSegmentSyncHandlerWithFailure_AlreadyClosedException_swallowed() throws Exception { + IndexShard spyReplicaShard = spy(replicaShard); + ForceSyncRequest forceSyncRequest = new ForceSyncRequest(1L, 1L, replicaShard.shardId()); + when(indicesService.getShardOrNull(forceSyncRequest.getShardId())).thenReturn(spyReplicaShard); + + AlreadyClosedException exception = new AlreadyClosedException("shard closed"); + doThrow(exception).when(spyReplicaShard).finalizeReplication(any()); + + // prevent shard failure to avoid test setup assertion + doNothing().when(spyReplicaShard).failShard(eq("replication failure"), any()); + transportService.submitRequest( + localNode, + SegmentReplicationTargetService.Actions.FORCE_SYNC, + forceSyncRequest, + TransportRequestOptions.builder().withTimeout(TRANSPORT_TIMEOUT).build(), + EmptyTransportResponseHandler.INSTANCE_SAME + ).txGet(); + } } diff --git a/server/src/test/java/org/opensearch/indices/replication/SegmentReplicationTargetTests.java b/server/src/test/java/org/opensearch/indices/replication/SegmentReplicationTargetTests.java index ac8904527f7fb..4fb1edb4e496e 100644 --- a/server/src/test/java/org/opensearch/indices/replication/SegmentReplicationTargetTests.java +++ b/server/src/test/java/org/opensearch/indices/replication/SegmentReplicationTargetTests.java @@ -27,6 +27,7 @@ import org.junit.Assert; import org.mockito.Mockito; import org.opensearch.ExceptionsHelper; +import org.opensearch.OpenSearchCorruptionException; import org.opensearch.action.ActionListener; import org.opensearch.cluster.metadata.IndexMetadata; import org.opensearch.common.settings.Settings; @@ -373,8 +374,9 @@ public void onResponse(Void replicationResponse) { @Override public void onFailure(Exception e) { - assert (e instanceof ReplicationFailedException); - assert (e.getMessage().contains("different segment files")); + assertTrue(e instanceof 
OpenSearchCorruptionException); + assertTrue(e.getMessage().contains("has local copies of segments that differ from the primary")); + segrepTarget.fail(new ReplicationFailedException(e), false); } }); } diff --git a/server/src/test/java/org/opensearch/recovery/ReplicationCollectionTests.java b/server/src/test/java/org/opensearch/recovery/ReplicationCollectionTests.java index 75ac1075e8ee0..f7f7a4ef6d965 100644 --- a/server/src/test/java/org/opensearch/recovery/ReplicationCollectionTests.java +++ b/server/src/test/java/org/opensearch/recovery/ReplicationCollectionTests.java @@ -38,12 +38,15 @@ import org.opensearch.index.shard.IndexShard; import org.opensearch.index.shard.ShardId; import org.opensearch.index.store.Store; +import org.opensearch.indices.replication.SegmentReplicationSource; +import org.opensearch.indices.replication.SegmentReplicationTarget; import org.opensearch.indices.replication.common.ReplicationCollection; import org.opensearch.indices.replication.common.ReplicationFailedException; import org.opensearch.indices.replication.common.ReplicationListener; import org.opensearch.indices.replication.common.ReplicationState; import org.opensearch.indices.recovery.RecoveryState; import org.opensearch.indices.recovery.RecoveryTarget; +import org.opensearch.indices.replication.common.ReplicationTarget; import java.util.concurrent.CountDownLatch; import java.util.concurrent.TimeUnit; @@ -51,6 +54,7 @@ import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.lessThan; +import static org.mockito.Mockito.mock; public class ReplicationCollectionTests extends OpenSearchIndexLevelReplicationTestCase { static final ReplicationListener listener = new ReplicationListener() { @@ -108,7 +112,30 @@ public void onFailure(ReplicationState state, ReplicationFailedException e, bool } } - public void testMultiReplicationsForSingleShard() throws Exception { + public void testStartMultipleReplicationsForSingleShard() throws Exception { + try (ReplicationGroup shards = createGroup(0)) { + shards.startAll(); + final ReplicationCollection collection = new ReplicationCollection<>(logger, threadPool); + final IndexShard shard = shards.addReplica(); + shards.recoverReplica(shard); + final SegmentReplicationTarget target1 = new SegmentReplicationTarget( + shard, + mock(SegmentReplicationSource.class), + mock(ReplicationListener.class) + ); + final SegmentReplicationTarget target2 = new SegmentReplicationTarget( + shard, + mock(SegmentReplicationSource.class), + mock(ReplicationListener.class) + ); + collection.startSafe(target1, TimeValue.ZERO); + assertThrows(ReplicationFailedException.class, () -> collection.startSafe(target2, TimeValue.ZERO)); + // close second target, otherwise it will hold store ref, first target will complete immediately. + target2.decRef(); + } + } + + public void testGetReplicationTargetMultiReplicationsForSingleShard() throws Exception { try (ReplicationGroup shards = createGroup(0)) { final ReplicationCollection collection = new ReplicationCollection<>(logger, threadPool); final IndexShard shard1 = shards.addReplica();