Skip to content

Commit

Permalink
Optimized multi dimension stats
Browse files Browse the repository at this point in the history
Signed-off-by: Peter Alfonsi <petealft@amazon.com>
  • Loading branch information
Peter Alfonsi committed Mar 1, 2024
1 parent a61f033 commit 2483981
Show file tree
Hide file tree
Showing 2 changed files with 68 additions and 25 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@

import java.io.IOException;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
Expand Down Expand Up @@ -42,7 +44,7 @@ public class MultiDimensionCacheStats implements CacheStats {
final String tierDimensionValue;

// A map from a set of cache stats dimensions -> stats for that combination of dimensions. Does not include the tier dimension in its keys.
final ConcurrentMap<Set<CacheStatsDimension>, CacheStatsResponse> map;
final ConcurrentMap<Key, CacheStatsResponse> map;

final int maxDimensionValues;
CacheStatsResponse totalStats;
Expand All @@ -62,11 +64,11 @@ public MultiDimensionCacheStats(List<String> dimensionNames, String tierDimensio
public MultiDimensionCacheStats(StreamInput in) throws IOException {
this.dimensionNames = List.of(in.readStringArray());
this.tierDimensionValue = in.readString();
Map<Set<CacheStatsDimension>, CacheStatsResponse> readMap = in.readMap(
i -> Set.of(i.readArray(CacheStatsDimension::new, CacheStatsDimension[]::new)),
Map<Key, CacheStatsResponse> readMap = in.readMap(
i -> new Key(Set.of(i.readArray(CacheStatsDimension::new, CacheStatsDimension[]::new))),
CacheStatsResponse::new
);
this.map = new ConcurrentHashMap<Set<CacheStatsDimension>, CacheStatsResponse>(readMap);
this.map = new ConcurrentHashMap<Key, CacheStatsResponse>(readMap);
this.totalStats = new CacheStatsResponse(in);
this.maxDimensionValues = in.readVInt();
}
Expand All @@ -77,7 +79,7 @@ public void writeTo(StreamOutput out) throws IOException {
out.writeString(tierDimensionValue);
out.writeMap(
map,
(o, keySet) -> o.writeArray((o1, dim) -> ((CacheStatsDimension) dim).writeTo(o1), keySet.toArray()),
(o, key) -> o.writeArray((o1, dim) -> ((CacheStatsDimension) dim).writeTo(o1), key.dimensions.toArray()),
(o, response) -> response.writeTo(o)
);
totalStats.writeTo(out);
Expand All @@ -89,6 +91,10 @@ public CacheStatsResponse getTotalStats() {
return totalStats;
}

/**
* Get the stats response aggregated by dimensions. If there are no values for the specified dimensions,
* returns an all-zero response.
*/
@Override
public CacheStatsResponse getStatsByDimensions(List<CacheStatsDimension> dimensions) {
if (!checkDimensionNames(dimensions)) {
Expand All @@ -103,12 +109,16 @@ public CacheStatsResponse getStatsByDimensions(List<CacheStatsDimension> dimensi
modifiedDimensions.remove(tierDim);
}

if (modifiedDimensions.size() == dimensionNames.size()) {
return map.getOrDefault(new Key(modifiedDimensions), new CacheStatsResponse());
}

// I don't think there's a more efficient way to get arbitrary combinations of dimensions than to just keep a map
// and iterate through it, checking if keys match. We can't pre-aggregate because it would consume a lot of memory.
CacheStatsResponse response = new CacheStatsResponse();
for (Set<CacheStatsDimension> storedDimensions : map.keySet()) {
if (storedDimensions.containsAll(modifiedDimensions)) {
response.add(map.get(storedDimensions));
for (Key key : map.keySet()) {
if (key.dimensions.containsAll(modifiedDimensions)) {
response.add(map.get(key));
}
}
return response;
Expand All @@ -128,27 +138,14 @@ private CacheStatsDimension getTierDimension(List<CacheStatsDimension> dimension

private boolean checkDimensionNames(List<CacheStatsDimension> dimensions) {
for (CacheStatsDimension dim : dimensions) {
if (!dimensionNames.contains(dim.dimensionName) && !dim.dimensionName.equals(CacheStatsDimension.TIER_DIMENSION_NAME)) {
if (!(dimensionNames.contains(dim.dimensionName) || dim.dimensionName.equals(CacheStatsDimension.TIER_DIMENSION_NAME))) {
// Reject dimension names that aren't in the list and aren't the tier dimension
return false;
}
}
return true;
}

private CacheStatsResponse getStatsBySingleDimension(CacheStatsDimension dimension) {
assert dimensionNames.size() == 1;
CacheStatsResponse response = new CacheStatsResponse();
for (Set<CacheStatsDimension> dimensions : map.keySet()) {
// Each set has only one element
for (CacheStatsDimension keyDimension : dimensions) {
if (keyDimension.dimensionValue.equals(dimension.dimensionValue)) {
response.add(map.get(dimensions));
}
}
}
return response;
}

@Override
public long getTotalHits() {
return totalStats.getHits();
Expand Down Expand Up @@ -231,11 +228,11 @@ public void decrementEntriesByDimensions(List<CacheStatsDimension> dimensions) {

private CacheStatsResponse internalGetStats(List<CacheStatsDimension> dimensions) {
assert dimensions.size() == dimensionNames.size();
CacheStatsResponse response = map.get(new HashSet<>(dimensions));
CacheStatsResponse response = map.get(new Key(dimensions));
if (response == null) {
if (map.size() < maxDimensionValues) {
response = new CacheStatsResponse();
map.put(new HashSet<>(dimensions), response);
map.put(new Key(dimensions), response);
} else {
throw new RuntimeException("Cannot add new combination of dimension values to stats object; reached maximum");
}
Expand All @@ -249,4 +246,33 @@ private void internalIncrement(List<CacheStatsDimension> dimensions, BiConsumer<
incrementer.accept(totalStats, amount);
}

/**
* Unmodifiable wrapper over a set of CacheStatsDimension. Pkg-private for testing.
*/
static class Key {
final Set<CacheStatsDimension> dimensions;
Key(Set<CacheStatsDimension> dimensions) {
this.dimensions = Collections.unmodifiableSet(dimensions);
}
Key(List<CacheStatsDimension> dimensions) {
this(new HashSet<>(dimensions));
}
@Override
public boolean equals(Object o) {
if (o == this) {
return true;
} if (o == null) {
return false;
} if (o.getClass() != Key.class) {
return false;
}
Key other = (Key) o;
return this.dimensions.equals(other.dimensions);
}

@Override
public int hashCode() {
return this.dimensions.hashCode();
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,23 @@ public void testTierLogic() throws Exception {
assertEquals(new CacheStatsResponse(), stats.getStatsByDimensions(List.of(wrongTierDim)));
}

public void testKeyEquality() throws Exception {
Set<CacheStatsDimension> dims1 = new HashSet<>();
dims1.add(new CacheStatsDimension("a", "1"));
dims1.add(new CacheStatsDimension("b", "2"));
dims1.add(new CacheStatsDimension("c", "3"));
MultiDimensionCacheStats.Key key1 = new MultiDimensionCacheStats.Key(dims1);

List<CacheStatsDimension> dims2 = new ArrayList<>();
dims2.add(new CacheStatsDimension("c", "3"));
dims2.add(new CacheStatsDimension("a", "1"));
dims2.add(new CacheStatsDimension("b", "2"));
MultiDimensionCacheStats.Key key2 = new MultiDimensionCacheStats.Key(dims2);

assertEquals(key1, key2);
assertEquals(key1.hashCode(), key2.hashCode());
}

private Map<String, List<String>> getUsedDimensionValues(MultiDimensionCacheStats stats, int numValuesPerDim) {
Map<String, List<String>> usedDimensionValues = new HashMap<>();
for (int i = 0; i < stats.dimensionNames.size(); i++) {
Expand Down

0 comments on commit 2483981

Please sign in to comment.