Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Backport 2.17] Optimise unassigned shards iteration after allocator timeout & Fix responsibility check for existing shards allocator when timed out #15650

Merged
merged 2 commits into from
Sep 4, 2024
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
Optimise unassigned shards iteration after allocator timeout (#14977)
* Optimise unassigned shards iteration after allocator timeout

Signed-off-by: Rishab Nahata <rnnahata@amazon.com>
  • Loading branch information
imRishN committed Sep 4, 2024
commit 438cfc4020674dcd6021acaeccdb872dcac6ba0f
Original file line number Diff line number Diff line change
@@ -61,6 +61,13 @@ public void run() {
"Time taken to execute timed runnables in this cycle:[{}ms]",
TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - startTime)
);
onComplete();
}

/**
* Completion hook for a batch of {@link TimeoutAwareRunnable} instances.
* <p>
* {@code run()} calls this method as its last statement, right after it logs the time taken to execute the timed
* runnables in the current cycle. The default implementation is intentionally a no-op; subclasses (including
* anonymous ones) override it to perform follow-up work once the batch has been processed.
* <p>
* NOTE(review): whether this hook also fires when the runnable list is empty depends on an early-return in
* {@code run()} that is not visible here - confirm before relying on it for empty batches.
*/
public void onComplete() {}
}
Original file line number Diff line number Diff line change
@@ -47,7 +47,6 @@
import org.opensearch.core.index.shard.ShardId;

import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Set;

@@ -82,17 +81,15 @@ public void allocateUnassigned(
executeDecision(shardRouting, allocateUnassignedDecision, allocation, unassignedAllocationHandler);
}

