Skip to content

Commit

Permalink
Weighted round-robin scheduling policy for shard coordination traffic…
Browse files Browse the repository at this point in the history
… routing

Signed-off-by: Anshu Agarwal <anshukag@amazon.com>
  • Loading branch information
Anshu Agarwal committed Aug 17, 2022
1 parent 280b938 commit c2eee63
Show file tree
Hide file tree
Showing 7 changed files with 671 additions and 9 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,157 @@
/*
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*/

package org.opensearch.cluster.metadata;

import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.OpenSearchParseException;
import org.opensearch.Version;
import org.opensearch.cluster.AbstractNamedDiffable;
import org.opensearch.cluster.NamedDiff;
import org.opensearch.cluster.routing.WRRWeight;
import org.opensearch.common.Strings;
import org.opensearch.common.io.stream.StreamInput;
import org.opensearch.common.io.stream.StreamOutput;
import org.opensearch.common.xcontent.ToXContent;
import org.opensearch.common.xcontent.XContentBuilder;
import org.opensearch.common.xcontent.XContentParser;

import java.io.IOException;
import java.util.EnumSet;
import java.util.HashMap;
import java.util.Map;

/**
* Contains metadata for weighted round-robin shard routing weights
*
* @opensearch.internal
*/
public class WeightedRoundRobinMetadata extends AbstractNamedDiffable<Metadata.Custom> implements Metadata.Custom {
private static final Logger logger = LogManager.getLogger(WeightedRoundRobinMetadata.class);
public static final String TYPE = "wrr_shard_routing";
private WRRWeight wrrWeight;

public WRRWeight getWrrWeight() {
return wrrWeight;
}

public WeightedRoundRobinMetadata setWrrWeight(WRRWeight wrrWeight) {
this.wrrWeight = wrrWeight;
return this;
}

public WeightedRoundRobinMetadata(StreamInput in) throws IOException {
this.wrrWeight = new WRRWeight(in);
}

public WeightedRoundRobinMetadata(WRRWeight wrrWeight) {
this.wrrWeight = wrrWeight;
}

@Override
public EnumSet<Metadata.XContentContext> context() {
return Metadata.API_AND_GATEWAY;
}

@Override
public String getWriteableName() {
return TYPE;
}

@Override
public Version getMinimalSupportedVersion() {
// TODO: Check if this needs to be changed
return Version.CURRENT.minimumCompatibilityVersion();
}

@Override
public void writeTo(StreamOutput out) throws IOException {
wrrWeight.writeTo(out);
}

public static NamedDiff<Metadata.Custom> readDiffFrom(StreamInput in) throws IOException {
return readDiffFrom(Metadata.Custom.class, TYPE, in);
}

public static WeightedRoundRobinMetadata fromXContent(XContentParser parser) throws IOException {
String attrKey = null;
Object attrValue;
String attributeName = null;
Map<String, Object> weights = new HashMap<>();
WRRWeight wrrWeight = null;
XContentParser.Token token;
// move to the first alias
parser.nextToken();
String awarenessField = null;

while ((token = parser.nextToken()) != XContentParser.Token.END_OBJECT) {
if (token == XContentParser.Token.FIELD_NAME) {
awarenessField = parser.currentName();
if (parser.nextToken() != XContentParser.Token.START_OBJECT) {
throw new OpenSearchParseException("failed to parse wrr metadata [{}], expected object", awarenessField);
}
while ((token = parser.nextToken()) != XContentParser.Token.END_OBJECT) {
attributeName = parser.currentName();
if (parser.nextToken() != XContentParser.Token.START_OBJECT) {
throw new OpenSearchParseException("failed to parse wrr metadata [{}], expected object", attributeName);
}
while ((token = parser.nextToken()) != XContentParser.Token.END_OBJECT) {
if (token == XContentParser.Token.FIELD_NAME) {
attrKey = parser.currentName();
} else if (token == XContentParser.Token.VALUE_STRING) {
attrValue = parser.text();
weights.put(attrKey, attrValue);
} else {
throw new OpenSearchParseException("failed to parse wrr metadata attribute [{}], unknown type", attributeName);
}
}
}
} else {
throw new OpenSearchParseException("failed to parse wrr metadata attribute [{}]", attributeName);
}
}
wrrWeight = new WRRWeight(attributeName, weights);
return new WeightedRoundRobinMetadata(wrrWeight);
}

@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
WeightedRoundRobinMetadata that = (WeightedRoundRobinMetadata) o;
return wrrWeight.equals(that.wrrWeight);
}

@Override
public int hashCode() {
return wrrWeight.hashCode();
}

@Override
public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException {
toXContent(wrrWeight, builder);
return builder;
}

