diff --git a/core/src/main/java/org/elasticsearch/action/support/replication/TransportReplicationAction.java b/core/src/main/java/org/elasticsearch/action/support/replication/TransportReplicationAction.java index 883d21154bde5..0fa027744ac62 100644 --- a/core/src/main/java/org/elasticsearch/action/support/replication/TransportReplicationAction.java +++ b/core/src/main/java/org/elasticsearch/action/support/replication/TransportReplicationAction.java @@ -38,10 +38,12 @@ import org.elasticsearch.cluster.metadata.IndexNameExpressionResolver; import org.elasticsearch.cluster.metadata.MetaData; import org.elasticsearch.cluster.node.DiscoveryNode; +import org.elasticsearch.cluster.routing.AllocationId; import org.elasticsearch.cluster.routing.IndexShardRoutingTable; import org.elasticsearch.cluster.routing.ShardRouting; import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.common.lease.Releasable; import org.elasticsearch.common.lease.Releasables; import org.elasticsearch.common.settings.Settings; @@ -53,14 +55,17 @@ import org.elasticsearch.index.shard.IndexShard; import org.elasticsearch.index.shard.IndexShardState; import org.elasticsearch.index.shard.ShardId; +import org.elasticsearch.index.shard.ShardNotFoundException; import org.elasticsearch.indices.IndicesService; import org.elasticsearch.node.NodeClosedException; import org.elasticsearch.tasks.Task; +import org.elasticsearch.tasks.TaskId; import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.transport.ConnectTransportException; import org.elasticsearch.transport.TransportChannel; import org.elasticsearch.transport.TransportChannelResponseHandler; import org.elasticsearch.transport.TransportException; +import org.elasticsearch.transport.TransportRequest; import org.elasticsearch.transport.TransportRequestHandler; import org.elasticsearch.transport.TransportRequestOptions; import org.elasticsearch.transport.TransportResponse; @@ -69,6 +74,7 @@ import org.elasticsearch.transport.TransportService; import java.io.IOException; +import java.util.Objects; import java.util.concurrent.atomic.AtomicBoolean; import java.util.function.Consumer; import java.util.function.Supplier; @@ -115,9 +121,12 @@ protected TransportReplicationAction(Settings settings, String actionName, Trans this.transportPrimaryAction = actionName + "[p]"; this.transportReplicaAction = actionName + "[r]"; transportService.registerRequestHandler(actionName, request, ThreadPool.Names.SAME, new OperationTransportHandler()); - transportService.registerRequestHandler(transportPrimaryAction, request, executor, new PrimaryOperationTransportHandler()); + transportService.registerRequestHandler(transportPrimaryAction, () -> new ConcreteShardRequest<>(request), executor, + new PrimaryOperationTransportHandler()); // we must never reject on because of thread pool capacity on replicas - transportService.registerRequestHandler(transportReplicaAction, replicaRequest, executor, true, true, + transportService.registerRequestHandler(transportReplicaAction, + () -> new ConcreteShardRequest<>(replicaRequest), + executor, true, true, new ReplicaOperationTransportHandler()); this.transportOptions = transportOptions(); @@ -163,7 +172,7 @@ protected void resolveRequest(MetaData metaData, IndexMetaData indexMetaData, Re /** * Synchronous replica operation on nodes with replica copies. This is done under the lock form - * {@link #acquireReplicaOperationLock(ShardId, long, ActionListener)}. + * {@link #acquireReplicaOperationLock(ShardId, long, String, ActionListener)}. */ protected abstract ReplicaResult shardOperationOnReplica(ReplicaRequest shardRequest); @@ -230,33 +239,36 @@ public void messageReceived(Request request, TransportChannel channel) throws Ex } } - class PrimaryOperationTransportHandler implements TransportRequestHandler { + class PrimaryOperationTransportHandler implements TransportRequestHandler> { @Override - public void messageReceived(final Request request, final TransportChannel channel) throws Exception { + public void messageReceived(final ConcreteShardRequest request, final TransportChannel channel) throws Exception { throw new UnsupportedOperationException("the task parameter is required for this operation"); } @Override - public void messageReceived(Request request, TransportChannel channel, Task task) { - new AsyncPrimaryAction(request, channel, (ReplicationTask) task).run(); + public void messageReceived(ConcreteShardRequest request, TransportChannel channel, Task task) { + new AsyncPrimaryAction(request.request, request.targetAllocationID, channel, (ReplicationTask) task).run(); } } class AsyncPrimaryAction extends AbstractRunnable implements ActionListener { private final Request request; + /** targetAllocationID of the shard this request is meant for */ + private final String targetAllocationID; private final TransportChannel channel; private final ReplicationTask replicationTask; - AsyncPrimaryAction(Request request, TransportChannel channel, ReplicationTask replicationTask) { + AsyncPrimaryAction(Request request, String targetAllocationID, TransportChannel channel, ReplicationTask replicationTask) { this.request = request; + this.targetAllocationID = targetAllocationID; this.channel = channel; this.replicationTask = replicationTask; } @Override protected void doRun() throws Exception { - acquirePrimaryShardReference(request.shardId(), this); + acquirePrimaryShardReference(request.shardId(), targetAllocationID, this); } @Override @@ -271,7 +283,9 @@ public void onResponse(PrimaryShardReference primaryShardReference) { final ShardRouting primary = primaryShardReference.routingEntry(); assert primary.relocating() : "indexShard is marked as relocated but routing isn't" + primary; DiscoveryNode relocatingNode = clusterService.state().nodes().get(primary.relocatingNodeId()); - transportService.sendRequest(relocatingNode, transportPrimaryAction, request, transportOptions, + transportService.sendRequest(relocatingNode, transportPrimaryAction, + new ConcreteShardRequest<>(request, primary.allocationId().getRelocationId()), + transportOptions, new TransportChannelResponseHandler(logger, channel, "rerouting indexing to target primary " + primary, TransportReplicationAction.this::newResponseInstance) { @@ -391,15 +405,17 @@ public void respond(ActionListener listener) { } } - class ReplicaOperationTransportHandler implements TransportRequestHandler { + class ReplicaOperationTransportHandler implements TransportRequestHandler> { @Override - public void messageReceived(final ReplicaRequest request, final TransportChannel channel) throws Exception { + public void messageReceived(final ConcreteShardRequest request, final TransportChannel channel) + throws Exception { throw new UnsupportedOperationException("the task parameter is required for this operation"); } @Override - public void messageReceived(ReplicaRequest request, TransportChannel channel, Task task) throws Exception { - new AsyncReplicaAction(request, channel, (ReplicationTask) task).run(); + public void messageReceived(ConcreteShardRequest requestWithAID, TransportChannel channel, Task task) + throws Exception { + new AsyncReplicaAction(requestWithAID.request, requestWithAID.targetAllocationID, channel, (ReplicationTask) task).run(); } } @@ -417,6 +433,8 @@ public RetryOnReplicaException(StreamInput in) throws IOException { private final class AsyncReplicaAction extends AbstractRunnable implements ActionListener { private final ReplicaRequest request; + // allocation id of the replica this request is meant for + private final String targetAllocationID; private final TransportChannel channel; /** * The task on the node with the replica shard. @@ -426,10 +444,11 @@ private final class AsyncReplicaAction extends AbstractRunnable implements Actio // something we want to avoid at all costs private final ClusterStateObserver observer = new ClusterStateObserver(clusterService, null, logger, threadPool.getThreadContext()); - AsyncReplicaAction(ReplicaRequest request, TransportChannel channel, ReplicationTask task) { + AsyncReplicaAction(ReplicaRequest request, String targetAllocationID, TransportChannel channel, ReplicationTask task) { this.request = request; this.channel = channel; this.task = task; + this.targetAllocationID = targetAllocationID; } @Override @@ -464,7 +483,9 @@ public void onNewClusterState(ClusterState state) { String extraMessage = "action [" + transportReplicaAction + "], request[" + request + "]"; TransportChannelResponseHandler handler = new TransportChannelResponseHandler<>(logger, channel, extraMessage, () -> TransportResponse.Empty.INSTANCE); - transportService.sendRequest(clusterService.localNode(), transportReplicaAction, request, handler); + transportService.sendRequest(clusterService.localNode(), transportReplicaAction, + new ConcreteShardRequest<>(request, targetAllocationID), + handler); } @Override @@ -501,7 +522,7 @@ protected void responseWithFailure(Exception e) { protected void doRun() throws Exception { setPhase(task, "replica"); assert request.shardId() != null : "request shardId must be set"; - acquireReplicaOperationLock(request.shardId(), request.primaryTerm(), this); + acquireReplicaOperationLock(request.shardId(), request.primaryTerm(), targetAllocationID, this); } /** @@ -598,7 +619,7 @@ private void performLocalAction(ClusterState state, ShardRouting primary, Discov logger.trace("send action [{}] on primary [{}] for request [{}] with cluster state version [{}] to [{}] ", transportPrimaryAction, request.shardId(), request, state.version(), primary.currentNodeId()); } - performAction(node, transportPrimaryAction, true); + performAction(node, transportPrimaryAction, true, new ConcreteShardRequest<>(request, primary.allocationId().getId())); } private void performRemoteAction(ClusterState state, ShardRouting primary, DiscoveryNode node) { @@ -620,7 +641,7 @@ private void performRemoteAction(ClusterState state, ShardRouting primary, Disco request.shardId(), request, state.version(), primary.currentNodeId()); } setPhase(task, "rerouted"); - performAction(node, actionName, false); + performAction(node, actionName, false, request); } private boolean retryIfUnavailable(ClusterState state, ShardRouting primary) { @@ -671,8 +692,9 @@ private void handleBlockException(ClusterBlockException blockException) { } } - private void performAction(final DiscoveryNode node, final String action, final boolean isPrimaryAction) { - transportService.sendRequest(node, action, request, transportOptions, new TransportResponseHandler() { + private void performAction(final DiscoveryNode node, final String action, final boolean isPrimaryAction, + final TransportRequest requestToPerform) { + transportService.sendRequest(node, action, requestToPerform, transportOptions, new TransportResponseHandler() { @Override public Response newInstance() { @@ -700,7 +722,7 @@ public void handleException(TransportException exp) { (org.apache.logging.log4j.util.Supplier) () -> new ParameterizedMessage( "received an error from node [{}] for request [{}], scheduling a retry", node.getId(), - request), + requestToPerform), exp); retry(exp); } else { @@ -794,7 +816,8 @@ void retryBecauseUnavailable(ShardId shardId, String message) { * tries to acquire reference to {@link IndexShard} to perform a primary operation. Released after performing primary operation locally * and replication of the operation to all replica shards is completed / failed (see {@link ReplicationOperation}). */ - protected void acquirePrimaryShardReference(ShardId shardId, ActionListener onReferenceAcquired) { + protected void acquirePrimaryShardReference(ShardId shardId, String allocationId, + ActionListener onReferenceAcquired) { IndexService indexService = indicesService.indexServiceSafe(shardId.getIndex()); IndexShard indexShard = indexService.getShard(shardId.id()); // we may end up here if the cluster state used to route the primary is so stale that the underlying @@ -804,6 +827,10 @@ protected void acquirePrimaryShardReference(ShardId shardId, ActionListener onAcquired = new ActionListener() { @Override @@ -823,9 +850,14 @@ public void onFailure(Exception e) { /** * tries to acquire an operation on replicas. The lock is closed as soon as replication is completed on the node. */ - protected void acquireReplicaOperationLock(ShardId shardId, long primaryTerm, ActionListener onLockAcquired) { + protected void acquireReplicaOperationLock(ShardId shardId, long primaryTerm, final String allocationId, + ActionListener onLockAcquired) { IndexService indexService = indicesService.indexServiceSafe(shardId.getIndex()); IndexShard indexShard = indexService.getShard(shardId.id()); + final String actualAllocationId = indexShard.routingEntry().allocationId().getId(); + if (actualAllocationId.equals(allocationId) == false) { + throw new ShardNotFoundException(shardId, "expected aID [{}] but found [{}]", allocationId, actualAllocationId); + } indexShard.acquireReplicaOperationLock(primaryTerm, onLockAcquired, executor); } @@ -888,7 +920,8 @@ public void performOn(ShardRouting replica, ReplicaRequest request, ActionListen listener.onFailure(new NoNodeAvailableException("unknown node [" + nodeId + "]")); return; } - transportService.sendRequest(node, transportReplicaAction, request, transportOptions, + transportService.sendRequest(node, transportReplicaAction, + new ConcreteShardRequest<>(request, replica.allocationId().getId()), transportOptions, new ActionListenerResponseHandler<>(listener, () -> TransportResponse.Empty.INSTANCE)); } @@ -930,6 +963,72 @@ public void onFailure(Exception shardFailedError) { } } + /** a wrapper class to encapsulate a request when being sent to a specific allocation id **/ + final class ConcreteShardRequest extends TransportRequest { + + /** {@link AllocationId#getId()} of the shard this request is sent to **/ + private String targetAllocationID; + + private R request; + + ConcreteShardRequest(Supplier requestSupplier) { + request = requestSupplier.get(); + // null now, but will be populated by reading from the streams + targetAllocationID = null; + } + + ConcreteShardRequest(R request, String targetAllocationID) { + Objects.requireNonNull(request); + Objects.requireNonNull(targetAllocationID); + this.request = request; + this.targetAllocationID = targetAllocationID; + } + + @Override + public void setParentTask(String parentTaskNode, long parentTaskId) { + request.setParentTask(parentTaskNode, parentTaskId); + } + + @Override + public void setParentTask(TaskId taskId) { + request.setParentTask(taskId); + } + + @Override + public TaskId getParentTask() { + return request.getParentTask(); + } + @Override + public Task createTask(long id, String type, String action, TaskId parentTaskId) { + return request.createTask(id, type, action, parentTaskId); + } + + @Override + public String getDescription() { + return "[" + request.getDescription() + "] for aID [" + targetAllocationID + "]"; + } + + @Override + public void readFrom(StreamInput in) throws IOException { + targetAllocationID = in.readString(); + request.readFrom(in); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeString(targetAllocationID); + request.writeTo(out); + } + + public R getRequest() { + return request; + } + + public String getTargetAllocationID() { + return targetAllocationID; + } + } + /** * Sets the current phase on the task if it isn't null. Pulled into its own * method because its more convenient that way. diff --git a/core/src/main/java/org/elasticsearch/cluster/routing/RoutingTable.java b/core/src/main/java/org/elasticsearch/cluster/routing/RoutingTable.java index 2ca165308b194..2d960ce0450bb 100644 --- a/core/src/main/java/org/elasticsearch/cluster/routing/RoutingTable.java +++ b/core/src/main/java/org/elasticsearch/cluster/routing/RoutingTable.java @@ -27,8 +27,8 @@ import org.elasticsearch.cluster.DiffableUtils; import org.elasticsearch.cluster.metadata.IndexMetaData; import org.elasticsearch.cluster.metadata.MetaData; -import org.elasticsearch.common.Nullable; import org.elasticsearch.cluster.routing.RecoverySource.SnapshotRecoverySource; +import org.elasticsearch.common.Nullable; import org.elasticsearch.common.collect.ImmutableOpenMap; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; @@ -88,6 +88,11 @@ public boolean hasIndex(String index) { return indicesRouting.containsKey(index); } + public boolean hasIndex(Index index) { + IndexRoutingTable indexRouting = index(index.getName()); + return indexRouting != null && indexRouting.getIndex().equals(index); + } + public IndexRoutingTable index(String index) { return indicesRouting.get(index); } diff --git a/core/src/main/java/org/elasticsearch/index/shard/ShardNotFoundException.java b/core/src/main/java/org/elasticsearch/index/shard/ShardNotFoundException.java index fa2c8ce710337..aa46240fd490f 100644 --- a/core/src/main/java/org/elasticsearch/index/shard/ShardNotFoundException.java +++ b/core/src/main/java/org/elasticsearch/index/shard/ShardNotFoundException.java @@ -33,10 +33,18 @@ public ShardNotFoundException(ShardId shardId) { } public ShardNotFoundException(ShardId shardId, Throwable ex) { - super("no such shard", ex); - setShard(shardId); + this(shardId, "no such shard", ex); + } + + public ShardNotFoundException(ShardId shardId, String msg, Object... args) { + this(shardId, msg, null, args); + } + public ShardNotFoundException(ShardId shardId, String msg, Throwable ex, Object... args) { + super(msg, ex, args); + setShard(shardId); } + public ShardNotFoundException(StreamInput in) throws IOException{ super(in); } diff --git a/core/src/test/java/org/elasticsearch/action/IndicesRequestIT.java b/core/src/test/java/org/elasticsearch/action/IndicesRequestIT.java index 934fdae254bcd..1a11e3f48034c 100644 --- a/core/src/test/java/org/elasticsearch/action/IndicesRequestIT.java +++ b/core/src/test/java/org/elasticsearch/action/IndicesRequestIT.java @@ -69,6 +69,7 @@ import org.elasticsearch.action.search.SearchRequest; import org.elasticsearch.action.search.SearchResponse; import org.elasticsearch.action.search.SearchType; +import org.elasticsearch.action.support.replication.TransportReplicationActionTests; import org.elasticsearch.action.termvectors.MultiTermVectorsAction; import org.elasticsearch.action.termvectors.MultiTermVectorsRequest; import org.elasticsearch.action.termvectors.TermVectorsAction; @@ -117,7 +118,6 @@ import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.greaterThan; import static org.hamcrest.Matchers.hasItem; -import static org.hamcrest.Matchers.instanceOf; @ClusterScope(scope = Scope.SUITE, numClientNodes = 1, minNumDataNodes = 2) public class IndicesRequestIT extends ESIntegTestCase { @@ -638,8 +638,7 @@ private static void assertSameIndices(IndicesRequest originalRequest, boolean op assertThat("no internal requests intercepted for action [" + action + "]", requests.size(), greaterThan(0)); } for (TransportRequest internalRequest : requests) { - assertThat(internalRequest, instanceOf(IndicesRequest.class)); - IndicesRequest indicesRequest = (IndicesRequest) internalRequest; + IndicesRequest indicesRequest = convertRequest(internalRequest); assertThat(internalRequest.getClass().getName(), indicesRequest.indices(), equalTo(originalRequest.indices())); assertThat(indicesRequest.indicesOptions(), equalTo(originalRequest.indicesOptions())); } @@ -651,14 +650,24 @@ private static void assertIndicesSubset(List indices, String... actions) List requests = consumeTransportRequests(action); assertThat("no internal requests intercepted for action [" + action + "]", requests.size(), greaterThan(0)); for (TransportRequest internalRequest : requests) { - assertThat(internalRequest, instanceOf(IndicesRequest.class)); - for (String index : ((IndicesRequest) internalRequest).indices()) { + IndicesRequest indicesRequest = convertRequest(internalRequest); + for (String index : indicesRequest.indices()) { assertThat(indices, hasItem(index)); } } } } + static IndicesRequest convertRequest(TransportRequest request) { + final IndicesRequest indicesRequest; + if (request instanceof IndicesRequest) { + indicesRequest = (IndicesRequest) request; + } else { + indicesRequest = TransportReplicationActionTests.resolveRequest(request); + } + return indicesRequest; + } + private String randomIndexOrAlias() { String index = randomFrom(indices); if (randomBoolean()) { diff --git a/core/src/test/java/org/elasticsearch/action/support/replication/TransportReplicationActionTests.java b/core/src/test/java/org/elasticsearch/action/support/replication/TransportReplicationActionTests.java index 6c30f015124ad..c8aec62339491 100644 --- a/core/src/test/java/org/elasticsearch/action/support/replication/TransportReplicationActionTests.java +++ b/core/src/test/java/org/elasticsearch/action/support/replication/TransportReplicationActionTests.java @@ -26,6 +26,7 @@ import org.elasticsearch.action.support.PlainActionFuture; import org.elasticsearch.client.transport.NoNodeAvailableException; import org.elasticsearch.cluster.ClusterState; +import org.elasticsearch.cluster.ESAllocationTestCase; import org.elasticsearch.cluster.action.shard.ShardStateAction; import org.elasticsearch.cluster.block.ClusterBlock; import org.elasticsearch.cluster.block.ClusterBlockException; @@ -36,6 +37,7 @@ import org.elasticsearch.cluster.metadata.MetaData; import org.elasticsearch.cluster.node.DiscoveryNodes; import org.elasticsearch.cluster.routing.IndexShardRoutingTable; +import org.elasticsearch.cluster.routing.RoutingNode; import org.elasticsearch.cluster.routing.ShardRouting; import org.elasticsearch.cluster.routing.ShardRoutingState; import org.elasticsearch.cluster.routing.TestShardRouting; @@ -47,21 +49,25 @@ import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.common.lease.Releasable; import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.index.Index; import org.elasticsearch.index.IndexNotFoundException; +import org.elasticsearch.index.IndexService; import org.elasticsearch.index.engine.EngineClosedException; import org.elasticsearch.index.shard.IndexShard; import org.elasticsearch.index.shard.IndexShardClosedException; +import org.elasticsearch.index.shard.IndexShardState; import org.elasticsearch.index.shard.ShardId; import org.elasticsearch.index.shard.ShardNotFoundException; +import org.elasticsearch.indices.IndicesService; import org.elasticsearch.node.NodeClosedException; import org.elasticsearch.rest.RestStatus; -import org.elasticsearch.cluster.ESAllocationTestCase; import org.elasticsearch.test.ESTestCase; import org.elasticsearch.test.transport.CapturingTransport; import org.elasticsearch.threadpool.TestThreadPool; import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.transport.TransportChannel; import org.elasticsearch.transport.TransportException; +import org.elasticsearch.transport.TransportRequest; import org.elasticsearch.transport.TransportResponse; import org.elasticsearch.transport.TransportResponseOptions; import org.elasticsearch.transport.TransportService; @@ -75,12 +81,12 @@ import java.util.Arrays; import java.util.HashSet; import java.util.List; +import java.util.Locale; import java.util.concurrent.ExecutionException; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicReference; -import java.util.function.Consumer; import java.util.stream.Collectors; import static org.elasticsearch.action.support.replication.ClusterStateCreationUtils.state; @@ -93,12 +99,32 @@ import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.instanceOf; import static org.hamcrest.Matchers.notNullValue; +import static org.hamcrest.Matchers.nullValue; +import static org.mockito.Matchers.any; +import static org.mockito.Matchers.anyInt; +import static org.mockito.Matchers.anyLong; +import static org.mockito.Matchers.anyString; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.doThrow; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; public class TransportReplicationActionTests extends ESTestCase { + /** + * takes a request that was sent by a {@link TransportReplicationAction} and captured + * and returns the underlying request if it's wrapped or the original (cast to the expected type). + * + * This will throw a {@link ClassCastException} if the request is of the wrong type. + */ + public static R resolveRequest(TransportRequest requestOrWrappedRequest) { + if (requestOrWrappedRequest instanceof TransportReplicationAction.ConcreteShardRequest) { + requestOrWrappedRequest = ((TransportReplicationAction.ConcreteShardRequest)requestOrWrappedRequest).getRequest(); + } + return (R) requestOrWrappedRequest; + } + private static ThreadPool threadPool; private ClusterService clusterService; @@ -411,7 +437,7 @@ public void testPrimaryPhaseExecutesOrDelegatesRequestToRelocationTarget() throw isRelocated.set(true); executeOnPrimary = false; } - action.new AsyncPrimaryAction(request, createTransportChannel(listener), task) { + action.new AsyncPrimaryAction(request, primaryShard.allocationId().getId(), createTransportChannel(listener), task) { @Override protected ReplicationOperation createReplicatedOperation(Request request, ActionListener actionListener, Action.PrimaryShardReference primaryShardReference, @@ -452,7 +478,8 @@ public void testPrimaryPhaseExecutesDelegatedRequestOnRelocationTarget() throws final String index = "test"; final ShardId shardId = new ShardId(index, "_na_", 0); ClusterState state = state(index, true, ShardRoutingState.RELOCATING); - String primaryTargetNodeId = state.getRoutingTable().shardRoutingTable(shardId).primaryShard().relocatingNodeId(); + final ShardRouting primaryShard = state.getRoutingTable().shardRoutingTable(shardId).primaryShard(); + String primaryTargetNodeId = primaryShard.relocatingNodeId(); // simulate execution of the primary phase on the relocation target node state = ClusterState.builder(state).nodes(DiscoveryNodes.builder(state.nodes()).localNodeId(primaryTargetNodeId)).build(); setState(clusterService, state); @@ -460,7 +487,7 @@ public void testPrimaryPhaseExecutesDelegatedRequestOnRelocationTarget() throws PlainActionFuture listener = new PlainActionFuture<>(); ReplicationTask task = maybeTask(); AtomicBoolean executed = new AtomicBoolean(); - action.new AsyncPrimaryAction(request, createTransportChannel(listener), task) { + action.new AsyncPrimaryAction(request, primaryShard.allocationId().getRelocationId(), createTransportChannel(listener), task) { @Override protected ReplicationOperation createReplicatedOperation(Request request, ActionListener actionListener, Action.PrimaryShardReference primaryShardReference, @@ -473,6 +500,11 @@ public void execute() throws Exception { } }; } + + @Override + public void onFailure(Exception e) { + throw new RuntimeException(e); + } }.run(); assertThat(executed.get(), equalTo(true)); assertPhase(task, "finished"); @@ -596,7 +628,9 @@ public void testShadowIndexDisablesReplication() throws Exception { state = ClusterState.builder(state).metaData(metaData).build(); setState(clusterService, state); AtomicBoolean executed = new AtomicBoolean(); - action.new AsyncPrimaryAction(new Request(shardId), createTransportChannel(new PlainActionFuture<>()), null) { + ShardRouting primaryShard = state.routingTable().shardRoutingTable(shardId).primaryShard(); + action.new AsyncPrimaryAction(new Request(shardId), primaryShard.allocationId().getId(), + createTransportChannel(new PlainActionFuture<>()), null) { @Override protected ReplicationOperation createReplicatedOperation(Request request, ActionListener actionListener, Action.PrimaryShardReference primaryShardReference, @@ -613,8 +647,10 @@ public void testCounterOnPrimary() throws Exception { final String index = "test"; final ShardId shardId = new ShardId(index, "_na_", 0); // no replica, we only want to test on primary - setState(clusterService, state(index, true, ShardRoutingState.STARTED)); + final ClusterState state = state(index, true, ShardRoutingState.STARTED); + setState(clusterService, state); logger.debug("--> using initial state:\n{}", clusterService.state().prettyPrint()); + final ShardRouting primaryShard = state.routingTable().shardRoutingTable(shardId).primaryShard(); Request request = new Request(shardId); PlainActionFuture listener = new PlainActionFuture<>(); ReplicationTask task = maybeTask(); @@ -622,7 +658,7 @@ public void testCounterOnPrimary() throws Exception { final boolean throwExceptionOnCreation = i == 1; final boolean throwExceptionOnRun = i == 2; final boolean respondWithError = i == 3; - action.new AsyncPrimaryAction(request, createTransportChannel(listener), task) { + action.new AsyncPrimaryAction(request, primaryShard.allocationId().getId(), createTransportChannel(listener), task) { @Override protected ReplicationOperation createReplicatedOperation(Request request, ActionListener actionListener, Action.PrimaryShardReference primaryShardReference, @@ -666,8 +702,9 @@ public void execute() throws Exception { public void testReplicasCounter() throws Exception { final ShardId shardId = new ShardId("test", "_na_", 0); - setState(clusterService, state(shardId.getIndexName(), true, - ShardRoutingState.STARTED, ShardRoutingState.STARTED)); + final ClusterState state = state(shardId.getIndexName(), true, ShardRoutingState.STARTED, ShardRoutingState.STARTED); + setState(clusterService, state); + final ShardRouting replicaRouting = state.getRoutingTable().shardRoutingTable(shardId).replicaShards().get(0); boolean throwException = randomBoolean(); final ReplicationTask task = maybeTask(); Action action = new Action(Settings.EMPTY, "testActionWithExceptions", transportService, clusterService, threadPool) { @@ -683,7 +720,8 @@ protected ReplicaResult shardOperationOnReplica(Request request) { }; final Action.ReplicaOperationTransportHandler replicaOperationTransportHandler = action.new ReplicaOperationTransportHandler(); try { - replicaOperationTransportHandler.messageReceived(new Request().setShardId(shardId), + replicaOperationTransportHandler.messageReceived( + action.new ConcreteShardRequest(new Request().setShardId(shardId), replicaRouting.allocationId().getId()), createTransportChannel(new PlainActionFuture<>()), task); } catch (ElasticsearchException e) { assertThat(e.getMessage(), containsString("simulated")); @@ -725,6 +763,111 @@ public void testDefaultWaitForActiveShardsUsesIndexSetting() throws Exception { assertEquals(ActiveShardCount.from(requestWaitForActiveShards), request.waitForActiveShards()); } + /** test that a primary request is rejected if it arrives at a shard with a wrong allocation id */ + public void testPrimaryActionRejectsWrongAid() throws Exception { + final String index = "test"; + final ShardId shardId = new ShardId(index, "_na_", 0); + setState(clusterService, state(index, true, ShardRoutingState.STARTED)); + PlainActionFuture listener = new PlainActionFuture<>(); + Request request = new Request(shardId).timeout("1ms"); + action.new PrimaryOperationTransportHandler().messageReceived( + action.new ConcreteShardRequest(request, "_not_a_valid_aid_"), + createTransportChannel(listener), maybeTask() + ); + try { + listener.get(); + fail("using a wrong aid didn't fail the operation"); + } catch (ExecutionException execException) { + Throwable throwable = execException.getCause(); + logger.debug("got exception:" , throwable); + assertTrue(throwable.getClass() + " is not a retry exception", action.retryPrimaryException(throwable)); + } + } + + /** test that a replica request is rejected if it arrives at a shard with a wrong allocation id */ + public void testReplicaActionRejectsWrongAid() throws Exception { + final String index = "test"; + final ShardId shardId = new ShardId(index, "_na_", 0); + ClusterState state = state(index, false, ShardRoutingState.STARTED, ShardRoutingState.STARTED); + final ShardRouting replica = state.routingTable().shardRoutingTable(shardId).replicaShards().get(0); + // simulate execution of the node holding the replica + state = ClusterState.builder(state).nodes(DiscoveryNodes.builder(state.nodes()).localNodeId(replica.currentNodeId())).build(); + setState(clusterService, state); + + PlainActionFuture listener = new PlainActionFuture<>(); + Request request = new Request(shardId).timeout("1ms"); + action.new ReplicaOperationTransportHandler().messageReceived( + action.new ConcreteShardRequest(request, "_not_a_valid_aid_"), + createTransportChannel(listener), maybeTask() + ); + try { + listener.get(); + fail("using a wrong aid didn't fail the operation"); + } catch (ExecutionException execException) { + Throwable throwable = execException.getCause(); + if (action.retryPrimaryException(throwable) == false) { + throw new AssertionError("thrown exception is not retriable", throwable); + } + assertThat(throwable.getMessage(), containsString("_not_a_valid_aid_")); + } + } + + /** + * test throwing a {@link org.elasticsearch.action.support.replication.TransportReplicationAction.RetryOnReplicaException} + * causes a retry + */ + public void testRetryOnReplica() throws Exception { + final ShardId shardId = new ShardId("test", "_na_", 0); + ClusterState state = state(shardId.getIndexName(), true, ShardRoutingState.STARTED, ShardRoutingState.STARTED); + final ShardRouting replica = state.getRoutingTable().shardRoutingTable(shardId).replicaShards().get(0); + // simulate execution of the node holding the replica + state = ClusterState.builder(state).nodes(DiscoveryNodes.builder(state.nodes()).localNodeId(replica.currentNodeId())).build(); + setState(clusterService, state); + AtomicBoolean throwException = new AtomicBoolean(true); + final ReplicationTask task = maybeTask(); + Action action = new Action(Settings.EMPTY, "testActionWithExceptions", transportService, clusterService, threadPool) { + @Override + protected ReplicaResult shardOperationOnReplica(Request request) { + assertPhase(task, "replica"); + if (throwException.get()) { + throw new RetryOnReplicaException(shardId, "simulation"); + } + return new ReplicaResult(); + } + }; + final Action.ReplicaOperationTransportHandler replicaOperationTransportHandler = action.new ReplicaOperationTransportHandler(); + final PlainActionFuture listener = new PlainActionFuture<>(); + final Request request = new Request().setShardId(shardId); + request.primaryTerm(state.metaData().getIndexSafe(shardId.getIndex()).primaryTerm(shardId.id())); + replicaOperationTransportHandler.messageReceived( + action.new ConcreteShardRequest(request, replica.allocationId().getId()), + createTransportChannel(listener), task); + if (listener.isDone()) { + listener.get(); // fail with the exception if there + fail("listener shouldn't be done"); + } + + // no retry yet + List capturedRequests = + transport.getCapturedRequestsByTargetNodeAndClear().get(replica.currentNodeId()); + assertThat(capturedRequests, nullValue()); + + // release the waiting + throwException.set(false); + setState(clusterService, state); + + capturedRequests = transport.getCapturedRequestsByTargetNodeAndClear().get(replica.currentNodeId()); + assertThat(capturedRequests, notNullValue()); + assertThat(capturedRequests.size(), equalTo(1)); + final CapturingTransport.CapturedRequest capturedRequest = capturedRequests.get(0); + assertThat(capturedRequest.action, equalTo("testActionWithExceptions[r]")); + assertThat(capturedRequest.request, instanceOf(TransportReplicationAction.ConcreteShardRequest.class)); + assertThat(((TransportReplicationAction.ConcreteShardRequest) capturedRequest.request).getRequest(), equalTo(request)); + assertThat(((TransportReplicationAction.ConcreteShardRequest) capturedRequest.request).getTargetAllocationID(), + equalTo(replica.allocationId().getId())); + } + + private void assertIndexShardCounter(int expected) { assertThat(count.get(), equalTo(expected)); } @@ -797,7 +940,7 @@ class Action extends TransportReplicationAction { Action(Settings settings, String actionName, TransportService transportService, ClusterService clusterService, ThreadPool threadPool) { - super(settings, actionName, transportService, clusterService, null, threadPool, + super(settings, actionName, transportService, clusterService, mockIndicesService(clusterService), threadPool, new ShardStateAction(settings, clusterService, transportService, null, null, threadPool), new ActionFilters(new HashSet<>()), new IndexNameExpressionResolver(Settings.EMPTY), Request::new, Request::new, ThreadPool.Names.SAME); @@ -825,43 +968,76 @@ protected ReplicaResult shardOperationOnReplica(Request request) { protected boolean resolveIndex() { return false; } + } - @Override - protected void acquirePrimaryShardReference(ShardId shardId, ActionListener onReferenceAcquired) { - count.incrementAndGet(); - PrimaryShardReference primaryShardReference = new PrimaryShardReference(null, null) { - @Override - public boolean isRelocated() { - return isRelocated.get(); - } - - @Override - public void failShard(String reason, @Nullable Exception e) { - throw new UnsupportedOperationException(); - } - - @Override - public ShardRouting routingEntry() { - ShardRouting shardRouting = clusterService.state().getRoutingTable() - .shardRoutingTable(shardId).primaryShard(); - assert shardRouting != null; - return shardRouting; - } - - @Override - public void close() { - count.decrementAndGet(); - } - }; + final IndicesService mockIndicesService(ClusterService clusterService) { + final IndicesService indicesService = mock(IndicesService.class); + when(indicesService.indexServiceSafe(any(Index.class))).then(invocation -> { + Index index = (Index)invocation.getArguments()[0]; + final ClusterState state = clusterService.state(); + final IndexMetaData indexSafe = state.metaData().getIndexSafe(index); + return mockIndexService(indexSafe, clusterService); + }); + when(indicesService.indexService(any(Index.class))).then(invocation -> { + Index index = (Index) invocation.getArguments()[0]; + final ClusterState state = clusterService.state(); + if (state.metaData().hasIndex(index.getName())) { + final IndexMetaData indexSafe = state.metaData().getIndexSafe(index); + return mockIndexService(clusterService.state().metaData().getIndexSafe(index), clusterService); + } else { + return null; + } + }); + return indicesService; + } - onReferenceAcquired.onResponse(primaryShardReference); - } + final IndexService mockIndexService(final IndexMetaData indexMetaData, ClusterService clusterService) { + final IndexService indexService = mock(IndexService.class); + when(indexService.getShard(anyInt())).then(invocation -> { + int shard = (Integer) invocation.getArguments()[0]; + final ShardId shardId = new ShardId(indexMetaData.getIndex(), shard); + if (shard > indexMetaData.getNumberOfShards()) { + throw new ShardNotFoundException(shardId); + } + return mockIndexShard(shardId, clusterService); + }); + return indexService; + } - @Override - protected void acquireReplicaOperationLock(ShardId shardId, long primaryTerm, ActionListener onLockAcquired) { + private IndexShard mockIndexShard(ShardId shardId, ClusterService clusterService) { + final IndexShard indexShard = mock(IndexShard.class); + doAnswer(invocation -> { + ActionListener callback = (ActionListener) invocation.getArguments()[0]; count.incrementAndGet(); - onLockAcquired.onResponse(count::decrementAndGet); - } + callback.onResponse(count::decrementAndGet); + return null; + }).when(indexShard).acquirePrimaryOperationLock(any(ActionListener.class), anyString()); + doAnswer(invocation -> { + long term = (Long)invocation.getArguments()[0]; + ActionListener callback = (ActionListener) invocation.getArguments()[1]; + final long primaryTerm = indexShard.getPrimaryTerm(); + if (term < primaryTerm) { + throw new IllegalArgumentException(String.format(Locale.ROOT, "%s operation term [%d] is too old (current [%d])", + shardId, term, primaryTerm)); + } + count.incrementAndGet(); + callback.onResponse(count::decrementAndGet); + return null; + }).when(indexShard).acquireReplicaOperationLock(anyLong(), any(ActionListener.class), anyString()); + when(indexShard.routingEntry()).thenAnswer(invocationOnMock -> { + final ClusterState state = clusterService.state(); + final RoutingNode node = state.getRoutingNodes().node(state.nodes().getLocalNodeId()); + final ShardRouting routing = node.getByShardId(shardId); + if (routing == null) { + throw new ShardNotFoundException(shardId, "shard is no longer assigned to current node"); + } + return routing; + }); + when(indexShard.state()).thenAnswer(invocationOnMock -> isRelocated.get() ? IndexShardState.RELOCATED : IndexShardState.STARTED); + doThrow(new AssertionError("failed shard is not supported")).when(indexShard).failShard(anyString(), any(Exception.class)); + when(indexShard.getPrimaryTerm()).thenAnswer(i -> + clusterService.state().metaData().getIndexSafe(shardId.getIndex()).primaryTerm(shardId.id())); + return indexShard; } class NoopReplicationOperation extends ReplicationOperation { @@ -879,11 +1055,6 @@ public void execute() throws Exception { * Transport channel that is needed for replica operation testing. */ public TransportChannel createTransportChannel(final PlainActionFuture listener) { - return createTransportChannel(listener, error -> { - }); - } - - public TransportChannel createTransportChannel(final PlainActionFuture listener, Consumer consumer) { return new TransportChannel() { @Override @@ -908,7 +1079,6 @@ public void sendResponse(TransportResponse response, TransportResponseOptions op @Override public void sendResponse(Exception exception) throws IOException { - consumer.accept(exception); listener.onFailure(exception); }