Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

improve reroute with many shards #48579

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,17 @@
import org.elasticsearch.index.shard.ShardId;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.EnumMap;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;

/**
* A {@link RoutingNode} represents a cluster node associated with a single {@link DiscoveryNode} including all shards
Expand All @@ -41,6 +48,8 @@ public class RoutingNode implements Iterable<ShardRouting> {

private final LinkedHashMap<ShardId, ShardRouting> shards; // LinkedHashMap to preserve order

private final Map<ShardRoutingState, Set<ShardRouting>> statesToShards;

/**
 * Creates a routing node from a varargs list of shards. The shards are first converted into a map by
 * {@code buildShardRoutingMap} and then passed, together with the node id and node, to the map-based constructor.
 *
 * @param nodeId the id of the node
 * @param node   the discovery node this routing node is associated with
 * @param shards the shards to place on this node
 */
public RoutingNode(String nodeId, DiscoveryNode node, ShardRouting... shards) {
this(nodeId, node, buildShardRoutingMap(shards));
}
Expand All @@ -49,6 +58,13 @@ public RoutingNode(String nodeId, DiscoveryNode node, ShardRouting... shards) {
this.nodeId = nodeId;
this.node = node;
this.shards = shards;
statesToShards = new HashMap<>(ShardRoutingState.values().length);
for (ShardRoutingState state : ShardRoutingState.values()) {
statesToShards.put(state, new HashSet<>());
}
for (ShardRouting shardRouting : shards.values()) {
statesToShards.get(shardRouting.state()).add(shardRouting);
}
}

private static LinkedHashMap<ShardId, ShardRouting> buildShardRoutingMap(ShardRouting... shardRoutings) {
Expand Down Expand Up @@ -104,6 +120,7 @@ void add(ShardRouting shard) {
+ "] where it already exists. current [" + shards.get(shard.shardId()) + "]. new [" + shard + "]");
}
shards.put(shard.shardId(), shard);
statesToShards.get(shard.state()).add(shard);
}

void update(ShardRouting oldShard, ShardRouting newShard) {
Expand All @@ -112,11 +129,14 @@ void update(ShardRouting oldShard, ShardRouting newShard) {
// TODO: change caller logic in RoutingNodes so that this check can go away
return;
}
statesToShards.get(oldShard.state()).remove(oldShard);
statesToShards.get(newShard.state()).add(newShard);
ShardRouting previousValue = shards.put(newShard.shardId(), newShard);
assert previousValue == oldShard : "expected shard " + previousValue + " but was " + oldShard;
}

/**
 * Removes the given shard from this node: first from the bucket of its current state in
 * {@code statesToShards}, then from the {@code shards} map under its shard id.
 *
 * @param shard the shard to remove; with assertions enabled it must be the very instance stored for its shard id
 */
