Skip to content

Commit c56a110

Browse files
committed
Reformat response
Signed-off-by: Andy Qin <qinandy@amazon.com>
1 parent f6fc682 commit c56a110

14 files changed

+135
-95
lines changed

src/main/java/org/opensearch/neuralsearch/rest/RestNeuralStatsAction.java

+4-4
Original file line numberDiff line numberDiff line change
@@ -38,20 +38,20 @@
3838
@Log4j2
3939
@AllArgsConstructor
4040
public class RestNeuralStatsAction extends BaseRestHandler {
41-
public static final String FLATTEN_PARAM = "flat_keys";
41+
public static final String FLATTEN_PARAM = "flat_stat_paths";
4242
public static final String INCLUDE_METADATA_PARAM = "include_metadata";
4343
private static final String NAME = "neural_stats_action";
4444

4545
private static final Set<String> EVENT_STAT_NAMES = EnumSet.allOf(EventStatName.class)
4646
.stream()
4747
.map(EventStatName::getNameString)
48-
.map(String::toLowerCase)
48+
.map(str -> str.toLowerCase(Locale.ROOT))
4949
.collect(Collectors.toSet());
5050

5151
private static final Set<String> STATE_STAT_NAMES = EnumSet.allOf(InfoStatName.class)
5252
.stream()
5353
.map(InfoStatName::getNameString)
54-
.map(String::toLowerCase)
54+
.map(str -> str.toLowerCase(Locale.ROOT))
5555
.collect(Collectors.toSet());
5656

5757
private static final List<Route> ROUTES = ImmutableList.of(
@@ -61,7 +61,7 @@ public class RestNeuralStatsAction extends BaseRestHandler {
6161
new Route(RestRequest.Method.GET, NEURAL_BASE_URI + "/stats/{stat}")
6262
);
6363

64-
private static final Set<String> RESPONSE_PARAMS = ImmutableSet.of("nodeId", "stat");
64+
private static final Set<String> RESPONSE_PARAMS = ImmutableSet.of("nodeId", "stat", INCLUDE_METADATA_PARAM, FLATTEN_PARAM);
6565

6666
private NeuralSearchSettingsAccessor settingsAccessor;
6767

src/main/java/org/opensearch/neuralsearch/stats/NeuralStatsInput.java

+3-2
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
import org.opensearch.core.xcontent.ToXContent;
1414
import org.opensearch.core.xcontent.ToXContentObject;
1515
import org.opensearch.core.xcontent.XContentBuilder;
16+
import org.opensearch.neuralsearch.rest.RestNeuralStatsAction;
1617
import org.opensearch.neuralsearch.stats.events.EventStatName;
1718
import org.opensearch.neuralsearch.stats.info.InfoStatName;
1819

@@ -160,8 +161,8 @@ public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params par
160161
if (infoStatNames != null) {
161162
builder.field(STATE_STAT_NAMES_FIELD, infoStatNames);
162163
}
163-
builder.field("include_metadata", includeMetadata);
164-
builder.field("flat_keys", flatten);
164+
builder.field(RestNeuralStatsAction.INCLUDE_METADATA_PARAM, includeMetadata);
165+
builder.field(RestNeuralStatsAction.FLATTEN_PARAM, flatten);
165166
builder.endObject();
166167
return builder;
167168
}

src/main/java/org/opensearch/neuralsearch/stats/events/EventStatName.java

+3-2
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
import org.opensearch.neuralsearch.stats.common.StatName;
1010

1111
import java.util.Arrays;
12+
import java.util.Locale;
1213
import java.util.Map;
1314
import java.util.stream.Collectors;
1415

@@ -50,7 +51,7 @@ public enum EventStatName implements StatName {
5051
// Validates all event stats are instantiated correctly. This is covered by unit tests as well.
5152
if (eventStat == null) {
5253
throw new IllegalArgumentException(
53-
String.format("Unable to initialize event stat [%s]. Unrecognized event stat type: [%s]", nameString, statType)
54+
String.format(Locale.ROOT, "Unable to initialize event stat [%s]. Unrecognized event stat type: [%s]", nameString, statType)
5455
);
5556
}
5657
}
@@ -63,7 +64,7 @@ public enum EventStatName implements StatName {
6364
*/
6465
public static EventStatName from(String name) {
6566
if (BY_NAME.containsKey(name) == false) {
66-
throw new IllegalArgumentException(String.format("Event stat not found: %s", name));
67+
throw new IllegalArgumentException(String.format(Locale.ROOT, "Event stat not found: %s", name));
6768
}
6869
return BY_NAME.get(name);
6970
}

src/main/java/org/opensearch/neuralsearch/stats/events/TimestampedEventStatSnapshot.java

+6
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,12 @@ public static TimestampedEventStatSnapshot aggregateEventStatSnapshots(Collectio
8787
Long minMinutes = null;
8888

8989
for (TimestampedEventStatSnapshot stat : snapshots) {
90+
// Mixed version clusters may have nodes that return null stat snapshots not available on older versions.
91+
// If so, exclude those from aggregation
92+
if (stat == null) {
93+
continue;
94+
}
95+
9096
// The first stat name is taken. This should never be called across event stats that don't share stat names
9197
if (name == null) {
9298
name = stat.getStatName();

src/main/java/org/opensearch/neuralsearch/stats/info/InfoStatName.java

+3-2
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
import org.opensearch.neuralsearch.stats.common.StatName;
1010

1111
import java.util.Arrays;
12+
import java.util.Locale;
1213
import java.util.Map;
1314
import java.util.stream.Collectors;
1415

@@ -19,7 +20,7 @@
1920
public enum InfoStatName implements StatName {
2021
// Cluster info
2122
CLUSTER_VERSION("cluster_version", "", InfoStatType.SETTABLE_STRING),
22-
TEXT_EMBEDDING_PROCESSORS("text_embedding_processors_in_pipelines", "processors.ingest", InfoStatType.COUNTABLE);
23+
TEXT_EMBEDDING_PROCESSORS("text_embedding_processors_in_pipelines", "processors.ingest", InfoStatType.INFO_COUNTABLE);
2324

2425
private final String nameString;
2526
private final String path;
@@ -48,7 +49,7 @@ public enum InfoStatName implements StatName {
4849
*/
4950
public static InfoStatName from(String value) {
5051
if (BY_NAME.containsKey(value) == false) {
51-
throw new IllegalArgumentException(String.format("Info stat not found: %s", value));
52+
throw new IllegalArgumentException(String.format(Locale.ROOT, "Info stat not found: %s", value));
5253
}
5354
return BY_NAME.get(value);
5455
}

src/main/java/org/opensearch/neuralsearch/stats/info/InfoStatType.java

+1-1
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
* Enum for different kinds of info stat types to track
1313
*/
1414
public enum InfoStatType implements StatType {
15-
COUNTABLE,
15+
INFO_COUNTABLE,
1616
SETTABLE_STRING,
1717
SETTABLE_BOOLEAN;
1818

src/main/java/org/opensearch/neuralsearch/stats/info/InfoStatsManager.java

+1-1
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@ private Map<InfoStatName, CountableInfoStatSnapshot> getCountableStats() {
7272
// Initialize empty map with keys so stat names are visible in JSON even if the value is not counted
7373
Map<InfoStatName, CountableInfoStatSnapshot> countableInfoStats = new HashMap<>();
7474
for (InfoStatName stat : EnumSet.allOf(InfoStatName.class)) {
75-
if (stat.getStatType() == InfoStatType.COUNTABLE) {
75+
if (stat.getStatType() == InfoStatType.INFO_COUNTABLE) {
7676
countableInfoStats.put(stat, new CountableInfoStatSnapshot(stat));
7777
}
7878
}

src/main/java/org/opensearch/neuralsearch/transport/NeuralStatsResponse.java

+30-15
Original file line numberDiff line numberDiff line change
@@ -25,9 +25,12 @@
2525
*/
2626
@Getter
2727
public class NeuralStatsResponse extends BaseNodesResponse<NeuralStatsNodeResponse> implements ToXContentObject {
28+
public static final String INFO_KEY_PREFIX = "info";
29+
public static final String NODES_KEY_PREFIX = "nodes";
30+
public static final String AGGREGATED_NODES_KEY_PREFIX = "all_nodes";
2831

29-
private static final String NODES_KEY = "nodes";
30-
private Map<String, StatSnapshot<?>> clusterLevelStats;
32+
private Map<String, StatSnapshot<?>> infoStats;
33+
private Map<String, StatSnapshot<?>> aggregatedNodeStats;
3134
private Map<String, Map<String, StatSnapshot<?>>> nodeIdToNodeEventStats;
3235
private boolean flatten;
3336
private boolean includeMetadata;
@@ -40,11 +43,13 @@ public class NeuralStatsResponse extends BaseNodesResponse<NeuralStatsNodeRespon
4043
*/
4144
public NeuralStatsResponse(StreamInput in) throws IOException {
4245
super(new ClusterName(in), in.readList(NeuralStatsNodeResponse::readStats), in.readList(FailedNodeException::new));
43-
Map<String, StatSnapshot<?>> castedStats = (Map<String, StatSnapshot<?>>) (Map) in.readMap();
46+
Map<String, StatSnapshot<?>> castedInfoStats = (Map<String, StatSnapshot<?>>) (Map) in.readMap();
47+
Map<String, StatSnapshot<?>> castedAggregatedNodeStats = (Map<String, StatSnapshot<?>>) (Map) in.readMap();
4448
Map<String, Map<String, StatSnapshot<?>>> castedNodeIdToNodeEventStats = (Map<String, Map<String, StatSnapshot<?>>>) (Map) in
4549
.readMap();
4650

47-
this.clusterLevelStats = castedStats;
51+
this.infoStats = castedInfoStats;
52+
this.aggregatedNodeStats = castedAggregatedNodeStats;
4853
this.nodeIdToNodeEventStats = castedNodeIdToNodeEventStats;
4954
this.flatten = in.readBoolean();
5055
this.includeMetadata = in.readBoolean();
@@ -56,19 +61,20 @@ public NeuralStatsResponse(StreamInput in) throws IOException {
5661
* @param clusterName name of cluster
5762
* @param nodes List of NeuralStatsNodeResponses
5863
* @param failures List of failures from nodes
59-
* @param clusterLevelStats
6064
*/
6165
public NeuralStatsResponse(
6266
ClusterName clusterName,
6367
List<NeuralStatsNodeResponse> nodes,
6468
List<FailedNodeException> failures,
65-
Map<String, StatSnapshot<?>> clusterLevelStats,
69+
Map<String, StatSnapshot<?>> infoStats,
70+
Map<String, StatSnapshot<?>> aggregatedNodeStats,
6671
Map<String, Map<String, StatSnapshot<?>>> nodeIdToNodeEventStats,
6772
boolean flatten,
6873
boolean includeMetadata
6974
) {
7075
super(clusterName, nodes, failures);
71-
this.clusterLevelStats = clusterLevelStats;
76+
this.infoStats = infoStats;
77+
this.aggregatedNodeStats = aggregatedNodeStats;
7278
this.nodeIdToNodeEventStats = nodeIdToNodeEventStats;
7379
this.flatten = flatten;
7480
this.includeMetadata = includeMetadata;
@@ -77,10 +83,13 @@ public NeuralStatsResponse(
7783
@Override
7884
public void writeTo(StreamOutput out) throws IOException {
7985
super.writeTo(out);
80-
Map<String, Object> downcastedStats = (Map<String, Object>) (Map) (clusterLevelStats);
81-
Map<String, Object> downcastedNodeStats = (Map<String, Object>) (Map) (nodeIdToNodeEventStats);
82-
out.writeMap(downcastedStats);
83-
out.writeMap(downcastedNodeStats);
86+
Map<String, Object> downcastedInfoStats = (Map<String, Object>) (Map) (infoStats);
87+
Map<String, Object> downcastedAggregatedNodeStats = (Map<String, Object>) (Map) (aggregatedNodeStats);
88+
Map<String, Object> downcastedNodeIdToNodeEventStats = (Map<String, Object>) (Map) (nodeIdToNodeEventStats);
89+
90+
out.writeMap(downcastedInfoStats);
91+
out.writeMap(downcastedAggregatedNodeStats);
92+
out.writeMap(downcastedNodeIdToNodeEventStats);
8493
out.writeBoolean(flatten);
8594
out.writeBoolean(includeMetadata);
8695
}
@@ -97,12 +106,18 @@ public List<NeuralStatsNodeResponse> readNodesFrom(StreamInput in) throws IOExce
97106

98107
@Override
99108
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
100-
Map<String, Object> formattedClusterLevelStats = formatStats(clusterLevelStats);
101-
builder.mapContents(formattedClusterLevelStats);
109+
Map<String, Object> formattedInfoStats = formatStats(infoStats);
110+
builder.startObject(INFO_KEY_PREFIX);
111+
builder.mapContents(formattedInfoStats);
112+
builder.endObject();
102113

103-
Map<String, Object> formattedNodeEventStats = formatNodeEventStats(nodeIdToNodeEventStats);
114+
Map<String, Object> formattedAggregatedNodeStats = formatStats(aggregatedNodeStats);
115+
builder.startObject(AGGREGATED_NODES_KEY_PREFIX);
116+
builder.mapContents(formattedAggregatedNodeStats);
117+
builder.endObject();
104118

105-
builder.startObject(NODES_KEY);
119+
Map<String, Object> formattedNodeEventStats = formatNodeEventStats(nodeIdToNodeEventStats);
120+
builder.startObject(NODES_KEY_PREFIX);
106121
builder.mapContents(formattedNodeEventStats);
107122
builder.endObject();
108123

src/main/java/org/opensearch/neuralsearch/transport/NeuralStatsTransportAction.java

+15-13
Original file line numberDiff line numberDiff line change
@@ -36,8 +36,6 @@ public class NeuralStatsTransportAction extends TransportNodesAction<
3636
NeuralStatsResponse,
3737
NeuralStatsNodeRequest,
3838
NeuralStatsNodeResponse> {
39-
private static final String ALL_NODES_PREFIX = "all_nodes";
40-
4139
private final EventStatsManager eventStatsManager;
4240
private final InfoStatsManager infoStatsManager;
4341

@@ -86,30 +84,25 @@ protected NeuralStatsResponse newResponse(
8684
Map<String, Map<String, StatSnapshot<?>>> nodeIdToEventStats = processorNodeEventStatsIntoMap(responses);
8785

8886
// Sum the map to aggregate
89-
Map<String, StatSnapshot<?>> nodeAggregatedEventStats = aggregateNodesResponses(
87+
Map<String, StatSnapshot<?>> aggregatedNodeStats = aggregateNodesResponses(
9088
responses,
9189
request.getNeuralStatsInput().getEventStatNames()
9290
);
9391

94-
// Add aggregate to summed map
95-
resultStats.putAll(nodeAggregatedEventStats);
96-
9792
// Get info stats
98-
Map<InfoStatName, StatSnapshot<?>> stateStats = infoStatsManager.getStats(request.getNeuralStatsInput().getInfoStatNames());
93+
Map<InfoStatName, StatSnapshot<?>> infoStats = infoStatsManager.getStats(request.getNeuralStatsInput().getInfoStatNames());
9994

10095
// Convert stat name keys into flat path strings
101-
Map<String, StatSnapshot<?>> flatStateStats = stateStats.entrySet()
96+
Map<String, StatSnapshot<?>> flatInfoStats = infoStats.entrySet()
10297
.stream()
10398
.collect(Collectors.toMap(entry -> entry.getKey().getFullPath(), Map.Entry::getValue));
10499

105-
// Add to map
106-
resultStats.putAll(flatStateStats);
107-
108100
return new NeuralStatsResponse(
109101
clusterService.getClusterName(),
110102
responses,
111103
failures,
112-
resultStats,
104+
flatInfoStats,
105+
aggregatedNodeStats,
113106
nodeIdToEventStats,
114107
request.getNeuralStatsInput().isFlatten(),
115108
request.getNeuralStatsInput().isIncludeMetadata()
@@ -152,6 +145,11 @@ private Map<String, StatSnapshot<?>> aggregateNodesResponses(
152145
List<NeuralStatsNodeResponse> responses,
153146
EnumSet<EventStatName> statsToRetrieve
154147
) {
148+
// Catch empty nodes responses case.
149+
if (responses == null || responses.isEmpty()) {
150+
return new HashMap<>();
151+
}
152+
155153
// Convert node responses into list of Map<EventStatName, EventStatData>
156154
List<Map<EventStatName, TimestampedEventStatSnapshot>> nodeEventStatsList = responses.stream()
157155
.map(NeuralStatsNodeResponse::getStats)
@@ -169,7 +167,11 @@ private Map<String, StatSnapshot<?>> aggregateNodesResponses(
169167
TimestampedEventStatSnapshot aggregatedEventSnapshots = TimestampedEventStatSnapshot.aggregateEventStatSnapshots(
170168
timestampedEventStatSnapshotCollection
171169
);
172-
aggregatedMap.put(ALL_NODES_PREFIX + "." + eventStatName.getFullPath(), aggregatedEventSnapshots);
170+
171+
// Skip adding null event stats. This happens when a node id parameter is invalid.
172+
if (aggregatedEventSnapshots != null) {
173+
aggregatedMap.put(eventStatName.getFullPath(), aggregatedEventSnapshots);
174+
}
173175
}
174176

175177
return aggregatedMap;

src/test/java/org/opensearch/neuralsearch/rest/RestNeuralStatsActionIT.java

+12-22
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,8 @@
55
package org.opensearch.neuralsearch.rest;
66

77
import lombok.extern.log4j.Log4j2;
8-
import org.apache.hc.core5.http.io.entity.EntityUtils;
98
import org.junit.After;
109
import org.junit.Before;
11-
import org.opensearch.client.Response;
1210
import org.opensearch.neuralsearch.BaseNeuralSearchIT;
1311
import org.opensearch.neuralsearch.settings.NeuralSearchSettings;
1412
import org.opensearch.neuralsearch.stats.common.StatSnapshot;
@@ -73,14 +71,12 @@ public void test_textEmbedding() throws Exception {
7371
assertEquals(3, getDocCount(INDEX_NAME));
7472

7573
// Get stats request
76-
Response response;
7774
String responseBody;
7875
Map<String, Object> stats;
7976
List<Map<String, Object>> nodesStats;
8077

81-
response = executeNeuralStatRequest(new ArrayList<>(), new ArrayList<>());
82-
responseBody = EntityUtils.toString(response.getEntity());
83-
stats = parseStatsResponse(responseBody);
78+
responseBody = executeNeuralStatRequest(new ArrayList<>(), new ArrayList<>());
79+
stats = parseInfoStatsResponse(responseBody);
8480
nodesStats = parseNodeStatsResponse(responseBody);
8581

8682
// Parse json to get stats
@@ -92,25 +88,22 @@ public void test_textEmbedding() throws Exception {
9288
updateClusterSettings("plugins.neural_search.stats_enabled", true);
9389

9490
// info stats should persist, event stats should be reset
95-
response = executeNeuralStatRequest(new ArrayList<>(), new ArrayList<>());
96-
responseBody = EntityUtils.toString(response.getEntity());
97-
stats = parseStatsResponse(responseBody);
91+
responseBody = executeNeuralStatRequest(new ArrayList<>(), new ArrayList<>());
92+
stats = parseInfoStatsResponse(responseBody);
9893
nodesStats = parseNodeStatsResponse(responseBody);
9994
assertEquals(0, getNestedValue(nodesStats.getFirst(), EventStatName.TEXT_EMBEDDING_PROCESSOR_EXECUTIONS));
10095
assertEquals(1, getNestedValue(stats, InfoStatName.TEXT_EMBEDDING_PROCESSORS));
10196
}
10297

10398
public void test_statsFiltering() throws Exception {
104-
Response response = executeNeuralStatRequest(
99+
String responseBody = executeNeuralStatRequest(
105100
new ArrayList<>(),
106101
Arrays.asList(InfoStatName.TEXT_EMBEDDING_PROCESSORS.getNameString())
107102
);
108103

109-
String responseBody = EntityUtils.toString(response.getEntity());
110-
Map<String, Object> stats = parseStatsResponse(responseBody);
104+
Map<String, Object> stats = parseInfoStatsResponse(responseBody);
111105
List<Map<String, Object>> nodesStats = parseNodeStatsResponse(responseBody);
112106

113-
//
114107
assertNull(getNestedValue(nodesStats.getFirst(), EventStatName.TEXT_EMBEDDING_PROCESSOR_EXECUTIONS.getFullPath()));
115108
assertNotNull(getNestedValue(stats, InfoStatName.TEXT_EMBEDDING_PROCESSORS.getFullPath()));
116109
}
@@ -119,10 +112,9 @@ public void test_flatten() throws Exception {
119112
Map<String, String> params = new HashMap<>();
120113
params.put(RestNeuralStatsAction.FLATTEN_PARAM, "true");
121114

122-
Response response = executeNeuralStatRequest(new ArrayList<>(), new ArrayList<>(), params);
115+
String responseBody = executeNeuralStatRequest(new ArrayList<>(), new ArrayList<>(), params);
123116

124-
String responseBody = EntityUtils.toString(response.getEntity());
125-
Map<String, Object> stats = parseStatsResponse(responseBody);
117+
Map<String, Object> stats = parseInfoStatsResponse(responseBody);
126118
List<Map<String, Object>> nodesStats = parseNodeStatsResponse(responseBody);
127119

128120
assertNotNull(nodesStats.getFirst().get(EventStatName.TEXT_EMBEDDING_PROCESSOR_EXECUTIONS.getFullPath()));
@@ -133,9 +125,8 @@ public void test_includeMetadata() throws Exception {
133125
Map<String, String> params = new HashMap<>();
134126
params.put(RestNeuralStatsAction.INCLUDE_METADATA_PARAM, "true");
135127

136-
Response response = executeNeuralStatRequest(new ArrayList<>(), new ArrayList<>(), params);
137-
String responseBody = EntityUtils.toString(response.getEntity());
138-
Map<String, Object> stats = parseStatsResponse(responseBody);
128+
String responseBody = executeNeuralStatRequest(new ArrayList<>(), new ArrayList<>(), params);
129+
Map<String, Object> stats = parseInfoStatsResponse(responseBody);
139130

140131
Object clusterVersionStatMetadata = getNestedValue(stats, InfoStatName.CLUSTER_VERSION.getFullPath());
141132

@@ -152,9 +143,8 @@ public void test_includeMetadata() throws Exception {
152143
// Fetch Without metadata
153144
params.put(RestNeuralStatsAction.INCLUDE_METADATA_PARAM, "false");
154145

155-
response = executeNeuralStatRequest(new ArrayList<>(), new ArrayList<>(), params);
156-
responseBody = EntityUtils.toString(response.getEntity());
157-
stats = parseStatsResponse(responseBody);
146+
responseBody = executeNeuralStatRequest(new ArrayList<>(), new ArrayList<>(), params);
147+
stats = parseInfoStatsResponse(responseBody);
158148

159149
// Path value should be the settable value
160150
String valueWithoutMetadata = (String) getNestedValue(stats, InfoStatName.CLUSTER_VERSION.getFullPath());

0 commit comments

Comments
 (0)