protected void allocateUnassignedBatchOnTimeout(List<ShardRouting> shardRoutings, RoutingAllocation allocation, boolean primary) {
Set<ShardId> shardIdsFromBatch = new HashSet<>();
for (ShardRouting shardRouting : shardRoutings) {
ShardId shardId = shardRouting.shardId();
shardIdsFromBatch.add(shardId);
protected void allocateUnassignedBatchOnTimeout(Set<ShardId> shardIds, RoutingAllocation allocation, boolean primary) {
if (shardIds.isEmpty()) {
return;
}
RoutingNodes.UnassignedShards.UnassignedIterator iterator = allocation.routingNodes().unassigned().iterator();
while (iterator.hasNext()) {
ShardRouting unassignedShard = iterator.next();
AllocateUnassignedDecision allocationDecision;
if (unassignedShard.primary() == primary && shardIdsFromBatch.contains(unassignedShard.shardId())) {
if (unassignedShard.primary() == primary && shardIds.contains(unassignedShard.shardId())) {
allocationDecision = AllocateUnassignedDecision.throttle(null);
executeDecision(unassignedShard, allocationDecision, allocation, iterator);
}
Original file line number Diff line number Diff line change
@@ -277,41 +277,51 @@ protected BatchRunnableExecutor innerAllocateUnassignedBatch(
}
List<TimeoutAwareRunnable> runnables = new ArrayList<>();
if (primary) {
Set<ShardId> timedOutPrimaryShardIds = new HashSet<>();
batchIdToStartedShardBatch.values()
.stream()
.filter(batch -> batchesToAssign.contains(batch.batchId))
.forEach(shardsBatch -> runnables.add(new TimeoutAwareRunnable() {
@Override
public void onTimeout() {
primaryBatchShardAllocator.allocateUnassignedBatchOnTimeout(
shardsBatch.getBatchedShardRoutings(),
allocation,
true
);
timedOutPrimaryShardIds.addAll(shardsBatch.getBatchedShards());
}

@Override
public void run() {
primaryBatchShardAllocator.allocateUnassignedBatch(shardsBatch.getBatchedShardRoutings(), allocation);
}
}));
return new BatchRunnableExecutor(runnables, () -> primaryShardsBatchGatewayAllocatorTimeout);
return new BatchRunnableExecutor(runnables, () -> primaryShardsBatchGatewayAllocatorTimeout) {
@Override
public void onComplete() {
logger.trace("Triggering oncomplete after timeout for [{}] primary shards", timedOutPrimaryShardIds.size());
primaryBatchShardAllocator.allocateUnassignedBatchOnTimeout(timedOutPrimaryShardIds, allocation, true);
}
};
} else {
Set<ShardId> timedOutReplicaShardIds = new HashSet<>();
batchIdToStoreShardBatch.values()
.stream()
.filter(batch -> batchesToAssign.contains(batch.batchId))
.forEach(batch -> runnables.add(new TimeoutAwareRunnable() {
@Override
public void onTimeout() {
replicaBatchShardAllocator.allocateUnassignedBatchOnTimeout(batch.getBatchedShardRoutings(), allocation, false);
timedOutReplicaShardIds.addAll(batch.getBatchedShards());
}

@Override
public void run() {
replicaBatchShardAllocator.allocateUnassignedBatch(batch.getBatchedShardRoutings(), allocation);
}
}));
return new BatchRunnableExecutor(runnables, () -> replicaShardsBatchGatewayAllocatorTimeout);
return new BatchRunnableExecutor(runnables, () -> replicaShardsBatchGatewayAllocatorTimeout) {
@Override
public void onComplete() {
logger.trace("Triggering oncomplete after timeout for [{}] replica shards", timedOutReplicaShardIds.size());
replicaBatchShardAllocator.allocateUnassignedBatchOnTimeout(timedOutReplicaShardIds, allocation, false);
}
};
}
}

@@ -846,11 +856,11 @@ public int getNumberOfStoreShardBatches() {
return batchIdToStoreShardBatch.size();
}

private void setPrimaryBatchAllocatorTimeout(TimeValue primaryShardsBatchGatewayAllocatorTimeout) {
protected void setPrimaryBatchAllocatorTimeout(TimeValue primaryShardsBatchGatewayAllocatorTimeout) {
this.primaryShardsBatchGatewayAllocatorTimeout = primaryShardsBatchGatewayAllocatorTimeout;
}

private void setReplicaBatchAllocatorTimeout(TimeValue replicaShardsBatchGatewayAllocatorTimeout) {
protected void setReplicaBatchAllocatorTimeout(TimeValue replicaShardsBatchGatewayAllocatorTimeout) {
this.replicaShardsBatchGatewayAllocatorTimeout = replicaShardsBatchGatewayAllocatorTimeout;
}
}
Original file line number Diff line number Diff line change
@@ -15,6 +15,7 @@
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.concurrent.CountDownLatch;
import java.util.function.Supplier;

import static org.mockito.Mockito.atMost;
@@ -42,33 +43,53 @@ public void setupRunnables() {
public void testRunWithoutTimeout() {
setupRunnables();
timeoutSupplier = () -> TimeValue.timeValueSeconds(1);
BatchRunnableExecutor executor = new BatchRunnableExecutor(runnableList, timeoutSupplier);
CountDownLatch countDownLatch = new CountDownLatch(1);
BatchRunnableExecutor executor = new BatchRunnableExecutor(runnableList, timeoutSupplier) {
@Override
public void onComplete() {
countDownLatch.countDown();
}
};
executor.run();
verify(runnable1, times(1)).run();
verify(runnable2, times(1)).run();
verify(runnable3, times(1)).run();
verify(runnable1, never()).onTimeout();
verify(runnable2, never()).onTimeout();
verify(runnable3, never()).onTimeout();
assertEquals(0, countDownLatch.getCount());
}

public void testRunWithTimeout() {
setupRunnables();
timeoutSupplier = () -> TimeValue.timeValueNanos(1);
BatchRunnableExecutor executor = new BatchRunnableExecutor(runnableList, timeoutSupplier);
CountDownLatch countDownLatch = new CountDownLatch(1);
BatchRunnableExecutor executor = new BatchRunnableExecutor(runnableList, timeoutSupplier) {
@Override
public void onComplete() {
countDownLatch.countDown();
}
};
executor.run();
verify(runnable1, times(1)).onTimeout();
verify(runnable2, times(1)).onTimeout();
verify(runnable3, times(1)).onTimeout();
verify(runnable1, never()).run();
verify(runnable2, never()).run();
verify(runnable3, never()).run();
assertEquals(0, countDownLatch.getCount());
}

public void testRunWithPartialTimeout() {
setupRunnables();
timeoutSupplier = () -> TimeValue.timeValueMillis(50);
BatchRunnableExecutor executor = new BatchRunnableExecutor(runnableList, timeoutSupplier);
CountDownLatch countDownLatch = new CountDownLatch(1);
BatchRunnableExecutor executor = new BatchRunnableExecutor(runnableList, timeoutSupplier) {
@Override
public void onComplete() {
countDownLatch.countDown();
}
};
doAnswer(invocation -> {
Thread.sleep(100);
return null;
@@ -81,17 +102,25 @@ public void testRunWithPartialTimeout() {
verify(runnable3, atMost(1)).onTimeout();
verify(runnable2, atMost(1)).onTimeout();
verify(runnable3, atMost(1)).onTimeout();
assertEquals(0, countDownLatch.getCount());
}

public void testRunWithEmptyRunnableList() {
setupRunnables();
BatchRunnableExecutor executor = new BatchRunnableExecutor(Collections.emptyList(), timeoutSupplier);
CountDownLatch countDownLatch = new CountDownLatch(1);
BatchRunnableExecutor executor = new BatchRunnableExecutor(Collections.emptyList(), timeoutSupplier) {
@Override
public void onComplete() {
countDownLatch.countDown();
}
};
executor.run();
verify(runnable1, never()).onTimeout();
verify(runnable2, never()).onTimeout();
verify(runnable3, never()).onTimeout();
verify(runnable1, never()).run();
verify(runnable2, never()).run();
verify(runnable3, never()).run();
assertEquals(1, countDownLatch.getCount());
}
}
Original file line number Diff line number Diff line change
@@ -32,6 +32,7 @@
import org.opensearch.cluster.routing.allocation.decider.AllocationDeciders;
import org.opensearch.common.collect.Tuple;
import org.opensearch.common.settings.Settings;
import org.opensearch.common.unit.TimeValue;
import org.opensearch.common.util.BatchRunnableExecutor;
import org.opensearch.common.util.set.Sets;
import org.opensearch.core.index.shard.ShardId;
@@ -45,6 +46,8 @@
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import java.util.stream.Collectors;

import static org.opensearch.gateway.ShardsBatchGatewayAllocator.PRIMARY_BATCH_ALLOCATOR_TIMEOUT_SETTING;
@@ -423,6 +426,24 @@ public void testReplicaAllocatorTimeout() {
assertEquals(-1, REPLICA_BATCH_ALLOCATOR_TIMEOUT_SETTING.get(build).getMillis());
}

/**
 * Runs the batch allocator with a zero timeout, first for primaries and then for replicas, and checks that the
 * timeout hook of the test allocator counts the latch down to zero in each pass.
 * <p>
 * The latch starts at 10 for both passes; presumably this is the number of distinct shard ids created by
 * {@code createIndexAndUpdateClusterState(2, 5, 2)} - verify against that helper.
 */
public void testCollectTimedOutShards() throws InterruptedException {
    createIndexAndUpdateClusterState(2, 5, 2);
    // Same scenario twice: primaries first, then replicas, each with a fresh allocator and latch.
    for (boolean primary : new boolean[] { true, false }) {
        CountDownLatch timedOutShardsLatch = new CountDownLatch(10);
        testShardsBatchGatewayAllocator = new TestShardBatchGatewayAllocator(timedOutShardsLatch);
        testShardsBatchGatewayAllocator.setPrimaryBatchAllocatorTimeout(TimeValue.ZERO);
        testShardsBatchGatewayAllocator.setReplicaBatchAllocatorTimeout(TimeValue.ZERO);
        BatchRunnableExecutor batchExecutor = testShardsBatchGatewayAllocator.allocateAllUnassignedShards(testAllocation, primary);
        batchExecutor.run();
        assertTrue(timedOutShardsLatch.await(1, TimeUnit.MINUTES));
    }
}

private void createIndexAndUpdateClusterState(int count, int numberOfShards, int numberOfReplicas) {
if (count == 0) return;
Metadata.Builder metadata = Metadata.builder();
Original file line number Diff line number Diff line change
@@ -41,7 +41,6 @@
import org.junit.Before;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
@@ -264,8 +263,9 @@ public void testAllocateUnassignedBatchOnTimeoutWithMatchingPrimaryShards() {
final RoutingAllocation routingAllocation = routingAllocationWithOnePrimary(allocationDeciders, CLUSTER_RECOVERED, "allocId-0");
ShardRouting shardRouting = routingAllocation.routingTable().getIndicesRouting().get("test").shard(shardId.id()).primaryShard();

List<ShardRouting> shardRoutings = Arrays.asList(shardRouting);
batchAllocator.allocateUnassignedBatchOnTimeout(shardRoutings, routingAllocation, true);
Set<ShardId> shardIds = new HashSet<>();
shardIds.add(shardRouting.shardId());
batchAllocator.allocateUnassignedBatchOnTimeout(shardIds, routingAllocation, true);

List<ShardRouting> ignoredShards = routingAllocation.routingNodes().unassigned().ignored();
assertEquals(1, ignoredShards.size());
@@ -277,8 +277,7 @@ public void testAllocateUnassignedBatchOnTimeoutWithNoMatchingPrimaryShards() {
AllocationDeciders allocationDeciders = randomAllocationDeciders(Settings.builder().build(), clusterSettings, random());
setUpShards(1);
final RoutingAllocation routingAllocation = routingAllocationWithOnePrimary(allocationDeciders, CLUSTER_RECOVERED, "allocId-0");
List<ShardRouting> shardRoutings = new ArrayList<>();
batchAllocator.allocateUnassignedBatchOnTimeout(shardRoutings, routingAllocation, true);
batchAllocator.allocateUnassignedBatchOnTimeout(new HashSet<>(), routingAllocation, true);

List<ShardRouting> ignoredShards = routingAllocation.routingNodes().unassigned().ignored();
assertEquals(0, ignoredShards.size());
@@ -296,13 +295,33 @@ public void testAllocateUnassignedBatchOnTimeoutWithNonPrimaryShards() {
.shard(shardId.id())
.replicaShards()
.get(0);
List<ShardRouting> shardRoutings = Arrays.asList(shardRouting);
batchAllocator.allocateUnassignedBatchOnTimeout(shardRoutings, routingAllocation, false);
Set<ShardId> shardIds = new HashSet<>();
shardIds.add(shardRouting.shardId());
batchAllocator.allocateUnassignedBatchOnTimeout(shardIds, routingAllocation, false);

List<ShardRouting> ignoredShards = routingAllocation.routingNodes().unassigned().ignored();
assertEquals(1, ignoredShards.size());
}

/**
 * Verifies that the timeout handler leaves the allocation untouched when it is given an empty set of shard ids:
 * after the call, the unassigned shards' ignored list must still be empty.
 */
public void testAllocateUnassignedBatchOnTimeoutWithNoShards() {
    ClusterSettings clusterSettings = new ClusterSettings(Settings.EMPTY, ClusterSettings.BUILT_IN_CLUSTER_SETTINGS);
    AllocationDeciders allocationDeciders = randomAllocationDeciders(Settings.builder().build(), clusterSettings, random());
    setUpShards(1);
    final RoutingAllocation routingAllocation = routingAllocationWithOnePrimary(allocationDeciders, CLUSTER_RECOVERED, "allocId-0");

    // No shard ids are passed on purpose, so no ShardRouting needs to be looked up for this scenario.
    Set<ShardId> shardIds = new HashSet<>();
    batchAllocator.allocateUnassignedBatchOnTimeout(shardIds, routingAllocation, false);

    List<ShardRouting> ignoredShards = routingAllocation.routingNodes().unassigned().ignored();
    assertEquals(0, ignoredShards.size());
}

private RoutingAllocation routingAllocationWithOnePrimary(
AllocationDeciders deciders,
UnassignedInfo.Reason reason,
Original file line number Diff line number Diff line change
@@ -720,9 +720,9 @@ public void testAllocateUnassignedBatchThrottlingAllocationDeciderIsHonoured() t
public void testAllocateUnassignedBatchOnTimeoutWithUnassignedReplicaShard() {
RoutingAllocation allocation = onePrimaryOnNode1And1Replica(yesAllocationDeciders());
final RoutingNodes.UnassignedShards.UnassignedIterator iterator = allocation.routingNodes().unassigned().iterator();
List<ShardRouting> shards = new ArrayList<>();
Set<ShardId> shards = new HashSet<>();
while (iterator.hasNext()) {
shards.add(iterator.next());
shards.add(iterator.next().shardId());
}
testBatchAllocator.allocateUnassignedBatchOnTimeout(shards, allocation, false);
assertThat(allocation.routingNodes().unassigned().ignored().size(), equalTo(1));
@@ -736,9 +736,9 @@ public void testAllocateUnassignedBatchOnTimeoutWithUnassignedReplicaShard() {
public void testAllocateUnassignedBatchOnTimeoutWithAlreadyRecoveringReplicaShard() {
RoutingAllocation allocation = onePrimaryOnNode1And1ReplicaRecovering(yesAllocationDeciders());
final RoutingNodes.UnassignedShards.UnassignedIterator iterator = allocation.routingNodes().unassigned().iterator();
List<ShardRouting> shards = new ArrayList<>();
Set<ShardId> shards = new HashSet<>();
while (iterator.hasNext()) {
shards.add(iterator.next());
shards.add(iterator.next().shardId());
}
testBatchAllocator.allocateUnassignedBatchOnTimeout(shards, allocation, false);
assertThat(allocation.routingNodes().unassigned().ignored().size(), equalTo(0));
Original file line number Diff line number Diff line change
@@ -29,13 +29,20 @@
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.CountDownLatch;

public class TestShardBatchGatewayAllocator extends ShardsBatchGatewayAllocator {

CountDownLatch latch;

/**
 * Creates a test allocator without a latch: the {@code latch} field is not assigned by this constructor.
 * Use the {@link CountDownLatch}-accepting constructor when a test needs to observe timed-out shards.
 */
public TestShardBatchGatewayAllocator() {
// Intentionally empty.
}

/**
 * Creates a test allocator whose timeout hooks report through the given latch.
 *
 * @param latch latch that the overridden {@code allocateUnassignedBatchOnTimeout} hooks count down once per
 *              timed-out shard id
 */
public TestShardBatchGatewayAllocator(CountDownLatch latch) {
this.latch = latch;
}

/**
 * Creates a test allocator, forwarding {@code maxBatchSize} to the superclass constructor.
 * The {@code latch} field is not assigned by this constructor.
 *
 * @param maxBatchSize value passed unchanged to {@code super(maxBatchSize)}
 */
public TestShardBatchGatewayAllocator(long maxBatchSize) {
super(maxBatchSize);
}
@@ -83,6 +90,13 @@ protected AsyncShardFetch.FetchResult<TransportNodesListGatewayStartedShardsBatc
}
return new AsyncShardFetch.FetchResult<>(foundShards, shardsToIgnoreNodes);
}

/**
 * Test hook for the timeout path.
 * <p>
 * When the enclosing allocator was created with a latch, the latch is counted down once per timed-out shard id
 * and the allocation is left untouched. When no latch was supplied, the call is forwarded to the inherited
 * implementation instead of failing with a {@link NullPointerException}.
 *
 * @param shardIds   ids of the shards whose batches timed out
 * @param allocation current routing allocation; only used when delegating to the inherited implementation
 * @param primary    forwarded unchanged when delegating to the inherited implementation
 */
@Override
protected void allocateUnassignedBatchOnTimeout(Set<ShardId> shardIds, RoutingAllocation allocation, boolean primary) {
    if (latch == null) {
        // Allocator was built through a constructor that does not set a latch: keep the real behaviour.
        super.allocateUnassignedBatchOnTimeout(shardIds, allocation, primary);
        return;
    }
    for (int i = 0; i < shardIds.size(); i++) {
        latch.countDown();
    }
}
};

ReplicaShardBatchAllocator replicaBatchShardAllocator = new ReplicaShardBatchAllocator() {
@@ -100,6 +114,13 @@ protected AsyncShardFetch.FetchResult<TransportNodesListShardStoreMetadataBatch.
protected boolean hasInitiatedFetching(ShardRouting shard) {
return true;
}

/**
 * Test hook for the replica timeout path.
 * <p>
 * When the enclosing allocator was created with a latch, the latch is counted down once per timed-out shard id
 * and the allocation is left untouched. When no latch was supplied, the call is forwarded to the inherited
 * implementation instead of failing with a {@link NullPointerException}.
 *
 * @param shardIds   ids of the shards whose batches timed out
 * @param allocation current routing allocation; only used when delegating to the inherited implementation
 * @param primary    forwarded unchanged when delegating to the inherited implementation
 */
@Override
protected void allocateUnassignedBatchOnTimeout(Set<ShardId> shardIds, RoutingAllocation allocation, boolean primary) {
    if (latch == null) {
        // Allocator was built through a constructor that does not set a latch: keep the real behaviour.
        super.allocateUnassignedBatchOnTimeout(shardIds, allocation, primary);
        return;
    }
    for (int i = 0; i < shardIds.size(); i++) {
        latch.countDown();
    }
}
};

@Override