Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature branch] Add Neural Stats API #1208

Open
wants to merge 22 commits into
base: feature/neural-stats-api
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
## [Unreleased 3.x](https://github.com/opensearch-project/neural-search/compare/main...HEAD)
### Features
- Lower bound for min-max normalization technique in hybrid query ([#1195](https://github.com/opensearch-project/neural-search/pull/1195))
- Add stats API ([#1208](https://github.com/opensearch-project/neural-search/pull/1208))
### Enhancements
### Bug Fixes
### Infrastructure
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

import static org.opensearch.neuralsearch.settings.NeuralSearchSettings.NEURAL_SEARCH_HYBRID_SEARCH_DISABLED;
import static org.opensearch.neuralsearch.settings.NeuralSearchSettings.RERANKER_MAX_DOC_FIELDS;
import static org.opensearch.neuralsearch.settings.NeuralSearchSettings.NEURAL_STATS_ENABLED;

import java.util.Arrays;
import java.util.Collection;
Expand All @@ -14,12 +15,22 @@
import java.util.Optional;
import java.util.function.Supplier;

import com.google.common.collect.ImmutableList;
import org.opensearch.action.ActionRequest;
import org.opensearch.neuralsearch.settings.NeuralSearchSettingsAccessor;
import org.opensearch.neuralsearch.stats.events.EventStatsManager;
import org.opensearch.neuralsearch.stats.info.InfoStatsManager;
import org.opensearch.transport.client.Client;
import org.opensearch.cluster.metadata.IndexNameExpressionResolver;
import org.opensearch.cluster.node.DiscoveryNodes;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.settings.ClusterSettings;
import org.opensearch.common.settings.IndexScopedSettings;
import org.opensearch.common.settings.Setting;
import org.opensearch.common.settings.Settings;
import org.opensearch.common.settings.SettingsFilter;
import org.opensearch.common.util.FeatureFlags;
import org.opensearch.core.action.ActionResponse;
import org.opensearch.core.common.io.stream.NamedWriteableRegistry;
import org.opensearch.core.xcontent.NamedXContentRegistry;
import org.opensearch.env.Environment;
Expand Down Expand Up @@ -55,15 +66,21 @@
import org.opensearch.neuralsearch.query.NeuralQueryBuilder;
import org.opensearch.neuralsearch.query.NeuralSparseQueryBuilder;
import org.opensearch.neuralsearch.query.ext.RerankSearchExtBuilder;
import org.opensearch.neuralsearch.rest.RestNeuralStatsAction;
import org.opensearch.neuralsearch.search.query.HybridQueryPhaseSearcher;
import org.opensearch.neuralsearch.transport.NeuralStatsAction;
import org.opensearch.neuralsearch.transport.NeuralStatsTransportAction;
import org.opensearch.neuralsearch.util.NeuralSearchClusterUtil;
import org.opensearch.neuralsearch.util.PipelineServiceUtil;
import org.opensearch.plugins.ActionPlugin;
import org.opensearch.plugins.ExtensiblePlugin;
import org.opensearch.plugins.IngestPlugin;
import org.opensearch.plugins.Plugin;
import org.opensearch.plugins.SearchPipelinePlugin;
import org.opensearch.plugins.SearchPlugin;
import org.opensearch.repositories.RepositoriesService;
import org.opensearch.rest.RestController;
import org.opensearch.rest.RestHandler;
import org.opensearch.script.ScriptService;
import org.opensearch.search.pipeline.SearchPhaseResultsProcessor;
import org.opensearch.search.pipeline.SearchRequestProcessor;
Expand All @@ -82,9 +99,13 @@
public class NeuralSearch extends Plugin implements ActionPlugin, SearchPlugin, IngestPlugin, ExtensiblePlugin, SearchPipelinePlugin {
private MLCommonsClientAccessor clientAccessor;
private NormalizationProcessorWorkflow normalizationProcessorWorkflow;
private NeuralSearchSettingsAccessor settingsAccessor;
private PipelineServiceUtil pipelineServiceUtil;
private InfoStatsManager infoStatsManager;
private final ScoreNormalizationFactory scoreNormalizationFactory = new ScoreNormalizationFactory();
private final ScoreCombinationFactory scoreCombinationFactory = new ScoreCombinationFactory();
public static final String EXPLANATION_RESPONSE_KEY = "explanation_response";
public static final String NEURAL_BASE_URI = "/_plugins/_neural";

@Override
public Collection<Object> createComponents(
Expand All @@ -105,7 +126,11 @@ public Collection<Object> createComponents(
NeuralSparseQueryBuilder.initialize(clientAccessor);
HybridQueryExecutor.initialize(threadPool);
normalizationProcessorWorkflow = new NormalizationProcessorWorkflow(new ScoreNormalizer(), new ScoreCombiner());
return List.of(clientAccessor);
settingsAccessor = new NeuralSearchSettingsAccessor(clusterService, environment.settings());
pipelineServiceUtil = new PipelineServiceUtil(clusterService);
infoStatsManager = new InfoStatsManager(NeuralSearchClusterUtil.instance(), settingsAccessor, pipelineServiceUtil);
EventStatsManager.instance().initialize(settingsAccessor);
return List.of(clientAccessor, EventStatsManager.instance(), infoStatsManager);
}

@Override
Expand All @@ -117,6 +142,25 @@ public List<QuerySpec<?>> getQueries() {
);
}

/**
 * Supplies the REST handlers contributed by this plugin.
 * Currently this is a single {@link RestNeuralStatsAction}, constructed with the plugin's settings accessor.
 *
 * @return immutable list containing the neural stats REST handler
 */
@Override
public List<RestHandler> getRestHandlers(
    Settings settings,
    RestController restController,
    ClusterSettings clusterSettings,
    IndexScopedSettings indexScopedSettings,
    SettingsFilter settingsFilter,
    IndexNameExpressionResolver indexNameExpressionResolver,
    Supplier<DiscoveryNodes> nodesInCluster
) {
    // None of the framework-provided arguments are needed; the handler only depends on settingsAccessor
    return ImmutableList.of(new RestNeuralStatsAction(settingsAccessor));
}

/**
 * Supplies the transport actions contributed by this plugin.
 * Binds {@link NeuralStatsAction#INSTANCE} to {@link NeuralStatsTransportAction}.
 *
 * @return immutable list with the neural stats action handler
 */
@Override
public List<ActionHandler<? extends ActionRequest, ? extends ActionResponse>> getActions() {
    // List.of matches the other list-returning methods in this class and yields a truly immutable list
    return List.of(new ActionHandler<>(NeuralStatsAction.INSTANCE, NeuralStatsTransportAction.class));
}

@Override
public List<ExecutorBuilder<?>> getExecutorBuilders(Settings settings) {
return List.of(HybridQueryExecutor.getExecutorBuilder(settings));
Expand Down Expand Up @@ -167,7 +211,7 @@ public Map<String, org.opensearch.search.pipeline.Processor.Factory<SearchPhaseR

@Override
public List<Setting<?>> getSettings() {
return List.of(NEURAL_SEARCH_HYBRID_SEARCH_DISABLED, RERANKER_MAX_DOC_FIELDS);
return List.of(NEURAL_SEARCH_HYBRID_SEARCH_DISABLED, RERANKER_MAX_DOC_FIELDS, NEURAL_STATS_ENABLED);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor;

import lombok.extern.log4j.Log4j2;
import org.opensearch.neuralsearch.stats.events.EventStatsManager;
import org.opensearch.neuralsearch.stats.events.EventStatName;

/**
* This processor is used for user input data text embedding processing, model_id can be used to indicate which model user use,
Expand Down Expand Up @@ -47,6 +49,7 @@ public void doExecute(
List<String> inferenceList,
BiConsumer<IngestDocument, Exception> handler
) {
EventStatsManager.increment(EventStatName.TEXT_EMBEDDING_PROCESSOR_EXECUTIONS);
mlCommonsClientAccessor.inferenceSentences(
TextInferenceRequest.builder().modelId(this.modelId).inputTexts(inferenceList).build(),
ActionListener.wrap(vectors -> {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,177 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/
package org.opensearch.neuralsearch.rest;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import lombok.AllArgsConstructor;
import lombok.extern.log4j.Log4j2;
import org.opensearch.core.rest.RestStatus;
import org.opensearch.neuralsearch.settings.NeuralSearchSettingsAccessor;
import org.opensearch.neuralsearch.stats.NeuralStatsInput;
import org.opensearch.neuralsearch.stats.events.EventStatName;
import org.opensearch.neuralsearch.stats.info.InfoStatName;
import org.opensearch.neuralsearch.transport.NeuralStatsAction;
import org.opensearch.neuralsearch.transport.NeuralStatsRequest;
import org.opensearch.rest.BaseRestHandler;
import org.opensearch.rest.BytesRestResponse;
import org.opensearch.rest.RestRequest;
import org.opensearch.rest.action.RestActions;
import org.opensearch.transport.client.node.NodeClient;

import java.util.Arrays;
import java.util.EnumSet;
import java.util.List;
import java.util.Locale;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors;

import static org.opensearch.neuralsearch.plugin.NeuralSearch.NEURAL_BASE_URI;

/**
 * Rest action handler for the neural stats API.
 * Translates the REST request (node ids, stat names and query flags) into a {@link NeuralStatsRequest}
 * and executes {@link NeuralStatsAction} through the node client. When stats are disabled in the
 * settings accessor, the handler answers with 403 (Forbidden) instead.
 */
@Log4j2
@AllArgsConstructor
public class RestNeuralStatsAction extends BaseRestHandler {
    /** Boolean query parameter forwarded to {@link NeuralStatsInput#setFlatten}; defaults to false. */
    public static final String FLATTEN_PARAM = "flat_keys";
    /** Boolean query parameter forwarded to {@link NeuralStatsInput#setIncludeMetadata}; defaults to false. */
    public static final String INCLUDE_METADATA_PARAM = "include_metadata";
    private static final String NAME = "neural_stats_action";

    private static final String NODE_ID_PARAM = "nodeId";
    private static final String STAT_PARAM = "stat";
    private static final String TIMEOUT_PARAM = "timeout";

    // Lower-cased with Locale.ROOT so the sets agree with the Locale.ROOT normalization applied to user input
    private static final Set<String> EVENT_STAT_NAMES = EnumSet.allOf(EventStatName.class)
        .stream()
        .map(EventStatName::getNameString)
        .map(name -> name.toLowerCase(Locale.ROOT))
        .collect(Collectors.toSet());

    private static final Set<String> INFO_STAT_NAMES = EnumSet.allOf(InfoStatName.class)
        .stream()
        .map(InfoStatName::getNameString)
        .map(name -> name.toLowerCase(Locale.ROOT))
        .collect(Collectors.toSet());

    private static final List<Route> ROUTES = ImmutableList.of(
        new Route(RestRequest.Method.GET, NEURAL_BASE_URI + "/{nodeId}/stats/"),
        new Route(RestRequest.Method.GET, NEURAL_BASE_URI + "/{nodeId}/stats/{stat}"),
        new Route(RestRequest.Method.GET, NEURAL_BASE_URI + "/stats/"),
        new Route(RestRequest.Method.GET, NEURAL_BASE_URI + "/stats/{stat}")
    );

    private static final Set<String> RESPONSE_PARAMS = ImmutableSet.of(NODE_ID_PARAM, STAT_PARAM);

    private NeuralSearchSettingsAccessor settingsAccessor;

    @Override
    public String getName() {
        return NAME;
    }

    @Override
    public List<Route> routes() {
        return ROUTES;
    }

    @Override
    protected Set<String> responseParams() {
        return RESPONSE_PARAMS;
    }

    /**
     * Builds the channel consumer for a stats request.
     *
     * @param request incoming REST request
     * @param client node client used to execute the stats transport action
     * @return consumer that either replies 403 (stats disabled) or executes {@link NeuralStatsAction}
     */
    @Override
    protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient client) {
        if (settingsAccessor.isStatsEnabled() == false) {
            // Mark every supported parameter as consumed; otherwise BaseRestHandler rejects the request
            // with a 400 "unrecognized parameter" error before the intended 403 can be sent.
            consumeSupportedParams(request);
            return channel -> channel.sendResponse(new BytesRestResponse(RestStatus.FORBIDDEN, "Stats endpoint is disabled"));
        }

        // Read inputs and convert to BaseNodesRequest with correct info configured
        NeuralStatsRequest neuralStatsRequest = createNeuralStatsRequest(request);

        return channel -> client.execute(
            NeuralStatsAction.INSTANCE,
            neuralStatsRequest,
            new RestActions.NodesResponseRestListener<>(channel)
        );
    }

    /**
     * Touches every parameter this handler understands so that it counts as consumed,
     * without parsing or validating the values.
     *
     * @param request Rest request
     */
    private void consumeSupportedParams(RestRequest request) {
        request.param(NODE_ID_PARAM);
        request.param(STAT_PARAM);
        request.param(FLATTEN_PARAM);
        request.param(INCLUDE_METADATA_PARAM);
        request.param(TIMEOUT_PARAM);
    }

    /**
     * Creates a NeuralStatsRequest from a RestRequest
     *
     * @param request Rest request
     * @return NeuralStatsRequest
     */
    private NeuralStatsRequest createNeuralStatsRequest(RestRequest request) {
        NeuralStatsInput neuralStatsInput = createNeuralStatsInputFromRequestParams(request);
        String[] nodeIdsArr = neuralStatsInput.getNodeIds().toArray(new String[0]);

        NeuralStatsRequest neuralStatsRequest = new NeuralStatsRequest(nodeIdsArr, neuralStatsInput);
        neuralStatsRequest.timeout(request.param(TIMEOUT_PARAM));

        return neuralStatsRequest;
    }

    /**
     * Parses node ids, query flags and requested stat names from the request.
     * If no stat is requested, or none of the requested names is recognized, all stats are selected.
     *
     * @param request Rest request
     * @return populated NeuralStatsInput
     */
    NeuralStatsInput createNeuralStatsInputFromRequestParams(RestRequest request) {
        NeuralStatsInput neuralStatsInput = new NeuralStatsInput();

        // Parse specified nodes
        splitCommaSeparatedParam(request, NODE_ID_PARAM).ifPresent(nodeIds -> neuralStatsInput.getNodeIds().addAll(Arrays.asList(nodeIds)));

        // Parse query parameters
        boolean flatten = request.paramAsBoolean(FLATTEN_PARAM, false);
        neuralStatsInput.setFlatten(flatten);

        boolean includeMetadata = request.paramAsBoolean(INCLUDE_METADATA_PARAM, false);
        neuralStatsInput.setIncludeMetadata(includeMetadata);

        // Determine which stat names to retrieve based on user parameters
        Optional<String[]> stats = splitCommaSeparatedParam(request, STAT_PARAM);

        if (stats.isPresent() == false) {
            // No specific stats requested, add all stats by default
            addAllStats(neuralStatsInput);
            return neuralStatsInput;
        }

        // Process requested stats
        boolean anyStatAdded = processRequestedStats(stats.get(), neuralStatsInput);

        // If no valid stats were added, fall back to all stats
        if (anyStatAdded == false) {
            addAllStats(neuralStatsInput);
        }

        return neuralStatsInput;
    }

    /**
     * Adds each recognized stat name to the input; unrecognized names are logged and skipped.
     *
     * @param stats raw stat names from the request path
     * @param neuralStatsInput input object to populate
     * @return true if at least one requested name matched a known event or info stat
     */
    private boolean processRequestedStats(String[] stats, NeuralStatsInput neuralStatsInput) {
        boolean statAdded = false;

        for (String stat : stats) {
            String normalizedStat = stat.toLowerCase(Locale.ROOT);
            if (EVENT_STAT_NAMES.contains(normalizedStat)) {
                neuralStatsInput.getEventStatNames().add(EventStatName.from(normalizedStat));
                statAdded = true;
            } else if (INFO_STAT_NAMES.contains(normalizedStat)) {
                neuralStatsInput.getInfoStatNames().add(InfoStatName.from(normalizedStat));
                statAdded = true;
            } else {
                // Only names that matched neither set are reported as invalid
                log.info("Invalid stat name parsed: {}", normalizedStat);
            }
        }
        return statAdded;
    }

    private void addAllStats(NeuralStatsInput neuralStatsInput) {
        neuralStatsInput.getEventStatNames().addAll(EnumSet.allOf(EventStatName.class));
        neuralStatsInput.getInfoStatNames().addAll(EnumSet.allOf(InfoStatName.class));
    }

    /**
     * Reads a request parameter and splits it on commas.
     *
     * @param request Rest request
     * @param paramName name of the parameter to read
     * @return the split values, or empty if the parameter is absent
     */
    private Optional<String[]> splitCommaSeparatedParam(RestRequest request, String paramName) {
        return Optional.ofNullable(request.param(paramName)).map(s -> s.split(","));
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -34,4 +34,16 @@ public final class NeuralSearchSettings {
50,
Setting.Property.NodeScope
);

/**
* Enables or disables the Stats API and event stat collection. Disabled ({@code false}) by default;
* the setting is node-scoped and can be updated dynamically.
* If the Stats API is called while stats are disabled, the REST handler responds with 403 (Forbidden).
* Event stat increment calls are also treated as no-ops while disabled.
*/
public static final Setting<Boolean> NEURAL_STATS_ENABLED = Setting.boolSetting(
"plugins.neural_search.stats_enabled",
false,
Setting.Property.NodeScope,
Setting.Property.Dynamic
);
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/
package org.opensearch.neuralsearch.settings;

import lombok.Getter;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.settings.Settings;
import org.opensearch.neuralsearch.stats.events.EventStatsManager;

/**
 * Exposes neural search settings to the rest of the plugin and keeps the cached values
 * in sync when the corresponding cluster settings are updated.
 */
public class NeuralSearchSettingsAccessor {
    // Cached value of NEURAL_STATS_ENABLED; rewritten by the settings update consumer
    @Getter
    private volatile boolean isStatsEnabled;

    /**
     * Reads the initial setting values and subscribes to later updates.
     *
     * @param clusterService cluster service whose cluster settings receive the update consumer
     * @param settings settings from which the initial stats-enabled value is read
     */
    public NeuralSearchSettingsAccessor(ClusterService clusterService, Settings settings) {
        this.isStatsEnabled = NeuralSearchSettings.NEURAL_STATS_ENABLED.get(settings);
        clusterService.getClusterSettings()
            .addSettingsUpdateConsumer(NeuralSearchSettings.NEURAL_STATS_ENABLED, this::onStatsEnabledUpdated);
    }

    /**
     * Applies a new value of the stats-enabled setting.
     * An enabled-to-disabled transition resets the event stats before the flag is updated.
     *
     * @param statsEnabled the updated setting value
     */
    private void onStatsEnabledUpdated(Boolean statsEnabled) {
        boolean turningOff = isStatsEnabled && (statsEnabled == false);
        if (turningOff) {
            EventStatsManager.instance().reset();
        }
        isStatsEnabled = statsEnabled;
    }
}
Loading
Loading