Skip to content

Commit

Permalink
move resource usages interactions into TaskResourceTrackingService
Browse files Browse the repository at this point in the history
Signed-off-by: Chenyang Ji <cyji@amazon.com>
  • Loading branch information
ansjcy committed Jun 4, 2024
1 parent e45809b commit 5bf3df4
Show file tree
Hide file tree
Showing 7 changed files with 117 additions and 86 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.action.ShardOperationFailedException;
import org.opensearch.core.index.shard.ShardId;
import org.opensearch.core.tasks.resourcetracker.TaskResourceInfo;
import org.opensearch.search.SearchPhaseResult;
import org.opensearch.search.SearchShardTarget;
import org.opensearch.search.internal.AliasFilter;
Expand Down Expand Up @@ -628,7 +629,7 @@ protected void onShardResult(Result result, SearchShardIterator shardIt) {
}

public void setPhaseResourceUsages() {
String taskResourceUsage = searchRequestContext.getTaskResourceUsageSupplier().get();
TaskResourceInfo taskResourceUsage = searchRequestContext.getTaskResourceUsageSupplier().get();
searchRequestContext.recordPhaseResourceUsage(taskResourceUsage);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,8 @@
import org.apache.logging.log4j.Logger;
import org.apache.lucene.search.TotalHits;
import org.opensearch.common.annotation.InternalApi;
import org.opensearch.common.xcontent.XContentHelper;
import org.opensearch.core.common.bytes.BytesArray;
import org.opensearch.core.tasks.resourcetracker.TaskResourceInfo;
import org.opensearch.core.xcontent.DeprecationHandler;
import org.opensearch.core.xcontent.MediaTypeRegistry;
import org.opensearch.core.xcontent.NamedXContentRegistry;
import org.opensearch.core.xcontent.XContentParser;

import java.io.IOException;
import java.util.ArrayList;
import java.util.EnumMap;
import java.util.HashMap;
Expand All @@ -45,12 +38,12 @@ public class SearchRequestContext {

private final SearchRequest searchRequest;
private final List<TaskResourceInfo> phaseResourceUsage;
private final Supplier<String> taskResourceUsageSupplier;
private final Supplier<TaskResourceInfo> taskResourceUsageSupplier;

SearchRequestContext(
final SearchRequestOperationsListener searchRequestOperationsListener,
final SearchRequest searchRequest,
final Supplier<String> taskResourceUsageSupplier
final Supplier<TaskResourceInfo> taskResourceUsageSupplier
) {
this.searchRequestOperationsListener = searchRequestOperationsListener;
this.absoluteStartNanos = System.nanoTime();
Expand Down Expand Up @@ -130,24 +123,12 @@ String formattedShardStats() {
}
}

public Supplier<String> getTaskResourceUsageSupplier() {
public Supplier<TaskResourceInfo> getTaskResourceUsageSupplier() {
return taskResourceUsageSupplier;
}

public void recordPhaseResourceUsage(String usage) {
try {
if (usage != null && !usage.isEmpty()) {
XContentParser parser = XContentHelper.createParser(
NamedXContentRegistry.EMPTY,
DeprecationHandler.THROW_UNSUPPORTED_OPERATION,
new BytesArray(usage),
MediaTypeRegistry.JSON
);
this.phaseResourceUsage.add(TaskResourceInfo.PARSER.apply(parser, null));
}
} catch (IOException e) {
logger.debug("fail to parse phase resource usages: ", e);
}
public void recordPhaseResourceUsage(TaskResourceInfo usage) {
this.phaseResourceUsage.add(usage);
}

public List<TaskResourceInfo> getPhaseResourceUsage() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,7 @@
import org.opensearch.search.profile.SearchProfileShardResults;
import org.opensearch.tasks.CancellableTask;
import org.opensearch.tasks.Task;
import org.opensearch.tasks.TaskResourceTrackingService;
import org.opensearch.telemetry.metrics.MetricsRegistry;
import org.opensearch.telemetry.tracing.Span;
import org.opensearch.telemetry.tracing.SpanBuilder;
Expand Down Expand Up @@ -125,7 +126,6 @@
import static org.opensearch.action.search.SearchType.DFS_QUERY_THEN_FETCH;
import static org.opensearch.action.search.SearchType.QUERY_THEN_FETCH;
import static org.opensearch.search.sort.FieldSortBuilder.hasPrimaryFieldSort;
import static org.opensearch.tasks.TaskResourceTrackingService.TASK_RESOURCE_USAGE;

/**
* Perform search action
Expand Down Expand Up @@ -187,6 +187,7 @@ public class TransportSearchAction extends HandledTransportAction<SearchRequest,
private final MetricsRegistry metricsRegistry;

private SearchQueryCategorizer searchQueryCategorizer;
private TaskResourceTrackingService taskResourceTrackingService;

@Inject
public TransportSearchAction(
Expand All @@ -204,7 +205,8 @@ public TransportSearchAction(
SearchPipelineService searchPipelineService,
MetricsRegistry metricsRegistry,
SearchRequestOperationsCompositeListenerFactory searchRequestOperationsCompositeListenerFactory,
Tracer tracer
Tracer tracer,
TaskResourceTrackingService taskResourceTrackingService
) {
super(SearchAction.NAME, transportService, actionFilters, (Writeable.Reader<SearchRequest>) SearchRequest::new);
this.client = client;
Expand All @@ -225,6 +227,7 @@ public TransportSearchAction(
clusterService.getClusterSettings()
.addSettingsUpdateConsumer(SEARCH_QUERY_METRICS_ENABLED_SETTING, this::setSearchQueryMetricsEnabled);
this.tracer = tracer;
this.taskResourceTrackingService = taskResourceTrackingService;
}

private void setSearchQueryMetricsEnabled(boolean searchQueryMetricsEnabled) {
Expand Down Expand Up @@ -452,14 +455,10 @@ private void executeRequest(
logger,
TraceableSearchRequestOperationsListener.create(tracer, requestSpan)
);
SearchRequestContext searchRequestContext = new SearchRequestContext(requestOperationsListeners, originalSearchRequest, () -> {
List<String> taskResourceUsages = threadPool.getThreadContext().getResponseHeaders().get(TASK_RESOURCE_USAGE);
if (taskResourceUsages != null && taskResourceUsages.size() > 0) {
return taskResourceUsages.get(0);
}
return null;
}

SearchRequestContext searchRequestContext = new SearchRequestContext(
requestOperationsListeners,
originalSearchRequest,
taskResourceTrackingService::getTaskResourceUsageFromThreadContext
);
searchRequestContext.getSearchRequestOperationsListener().onRequestStart(searchRequestContext);

Expand Down
56 changes: 8 additions & 48 deletions server/src/main/java/org/opensearch/search/SearchService.java
Original file line number Diff line number Diff line change
Expand Up @@ -73,13 +73,7 @@
import org.opensearch.core.index.Index;
import org.opensearch.core.index.shard.ShardId;
import org.opensearch.core.indices.breaker.CircuitBreakerService;
import org.opensearch.core.tasks.resourcetracker.ResourceStats;
import org.opensearch.core.tasks.resourcetracker.ResourceStatsType;
import org.opensearch.core.tasks.resourcetracker.ResourceUsageInfo;
import org.opensearch.core.tasks.resourcetracker.ResourceUsageMetric;
import org.opensearch.core.tasks.resourcetracker.TaskResourceInfo;
import org.opensearch.core.tasks.resourcetracker.TaskResourceUsage;
import org.opensearch.core.tasks.resourcetracker.ThreadResourceInfo;
import org.opensearch.index.IndexNotFoundException;
import org.opensearch.index.IndexService;
import org.opensearch.index.IndexSettings;
Expand Down Expand Up @@ -1140,49 +1134,15 @@ private DefaultSearchContext createSearchContext(ReaderContext reader, ShardSear

private void writeTaskResourceUsage(SearchShardTask task) {
try {
// Get resource usages from when the task started
ThreadResourceInfo threadResourceInfo = task.getActiveThreadResourceInfo(
Thread.currentThread().getId(),
ResourceStatsType.WORKER_STATS
);
if (threadResourceInfo == null) {
return;
}
Map<ResourceStats, ResourceUsageInfo.ResourceStatsInfo> startValues = threadResourceInfo.getResourceUsageInfo().getStatsInfo();
if (!(startValues.containsKey(ResourceStats.CPU) && startValues.containsKey(ResourceStats.MEMORY))) {
return;
}
// Get current resource usages
ResourceUsageMetric[] endValues = taskResourceTrackingService.getResourceUsageMetricsForThread(Thread.currentThread().getId());
long cpu = -1, mem = -1;
for (ResourceUsageMetric endValue : endValues) {
if (endValue.getStats() == ResourceStats.MEMORY) {
mem = endValue.getValue();
} else if (endValue.getStats() == ResourceStats.CPU) {
cpu = endValue.getValue();
}
}
if (cpu == -1 || mem == -1) {
logger.debug("Invalid resource usage value, cpu [{}], memory [{}]: ", cpu, mem);
return;
// Get current resource usages from when the task started
TaskResourceInfo.Builder builder = taskResourceTrackingService.getCurrentTaskResourceUsageBuilder(task);
if (builder != null) {
// Attach NodeId to taskResourceInfo
TaskResourceInfo taskResourceInfo = builder.setNodeId(clusterService.localNode().getId()).build();
// Remove the existing TASK_RESOURCE_USAGE header since it would have come from an earlier phase in the same request.
threadPool.getThreadContext().removeResponseHeader(TASK_RESOURCE_USAGE);
threadPool.getThreadContext().addResponseHeader(TASK_RESOURCE_USAGE, taskResourceInfo.toString());
}

// Build task resource usage info
TaskResourceInfo taskResourceInfo = new TaskResourceInfo.Builder().setAction(task.getAction())
.setTaskId(task.getId())
.setParentTaskId(task.getParentTaskId().getId())
.setNodeId(clusterService.localNode().getId())
.setTaskResourceUsage(
new TaskResourceUsage(
cpu - startValues.get(ResourceStats.CPU).getStartValue(),
mem - startValues.get(ResourceStats.MEMORY).getStartValue()
)
)
.build();

// Remove the existing TASK_RESOURCE_USAGE header since it would have come from an earlier phase in the same request.
threadPool.getThreadContext().removeResponseHeader(TASK_RESOURCE_USAGE);
threadPool.getThreadContext().addResponseHeader(TASK_RESOURCE_USAGE, taskResourceInfo.toString());
} catch (Exception e) {
logger.debug("Error during writing task resource usage: ", e);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import org.apache.logging.log4j.Logger;
import org.apache.logging.log4j.message.ParameterizedMessage;
import org.opensearch.ExceptionsHelper;
import org.opensearch.action.search.SearchShardTask;
import org.opensearch.common.SuppressForbidden;
import org.opensearch.common.inject.Inject;
import org.opensearch.common.settings.ClusterSettings;
Expand All @@ -22,12 +23,23 @@
import org.opensearch.common.util.concurrent.ConcurrentCollections;
import org.opensearch.common.util.concurrent.ConcurrentMapLong;
import org.opensearch.common.util.concurrent.ThreadContext;
import org.opensearch.common.xcontent.XContentHelper;
import org.opensearch.core.common.bytes.BytesArray;
import org.opensearch.core.tasks.resourcetracker.ResourceStats;
import org.opensearch.core.tasks.resourcetracker.ResourceStatsType;
import org.opensearch.core.tasks.resourcetracker.ResourceUsageInfo;
import org.opensearch.core.tasks.resourcetracker.ResourceUsageMetric;
import org.opensearch.core.tasks.resourcetracker.TaskResourceInfo;
import org.opensearch.core.tasks.resourcetracker.TaskResourceUsage;
import org.opensearch.core.tasks.resourcetracker.ThreadResourceInfo;
import org.opensearch.core.xcontent.DeprecationHandler;
import org.opensearch.core.xcontent.MediaTypeRegistry;
import org.opensearch.core.xcontent.NamedXContentRegistry;
import org.opensearch.core.xcontent.XContentParser;
import org.opensearch.threadpool.RunnableTaskExecutionListener;
import org.opensearch.threadpool.ThreadPool;

import java.io.IOException;
import java.lang.management.ManagementFactory;
import java.util.ArrayList;
import java.util.Collections;
Expand Down Expand Up @@ -212,7 +224,7 @@ public Map<Long, Task> getResourceAwareTasks() {
return Collections.unmodifiableMap(resourceAwareTasks);
}

public ResourceUsageMetric[] getResourceUsageMetricsForThread(long threadId) {
private ResourceUsageMetric[] getResourceUsageMetricsForThread(long threadId) {
ResourceUsageMetric currentMemoryUsage = new ResourceUsageMetric(
ResourceStats.MEMORY,
threadMXBean.getThreadAllocatedBytes(threadId)
Expand Down Expand Up @@ -262,6 +274,83 @@ private ThreadContext.StoredContext addTaskIdToThreadContext(Task task) {
return storedContext;
}

/**
* Get the current task level resource usage.
*
* @param task {@link SearchShardTask}
* @return {@link TaskResourceInfo.Builder}
*/
public TaskResourceInfo.Builder getCurrentTaskResourceUsageBuilder(SearchShardTask task) {
try {
// Get resource usages from when the task started
ThreadResourceInfo threadResourceInfo = task.getActiveThreadResourceInfo(
Thread.currentThread().getId(),
ResourceStatsType.WORKER_STATS
);
if (threadResourceInfo == null) {
return null;
}
Map<ResourceStats, ResourceUsageInfo.ResourceStatsInfo> startValues = threadResourceInfo.getResourceUsageInfo().getStatsInfo();
if (!(startValues.containsKey(ResourceStats.CPU) && startValues.containsKey(ResourceStats.MEMORY))) {
return null;
}
// Get current resource usages
ResourceUsageMetric[] endValues = getResourceUsageMetricsForThread(Thread.currentThread().getId());
long cpu = -1, mem = -1;
for (ResourceUsageMetric endValue : endValues) {
if (endValue.getStats() == ResourceStats.MEMORY) {
mem = endValue.getValue();
} else if (endValue.getStats() == ResourceStats.CPU) {
cpu = endValue.getValue();
}
}
if (cpu == -1 || mem == -1) {
logger.debug("Invalid resource usage value, cpu [{}], memory [{}]: ", cpu, mem);
return null;
}

// Build task resource usage info
return new TaskResourceInfo.Builder().setAction(task.getAction())
.setTaskId(task.getId())
.setParentTaskId(task.getParentTaskId().getId())
.setTaskResourceUsage(
new TaskResourceUsage(
cpu - startValues.get(ResourceStats.CPU).getStartValue(),
mem - startValues.get(ResourceStats.MEMORY).getStartValue()
)
);
} catch (Exception e) {
logger.debug("Error during writing task resource usage: ", e);
return null;
}
}

/**
* Get the task resource usages from {@link ThreadContext}
*
* @return {@link TaskResourceInfo}
*/
public TaskResourceInfo getTaskResourceUsageFromThreadContext() {
List<String> taskResourceUsages = threadPool.getThreadContext().getResponseHeaders().get(TASK_RESOURCE_USAGE);
if (taskResourceUsages != null && taskResourceUsages.size() > 0) {
String usage = taskResourceUsages.get(0);
try {
if (usage != null && !usage.isEmpty()) {
XContentParser parser = XContentHelper.createParser(
NamedXContentRegistry.EMPTY,
DeprecationHandler.THROW_UNSUPPORTED_OPERATION,
new BytesArray(usage),
MediaTypeRegistry.JSON
);
return TaskResourceInfo.PARSER.apply(parser, null);
}
} catch (IOException e) {
logger.debug("fail to parse phase resource usages: ", e);
}
}
return null;
}

/**
* Listener that gets invoked when a task execution completes.
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,7 @@ private AbstractSearchAsyncAction<SearchPhaseResult> createAction(
new SearchRequestContext(
new SearchRequestOperationsListener.CompositeListener(List.of(assertingListener), LogManager.getLogger()),
request,
() -> ""
() -> null
),
NoopTracer.INSTANCE
) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2327,7 +2327,8 @@ public void onFailure(final Exception e) {
),
NoopMetricsRegistry.INSTANCE,
searchRequestOperationsCompositeListenerFactory,
NoopTracer.INSTANCE
NoopTracer.INSTANCE,
new TaskResourceTrackingService(settings, clusterSettings, threadPool)
)
);
actions.put(
Expand Down

0 comments on commit 5bf3df4

Please sign in to comment.