void remove(ShardRouting shard) {
    final Set<ShardRouting> sameStateShards = statesToShards.get(shard.state());
    sameStateShards.remove(shard);
    final ShardRouting evicted = shards.remove(shard.shardId());
    assert evicted == shard : "expected shard " + evicted + " but was " + shard;
}
Expand All @@ -127,15 +147,7 @@ void remove(ShardRouting shard) {
* @return number of shards
*/
public int numberOfShardsWithState(ShardRoutingState... states) {
// NOTE(review): this span of the diff shows two bodies back to back. The loop version (down to the first return)
// walks every shard on this node and increments the counter once for each entry in states that equals the
// shard's state.
int count = 0;
for (ShardRouting shardEntry : this) {
for (ShardRoutingState state : states) {
if (shardEntry.state() == state) {
count++;
}
}
}
return count;
// Index-based version: adds up the sizes of the statesToShards buckets for the requested states, so no shard is
// visited. A state that appears twice in states is counted twice, exactly as in the loop version.
return Arrays.stream(states).mapToInt(s -> statesToShards.get(s).size()).sum();
}

/**
Expand All @@ -144,51 +156,34 @@ public int numberOfShardsWithState(ShardRoutingState... states) {
* @return List of shards
*/
public List<ShardRouting> shardsWithState(ShardRoutingState... states) {
// NOTE(review): this span of the diff shows two bodies back to back. The loop version (down to the first return)
// iterates this node and appends a shard once for each entry in states that equals the shard's state, so the
// result follows the node's own iteration order.
List<ShardRouting> shards = new ArrayList<>();
for (ShardRouting shardEntry : this) {
for (ShardRoutingState state : states) {
if (shardEntry.state() == state) {
shards.add(shardEntry);
}
}
}
return shards;
// Index-based version: concatenates the statesToShards buckets in the order the states were passed in.
// NOTE(review): the order within one state is the iteration order of that bucket (a Set), which may differ from
// the node's own iteration order - confirm that no caller depends on the previous ordering.
return Arrays.stream(states)
.map(state -> statesToShards.get(state))
.flatMap(Collection::stream)
.collect(Collectors.toList());
}

/**
* Determine the shards of an index with a specific state
* @param index name of the index; compared against each shard's {@code getIndexName()}
* @param states set of states which should be listed
* @return a list of shards
*/
public List<ShardRouting> shardsWithState(String index, ShardRoutingState... states) {
// NOTE(review): this span of the diff shows two bodies back to back. The loop version (down to the first return)
// iterates this node, skips shards of other indices, and appends a shard once for each entry in states that
// equals the shard's state.
List<ShardRouting> shards = new ArrayList<>();

for (ShardRouting shardEntry : this) {
if (!shardEntry.getIndexName().equals(index)) {
continue;
}
for (ShardRoutingState state : states) {
if (shardEntry.state() == state) {
shards.add(shardEntry);
}
}
}
return shards;
// Index-based version: streams the statesToShards buckets of the requested states (in the order the states were
// passed in) and keeps only the shards whose index name equals the given one.
// NOTE(review): the order within one state is the iteration order of that bucket (a Set) - confirm that no
// caller depends on the node's own iteration order.
return Arrays.stream(states)
.map(state -> statesToShards.get(state))
.flatMap(Collection::stream)
.filter(shard -> shard.getIndexName().equals(index))
.collect(Collectors.toList());
}

/**
* The number of shards on this node that will not be eventually relocated.
* In terms of the code below: every shard whose state is not {@code RELOCATING} is counted.
*
* @return the number of shards on this node in any state other than {@code RELOCATING}
*/
public int numberOfOwningShards() {
// NOTE(review): this span of the diff shows two bodies back to back. The loop version (down to the first return)
// iterates this node and counts each shard that is not RELOCATING.
int count = 0;
for (ShardRouting shardEntry : this) {
if (shardEntry.state() != ShardRoutingState.RELOCATING) {
count++;
}
}

return count;
// Index-based version: sums the sizes of the statesToShards buckets of every state except RELOCATING,
// so no individual shard is visited.
return Arrays.stream(ShardRoutingState.values())
.filter(s -> s != ShardRoutingState.RELOCATING)
.mapToInt(s -> statesToShards.get(s).size())
.sum();
}

public String prettyPrint() {
Expand All @@ -200,6 +195,7 @@ public String prettyPrint() {
return sb.toString();
}

@Override
public String toString() {
StringBuilder sb = new StringBuilder();
sb.append("routingNode ([");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import org.elasticsearch.cluster.routing.RecoverySource;
import org.elasticsearch.cluster.routing.RoutingNode;
import org.elasticsearch.cluster.routing.ShardRouting;
import org.elasticsearch.cluster.routing.ShardRoutingState;
import org.elasticsearch.cluster.routing.UnassignedInfo;
import org.elasticsearch.cluster.routing.allocation.RoutingAllocation;
import org.elasticsearch.common.settings.ClusterSettings;
Expand Down Expand Up @@ -122,10 +123,10 @@ public Decision canAllocate(ShardRouting shardRouting, RoutingNode node, Routing
// count *just the primaries* currently doing recovery on the node and check against primariesInitialRecoveries

int primariesInRecovery = 0;
for (ShardRouting shard : node) {
for (ShardRouting shard : node.shardsWithState(ShardRoutingState.INITIALIZING)) {
// when a primary shard is INITIALIZING, it can be because of *initial recovery* or *relocation from another node*
// we only count initial recoveries here, so we need to make sure that relocating node is null
if (shard.initializing() && shard.primary() && shard.relocatingNodeId() == null) {
if (shard.primary() && shard.relocatingNodeId() == null) {
primariesInRecovery++;
}
}
Expand Down