From 2483981e6e04beeae940e337a9221a1dd7e7a90c Mon Sep 17 00:00:00 2001 From: Peter Alfonsi Date: Fri, 1 Mar 2024 10:24:22 -0800 Subject: [PATCH] Optimized multi dimension stats Signed-off-by: Peter Alfonsi --- .../cache/stats/MultiDimensionCacheStats.java | 76 +++++++++++++------ .../stats/MultiDimensionCacheStatsTests.java | 17 +++++ 2 files changed, 68 insertions(+), 25 deletions(-) diff --git a/server/src/main/java/org/opensearch/common/cache/stats/MultiDimensionCacheStats.java b/server/src/main/java/org/opensearch/common/cache/stats/MultiDimensionCacheStats.java index ba356335c11b4..370010dd7e282 100644 --- a/server/src/main/java/org/opensearch/common/cache/stats/MultiDimensionCacheStats.java +++ b/server/src/main/java/org/opensearch/common/cache/stats/MultiDimensionCacheStats.java @@ -13,6 +13,8 @@ import java.io.IOException; import java.util.ArrayList; +import java.util.Collection; +import java.util.Collections; import java.util.HashSet; import java.util.List; import java.util.Map; @@ -42,7 +44,7 @@ public class MultiDimensionCacheStats implements CacheStats { final String tierDimensionValue; // A map from a set of cache stats dimensions -> stats for that combination of dimensions. Does not include the tier dimension in its keys. - final ConcurrentMap, CacheStatsResponse> map; + final ConcurrentMap map; final int maxDimensionValues; CacheStatsResponse totalStats; @@ -62,11 +64,11 @@ public MultiDimensionCacheStats(List dimensionNames, String tierDimensio public MultiDimensionCacheStats(StreamInput in) throws IOException { this.dimensionNames = List.of(in.readStringArray()); this.tierDimensionValue = in.readString(); - Map, CacheStatsResponse> readMap = in.readMap( - i -> Set.of(i.readArray(CacheStatsDimension::new, CacheStatsDimension[]::new)), + Map readMap = in.readMap( + i -> new Key(Set.of(i.readArray(CacheStatsDimension::new, CacheStatsDimension[]::new))), CacheStatsResponse::new ); - this.map = new ConcurrentHashMap, CacheStatsResponse>(readMap); + this.map = new ConcurrentHashMap(readMap); this.totalStats = new CacheStatsResponse(in); this.maxDimensionValues = in.readVInt(); } @@ -77,7 +79,7 @@ public void writeTo(StreamOutput out) throws IOException { out.writeString(tierDimensionValue); out.writeMap( map, - (o, keySet) -> o.writeArray((o1, dim) -> ((CacheStatsDimension) dim).writeTo(o1), keySet.toArray()), + (o, key) -> o.writeArray((o1, dim) -> ((CacheStatsDimension) dim).writeTo(o1), key.dimensions.toArray()), (o, response) -> response.writeTo(o) ); totalStats.writeTo(out); @@ -89,6 +91,10 @@ public CacheStatsResponse getTotalStats() { return totalStats; } + /** + * Get the stats response aggregated by dimensions. If there are no values for the specified dimensions, + * returns an all-zero response. + */ @Override public CacheStatsResponse getStatsByDimensions(List dimensions) { if (!checkDimensionNames(dimensions)) { @@ -103,12 +109,16 @@ public CacheStatsResponse getStatsByDimensions(List dimensi modifiedDimensions.remove(tierDim); } + if (modifiedDimensions.size() == dimensionNames.size()) { + return map.getOrDefault(new Key(modifiedDimensions), new CacheStatsResponse()); + } + // I don't think there's a more efficient way to get arbitrary combinations of dimensions than to just keep a map // and iterate through it, checking if keys match. We can't pre-aggregate because it would consume a lot of memory. CacheStatsResponse response = new CacheStatsResponse(); - for (Set storedDimensions : map.keySet()) { - if (storedDimensions.containsAll(modifiedDimensions)) { - response.add(map.get(storedDimensions)); + for (Key key : map.keySet()) { + if (key.dimensions.containsAll(modifiedDimensions)) { + response.add(map.get(key)); } } return response; @@ -128,27 +138,14 @@ private CacheStatsDimension getTierDimension(List dimension private boolean checkDimensionNames(List dimensions) { for (CacheStatsDimension dim : dimensions) { - if (!dimensionNames.contains(dim.dimensionName) && !dim.dimensionName.equals(CacheStatsDimension.TIER_DIMENSION_NAME)) { + if (!(dimensionNames.contains(dim.dimensionName) || dim.dimensionName.equals(CacheStatsDimension.TIER_DIMENSION_NAME))) { + // Reject dimension names that aren't in the list and aren't the tier dimension return false; } } return true; } - private CacheStatsResponse getStatsBySingleDimension(CacheStatsDimension dimension) { - assert dimensionNames.size() == 1; - CacheStatsResponse response = new CacheStatsResponse(); - for (Set dimensions : map.keySet()) { - // Each set has only one element - for (CacheStatsDimension keyDimension : dimensions) { - if (keyDimension.dimensionValue.equals(dimension.dimensionValue)) { - response.add(map.get(dimensions)); - } - } - } - return response; - } - @Override public long getTotalHits() { return totalStats.getHits(); @@ -231,11 +228,11 @@ public void decrementEntriesByDimensions(List dimensions) { private CacheStatsResponse internalGetStats(List dimensions) { assert dimensions.size() == dimensionNames.size(); - CacheStatsResponse response = map.get(new HashSet<>(dimensions)); + CacheStatsResponse response = map.get(new Key(dimensions)); if (response == null) { if (map.size() < maxDimensionValues) { response = new CacheStatsResponse(); - map.put(new HashSet<>(dimensions), response); + map.put(new Key(dimensions), response); } else { throw new RuntimeException("Cannot add new combination of dimension values to stats object; reached maximum"); } @@ -249,4 +246,33 @@ private void internalIncrement(List dimensions, BiConsumer< incrementer.accept(totalStats, amount); } + /** + * Unmodifiable wrapper over a set of CacheStatsDimension. Pkg-private for testing. + */ + static class Key { + final Set dimensions; + Key(Set dimensions) { + this.dimensions = Collections.unmodifiableSet(dimensions); + } + Key(List dimensions) { + this(new HashSet<>(dimensions)); + } + @Override + public boolean equals(Object o) { + if (o == this) { + return true; + } if (o == null) { + return false; + } if (o.getClass() != Key.class) { + return false; + } + Key other = (Key) o; + return this.dimensions.equals(other.dimensions); + } + + @Override + public int hashCode() { + return this.dimensions.hashCode(); + } + } } diff --git a/server/src/test/java/org/opensearch/common/cache/stats/MultiDimensionCacheStatsTests.java b/server/src/test/java/org/opensearch/common/cache/stats/MultiDimensionCacheStatsTests.java index 127b5c979f27a..7855a2d202246 100644 --- a/server/src/test/java/org/opensearch/common/cache/stats/MultiDimensionCacheStatsTests.java +++ b/server/src/test/java/org/opensearch/common/cache/stats/MultiDimensionCacheStatsTests.java @@ -149,6 +149,23 @@ public void testTierLogic() throws Exception { assertEquals(new CacheStatsResponse(), stats.getStatsByDimensions(List.of(wrongTierDim))); } + public void testKeyEquality() throws Exception { + Set dims1 = new HashSet<>(); + dims1.add(new CacheStatsDimension("a", "1")); + dims1.add(new CacheStatsDimension("b", "2")); + dims1.add(new CacheStatsDimension("c", "3")); + MultiDimensionCacheStats.Key key1 = new MultiDimensionCacheStats.Key(dims1); + + List dims2 = new ArrayList<>(); + dims2.add(new CacheStatsDimension("c", "3")); + dims2.add(new CacheStatsDimension("a", "1")); + dims2.add(new CacheStatsDimension("b", "2")); + MultiDimensionCacheStats.Key key2 = new MultiDimensionCacheStats.Key(dims2); + + assertEquals(key1, key2); + assertEquals(key1.hashCode(), key2.hashCode()); + } + private Map> getUsedDimensionValues(MultiDimensionCacheStats stats, int numValuesPerDim) { Map> usedDimensionValues = new HashMap<>(); for (int i = 0; i < stats.dimensionNames.size(); i++) {