Skip to content

Commit

Permalink
Optimise unassigned shards iteration after allocator timeout (#14977)
Browse files Browse the repository at this point in the history
* Optimise unassigned shards iteration after allocator timeout

Signed-off-by: Rishab Nahata <rnnahata@amazon.com>
  • Loading branch information
imRishN authored Aug 8, 2024
1 parent f03dde9 commit 555a56d
Show file tree
Hide file tree
Showing 8 changed files with 136 additions and 32 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,13 @@ public void run() {
"Time taken to execute timed runnables in this cycle:[{}ms]",
TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - startTime)
);
onComplete();
}

/**
* Callback method that is invoked after all {@link TimeoutAwareRunnable} instances in the batch have been processed.
* By default, this method does nothing, but it can be overridden by subclasses or modified in the implementation if
* there is a need to perform additional actions once the batch execution is completed.
*/
public void onComplete() {}
}
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,6 @@
import org.opensearch.core.index.shard.ShardId;

import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Set;

Expand Down Expand Up @@ -82,17 +81,15 @@ public void allocateUnassigned(
executeDecision(shardRouting, allocateUnassignedDecision, allocation, unassignedAllocationHandler);
}

protected void allocateUnassignedBatchOnTimeout(List<ShardRouting> shardRoutings, RoutingAllocation allocation, boolean primary) {
Set<ShardId> shardIdsFromBatch = new HashSet<>();
for (ShardRouting shardRouting : shardRoutings) {
ShardId shardId = shardRouting.shardId();
shardIdsFromBatch.add(shardId);
protected void allocateUnassignedBatchOnTimeout(Set<ShardId> shardIds, RoutingAllocation allocation, boolean primary) {
if (shardIds.isEmpty()) {
return;
}
RoutingNodes.UnassignedShards.UnassignedIterator iterator = allocation.routingNodes().unassigned().iterator();
while (iterator.hasNext()) {
ShardRouting unassignedShard = iterator.next();
AllocateUnassignedDecision allocationDecision;
if (unassignedShard.primary() == primary && shardIdsFromBatch.contains(unassignedShard.shardId())) {
if (unassignedShard.primary() == primary && shardIds.contains(unassignedShard.shardId())) {
allocationDecision = AllocateUnassignedDecision.throttle(null);
executeDecision(unassignedShard, allocationDecision, allocation, iterator);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -277,41 +277,51 @@ protected BatchRunnableExecutor innerAllocateUnassignedBatch(
}
List<TimeoutAwareRunnable> runnables = new ArrayList<>();
if (primary) {
Set<ShardId> timedOutPrimaryShardIds = new HashSet<>();
batchIdToStartedShardBatch.values()
.stream()
.filter(batch -> batchesToAssign.contains(batch.batchId))
.forEach(shardsBatch -> runnables.add(new TimeoutAwareRunnable() {
@Override
public void onTimeout() {
primaryBatchShardAllocator.allocateUnassignedBatchOnTimeout(
shardsBatch.getBatchedShardRoutings(),
allocation,
true
);
timedOutPrimaryShardIds.addAll(shardsBatch.getBatchedShards());
}

@Override
public void run() {
primaryBatchShardAllocator.allocateUnassignedBatch(shardsBatch.getBatchedShardRoutings(), allocation);
}
}));
return new BatchRunnableExecutor(runnables, () -> primaryShardsBatchGatewayAllocatorTimeout);
return new BatchRunnableExecutor(runnables, () -> primaryShardsBatchGatewayAllocatorTimeout) {
@Override
public void onComplete() {
logger.trace("Triggering oncomplete after timeout for [{}] primary shards", timedOutPrimaryShardIds.size());
primaryBatchShardAllocator.allocateUnassignedBatchOnTimeout(timedOutPrimaryShardIds, allocation, true);
}
};
} else {
Set<ShardId> timedOutReplicaShardIds = new HashSet<>();
batchIdToStoreShardBatch.values()
.stream()
.filter(batch -> batchesToAssign.contains(batch.batchId))
.forEach(batch -> runnables.add(new TimeoutAwareRunnable() {
@Override
public void onTimeout() {
replicaBatchShardAllocator.allocateUnassignedBatchOnTimeout(batch.getBatchedShardRoutings(), allocation, false);
timedOutReplicaShardIds.addAll(batch.getBatchedShards());
}

@Override
public void run() {
replicaBatchShardAllocator.allocateUnassignedBatch(batch.getBatchedShardRoutings(), allocation);
}
}));
return new BatchRunnableExecutor(runnables, () -> replicaShardsBatchGatewayAllocatorTimeout);
return new BatchRunnableExecutor(runnables, () -> replicaShardsBatchGatewayAllocatorTimeout) {
@Override
public void onComplete() {
logger.trace("Triggering oncomplete after timeout for [{}] replica shards", timedOutReplicaShardIds.size());
replicaBatchShardAllocator.allocateUnassignedBatchOnTimeout(timedOutReplicaShardIds, allocation, false);
}
};
}
}

Expand Down Expand Up @@ -846,11 +856,11 @@ public int getNumberOfStoreShardBatches() {
return batchIdToStoreShardBatch.size();
}

private void setPrimaryBatchAllocatorTimeout(TimeValue primaryShardsBatchGatewayAllocatorTimeout) {
protected void setPrimaryBatchAllocatorTimeout(TimeValue primaryShardsBatchGatewayAllocatorTimeout) {
this.primaryShardsBatchGatewayAllocatorTimeout = primaryShardsBatchGatewayAllocatorTimeout;
}

private void setReplicaBatchAllocatorTimeout(TimeValue replicaShardsBatchGatewayAllocatorTimeout) {
protected void setReplicaBatchAllocatorTimeout(TimeValue replicaShardsBatchGatewayAllocatorTimeout) {
this.replicaShardsBatchGatewayAllocatorTimeout = replicaShardsBatchGatewayAllocatorTimeout;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.concurrent.CountDownLatch;
import java.util.function.Supplier;

import static org.mockito.Mockito.atMost;
Expand Down Expand Up @@ -42,33 +43,53 @@ public void setupRunnables() {
public void testRunWithoutTimeout() {
setupRunnables();
timeoutSupplier = () -> TimeValue.timeValueSeconds(1);
BatchRunnableExecutor executor = new BatchRunnableExecutor(runnableList, timeoutSupplier);
CountDownLatch countDownLatch = new CountDownLatch(1);
BatchRunnableExecutor executor = new BatchRunnableExecutor(runnableList, timeoutSupplier) {
@Override
public void onComplete() {
countDownLatch.countDown();
}
};
executor.run();
verify(runnable1, times(1)).run();
verify(runnable2, times(1)).run();
verify(runnable3, times(1)).run();
verify(runnable1, never()).onTimeout();
verify(runnable2, never()).onTimeout();
verify(runnable3, never()).onTimeout();
assertEquals(0, countDownLatch.getCount());
}

public void testRunWithTimeout() {
setupRunnables();
timeoutSupplier = () -> TimeValue.timeValueNanos(1);
BatchRunnableExecutor executor = new BatchRunnableExecutor(runnableList, timeoutSupplier);
CountDownLatch countDownLatch = new CountDownLatch(1);
BatchRunnableExecutor executor = new BatchRunnableExecutor(runnableList, timeoutSupplier) {
@Override
public void onComplete() {
countDownLatch.countDown();
}
};
executor.run();
verify(runnable1, times(1)).onTimeout();
verify(runnable2, times(1)).onTimeout();
verify(runnable3, times(1)).onTimeout();
verify(runnable1, never()).run();
verify(runnable2, never()).run();
verify(runnable3, never()).run();
assertEquals(0, countDownLatch.getCount());
}

public void testRunWithPartialTimeout() {
setupRunnables();
timeoutSupplier = () -> TimeValue.timeValueMillis(50);
BatchRunnableExecutor executor = new BatchRunnableExecutor(runnableList, timeoutSupplier);
CountDownLatch countDownLatch = new CountDownLatch(1);
BatchRunnableExecutor executor = new BatchRunnableExecutor(runnableList, timeoutSupplier) {
@Override
public void onComplete() {
countDownLatch.countDown();
}
};
doAnswer(invocation -> {
Thread.sleep(100);
return null;
Expand All @@ -81,17 +102,25 @@ public void testRunWithPartialTimeout() {
verify(runnable3, atMost(1)).onTimeout();
verify(runnable2, atMost(1)).onTimeout();
verify(runnable3, atMost(1)).onTimeout();
assertEquals(0, countDownLatch.getCount());
}

public void testRunWithEmptyRunnableList() {
setupRunnables();
BatchRunnableExecutor executor = new BatchRunnableExecutor(Collections.emptyList(), timeoutSupplier);
CountDownLatch countDownLatch = new CountDownLatch(1);
BatchRunnableExecutor executor = new BatchRunnableExecutor(Collections.emptyList(), timeoutSupplier) {
@Override
public void onComplete() {
countDownLatch.countDown();
}
};
executor.run();
verify(runnable1, never()).onTimeout();
verify(runnable2, never()).onTimeout();
verify(runnable3, never()).onTimeout();
verify(runnable1, never()).run();
verify(runnable2, never()).run();
verify(runnable3, never()).run();
assertEquals(1, countDownLatch.getCount());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
import org.opensearch.cluster.routing.allocation.decider.AllocationDeciders;
import org.opensearch.common.collect.Tuple;
import org.opensearch.common.settings.Settings;
import org.opensearch.common.unit.TimeValue;
import org.opensearch.common.util.BatchRunnableExecutor;
import org.opensearch.common.util.set.Sets;
import org.opensearch.core.index.shard.ShardId;
Expand All @@ -45,6 +46,8 @@
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import java.util.stream.Collectors;

import static org.opensearch.gateway.ShardsBatchGatewayAllocator.PRIMARY_BATCH_ALLOCATOR_TIMEOUT_SETTING;
Expand Down Expand Up @@ -423,6 +426,24 @@ public void testReplicaAllocatorTimeout() {
assertEquals(-1, REPLICA_BATCH_ALLOCATOR_TIMEOUT_SETTING.get(build).getMillis());
}

public void testCollectTimedOutShards() throws InterruptedException {
createIndexAndUpdateClusterState(2, 5, 2);
CountDownLatch latch = new CountDownLatch(10);
testShardsBatchGatewayAllocator = new TestShardBatchGatewayAllocator(latch);
testShardsBatchGatewayAllocator.setPrimaryBatchAllocatorTimeout(TimeValue.ZERO);
testShardsBatchGatewayAllocator.setReplicaBatchAllocatorTimeout(TimeValue.ZERO);
BatchRunnableExecutor executor = testShardsBatchGatewayAllocator.allocateAllUnassignedShards(testAllocation, true);
executor.run();
assertTrue(latch.await(1, TimeUnit.MINUTES));
latch = new CountDownLatch(10);
testShardsBatchGatewayAllocator = new TestShardBatchGatewayAllocator(latch);
testShardsBatchGatewayAllocator.setPrimaryBatchAllocatorTimeout(TimeValue.ZERO);
testShardsBatchGatewayAllocator.setReplicaBatchAllocatorTimeout(TimeValue.ZERO);
executor = testShardsBatchGatewayAllocator.allocateAllUnassignedShards(testAllocation, false);
executor.run();
assertTrue(latch.await(1, TimeUnit.MINUTES));
}

private void createIndexAndUpdateClusterState(int count, int numberOfShards, int numberOfReplicas) {
if (count == 0) return;
Metadata.Builder metadata = Metadata.builder();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,6 @@
import org.junit.Before;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
Expand Down Expand Up @@ -264,8 +263,9 @@ public void testAllocateUnassignedBatchOnTimeoutWithMatchingPrimaryShards() {
final RoutingAllocation routingAllocation = routingAllocationWithOnePrimary(allocationDeciders, CLUSTER_RECOVERED, "allocId-0");
ShardRouting shardRouting = routingAllocation.routingTable().getIndicesRouting().get("test").shard(shardId.id()).primaryShard();

List<ShardRouting> shardRoutings = Arrays.asList(shardRouting);
batchAllocator.allocateUnassignedBatchOnTimeout(shardRoutings, routingAllocation, true);
Set<ShardId> shardIds = new HashSet<>();
shardIds.add(shardRouting.shardId());
batchAllocator.allocateUnassignedBatchOnTimeout(shardIds, routingAllocation, true);

List<ShardRouting> ignoredShards = routingAllocation.routingNodes().unassigned().ignored();
assertEquals(1, ignoredShards.size());
Expand All @@ -277,8 +277,7 @@ public void testAllocateUnassignedBatchOnTimeoutWithNoMatchingPrimaryShards() {
AllocationDeciders allocationDeciders = randomAllocationDeciders(Settings.builder().build(), clusterSettings, random());
setUpShards(1);
final RoutingAllocation routingAllocation = routingAllocationWithOnePrimary(allocationDeciders, CLUSTER_RECOVERED, "allocId-0");
List<ShardRouting> shardRoutings = new ArrayList<>();
batchAllocator.allocateUnassignedBatchOnTimeout(shardRoutings, routingAllocation, true);
batchAllocator.allocateUnassignedBatchOnTimeout(new HashSet<>(), routingAllocation, true);

List<ShardRouting> ignoredShards = routingAllocation.routingNodes().unassigned().ignored();
assertEquals(0, ignoredShards.size());
Expand All @@ -296,13 +295,33 @@ public void testAllocateUnassignedBatchOnTimeoutWithNonPrimaryShards() {
.shard(shardId.id())
.replicaShards()
.get(0);
List<ShardRouting> shardRoutings = Arrays.asList(shardRouting);
batchAllocator.allocateUnassignedBatchOnTimeout(shardRoutings, routingAllocation, false);
Set<ShardId> shardIds = new HashSet<>();
shardIds.add(shardRouting.shardId());
batchAllocator.allocateUnassignedBatchOnTimeout(shardIds, routingAllocation, false);

List<ShardRouting> ignoredShards = routingAllocation.routingNodes().unassigned().ignored();
assertEquals(1, ignoredShards.size());
}

public void testAllocateUnassignedBatchOnTimeoutWithNoShards() {
ClusterSettings clusterSettings = new ClusterSettings(Settings.EMPTY, ClusterSettings.BUILT_IN_CLUSTER_SETTINGS);
AllocationDeciders allocationDeciders = randomAllocationDeciders(Settings.builder().build(), clusterSettings, random());
setUpShards(1);
final RoutingAllocation routingAllocation = routingAllocationWithOnePrimary(allocationDeciders, CLUSTER_RECOVERED, "allocId-0");

ShardRouting shardRouting = routingAllocation.routingTable()
.getIndicesRouting()
.get("test")
.shard(shardId.id())
.replicaShards()
.get(0);
Set<ShardId> shardIds = new HashSet<>();
batchAllocator.allocateUnassignedBatchOnTimeout(shardIds, routingAllocation, false);

List<ShardRouting> ignoredShards = routingAllocation.routingNodes().unassigned().ignored();
assertEquals(0, ignoredShards.size());
}

private RoutingAllocation routingAllocationWithOnePrimary(
AllocationDeciders deciders,
UnassignedInfo.Reason reason,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -720,9 +720,9 @@ public void testAllocateUnassignedBatchThrottlingAllocationDeciderIsHonoured() t
public void testAllocateUnassignedBatchOnTimeoutWithUnassignedReplicaShard() {
RoutingAllocation allocation = onePrimaryOnNode1And1Replica(yesAllocationDeciders());
final RoutingNodes.UnassignedShards.UnassignedIterator iterator = allocation.routingNodes().unassigned().iterator();
List<ShardRouting> shards = new ArrayList<>();
Set<ShardId> shards = new HashSet<>();
while (iterator.hasNext()) {
shards.add(iterator.next());
shards.add(iterator.next().shardId());
}
testBatchAllocator.allocateUnassignedBatchOnTimeout(shards, allocation, false);
assertThat(allocation.routingNodes().unassigned().ignored().size(), equalTo(1));
Expand All @@ -736,9 +736,9 @@ public void testAllocateUnassignedBatchOnTimeoutWithUnassignedReplicaShard() {
public void testAllocateUnassignedBatchOnTimeoutWithAlreadyRecoveringReplicaShard() {
RoutingAllocation allocation = onePrimaryOnNode1And1ReplicaRecovering(yesAllocationDeciders());
final RoutingNodes.UnassignedShards.UnassignedIterator iterator = allocation.routingNodes().unassigned().iterator();
List<ShardRouting> shards = new ArrayList<>();
Set<ShardId> shards = new HashSet<>();
while (iterator.hasNext()) {
shards.add(iterator.next());
shards.add(iterator.next().shardId());
}
testBatchAllocator.allocateUnassignedBatchOnTimeout(shards, allocation, false);
assertThat(allocation.routingNodes().unassigned().ignored().size(), equalTo(0));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,13 +29,20 @@
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.CountDownLatch;

public class TestShardBatchGatewayAllocator extends ShardsBatchGatewayAllocator {

CountDownLatch latch;

public TestShardBatchGatewayAllocator() {

}

public TestShardBatchGatewayAllocator(CountDownLatch latch) {
this.latch = latch;
}

public TestShardBatchGatewayAllocator(long maxBatchSize) {
super(maxBatchSize);
}
Expand Down Expand Up @@ -83,6 +90,13 @@ protected AsyncShardFetch.FetchResult<TransportNodesListGatewayStartedShardsBatc
}
return new AsyncShardFetch.FetchResult<>(foundShards, shardsToIgnoreNodes);
}

@Override
protected void allocateUnassignedBatchOnTimeout(Set<ShardId> shardIds, RoutingAllocation allocation, boolean primary) {
for (int i = 0; i < shardIds.size(); i++) {
latch.countDown();
}
}
};

ReplicaShardBatchAllocator replicaBatchShardAllocator = new ReplicaShardBatchAllocator() {
Expand All @@ -100,6 +114,13 @@ protected AsyncShardFetch.FetchResult<TransportNodesListShardStoreMetadataBatch.
protected boolean hasInitiatedFetching(ShardRouting shard) {
return true;
}

@Override
protected void allocateUnassignedBatchOnTimeout(Set<ShardId> shardIds, RoutingAllocation allocation, boolean primary) {
for (int i = 0; i < shardIds.size(); i++) {
latch.countDown();
}
}
};

@Override
Expand Down

0 comments on commit 555a56d

Please sign in to comment.