diff --git a/server/src/main/java/org/opensearch/cluster/routing/RoutingNode.java b/server/src/main/java/org/opensearch/cluster/routing/RoutingNode.java index 123d09246bb7b..62e958777f1bf 100644 --- a/server/src/main/java/org/opensearch/cluster/routing/RoutingNode.java +++ b/server/src/main/java/org/opensearch/cluster/routing/RoutingNode.java @@ -48,6 +48,7 @@ import java.util.Map; import java.util.Set; import java.util.stream.Collectors; +import java.util.stream.StreamSupport; /** * A {@link RoutingNode} represents a cluster node associated with a single {@link DiscoveryNode} including all shards @@ -55,11 +56,85 @@ */ public class RoutingNode implements Iterable { + static class BucketedShards implements Iterable { + private static Map map = new HashMap() {{ + put(true, 0); + put(false, 1); + }}; + + private final LinkedHashMap[] shards; // LinkedHashMap to preserve order + + BucketedShards (LinkedHashMap primaryShards, LinkedHashMap replicaShards) { + this.shards = new LinkedHashMap[2]; + this.shards[0] = primaryShards; + this.shards[1] = replicaShards; + } + + public boolean isEmpty() { + return this.shards[0].isEmpty() && this.shards[1].isEmpty(); + } + + public int size() { + return this.shards[0].size() + this.shards[1].size(); + } + + public boolean containsKey(ShardId shardId) { + return this.shards[0].containsKey(shardId) || this.shards[1].containsKey(shardId); + } + + public ShardRouting get(ShardId shardId) { + if (this.shards[0].containsKey(shardId)) { + return this.shards[0].get(shardId); + } + return this.shards[1].get(shardId); + } + + public ShardRouting add(ShardRouting shardRouting) { + return put(shardRouting.shardId(), shardRouting); + } + + public ShardRouting put(ShardId shardId, ShardRouting shardRouting) { + ShardRouting ret = this.shards[map.get(shardRouting.primary())].put(shardId, shardRouting); + if (this.shards[map.get(!shardRouting.primary())].containsKey(shardId)) { + return this.shards[map.get(!shardRouting.primary())].remove(shardId); + } + + return ret; + } + + public ShardRouting remove(ShardId shardId) { + if (this.shards[0].containsKey(shardId)) { + return this.shards[0].remove(shardId); + } + return this.shards[1].remove(shardId); + } + + @Override + public Iterator iterator() { + final Iterator iterator1 = Collections.unmodifiableCollection(shards[0].values()).iterator(); + final Iterator iterator2 = Collections.unmodifiableCollection(shards[1].values()).iterator(); + return new Iterator() { + @Override + public boolean hasNext() { + return iterator1.hasNext() || iterator2.hasNext(); + } + + @Override + public ShardRouting next() { + if (iterator1.hasNext()) { + return iterator1.next(); + } + return iterator2.next(); + } + }; + } + } + private final String nodeId; private final DiscoveryNode node; - private final LinkedHashMap shards; // LinkedHashMap to preserve order + private final BucketedShards shards; private final LinkedHashSet initializingShards; @@ -67,44 +142,43 @@ public class RoutingNode implements Iterable { private final HashMap> shardsByIndex; - public RoutingNode(String nodeId, DiscoveryNode node, ShardRouting... shards) { - this(nodeId, node, buildShardRoutingMap(shards)); - } - - RoutingNode(String nodeId, DiscoveryNode node, LinkedHashMap shards) { + public RoutingNode(String nodeId, DiscoveryNode node, ShardRouting... shardRoutings) { this.nodeId = nodeId; this.node = node; - this.shards = shards; + final LinkedHashMap primaryShards = new LinkedHashMap<>(); + final LinkedHashMap replicaShards = new LinkedHashMap<>(); + this.shards = new BucketedShards(primaryShards, replicaShards); this.relocatingShards = new LinkedHashSet<>(); this.initializingShards = new LinkedHashSet<>(); this.shardsByIndex = new LinkedHashMap<>(); - for (ShardRouting shardRouting : shards.values()) { + + for (ShardRouting shardRouting : shardRoutings) { if (shardRouting.initializing()) { initializingShards.add(shardRouting); } else if (shardRouting.relocating()) { relocatingShards.add(shardRouting); } shardsByIndex.computeIfAbsent(shardRouting.index(), k -> new LinkedHashSet<>()).add(shardRouting); - } - assert invariant(); - } - private static LinkedHashMap buildShardRoutingMap(ShardRouting... shardRoutings) { - final LinkedHashMap shards = new LinkedHashMap<>(); - for (ShardRouting shardRouting : shardRoutings) { - ShardRouting previousValue = shards.put(shardRouting.shardId(), shardRouting); + ShardRouting previousValue; + if (shardRouting.primary()) { + previousValue = primaryShards.put(shardRouting.shardId(), shardRouting); + } else { + previousValue = replicaShards.put(shardRouting.shardId(), shardRouting); + } + if (previousValue != null) { - throw new IllegalArgumentException( - "Cannot have two different shards with same shard id " + shardRouting.shardId() + " on same node " - ); + throw new IllegalArgumentException("Cannot have two different shards with same shard id " + shardRouting.shardId() + + " on same node "); } } - return shards; + + assert invariant(); } @Override public Iterator iterator() { - return Collections.unmodifiableCollection(shards.values()).iterator(); + return shards.iterator(); } /** @@ -139,7 +213,7 @@ public int size() { */ void add(ShardRouting shard) { assert invariant(); - if (shards.containsKey(shard.shardId())) { + if (shards.add(shard) != null) { throw new IllegalStateException( "Trying to add a shard " + shard.shardId() @@ -152,7 +226,6 @@ void add(ShardRouting shard) { + "]" ); } - shards.put(shard.shardId(), shard); if (shard.initializing()) { initializingShards.add(shard); @@ -322,7 +395,7 @@ public int numberOfOwningShardsForIndex(final Index index) { public String prettyPrint() { StringBuilder sb = new StringBuilder(); sb.append("-----node_id[").append(nodeId).append("][").append(node == null ? "X" : "V").append("]\n"); - for (ShardRouting entry : shards.values()) { + for (ShardRouting entry : shards) { sb.append("--------").append(entry.shortSummary()).append('\n'); } return sb.toString(); @@ -345,7 +418,9 @@ public String toString() { } public List copyShards() { - return new ArrayList<>(shards.values()); + List result = new ArrayList<>(); + shards.forEach(result::add); + return result; } public boolean isEmpty() { @@ -355,23 +430,23 @@ public boolean isEmpty() { private boolean invariant() { // initializingShards must consistent with that in shards - Collection shardRoutingsInitializing = shards.values() - .stream() + Collection shardRoutingsInitializing = StreamSupport + .stream(shards.spliterator(), false) .filter(ShardRouting::initializing) .collect(Collectors.toList()); assert initializingShards.size() == shardRoutingsInitializing.size(); assert initializingShards.containsAll(shardRoutingsInitializing); // relocatingShards must consistent with that in shards - Collection shardRoutingsRelocating = shards.values() - .stream() + Collection shardRoutingsRelocating = StreamSupport + .stream(shards.spliterator(), false) .filter(ShardRouting::relocating) .collect(Collectors.toList()); assert relocatingShards.size() == shardRoutingsRelocating.size(); assert relocatingShards.containsAll(shardRoutingsRelocating); - final Map> shardRoutingsByIndex = shards.values() - .stream() + final Map> shardRoutingsByIndex = StreamSupport + .stream(shards.spliterator(), false) .collect(Collectors.groupingBy(ShardRouting::index, Collectors.toSet())); assert shardRoutingsByIndex.equals(shardsByIndex); diff --git a/server/src/main/java/org/opensearch/cluster/routing/RoutingNodes.java b/server/src/main/java/org/opensearch/cluster/routing/RoutingNodes.java index bf79ba26d527a..de51d579713e1 100644 --- a/server/src/main/java/org/opensearch/cluster/routing/RoutingNodes.java +++ b/server/src/main/java/org/opensearch/cluster/routing/RoutingNodes.java @@ -1284,6 +1284,8 @@ public Iterator nodeInterleavedShardIterator() { queue.add(entry.getValue().copyShards().iterator()); } return new Iterator() { + private Queue replicaShards = new ArrayDeque<>(); + private Queue> replicaIterators = new ArrayDeque<>(); public boolean hasNext() { while (!queue.isEmpty()) { if (queue.peek().hasNext()) { @@ -1291,6 +1293,15 @@ public boolean hasNext() { } queue.poll(); } + if (!replicaShards.isEmpty()) { + return true; + } + while (!replicaIterators.isEmpty()) { + if (replicaIterators.peek().hasNext()) { + return true; + } + replicaIterators.poll(); + } return false; } @@ -1298,10 +1309,25 @@ public ShardRouting next() { if (hasNext() == false) { throw new NoSuchElementException(); } - Iterator iter = queue.poll(); - ShardRouting result = iter.next(); - queue.offer(iter); - return result; + while (!queue.isEmpty()) { + Iterator iter = queue.poll(); + if (iter.hasNext()) { + ShardRouting result = iter.next(); + if (result.primary()) { + queue.offer(iter); + return result; + } + replicaShards.offer(result); + replicaIterators.offer(iter); + } + } + if (!replicaShards.isEmpty()) { + return replicaShards.poll(); + } + Iterator replicaIterator = replicaIterators.poll(); + ShardRouting replicaShard = replicaIterator.next(); + replicaIterators.offer(replicaIterator); + return replicaShard; } public void remove() {