Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Backport 2.x] Optimise unassigned shards iteration after allocator timeout #15176

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,13 @@ public void run() {
"Time taken to execute timed runnables in this cycle:[{}ms]",
TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - startTime)
);
onComplete();
}

/**
* Callback method that is invoked after all {@link TimeoutAwareRunnable} instances in the batch have been processed.
* By default, this method does nothing, but it can be overridden by subclasses or modified in the implementation if
* there is a need to perform additional actions once the batch execution is completed.
*/
public void onComplete() {}
}
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,6 @@
import org.opensearch.core.index.shard.ShardId;

import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Set;

Expand Down Expand Up @@ -82,17 +81,15 @@ public void allocateUnassigned(
executeDecision(shardRouting, allocateUnassignedDecision, allocation, unassignedAllocationHandler);
}

protected void allocateUnassignedBatchOnTimeout(List<ShardRouting> shardRoutings, RoutingAllocation allocation, boolean primary) {
Set<ShardId> shardIdsFromBatch = new HashSet<>();
for (ShardRouting shardRouting : shardRoutings) {
ShardId shardId = shardRouting.shardId();
shardIdsFromBatch.add(shardId);
protected void allocateUnassignedBatchOnTimeout(Set<ShardId> shardIds, RoutingAllocation allocation, boolean primary) {
if (shardIds.isEmpty()) {
return;
}
RoutingNodes.UnassignedShards.UnassignedIterator iterator = allocation.routingNodes().unassigned().iterator();
while (iterator.hasNext()) {
ShardRouting unassignedShard = iterator.next();
AllocateUnassignedDecision allocationDecision;
if (unassignedShard.primary() == primary && shardIdsFromBatch.contains(unassignedShard.shardId())) {
if (unassignedShard.primary() == primary && shardIds.contains(unassignedShard.shardId())) {
allocationDecision = AllocateUnassignedDecision.throttle(null);
executeDecision(unassignedShard, allocationDecision, allocation, iterator);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -277,41 +277,51 @@ protected BatchRunnableExecutor innerAllocateUnassignedBatch(
}
List<TimeoutAwareRunnable> runnables = new ArrayList<>();
if (primary) {
Set<ShardId> timedOutPrimaryShardIds = new HashSet<>();
batchIdToStartedShardBatch.values()
.stream()
.filter(batch -> batchesToAssign.contains(batch.batchId))
.forEach(shardsBatch -> runnables.add(new TimeoutAwareRunnable() {
@Override
public void onTimeout() {
primaryBatchShardAllocator.allocateUnassignedBatchOnTimeout(
shardsBatch.getBatchedShardRoutings(),
allocation,
true
);
timedOutPrimaryShardIds.addAll(shardsBatch.getBatchedShards());
}

@Override
public void run() {
primaryBatchShardAllocator.allocateUnassignedBatch(shardsBatch.getBatchedShardRoutings(), allocation);
}
}));
return new BatchRunnableExecutor(runnables, () -> primaryShardsBatchGatewayAllocatorTimeout);
return new BatchRunnableExecutor(runnables, () -> primaryShardsBatchGatewayAllocatorTimeout) {
@Override
public void onComplete() {
logger.trace("Triggering oncomplete after timeout for [{}] primary shards", timedOutPrimaryShardIds.size());
primaryBatchShardAllocator.allocateUnassignedBatchOnTimeout(timedOutPrimaryShardIds, allocation, true);
}
};
} else {
Set<ShardId> timedOutReplicaShardIds = new HashSet<>();
batchIdToStoreShardBatch.values()
.stream()
.filter(batch -> batchesToAssign.contains(batch.batchId))
.forEach(batch -> runnables.add(new TimeoutAwareRunnable() {
@Override
public void onTimeout() {
replicaBatchShardAllocator.allocateUnassignedBatchOnTimeout(batch.getBatchedShardRoutings(), allocation, false);
timedOutReplicaShardIds.addAll(batch.getBatchedShards());
}

@Override
public void run() {
replicaBatchShardAllocator.allocateUnassignedBatch(batch.getBatchedShardRoutings(), allocation);
}
}));
return new BatchRunnableExecutor(runnables, () -> replicaShardsBatchGatewayAllocatorTimeout);
return new BatchRunnableExecutor(runnables, () -> replicaShardsBatchGatewayAllocatorTimeout) {
@Override
public void onComplete() {
logger.trace("Triggering oncomplete after timeout for [{}] replica shards", timedOutReplicaShardIds.size());
replicaBatchShardAllocator.allocateUnassignedBatchOnTimeout(timedOutReplicaShardIds, allocation, false);
}
};
}
}

Expand Down Expand Up @@ -846,11 +856,11 @@ public int getNumberOfStoreShardBatches() {
return batchIdToStoreShardBatch.size();
}

private void setPrimaryBatchAllocatorTimeout(TimeValue primaryShardsBatchGatewayAllocatorTimeout) {
protected void setPrimaryBatchAllocatorTimeout(TimeValue primaryShardsBatchGatewayAllocatorTimeout) {
this.primaryShardsBatchGatewayAllocatorTimeout = primaryShardsBatchGatewayAllocatorTimeout;
}

private void setReplicaBatchAllocatorTimeout(TimeValue replicaShardsBatchGatewayAllocatorTimeout) {
protected void setReplicaBatchAllocatorTimeout(TimeValue replicaShardsBatchGatewayAllocatorTimeout) {
this.replicaShardsBatchGatewayAllocatorTimeout = replicaShardsBatchGatewayAllocatorTimeout;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.concurrent.CountDownLatch;
import java.util.function.Supplier;

import static org.mockito.Mockito.atMost;
Expand Down Expand Up @@ -42,33 +43,53 @@ public void setupRunnables() {
public void testRunWithoutTimeout() {
setupRunnables();
timeoutSupplier = () -> TimeValue.timeValueSeconds(1);
BatchRunnableExecutor executor = new BatchRunnableExecutor(runnableList, timeoutSupplier);
CountDownLatch countDownLatch = new CountDownLatch(1);
BatchRunnableExecutor executor = new BatchRunnableExecutor(runnableList, timeoutSupplier) {
@Override
public void onComplete() {
countDownLatch.countDown();
}
};
executor.run();
verify(runnable1, times(1)).run();
verify(runnable2, times(1)).run();
verify(runnable3, times(1)).run();
verify(runnable1, never()).onTimeout();
verify(runnable2, never()).onTimeout();
verify(runnable3, never()).onTimeout();
assertEquals(0, countDownLatch.getCount());
}

public void testRunWithTimeout() {
setupRunnables();
timeoutSupplier = () -> TimeValue.timeValueNanos(1);
BatchRunnableExecutor executor = new BatchRunnableExecutor(runnableList, timeoutSupplier);
CountDownLatch countDownLatch = new CountDownLatch(1);
BatchRunnableExecutor executor = new BatchRunnableExecutor(runnableList, timeoutSupplier) {
@Override
public void onComplete() {
countDownLatch.countDown();
}
};
executor.run();
verify(runnable1, times(1)).onTimeout();
verify(runnable2, times(1)).onTimeout();
verify(runnable3, times(1)).onTimeout();
verify(runnable1, never()).run();
verify(runnable2, never()).run();
verify(runnable3, never()).run();
assertEquals(0, countDownLatch.getCount());
}

public void testRunWithPartialTimeout() {
setupRunnables();
timeoutSupplier = () -> TimeValue.timeValueMillis(50);
BatchRunnableExecutor executor = new BatchRunnableExecutor(runnableList, timeoutSupplier);
CountDownLatch countDownLatch = new CountDownLatch(1);
BatchRunnableExecutor executor = new BatchRunnableExecutor(runnableList, timeoutSupplier) {
@Override
public void onComplete() {
countDownLatch.countDown();
}
};
doAnswer(invocation -> {
Thread.sleep(100);
return null;
Expand All @@ -81,17 +102,25 @@ public void testRunWithPartialTimeout() {
verify(runnable3, atMost(1)).onTimeout();
verify(runnable2, atMost(1)).onTimeout();
verify(runnable3, atMost(1)).onTimeout();
assertEquals(0, countDownLatch.getCount());
}

public void testRunWithEmptyRunnableList() {
setupRunnables();
BatchRunnableExecutor executor = new BatchRunnableExecutor(Collections.emptyList(), timeoutSupplier);
CountDownLatch countDownLatch = new CountDownLatch(1);
BatchRunnableExecutor executor = new BatchRunnableExecutor(Collections.emptyList(), timeoutSupplier) {
@Override
public void onComplete() {
countDownLatch.countDown();
}
};
executor.run();
verify(runnable1, never()).onTimeout();
verify(runnable2, never()).onTimeout();
verify(runnable3, never()).onTimeout();
verify(runnable1, never()).run();
verify(runnable2, never()).run();
verify(runnable3, never()).run();
assertEquals(1, countDownLatch.getCount());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
import org.opensearch.cluster.routing.allocation.decider.AllocationDeciders;
import org.opensearch.common.collect.Tuple;
import org.opensearch.common.settings.Settings;
import org.opensearch.common.unit.TimeValue;
import org.opensearch.common.util.BatchRunnableExecutor;
import org.opensearch.common.util.set.Sets;
import org.opensearch.core.index.shard.ShardId;
Expand All @@ -45,6 +46,8 @@
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import java.util.stream.Collectors;

import static org.opensearch.gateway.ShardsBatchGatewayAllocator.PRIMARY_BATCH_ALLOCATOR_TIMEOUT_SETTING;
Expand Down Expand Up @@ -423,6 +426,24 @@ public void testReplicaAllocatorTimeout() {
assertEquals(-1, REPLICA_BATCH_ALLOCATOR_TIMEOUT_SETTING.get(build).getMillis());
}

public void testCollectTimedOutShards() throws InterruptedException {
createIndexAndUpdateClusterState(2, 5, 2);
CountDownLatch latch = new CountDownLatch(10);
testShardsBatchGatewayAllocator = new TestShardBatchGatewayAllocator(latch);
testShardsBatchGatewayAllocator.setPrimaryBatchAllocatorTimeout(TimeValue.ZERO);
testShardsBatchGatewayAllocator.setReplicaBatchAllocatorTimeout(TimeValue.ZERO);
BatchRunnableExecutor executor = testShardsBatchGatewayAllocator.allocateAllUnassignedShards(testAllocation, true);
executor.run();
assertTrue(latch.await(1, TimeUnit.MINUTES));
latch = new CountDownLatch(10);
testShardsBatchGatewayAllocator = new TestShardBatchGatewayAllocator(latch);
testShardsBatchGatewayAllocator.setPrimaryBatchAllocatorTimeout(TimeValue.ZERO);
testShardsBatchGatewayAllocator.setReplicaBatchAllocatorTimeout(TimeValue.ZERO);
executor = testShardsBatchGatewayAllocator.allocateAllUnassignedShards(testAllocation, false);
executor.run();
assertTrue(latch.await(1, TimeUnit.MINUTES));
}

private void createIndexAndUpdateClusterState(int count, int numberOfShards, int numberOfReplicas) {
if (count == 0) return;
Metadata.Builder metadata = Metadata.builder();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,6 @@
import org.junit.Before;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
Expand Down Expand Up @@ -264,8 +263,9 @@ public void testAllocateUnassignedBatchOnTimeoutWithMatchingPrimaryShards() {
final RoutingAllocation routingAllocation = routingAllocationWithOnePrimary(allocationDeciders, CLUSTER_RECOVERED, "allocId-0");
ShardRouting shardRouting = routingAllocation.routingTable().getIndicesRouting().get("test").shard(shardId.id()).primaryShard();

List<ShardRouting> shardRoutings = Arrays.asList(shardRouting);
batchAllocator.allocateUnassignedBatchOnTimeout(shardRoutings, routingAllocation, true);
Set<ShardId> shardIds = new HashSet<>();
shardIds.add(shardRouting.shardId());
batchAllocator.allocateUnassignedBatchOnTimeout(shardIds, routingAllocation, true);

List<ShardRouting> ignoredShards = routingAllocation.routingNodes().unassigned().ignored();
assertEquals(1, ignoredShards.size());
Expand All @@ -277,8 +277,7 @@ public void testAllocateUnassignedBatchOnTimeoutWithNoMatchingPrimaryShards() {
AllocationDeciders allocationDeciders = randomAllocationDeciders(Settings.builder().build(), clusterSettings, random());
setUpShards(1);
final RoutingAllocation routingAllocation = routingAllocationWithOnePrimary(allocationDeciders, CLUSTER_RECOVERED, "allocId-0");
List<ShardRouting> shardRoutings = new ArrayList<>();
batchAllocator.allocateUnassignedBatchOnTimeout(shardRoutings, routingAllocation, true);
batchAllocator.allocateUnassignedBatchOnTimeout(new HashSet<>(), routingAllocation, true);

List<ShardRouting> ignoredShards = routingAllocation.routingNodes().unassigned().ignored();
assertEquals(0, ignoredShards.size());
Expand All @@ -296,13 +295,33 @@ public void testAllocateUnassignedBatchOnTimeoutWithNonPrimaryShards() {
.shard(shardId.id())
.replicaShards()
.get(0);
List<ShardRouting> shardRoutings = Arrays.asList(shardRouting);
batchAllocator.allocateUnassignedBatchOnTimeout(shardRoutings, routingAllocation, false);
Set<ShardId> shardIds = new HashSet<>();
shardIds.add(shardRouting.shardId());
batchAllocator.allocateUnassignedBatchOnTimeout(shardIds, routingAllocation, false);

List<ShardRouting> ignoredShards = routingAllocation.routingNodes().unassigned().ignored();
assertEquals(1, ignoredShards.size());
}

public void testAllocateUnassignedBatchOnTimeoutWithNoShards() {
ClusterSettings clusterSettings = new ClusterSettings(Settings.EMPTY, ClusterSettings.BUILT_IN_CLUSTER_SETTINGS);
AllocationDeciders allocationDeciders = randomAllocationDeciders(Settings.builder().build(), clusterSettings, random());
setUpShards(1);
final RoutingAllocation routingAllocation = routingAllocationWithOnePrimary(allocationDeciders, CLUSTER_RECOVERED, "allocId-0");

ShardRouting shardRouting = routingAllocation.routingTable()
.getIndicesRouting()
.get("test")
.shard(shardId.id())
.replicaShards()
.get(0);
Set<ShardId> shardIds = new HashSet<>();
batchAllocator.allocateUnassignedBatchOnTimeout(shardIds, routingAllocation, false);

List<ShardRouting> ignoredShards = routingAllocation.routingNodes().unassigned().ignored();
assertEquals(0, ignoredShards.size());
}

private RoutingAllocation routingAllocationWithOnePrimary(
AllocationDeciders deciders,
UnassignedInfo.Reason reason,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -720,9 +720,9 @@ public void testAllocateUnassignedBatchThrottlingAllocationDeciderIsHonoured() t
public void testAllocateUnassignedBatchOnTimeoutWithUnassignedReplicaShard() {
RoutingAllocation allocation = onePrimaryOnNode1And1Replica(yesAllocationDeciders());
final RoutingNodes.UnassignedShards.UnassignedIterator iterator = allocation.routingNodes().unassigned().iterator();
List<ShardRouting> shards = new ArrayList<>();
Set<ShardId> shards = new HashSet<>();
while (iterator.hasNext()) {
shards.add(iterator.next());
shards.add(iterator.next().shardId());
}
testBatchAllocator.allocateUnassignedBatchOnTimeout(shards, allocation, false);
assertThat(allocation.routingNodes().unassigned().ignored().size(), equalTo(1));
Expand All @@ -736,9 +736,9 @@ public void testAllocateUnassignedBatchOnTimeoutWithUnassignedReplicaShard() {
public void testAllocateUnassignedBatchOnTimeoutWithAlreadyRecoveringReplicaShard() {
RoutingAllocation allocation = onePrimaryOnNode1And1ReplicaRecovering(yesAllocationDeciders());
final RoutingNodes.UnassignedShards.UnassignedIterator iterator = allocation.routingNodes().unassigned().iterator();
List<ShardRouting> shards = new ArrayList<>();
Set<ShardId> shards = new HashSet<>();
while (iterator.hasNext()) {
shards.add(iterator.next());
shards.add(iterator.next().shardId());
}
testBatchAllocator.allocateUnassignedBatchOnTimeout(shards, allocation, false);
assertThat(allocation.routingNodes().unassigned().ignored().size(), equalTo(0));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,13 +29,20 @@
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.CountDownLatch;

public class TestShardBatchGatewayAllocator extends ShardsBatchGatewayAllocator {

CountDownLatch latch;

public TestShardBatchGatewayAllocator() {

}

public TestShardBatchGatewayAllocator(CountDownLatch latch) {
this.latch = latch;
}

public TestShardBatchGatewayAllocator(long maxBatchSize) {
super(maxBatchSize);
}
Expand Down Expand Up @@ -83,6 +90,13 @@ protected AsyncShardFetch.FetchResult<TransportNodesListGatewayStartedShardsBatc
}
return new AsyncShardFetch.FetchResult<>(foundShards, shardsToIgnoreNodes);
}

@Override
protected void allocateUnassignedBatchOnTimeout(Set<ShardId> shardIds, RoutingAllocation allocation, boolean primary) {
for (int i = 0; i < shardIds.size(); i++) {
latch.countDown();
}
}
};

ReplicaShardBatchAllocator replicaBatchShardAllocator = new ReplicaShardBatchAllocator() {
Expand All @@ -100,6 +114,13 @@ protected AsyncShardFetch.FetchResult<TransportNodesListShardStoreMetadataBatch.
protected boolean hasInitiatedFetching(ShardRouting shard) {
return true;
}

@Override
protected void allocateUnassignedBatchOnTimeout(Set<ShardId> shardIds, RoutingAllocation allocation, boolean primary) {
for (int i = 0; i < shardIds.size(); i++) {
latch.countDown();
}
}
};

@Override
Expand Down
Loading