diff --git a/server/src/main/java/org/opensearch/action/search/AbstractSearchAsyncAction.java b/server/src/main/java/org/opensearch/action/search/AbstractSearchAsyncAction.java index af84422df7067..f0fc05c595d6f 100644 --- a/server/src/main/java/org/opensearch/action/search/AbstractSearchAsyncAction.java +++ b/server/src/main/java/org/opensearch/action/search/AbstractSearchAsyncAction.java @@ -51,6 +51,7 @@ import org.opensearch.core.action.ActionListener; import org.opensearch.core.action.ShardOperationFailedException; import org.opensearch.core.index.shard.ShardId; +import org.opensearch.core.tasks.resourcetracker.TaskResourceInfo; import org.opensearch.search.SearchPhaseResult; import org.opensearch.search.SearchShardTarget; import org.opensearch.search.internal.AliasFilter; @@ -628,7 +629,7 @@ protected void onShardResult(Result result, SearchShardIterator shardIt) { } public void setPhaseResourceUsages() { - String taskResourceUsage = searchRequestContext.getTaskResourceUsageSupplier().get(); + TaskResourceInfo taskResourceUsage = searchRequestContext.getTaskResourceUsageSupplier().get(); searchRequestContext.recordPhaseResourceUsage(taskResourceUsage); } diff --git a/server/src/main/java/org/opensearch/action/search/SearchRequestContext.java b/server/src/main/java/org/opensearch/action/search/SearchRequestContext.java index adbe19f4a613e..f27cafb8859a4 100644 --- a/server/src/main/java/org/opensearch/action/search/SearchRequestContext.java +++ b/server/src/main/java/org/opensearch/action/search/SearchRequestContext.java @@ -12,15 +12,8 @@ import org.apache.logging.log4j.Logger; import org.apache.lucene.search.TotalHits; import org.opensearch.common.annotation.InternalApi; -import org.opensearch.common.xcontent.XContentHelper; -import org.opensearch.core.common.bytes.BytesArray; import org.opensearch.core.tasks.resourcetracker.TaskResourceInfo; -import org.opensearch.core.xcontent.DeprecationHandler; -import org.opensearch.core.xcontent.MediaTypeRegistry; 
-import org.opensearch.core.xcontent.NamedXContentRegistry; -import org.opensearch.core.xcontent.XContentParser; -import java.io.IOException; import java.util.ArrayList; import java.util.EnumMap; import java.util.HashMap; @@ -45,12 +38,12 @@ public class SearchRequestContext { private final SearchRequest searchRequest; private final List<TaskResourceInfo> phaseResourceUsage; - private final Supplier<String> taskResourceUsageSupplier; + private final Supplier<TaskResourceInfo> taskResourceUsageSupplier; SearchRequestContext( final SearchRequestOperationsListener searchRequestOperationsListener, final SearchRequest searchRequest, - final Supplier<String> taskResourceUsageSupplier + final Supplier<TaskResourceInfo> taskResourceUsageSupplier ) { this.searchRequestOperationsListener = searchRequestOperationsListener; this.absoluteStartNanos = System.nanoTime(); @@ -130,24 +123,19 @@ String formattedShardStats() { } } - public Supplier<String> getTaskResourceUsageSupplier() { + public Supplier<TaskResourceInfo> getTaskResourceUsageSupplier() { return taskResourceUsageSupplier; } - public void recordPhaseResourceUsage(String usage) { - try { - if (usage != null && !usage.isEmpty()) { - XContentParser parser = XContentHelper.createParser( - NamedXContentRegistry.EMPTY, - DeprecationHandler.THROW_UNSUPPORTED_OPERATION, - new BytesArray(usage), - MediaTypeRegistry.JSON - ); - this.phaseResourceUsage.add(TaskResourceInfo.PARSER.apply(parser, null)); - } - } catch (IOException e) { - logger.debug("fail to parse phase resource usages: ", e); - } + /** + * Record the resource usage reported for a search phase. 
+ * + * @param usage usage to record; ignored when {@code null}, which the usage supplier may return when nothing is available + */ + public void recordPhaseResourceUsage(TaskResourceInfo usage) { + if (usage != null) { + this.phaseResourceUsage.add(usage); + } } public List<TaskResourceInfo> getPhaseResourceUsage() { diff --git a/server/src/main/java/org/opensearch/action/search/TransportSearchAction.java b/server/src/main/java/org/opensearch/action/search/TransportSearchAction.java index 09da8d03a0aeb..6e380775355a2 100644 --- a/server/src/main/java/org/opensearch/action/search/TransportSearchAction.java +++ b/server/src/main/java/org/opensearch/action/search/TransportSearchAction.java @@ -87,6 +87,7 
@@ import org.opensearch.search.profile.SearchProfileShardResults; import org.opensearch.tasks.CancellableTask; import org.opensearch.tasks.Task; +import org.opensearch.tasks.TaskResourceTrackingService; import org.opensearch.telemetry.metrics.MetricsRegistry; import org.opensearch.telemetry.tracing.Span; import org.opensearch.telemetry.tracing.SpanBuilder; @@ -125,7 +126,6 @@ import static org.opensearch.action.search.SearchType.DFS_QUERY_THEN_FETCH; import static org.opensearch.action.search.SearchType.QUERY_THEN_FETCH; import static org.opensearch.search.sort.FieldSortBuilder.hasPrimaryFieldSort; -import static org.opensearch.tasks.TaskResourceTrackingService.TASK_RESOURCE_USAGE; /** * Perform search action @@ -187,6 +187,7 @@ public class TransportSearchAction extends HandledTransportAction) SearchRequest::new); this.client = client; @@ -225,6 +227,7 @@ public TransportSearchAction( clusterService.getClusterSettings() .addSettingsUpdateConsumer(SEARCH_QUERY_METRICS_ENABLED_SETTING, this::setSearchQueryMetricsEnabled); this.tracer = tracer; + this.taskResourceTrackingService = taskResourceTrackingService; } private void setSearchQueryMetricsEnabled(boolean searchQueryMetricsEnabled) { @@ -452,14 +455,10 @@ private void executeRequest( logger, TraceableSearchRequestOperationsListener.create(tracer, requestSpan) ); - SearchRequestContext searchRequestContext = new SearchRequestContext(requestOperationsListeners, originalSearchRequest, () -> { - List<String> taskResourceUsages = threadPool.getThreadContext().getResponseHeaders().get(TASK_RESOURCE_USAGE); - if (taskResourceUsages != null && taskResourceUsages.size() > 0) { - return taskResourceUsages.get(0); - } - return null; - } - + SearchRequestContext searchRequestContext = new SearchRequestContext( + requestOperationsListeners, + originalSearchRequest, + taskResourceTrackingService::getTaskResourceUsageFromThreadContext ); searchRequestContext.getSearchRequestOperationsListener().onRequestStart(searchRequestContext); 
diff --git a/server/src/main/java/org/opensearch/search/SearchService.java b/server/src/main/java/org/opensearch/search/SearchService.java index 403f3c73ab4c1..b98030909422b 100644 --- a/server/src/main/java/org/opensearch/search/SearchService.java +++ b/server/src/main/java/org/opensearch/search/SearchService.java @@ -73,13 +73,7 @@ import org.opensearch.core.index.Index; import org.opensearch.core.index.shard.ShardId; import org.opensearch.core.indices.breaker.CircuitBreakerService; -import org.opensearch.core.tasks.resourcetracker.ResourceStats; -import org.opensearch.core.tasks.resourcetracker.ResourceStatsType; -import org.opensearch.core.tasks.resourcetracker.ResourceUsageInfo; -import org.opensearch.core.tasks.resourcetracker.ResourceUsageMetric; import org.opensearch.core.tasks.resourcetracker.TaskResourceInfo; -import org.opensearch.core.tasks.resourcetracker.TaskResourceUsage; -import org.opensearch.core.tasks.resourcetracker.ThreadResourceInfo; import org.opensearch.index.IndexNotFoundException; import org.opensearch.index.IndexService; import org.opensearch.index.IndexSettings; @@ -1140,49 +1134,15 @@ private DefaultSearchContext createSearchContext(ReaderContext reader, ShardSear private void writeTaskResourceUsage(SearchShardTask task) { try { - // Get resource usages from when the task started - ThreadResourceInfo threadResourceInfo = task.getActiveThreadResourceInfo( - Thread.currentThread().getId(), - ResourceStatsType.WORKER_STATS - ); - if (threadResourceInfo == null) { - return; - } - Map startValues = threadResourceInfo.getResourceUsageInfo().getStatsInfo(); - if (!(startValues.containsKey(ResourceStats.CPU) && startValues.containsKey(ResourceStats.MEMORY))) { - return; - } - // Get current resource usages - ResourceUsageMetric[] endValues = taskResourceTrackingService.getResourceUsageMetricsForThread(Thread.currentThread().getId()); - long cpu = -1, mem = -1; - for (ResourceUsageMetric endValue : endValues) { - if (endValue.getStats() == 
ResourceStats.MEMORY) { - mem = endValue.getValue(); - } else if (endValue.getStats() == ResourceStats.CPU) { - cpu = endValue.getValue(); - } - } - if (cpu == -1 || mem == -1) { - logger.debug("Invalid resource usage value, cpu [{}], memory [{}]: ", cpu, mem); - return; + // Resource usage accumulated by this task on the current thread since it started; the builder is null if that cannot be determined + TaskResourceInfo.Builder builder = taskResourceTrackingService.getCurrentTaskResourceUsageBuilder(task); + if (builder != null) { + // The builder carries no node id yet, so attach the local node id before building + TaskResourceInfo taskResourceInfo = builder.setNodeId(clusterService.localNode().getId()).build(); + // Remove the existing TASK_RESOURCE_USAGE header since it would have come from an earlier phase in the same request. + threadPool.getThreadContext().removeResponseHeader(TASK_RESOURCE_USAGE); + threadPool.getThreadContext().addResponseHeader(TASK_RESOURCE_USAGE, taskResourceInfo.toString()); } - - // Build task resource usage info - TaskResourceInfo taskResourceInfo = new TaskResourceInfo.Builder().setAction(task.getAction()) - .setTaskId(task.getId()) - .setParentTaskId(task.getParentTaskId().getId()) - .setNodeId(clusterService.localNode().getId()) - .setTaskResourceUsage( - new TaskResourceUsage( - cpu - startValues.get(ResourceStats.CPU).getStartValue(), - mem - startValues.get(ResourceStats.MEMORY).getStartValue() - ) - ) - .build(); - - // Remove the existing TASK_RESOURCE_USAGE header since it would have come from an earlier phase in the same request. 
- threadPool.getThreadContext().removeResponseHeader(TASK_RESOURCE_USAGE); - threadPool.getThreadContext().addResponseHeader(TASK_RESOURCE_USAGE, taskResourceInfo.toString()); } catch (Exception e) { logger.debug("Error during writing task resource usage: ", e); } diff --git a/server/src/main/java/org/opensearch/tasks/TaskResourceTrackingService.java b/server/src/main/java/org/opensearch/tasks/TaskResourceTrackingService.java index 59e719a3c3250..97d2aa069afff 100644 --- a/server/src/main/java/org/opensearch/tasks/TaskResourceTrackingService.java +++ b/server/src/main/java/org/opensearch/tasks/TaskResourceTrackingService.java @@ -14,6 +14,7 @@ import org.apache.logging.log4j.Logger; import org.apache.logging.log4j.message.ParameterizedMessage; import org.opensearch.ExceptionsHelper; +import org.opensearch.action.search.SearchShardTask; import org.opensearch.common.SuppressForbidden; import org.opensearch.common.inject.Inject; import org.opensearch.common.settings.ClusterSettings; @@ -22,12 +23,23 @@ import org.opensearch.common.util.concurrent.ConcurrentCollections; import org.opensearch.common.util.concurrent.ConcurrentMapLong; import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.common.xcontent.XContentHelper; +import org.opensearch.core.common.bytes.BytesArray; import org.opensearch.core.tasks.resourcetracker.ResourceStats; +import org.opensearch.core.tasks.resourcetracker.ResourceStatsType; +import org.opensearch.core.tasks.resourcetracker.ResourceUsageInfo; import org.opensearch.core.tasks.resourcetracker.ResourceUsageMetric; +import org.opensearch.core.tasks.resourcetracker.TaskResourceInfo; +import org.opensearch.core.tasks.resourcetracker.TaskResourceUsage; import org.opensearch.core.tasks.resourcetracker.ThreadResourceInfo; +import org.opensearch.core.xcontent.DeprecationHandler; +import org.opensearch.core.xcontent.MediaTypeRegistry; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import 
org.opensearch.core.xcontent.XContentParser; import org.opensearch.threadpool.RunnableTaskExecutionListener; import org.opensearch.threadpool.ThreadPool; +import java.io.IOException; import java.lang.management.ManagementFactory; import java.util.ArrayList; import java.util.Collections; @@ -212,7 +224,7 @@ public Map getResourceAwareTasks() { return Collections.unmodifiableMap(resourceAwareTasks); } - public ResourceUsageMetric[] getResourceUsageMetricsForThread(long threadId) { + private ResourceUsageMetric[] getResourceUsageMetricsForThread(long threadId) { ResourceUsageMetric currentMemoryUsage = new ResourceUsageMetric( ResourceStats.MEMORY, threadMXBean.getThreadAllocatedBytes(threadId) @@ -262,6 +274,84 @@ private ThreadContext.StoredContext addTaskIdToThreadContext(Task task) { return storedContext; } + /** + * Get the resource usage the given task has accumulated on the current thread since it started. + * + * @param task {@link SearchShardTask} + * @return {@link TaskResourceInfo.Builder} without a node id set, or {@code null} if the usage cannot be determined + */ + public TaskResourceInfo.Builder getCurrentTaskResourceUsageBuilder(SearchShardTask task) { + try { + // Get resource usages from when the task started + ThreadResourceInfo threadResourceInfo = task.getActiveThreadResourceInfo( + Thread.currentThread().getId(), + ResourceStatsType.WORKER_STATS + ); + if (threadResourceInfo == null) { + return null; + } + Map<ResourceStats, ResourceUsageInfo.ResourceStatsInfo> startValues = threadResourceInfo.getResourceUsageInfo().getStatsInfo(); + if (!(startValues.containsKey(ResourceStats.CPU) && startValues.containsKey(ResourceStats.MEMORY))) { + return null; + } + // Get current resource usages + ResourceUsageMetric[] endValues = getResourceUsageMetricsForThread(Thread.currentThread().getId()); + long cpu = -1, mem = -1; + for (ResourceUsageMetric endValue : endValues) { + if (endValue.getStats() == ResourceStats.MEMORY) { + mem = endValue.getValue(); + } else if (endValue.getStats() == ResourceStats.CPU) { + cpu = endValue.getValue(); + } + } + if (cpu == -1 || mem == -1) { + logger.debug("Invalid resource usage value, cpu 
[{}], memory [{}]: ", cpu, mem); + return null; + } + + // Build task resource usage info + return new TaskResourceInfo.Builder().setAction(task.getAction()) + .setTaskId(task.getId()) + .setParentTaskId(task.getParentTaskId().getId()) + .setTaskResourceUsage( + new TaskResourceUsage( + cpu - startValues.get(ResourceStats.CPU).getStartValue(), + mem - startValues.get(ResourceStats.MEMORY).getStartValue() + ) + ); + } catch (Exception e) { + logger.debug("Error during writing task resource usage: ", e); + return null; + } + } + + /** + * Get the task resource usage from the first TASK_RESOURCE_USAGE response header in {@link ThreadContext}. + * + * @return the parsed {@link TaskResourceInfo}, or {@code null} if the header is missing or empty, or an {@link IOException} occurs while parsing it + */ + public TaskResourceInfo getTaskResourceUsageFromThreadContext() { + List<String> taskResourceUsages = threadPool.getThreadContext().getResponseHeaders().get(TASK_RESOURCE_USAGE); + if (taskResourceUsages != null && taskResourceUsages.size() > 0) { + String usage = taskResourceUsages.get(0); + if (usage != null && !usage.isEmpty()) { + try ( + XContentParser parser = XContentHelper.createParser( + NamedXContentRegistry.EMPTY, + DeprecationHandler.THROW_UNSUPPORTED_OPERATION, + new BytesArray(usage), + MediaTypeRegistry.JSON + ) + ) { + return TaskResourceInfo.PARSER.apply(parser, null); + } catch (IOException e) { + logger.debug("fail to parse phase resource usages: ", e); + } + } + } + return null; + } + /** * Listener that gets invoked when a task execution completes. 
*/ diff --git a/server/src/test/java/org/opensearch/action/search/AbstractSearchAsyncActionTests.java b/server/src/test/java/org/opensearch/action/search/AbstractSearchAsyncActionTests.java index 730f0569f8bc5..27336e86e52b0 100644 --- a/server/src/test/java/org/opensearch/action/search/AbstractSearchAsyncActionTests.java +++ b/server/src/test/java/org/opensearch/action/search/AbstractSearchAsyncActionTests.java @@ -201,7 +201,7 @@ private AbstractSearchAsyncAction createAction( new SearchRequestContext( new SearchRequestOperationsListener.CompositeListener(List.of(assertingListener), LogManager.getLogger()), request, - () -> "" + () -> null ), NoopTracer.INSTANCE ) { diff --git a/server/src/test/java/org/opensearch/snapshots/SnapshotResiliencyTests.java b/server/src/test/java/org/opensearch/snapshots/SnapshotResiliencyTests.java index c8b17f6f1eacc..6793c6d1925d2 100644 --- a/server/src/test/java/org/opensearch/snapshots/SnapshotResiliencyTests.java +++ b/server/src/test/java/org/opensearch/snapshots/SnapshotResiliencyTests.java @@ -2327,7 +2327,8 @@ public void onFailure(final Exception e) { ), NoopMetricsRegistry.INSTANCE, searchRequestOperationsCompositeListenerFactory, - NoopTracer.INSTANCE + NoopTracer.INSTANCE, + new TaskResourceTrackingService(settings, clusterSettings, threadPool) ) ); actions.put(