public static void toXContent(WRRWeight wrrWeight, XContentBuilder builder) throws IOException {
builder.startObject("awareness");
builder.startObject(wrrWeight.attributeName());
for (Map.Entry<String, Object> entry : wrrWeight.weights().entrySet()) {
builder.field(entry.getKey(), entry.getValue());
}
builder.endObject();
builder.endObject();
}

@Override
public String toString() {
return Strings.toString(this);
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,6 @@
import org.opensearch.index.Index;
import org.opensearch.index.shard.ShardId;
import org.opensearch.node.ResponseCollectorService;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
Expand Down Expand Up @@ -292,6 +291,65 @@ public ShardIterator activeInitializingShardsRankedIt(
return new PlainShardIterator(shardId, ordered);
}

/**
* *
* @param wrrWeight Weighted round-robin weight entity
* @param nodes discovered nodes in the cluster
* @return an interator over active and initializing shards, ordered by weighted round-robin
* scheduling policy. Making sure that initializing shards are the last to iterate through.
*/
public ShardIterator activeInitializingShardsWRR(WRRWeight wrrWeight, DiscoveryNodes nodes) {
final int seed = shuffler.nextSeed();
ArrayList<ShardRouting> ordered = new ArrayList<>(activeShards.size() + allInitializingShards.size());
ArrayList<ShardRouting> orderedActiveShards = getShardsWRR(activeShards, wrrWeight, nodes);
ordered.addAll(shuffler.shuffle(orderedActiveShards, seed));
if (!allInitializingShards.isEmpty()) {
ArrayList<ShardRouting> orderedInitializingShards = getShardsWRR(allInitializingShards, wrrWeight, nodes);
ordered.addAll(shuffler.shuffle(orderedInitializingShards, seed));
}
return new PlainShardIterator(shardId, ordered);
}

/**
*
* @param shards shards to be ordered using weighted round-robin scheduling policy
* @param wrrWeight weights to be considered for routing
* @param nodes discovered nodes in the cluster
* @return list of shards ordered using weighted round-robin scheduling.
*/
private ArrayList<ShardRouting> getShardsWRR(List<ShardRouting> shards, WRRWeight wrrWeight, DiscoveryNodes nodes) {
List<WeightedRoundRobin.Entity<ShardRouting>> weightedShards = calculateShardWeight(shards, wrrWeight, nodes);
WeightedRoundRobin<ShardRouting> wrr = new WeightedRoundRobin<>(weightedShards);
List<WeightedRoundRobin.Entity<ShardRouting>> wrrOrderedActiveShards = wrr.orderEntities();
ArrayList<ShardRouting> orderedActiveShards = new ArrayList<>(activeShards.size());
for (WeightedRoundRobin.Entity<ShardRouting> shardRouting : wrrOrderedActiveShards) {
orderedActiveShards.add(shardRouting.getTarget());
}
return orderedActiveShards;
}

/**
* *
* @param shards associate weights to shards
* @param wrrWeight weights to be used for association
* @param nodes
* @return list of entity containing shard routing and associated weight.
*/
private List<WeightedRoundRobin.Entity<ShardRouting>> calculateShardWeight(
List<ShardRouting> shards,
WRRWeight wrrWeight,
DiscoveryNodes nodes
) {
List<WeightedRoundRobin.Entity<ShardRouting>> weightedShards = new ArrayList<>();
for (ShardRouting shard : shards) {
shard.currentNodeId();
DiscoveryNode node = nodes.get(shard.currentNodeId());
String attVal = node.getAttributes().get(wrrWeight.attributeName());
weightedShards.add(new WeightedRoundRobin.Entity<>(Double.parseDouble(wrrWeight.weights().get(attVal).toString()), shard));
}
return weightedShards;
}

private static Set<String> getAllNodeIds(final List<ShardRouting> shards) {
final Set<String> nodeIds = new HashSet<>();
for (ShardRouting shard : shards) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@

import org.opensearch.cluster.ClusterState;
import org.opensearch.cluster.metadata.IndexMetadata;
import org.opensearch.cluster.metadata.WeightedRoundRobinMetadata;
import org.opensearch.cluster.node.DiscoveryNodes;
import org.opensearch.cluster.routing.allocation.decider.AwarenessAllocationDecider;
import org.opensearch.common.Nullable;
Expand Down Expand Up @@ -68,6 +69,12 @@ public class OperationRouting {
Setting.Property.NodeScope
);

public static final Setting<Boolean> USE_WEIGHTED_ROUND_ROBIN = Setting.boolSetting(
"cluster.routing.use_weighted_round_robin",
true,
Setting.Property.Dynamic,
Setting.Property.NodeScope
);
public static final String IGNORE_AWARENESS_ATTRIBUTES = "cluster.search.ignore_awareness_attributes";
public static final Setting<Boolean> IGNORE_AWARENESS_ATTRIBUTES_SETTING = Setting.boolSetting(
IGNORE_AWARENESS_ATTRIBUTES,
Expand All @@ -79,6 +86,8 @@ public class OperationRouting {
private volatile boolean useAdaptiveReplicaSelection;
private volatile boolean ignoreAwarenessAttr;

private volatile boolean useWeightedRoundRobin;

public OperationRouting(Settings settings, ClusterSettings clusterSettings) {
// whether to ignore awareness attributes when routing requests
this.ignoreAwarenessAttr = clusterSettings.get(IGNORE_AWARENESS_ATTRIBUTES_SETTING);
Expand All @@ -88,8 +97,11 @@ public OperationRouting(Settings settings, ClusterSettings clusterSettings) {
this::setAwarenessAttributes
);
this.useAdaptiveReplicaSelection = USE_ADAPTIVE_REPLICA_SELECTION_SETTING.get(settings);
this.useWeightedRoundRobin = USE_WEIGHTED_ROUND_ROBIN.get(settings);
clusterSettings.addSettingsUpdateConsumer(USE_ADAPTIVE_REPLICA_SELECTION_SETTING, this::setUseAdaptiveReplicaSelection);
clusterSettings.addSettingsUpdateConsumer(IGNORE_AWARENESS_ATTRIBUTES_SETTING, this::setIgnoreAwarenessAttributes);
clusterSettings.addSettingsUpdateConsumer(USE_WEIGHTED_ROUND_ROBIN, this::setUseWeightedRoundRobin);

}

void setUseAdaptiveReplicaSelection(boolean useAdaptiveReplicaSelection) {
Expand All @@ -100,6 +112,14 @@ void setIgnoreAwarenessAttributes(boolean ignoreAwarenessAttributes) {
this.ignoreAwarenessAttr = ignoreAwarenessAttributes;
}

public boolean isUseWeightedRoundRobin() {
return useWeightedRoundRobin;
}

public void setUseWeightedRoundRobin(boolean useWeightedRoundRobin) {
this.useWeightedRoundRobin = useWeightedRoundRobin;
}

public boolean isIgnoreAwarenessAttr() {
return ignoreAwarenessAttr;
}
Expand Down Expand Up @@ -169,21 +189,37 @@ public GroupShardsIterator<ShardIterator> searchShards(
final Set<IndexShardRoutingTable> shards = computeTargetedShards(clusterState, concreteIndices, routing);
final Set<ShardIterator> set = new HashSet<>(shards.size());
for (IndexShardRoutingTable shard : shards) {
ShardIterator iterator = preferenceActiveShardIterator(
shard,
clusterState.nodes().getLocalNodeId(),
clusterState.nodes(),
preference,
collectorService,
nodeCounts
);
ShardIterator iterator = null;
// TODO: Do we need similar changes in getShards call??
if (isWeightedRoundRobinEnabled(clusterState)) {
WeightedRoundRobinMetadata weightedRoundRobinMetadata = clusterState.metadata().custom(WeightedRoundRobinMetadata.TYPE);
iterator = shard.activeInitializingShardsWRR(weightedRoundRobinMetadata.getWrrWeight(), clusterState.nodes());
} else {
iterator = preferenceActiveShardIterator(
shard,
clusterState.nodes().getLocalNodeId(),
clusterState.nodes(),
preference,
collectorService,
nodeCounts
);
}

if (iterator != null) {
set.add(iterator);
}
}
return GroupShardsIterator.sortAndCreate(new ArrayList<>(set));
}

private boolean isWeightedRoundRobinEnabled(ClusterState clusterState) {
WeightedRoundRobinMetadata weightedRoundRobinMetadata = clusterState.metadata().custom(WeightedRoundRobinMetadata.TYPE);
if (useWeightedRoundRobin && weightedRoundRobinMetadata != null) {
return true;
}
return false;
}

public static ShardIterator getShards(ClusterState clusterState, ShardId shardId) {
final IndexShardRoutingTable shard = clusterState.routingTable().shardRoutingTable(shardId);
return shard.activeInitializingShardsRandomIt();
Expand Down Expand Up @@ -227,6 +263,7 @@ private ShardIterator preferenceActiveShardIterator(
@Nullable ResponseCollectorService collectorService,
@Nullable Map<String, Long> nodeCounts
) {

if (preference == null || preference.isEmpty()) {
return shardRoutings(indexShard, nodes, collectorService, nodeCounts);
}
Expand Down
Loading

0 comments on commit c2eee63

Please sign in to comment.