From 66e1aeecd1389dc789ccbd288ebae8069057cbf2 Mon Sep 17 00:00:00 2001 From: Tushar Kharbanda Date: Wed, 23 Mar 2022 20:19:45 +0530 Subject: [PATCH 01/26] Add Task id in Thread Context Signed-off-by: Tushar Kharbanda --- .../admin/cluster/node/tasks/TasksIT.java | 6 +++ .../action/support/TransportAction.java | 28 ++++++---- .../org/opensearch/tasks/TaskManager.java | 34 ++++++++++++ .../transport/RequestHandlerRegistry.java | 4 ++ .../tasks/RecordingTaskManagerListener.java | 3 ++ .../node/tasks/TransportTasksActionTests.java | 52 +++++++++++++++++++ .../bulk/TransportBulkActionIngestTests.java | 3 +- .../opensearch/tasks/TaskManagerTests.java | 32 ++++++++++++ .../test/tasks/MockTaskManager.java | 16 ++++++ .../test/tasks/MockTaskManagerListener.java | 7 +++ 10 files changed, 174 insertions(+), 11 deletions(-) diff --git a/server/src/internalClusterTest/java/org/opensearch/action/admin/cluster/node/tasks/TasksIT.java b/server/src/internalClusterTest/java/org/opensearch/action/admin/cluster/node/tasks/TasksIT.java index ac0ae44eb732e..4042dc27338fc 100644 --- a/server/src/internalClusterTest/java/org/opensearch/action/admin/cluster/node/tasks/TasksIT.java +++ b/server/src/internalClusterTest/java/org/opensearch/action/admin/cluster/node/tasks/TasksIT.java @@ -470,6 +470,9 @@ public void onTaskUnregistered(Task task) {} @Override public void waitForTaskCompletion(Task task) {} + + @Override + public void onThreadContextUpdate(Task task, Boolean taskIdAdded) {} }); } // Need to run the task in a separate thread because node client's .execute() is blocked by our task listener @@ -651,6 +654,9 @@ public void waitForTaskCompletion(Task task) { waitForWaitingToStart.countDown(); } + @Override + public void onThreadContextUpdate(Task task, Boolean taskIdAdded) {} + @Override public void onTaskRegistered(Task task) {} diff --git a/server/src/main/java/org/opensearch/action/support/TransportAction.java 
b/server/src/main/java/org/opensearch/action/support/TransportAction.java index 84ece8cfec530..97b975c255d2b 100644 --- a/server/src/main/java/org/opensearch/action/support/TransportAction.java +++ b/server/src/main/java/org/opensearch/action/support/TransportAction.java @@ -40,6 +40,7 @@ import org.opensearch.action.ActionResponse; import org.opensearch.common.lease.Releasable; import org.opensearch.common.lease.Releasables; +import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.tasks.Task; import org.opensearch.tasks.TaskCancelledException; import org.opensearch.tasks.TaskId; @@ -155,18 +156,25 @@ public void onFailure(Exception e) { * Use this method when the transport action should continue to run in the context of the current task */ public final void execute(Task task, Request request, ActionListener listener) { - ActionRequestValidationException validationException = request.validate(); - if (validationException != null) { - listener.onFailure(validationException); - return; - } + ThreadContext.StoredContext storedContext = taskManager.addTaskIdInThreadContext(task); - if (task != null && request.getShouldStoreResult()) { - listener = new TaskResultStoringActionListener<>(taskManager, task, listener); - } + try { + ActionRequestValidationException validationException = request.validate(); + if (validationException != null) { + listener.onFailure(validationException); + return; + } - RequestFilterChain requestFilterChain = new RequestFilterChain<>(this, logger); - requestFilterChain.proceed(task, actionName, request, listener); + if (task != null && request.getShouldStoreResult()) { + listener = new TaskResultStoringActionListener<>(taskManager, task, listener); + } + + RequestFilterChain requestFilterChain = new RequestFilterChain<>(this, logger); + requestFilterChain.proceed(task, actionName, request, listener); + + } finally { + storedContext.restore(); + } } protected abstract void doExecute(Task task, Request request, 
ActionListener listener); diff --git a/server/src/main/java/org/opensearch/tasks/TaskManager.java b/server/src/main/java/org/opensearch/tasks/TaskManager.java index 1f6169768f245..1ea9106ec8439 100644 --- a/server/src/main/java/org/opensearch/tasks/TaskManager.java +++ b/server/src/main/java/org/opensearch/tasks/TaskManager.java @@ -48,6 +48,7 @@ import org.opensearch.cluster.ClusterStateApplier; import org.opensearch.cluster.node.DiscoveryNode; import org.opensearch.cluster.node.DiscoveryNodes; +import org.opensearch.common.Nullable; import org.opensearch.common.lease.Releasable; import org.opensearch.common.lease.Releasables; import org.opensearch.common.settings.Settings; @@ -85,6 +86,8 @@ */ public class TaskManager implements ClusterStateApplier { + public static final String TASK_ID = "TASK_ID"; + private static final Logger logger = LogManager.getLogger(TaskManager.class); private static final TimeValue WAIT_FOR_COMPLETION_POLL = timeValueMillis(100); @@ -448,6 +451,37 @@ public void waitForTaskCompletion(Task task, long untilInNanos) { throw new OpenSearchTimeoutException("Timed out waiting for completion of [{}]", task); } + /** + * Adds Task Id in the ThreadContext. + *

+ * Stashes the existing ThreadContext and preserves all the existing ThreadContext's data in the new ThreadContext + * as well. + * + * @param task for which Task Id needs to be added in ThreadContext. + * @return StoredContext reference to restore the ThreadContext from which we created a new one. + * Caller can call context.restore() to get the existing ThreadContext back. + */ + public ThreadContext.StoredContext addTaskIdInThreadContext(@Nullable Task task) { + if (task == null) { + return () -> {}; + } + + ThreadContext threadContext = threadPool.getThreadContext(); + + if (threadContext.getTransient(TASK_ID) != null) { + logger.warn( + "Task Id already present in the thread context. Thread Id: {}, Existing Task Id: {}, New Task Id: {}. Overwriting", + Thread.currentThread().getId(), + threadContext.getTransient(TASK_ID), + task.getId() + ); + } + + ThreadContext.StoredContext storedContext = threadContext.newStoredContext(true, Collections.singletonList(TASK_ID)); + threadContext.putTransient(TASK_ID, task.getId()); + return storedContext; + } + private static class CancellableTaskHolder { private final CancellableTask task; private boolean finished = false; diff --git a/server/src/main/java/org/opensearch/transport/RequestHandlerRegistry.java b/server/src/main/java/org/opensearch/transport/RequestHandlerRegistry.java index dcb021531f0ac..d5de970f728e3 100644 --- a/server/src/main/java/org/opensearch/transport/RequestHandlerRegistry.java +++ b/server/src/main/java/org/opensearch/transport/RequestHandlerRegistry.java @@ -37,6 +37,7 @@ import org.opensearch.common.lease.Releasable; import org.opensearch.common.lease.Releasables; import org.opensearch.search.internal.ShardSearchRequest; +import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.tasks.CancellableTask; import org.opensearch.tasks.Task; import org.opensearch.tasks.TaskManager; @@ -81,6 +82,8 @@ public Request newRequest(StreamInput in) throws IOException { public void 
processMessageReceived(Request request, TransportChannel channel) throws Exception { final Task task = taskManager.register(channel.getChannelType(), action, request); + ThreadContext.StoredContext storedContext = taskManager.addTaskIdInThreadContext(task); + Releasable unregisterTask = () -> taskManager.unregister(task); try { if (channel instanceof TcpTransportChannel && task instanceof CancellableTask) { @@ -99,6 +102,7 @@ public void processMessageReceived(Request request, TransportChannel channel) th unregisterTask = null; } finally { Releasables.close(unregisterTask); + storedContext.restore(); } } diff --git a/server/src/test/java/org/opensearch/action/admin/cluster/node/tasks/RecordingTaskManagerListener.java b/server/src/test/java/org/opensearch/action/admin/cluster/node/tasks/RecordingTaskManagerListener.java index 7756eb12bb3f4..7c35a3c79d66f 100644 --- a/server/src/test/java/org/opensearch/action/admin/cluster/node/tasks/RecordingTaskManagerListener.java +++ b/server/src/test/java/org/opensearch/action/admin/cluster/node/tasks/RecordingTaskManagerListener.java @@ -75,6 +75,9 @@ public synchronized void onTaskUnregistered(Task task) { @Override public void waitForTaskCompletion(Task task) {} + @Override + public void onThreadContextUpdate(Task task, Boolean taskIdAdded) {} + public synchronized List> getEvents() { return Collections.unmodifiableList(new ArrayList<>(events)); } diff --git a/server/src/test/java/org/opensearch/action/admin/cluster/node/tasks/TransportTasksActionTests.java b/server/src/test/java/org/opensearch/action/admin/cluster/node/tasks/TransportTasksActionTests.java index 7590bf88eeca0..fb7734c5f6199 100644 --- a/server/src/test/java/org/opensearch/action/admin/cluster/node/tasks/TransportTasksActionTests.java +++ b/server/src/test/java/org/opensearch/action/admin/cluster/node/tasks/TransportTasksActionTests.java @@ -66,6 +66,7 @@ import org.opensearch.tasks.TaskId; import org.opensearch.tasks.TaskInfo; import 
org.opensearch.test.tasks.MockTaskManager; +import org.opensearch.test.tasks.MockTaskManagerListener; import org.opensearch.threadpool.ThreadPool; import org.opensearch.transport.TransportService; @@ -82,12 +83,14 @@ import java.util.concurrent.atomic.AtomicReference; import java.util.stream.Collectors; +import static org.hamcrest.Matchers.containsInAnyOrder; import static org.opensearch.action.support.PlainActionFuture.newFuture; import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.endsWith; import static org.hamcrest.Matchers.greaterThanOrEqualTo; import static org.hamcrest.Matchers.hasSize; import static org.hamcrest.Matchers.not; +import static org.opensearch.tasks.TaskManager.TASK_ID; public class TransportTasksActionTests extends TaskManagerTestCase { @@ -648,6 +651,55 @@ protected void taskOperation(TestTasksRequest request, Task task, ActionListener assertEquals(0, responses.failureCount()); } + public void testTaskIdPersistsInThreadContext() { + Settings settings = Settings.builder().put(MockTaskManager.USE_MOCK_TASK_MANAGER_SETTING.getKey(), true).build(); + setupTestNodes(settings); + connectNodes(testNodes); + + final List taskIdsAddedToThreadContext = new ArrayList<>(); + final List taskIdsRemovedFromThreadContext = new ArrayList<>(); + final long[] actualTaskIdInThreadContext = new long[1]; + final long[] expectedTaskIdInThreadContext = new long[1]; + + ((MockTaskManager) testNodes[0].transportService.getTaskManager()).addListener(new MockTaskManagerListener() { + @Override + public void waitForTaskCompletion(Task task) {} + + @Override + public void onThreadContextUpdate(Task task, Boolean taskIdAdded) { + if (taskIdAdded) { + taskIdsAddedToThreadContext.add(task.getId()); + } else { + taskIdsRemovedFromThreadContext.add(task.getId()); + } + } + + @Override + public void onTaskRegistered(Task task) {} + + @Override + public void onTaskUnregistered(Task task) { + if (task.getAction().equals("action1")) { + 
expectedTaskIdInThreadContext[0] = task.getId(); + actualTaskIdInThreadContext[0] = threadPool.getThreadContext().getTransient(TASK_ID); + } + } + }); + + TestTasksAction action = new TestTasksAction("action1", testNodes[0].clusterService, testNodes[0].transportService) { + @Override + protected void taskOperation(TestTasksRequest request, Task task, ActionListener listener) { + listener.onResponse(new TestTaskResponse(testNodes[0].getNodeId())); + } + }; + TestTasksRequest testTasksRequest = new TestTasksRequest(); + testTasksRequest.setActions("action1"); + ActionTestUtils.executeBlocking(action, testTasksRequest); + + assertEquals(expectedTaskIdInThreadContext[0], actualTaskIdInThreadContext[0]); + assertThat(taskIdsAddedToThreadContext, containsInAnyOrder(taskIdsRemovedFromThreadContext.toArray())); + } + /** * This test starts nodes actions that blocks on all nodes. While node actions are blocked in the middle of execution * it executes a tasks action that targets these blocked node actions. 
The test verifies that task actions are only diff --git a/server/src/test/java/org/opensearch/action/bulk/TransportBulkActionIngestTests.java b/server/src/test/java/org/opensearch/action/bulk/TransportBulkActionIngestTests.java index 4b98870422ce8..202f1b7dcb5b4 100644 --- a/server/src/test/java/org/opensearch/action/bulk/TransportBulkActionIngestTests.java +++ b/server/src/test/java/org/opensearch/action/bulk/TransportBulkActionIngestTests.java @@ -91,6 +91,7 @@ import static java.util.Collections.emptyMap; import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.sameInstance; +import static org.mockito.Answers.RETURNS_MOCKS; import static org.mockito.Mockito.any; import static org.mockito.Mockito.anyInt; import static org.mockito.Mockito.anyString; @@ -224,7 +225,7 @@ public void setupAction() { remoteResponseHandler = ArgumentCaptor.forClass(TransportResponseHandler.class); // setup services that will be called by action - transportService = mock(TransportService.class); + transportService = mock(TransportService.class, RETURNS_MOCKS); clusterService = mock(ClusterService.class); localIngest = true; // setup nodes for local and remote diff --git a/server/src/test/java/org/opensearch/tasks/TaskManagerTests.java b/server/src/test/java/org/opensearch/tasks/TaskManagerTests.java index 0f09b0de34206..d8ed4b81973f8 100644 --- a/server/src/test/java/org/opensearch/tasks/TaskManagerTests.java +++ b/server/src/test/java/org/opensearch/tasks/TaskManagerTests.java @@ -39,6 +39,7 @@ import org.opensearch.common.settings.Settings; import org.opensearch.common.unit.TimeValue; import org.opensearch.common.util.concurrent.ConcurrentCollections; +import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.test.OpenSearchTestCase; import org.opensearch.threadpool.TestThreadPool; import org.opensearch.threadpool.ThreadPool; @@ -64,6 +65,7 @@ import static org.hamcrest.Matchers.everyItem; import static org.hamcrest.Matchers.in; 
import static org.mockito.Mockito.mock; +import static org.opensearch.tasks.TaskManager.TASK_ID; public class TaskManagerTests extends OpenSearchTestCase { private ThreadPool threadPool; @@ -91,6 +93,36 @@ public void testResultsServiceRetryTotalTime() { assertEquals(600000L, total); } + public void testAddTaskIdToThreadContext() { + final TaskManager taskManager = new TaskManager(Settings.EMPTY, threadPool, Collections.emptySet()); + final Task task = taskManager.register("transport", "test", new CancellableRequest("1")); + String key = "KEY"; + String value = "VALUE"; + + // Prepare thread context + threadPool.getThreadContext().putHeader(key, value); + threadPool.getThreadContext().putTransient(key, value); + threadPool.getThreadContext().addResponseHeader(key, value); + + ThreadContext.StoredContext storedContext = taskManager.addTaskIdInThreadContext(task); + + // All headers should be preserved and Task Id should also be included in thread context + verifyThreadContextFixedHeaders(key, value); + assertEquals((long) threadPool.getThreadContext().getTransient(TASK_ID), task.getId()); + + storedContext.restore(); + + // Post restore only task id should be removed from the thread context + verifyThreadContextFixedHeaders(key, value); + assertNull(threadPool.getThreadContext().getTransient(TASK_ID)); + } + + private void verifyThreadContextFixedHeaders(String key, String value) { + assertEquals(threadPool.getThreadContext().getHeader(key), value); + assertEquals(threadPool.getThreadContext().getTransient(key), value); + assertEquals(threadPool.getThreadContext().getResponseHeaders().get(key).get(0), value); + } + public void testTrackingChannelTask() throws Exception { final TaskManager taskManager = new TaskManager(Settings.EMPTY, threadPool, Collections.emptySet()); Set cancelledTasks = ConcurrentCollections.newConcurrentSet(); diff --git a/test/framework/src/main/java/org/opensearch/test/tasks/MockTaskManager.java 
b/test/framework/src/main/java/org/opensearch/test/tasks/MockTaskManager.java index e60871f67ea54..d3f360e0d5414 100644 --- a/test/framework/src/main/java/org/opensearch/test/tasks/MockTaskManager.java +++ b/test/framework/src/main/java/org/opensearch/test/tasks/MockTaskManager.java @@ -39,6 +39,7 @@ import org.opensearch.common.settings.Setting; import org.opensearch.common.settings.Setting.Property; import org.opensearch.common.settings.Settings; +import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.tasks.Task; import org.opensearch.tasks.TaskAwareRequest; import org.opensearch.tasks.TaskManager; @@ -127,6 +128,21 @@ public void waitForTaskCompletion(Task task, long untilInNanos) { super.waitForTaskCompletion(task, untilInNanos); } + @Override + public ThreadContext.StoredContext addTaskIdInThreadContext(Task task) { + for (MockTaskManagerListener listener : listeners) { + listener.onThreadContextUpdate(task, true); + } + + ThreadContext.StoredContext storedContext = super.addTaskIdInThreadContext(task); + return () -> { + for (MockTaskManagerListener listener : listeners) { + listener.onThreadContextUpdate(task, false); + } + storedContext.restore(); + }; + } + public void addListener(MockTaskManagerListener listener) { listeners.add(listener); } diff --git a/test/framework/src/main/java/org/opensearch/test/tasks/MockTaskManagerListener.java b/test/framework/src/main/java/org/opensearch/test/tasks/MockTaskManagerListener.java index eb8361ac552fc..1736695d6e33e 100644 --- a/test/framework/src/main/java/org/opensearch/test/tasks/MockTaskManagerListener.java +++ b/test/framework/src/main/java/org/opensearch/test/tasks/MockTaskManagerListener.java @@ -43,4 +43,11 @@ public interface MockTaskManagerListener { void onTaskUnregistered(Task task); void waitForTaskCompletion(Task task); + + /** + * + * @param taskIdAdded if false then task id is removed from the thread context. 
Null if no change + */ + void onThreadContextUpdate(Task task, Boolean taskIdAdded); + } From 84cc493aa2a1ff955e32b9004fea69b5b2420fe4 Mon Sep 17 00:00:00 2001 From: Tushar Kharbanda Date: Wed, 23 Mar 2022 20:25:08 +0530 Subject: [PATCH 02/26] Add resource tracking update support for tasks Signed-off-by: Tushar Kharbanda --- .../action/search/SearchShardTask.java | 5 + .../opensearch/action/search/SearchTask.java | 5 + .../common/settings/ClusterSettings.java | 4 +- .../util/concurrent/OpenSearchExecutors.java | 17 +- .../main/java/org/opensearch/node/Node.java | 4 + .../main/java/org/opensearch/tasks/Task.java | 15 +- .../org/opensearch/tasks/TaskManager.java | 84 ++- .../tasks/TaskResourceStatsUtil.java | 0 .../RunnableTaskExecutionListener.java | 33 ++ .../threadpool/TaskAwareRunnable.java | 85 +++ .../transport/TransportService.java | 4 + .../node/tasks/ResourceAwareTasksTests.java | 519 ++++++++++++++++++ 12 files changed, 763 insertions(+), 12 deletions(-) create mode 100644 server/src/main/java/org/opensearch/tasks/TaskResourceStatsUtil.java create mode 100644 server/src/main/java/org/opensearch/threadpool/RunnableTaskExecutionListener.java create mode 100644 server/src/main/java/org/opensearch/threadpool/TaskAwareRunnable.java create mode 100644 server/src/test/java/org/opensearch/action/admin/cluster/node/tasks/ResourceAwareTasksTests.java diff --git a/server/src/main/java/org/opensearch/action/search/SearchShardTask.java b/server/src/main/java/org/opensearch/action/search/SearchShardTask.java index 2e506c6fe181b..f09701c7769eb 100644 --- a/server/src/main/java/org/opensearch/action/search/SearchShardTask.java +++ b/server/src/main/java/org/opensearch/action/search/SearchShardTask.java @@ -49,6 +49,11 @@ public SearchShardTask(long id, String type, String action, String description, super(id, type, action, description, parentTaskId, headers); } + @Override + public boolean supportsResourceTracking() { + return true; + } + @Override public boolean 
shouldCancelChildrenOnCancellation() { return false; diff --git a/server/src/main/java/org/opensearch/action/search/SearchTask.java b/server/src/main/java/org/opensearch/action/search/SearchTask.java index 7f80f7836be6c..bf6f141a3e829 100644 --- a/server/src/main/java/org/opensearch/action/search/SearchTask.java +++ b/server/src/main/java/org/opensearch/action/search/SearchTask.java @@ -78,6 +78,11 @@ public final String getDescription() { return descriptionSupplier.get(); } + @Override + public boolean supportsResourceTracking() { + return true; + } + /** * Attach a {@link SearchProgressListener} to this task. */ diff --git a/server/src/main/java/org/opensearch/common/settings/ClusterSettings.java b/server/src/main/java/org/opensearch/common/settings/ClusterSettings.java index c758b7d2918e7..d79de1f9a4179 100644 --- a/server/src/main/java/org/opensearch/common/settings/ClusterSettings.java +++ b/server/src/main/java/org/opensearch/common/settings/ClusterSettings.java @@ -40,6 +40,7 @@ import org.opensearch.index.ShardIndexingPressureMemoryManager; import org.opensearch.index.ShardIndexingPressureSettings; import org.opensearch.index.ShardIndexingPressureStore; +import org.opensearch.tasks.TaskManager; import org.opensearch.watcher.ResourceWatcherService; import org.opensearch.action.admin.cluster.configuration.TransportAddVotingConfigExclusionsAction; import org.opensearch.action.admin.indices.close.TransportCloseIndexAction; @@ -568,7 +569,8 @@ public void apply(Settings value, Settings current, Settings previous) { ShardIndexingPressureMemoryManager.THROUGHPUT_DEGRADATION_LIMITS, ShardIndexingPressureMemoryManager.SUCCESSFUL_REQUEST_ELAPSED_TIMEOUT, ShardIndexingPressureMemoryManager.MAX_OUTSTANDING_REQUESTS, - IndexingPressure.MAX_INDEXING_BYTES + IndexingPressure.MAX_INDEXING_BYTES, + TaskManager.TASK_RESOURCE_TRACKING_ENABLED ) ) ); diff --git a/server/src/main/java/org/opensearch/common/util/concurrent/OpenSearchExecutors.java 
b/server/src/main/java/org/opensearch/common/util/concurrent/OpenSearchExecutors.java index 5a967528a6ae2..f44464d95efd4 100644 --- a/server/src/main/java/org/opensearch/common/util/concurrent/OpenSearchExecutors.java +++ b/server/src/main/java/org/opensearch/common/util/concurrent/OpenSearchExecutors.java @@ -40,6 +40,7 @@ import org.opensearch.common.settings.Settings; import org.opensearch.common.unit.TimeValue; import org.opensearch.node.Node; +import org.opensearch.threadpool.TaskAwareRunnable; import java.util.List; import java.util.Optional; @@ -175,11 +176,11 @@ public static OpenSearchThreadPoolExecutor newFixed( /** * Return a new executor that will automatically adjust the queue size based on queue throughput. * - * @param size number of fixed threads to use for executing tasks + * @param size number of fixed threads to use for executing tasks * @param initialQueueCapacity initial size of the executor queue - * @param minQueueSize minimum queue size that the queue can be adjusted to - * @param maxQueueSize maximum queue size that the queue can be adjusted to - * @param frameSize number of tasks during which stats are collected before adjusting queue size + * @param minQueueSize minimum queue size that the queue can be adjusted to + * @param maxQueueSize maximum queue size that the queue can be adjusted to + * @param frameSize number of tasks during which stats are collected before adjusting queue size */ public static OpenSearchThreadPoolExecutor newAutoQueueFixed( String name, @@ -201,6 +202,12 @@ public static OpenSearchThreadPoolExecutor newAutoQueueFixed( ConcurrentCollections.newBlockingQueue(), initialQueueCapacity ); + + Function runnableWrapper = (runnable) -> { + TaskAwareRunnable taskAwareRunnable = new TaskAwareRunnable(contextHolder, runnable); + return new TimedRunnable(taskAwareRunnable); + }; + return new QueueResizingOpenSearchThreadPoolExecutor( name, size, @@ -210,7 +217,7 @@ public static OpenSearchThreadPoolExecutor newAutoQueueFixed( 
queue, minQueueSize, maxQueueSize, - TimedRunnable::new, + runnableWrapper, frameSize, targetedResponseTime, threadFactory, diff --git a/server/src/main/java/org/opensearch/node/Node.java b/server/src/main/java/org/opensearch/node/Node.java index 8ede6fdf76653..da28c8edcefe9 100644 --- a/server/src/main/java/org/opensearch/node/Node.java +++ b/server/src/main/java/org/opensearch/node/Node.java @@ -37,6 +37,7 @@ import org.apache.lucene.util.Constants; import org.apache.lucene.util.SetOnce; import org.opensearch.index.IndexingPressureService; +import org.opensearch.threadpool.TaskAwareRunnable; import org.opensearch.watcher.ResourceWatcherService; import org.opensearch.Assertions; import org.opensearch.Build; @@ -1057,6 +1058,9 @@ public Node start() throws NodeValidationException { TransportService transportService = injector.getInstance(TransportService.class); transportService.getTaskManager().setTaskResultsService(injector.getInstance(TaskResultsService.class)); transportService.getTaskManager().setTaskCancellationService(new TaskCancellationService(transportService)); + + TaskAwareRunnable.setListener(transportService.getTaskManager()); + transportService.start(); assert localNodeFactory.getNode() != null; assert transportService.getLocalNode().equals(localNodeFactory.getNode()) diff --git a/server/src/main/java/org/opensearch/tasks/Task.java b/server/src/main/java/org/opensearch/tasks/Task.java index 62453d08724ce..9aad853b070ba 100644 --- a/server/src/main/java/org/opensearch/tasks/Task.java +++ b/server/src/main/java/org/opensearch/tasks/Task.java @@ -32,8 +32,6 @@ package org.opensearch.tasks; -import org.apache.logging.log4j.LogManager; -import org.apache.logging.log4j.Logger; import org.opensearch.action.ActionResponse; import org.opensearch.cluster.node.DiscoveryNode; import org.opensearch.common.io.stream.NamedWriteable; @@ -53,8 +51,6 @@ */ public class Task { - private static final Logger logger = LogManager.getLogger(Task.class); - /** * The request 
header to mark tasks with specific ids */ @@ -336,6 +332,17 @@ public void stopThreadResourceTracking(long threadId, ResourceStatsType statsTyp throw new IllegalStateException("cannot update final values if active thread resource entry is not present"); } + /** + * Individual tasks can override this if they want to support task resource tracking. We just need to make sure that + * the ThreadPool on which the task runs on have runnable wrapper similar to + * {@link org.opensearch.common.util.concurrent.OpenSearchExecutors#newAutoQueueFixed} + * + * @return true if resource tracking is supported by the task + */ + public boolean supportsResourceTracking() { + return false; + } + /** * Report of the internal status of a task. These can vary wildly from task * to task because each task is implemented differently but we should try diff --git a/server/src/main/java/org/opensearch/tasks/TaskManager.java b/server/src/main/java/org/opensearch/tasks/TaskManager.java index 1ea9106ec8439..cfd2873a79bb0 100644 --- a/server/src/main/java/org/opensearch/tasks/TaskManager.java +++ b/server/src/main/java/org/opensearch/tasks/TaskManager.java @@ -34,6 +34,7 @@ import com.carrotsearch.hppc.ObjectIntHashMap; import com.carrotsearch.hppc.ObjectIntMap; +import com.sun.management.ThreadMXBean; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.apache.logging.log4j.message.ParameterizedMessage; @@ -51,6 +52,7 @@ import org.opensearch.common.Nullable; import org.opensearch.common.lease.Releasable; import org.opensearch.common.lease.Releasables; +import org.opensearch.common.settings.Setting; import org.opensearch.common.settings.Settings; import org.opensearch.common.unit.ByteSizeValue; import org.opensearch.common.unit.TimeValue; @@ -58,10 +60,12 @@ import org.opensearch.common.util.concurrent.ConcurrentCollections; import org.opensearch.common.util.concurrent.ConcurrentMapLong; import org.opensearch.common.util.concurrent.ThreadContext; +import 
org.opensearch.threadpool.RunnableTaskExecutionListener; import org.opensearch.threadpool.ThreadPool; import org.opensearch.transport.TcpChannel; import java.io.IOException; +import java.lang.management.ManagementFactory; import java.util.ArrayList; import java.util.Collection; import java.util.Collections; @@ -80,11 +84,19 @@ import static org.opensearch.common.unit.TimeValue.timeValueMillis; import static org.opensearch.http.HttpTransportSettings.SETTING_HTTP_MAX_HEADER_SIZE; +import static org.opensearch.tasks.ResourceStatsType.WORKER_STATS; /** * Task Manager service for keeping track of currently running tasks on the nodes */ -public class TaskManager implements ClusterStateApplier { +public class TaskManager implements ClusterStateApplier, RunnableTaskExecutionListener { + + public static final Setting TASK_RESOURCE_TRACKING_ENABLED = Setting.boolSetting( + "task_resource_tracking.enabled", + false, + Setting.Property.Dynamic, + Setting.Property.NodeScope + ); public static final String TASK_ID = "TASK_ID"; @@ -92,7 +104,11 @@ public class TaskManager implements ClusterStateApplier { private static final TimeValue WAIT_FOR_COMPLETION_POLL = timeValueMillis(100); - /** Rest headers that are copied to the task */ + private static final ThreadMXBean threadMXBean = (ThreadMXBean) ManagementFactory.getThreadMXBean(); + + /** + * Rest headers that are copied to the task + */ private final List taskHeaders; private final ThreadPool threadPool; @@ -110,13 +126,23 @@ public class TaskManager implements ClusterStateApplier { private volatile DiscoveryNodes lastDiscoveryNodes = DiscoveryNodes.EMPTY_NODES; private final ByteSizeValue maxHeaderSize; + private volatile boolean taskResourceTrackingEnabled; private final Map channelPendingTaskTrackers = ConcurrentCollections.newConcurrentMap(); private final SetOnce cancellationService = new SetOnce<>(); + private final ConcurrentMapLong resourceAwareTasks = 
ConcurrentCollections.newConcurrentMapLongWithAggressiveConcurrency(); + public TaskManager(Settings settings, ThreadPool threadPool, Set taskHeaders) { this.threadPool = threadPool; this.taskHeaders = new ArrayList<>(taskHeaders); this.maxHeaderSize = SETTING_HTTP_MAX_HEADER_SIZE.get(settings); + this.taskResourceTrackingEnabled = TASK_RESOURCE_TRACKING_ENABLED.get(settings); + } + + public boolean isTaskResourceTrackingEnabled() { + return taskResourceTrackingEnabled + && threadMXBean.isThreadAllocatedMemorySupported() + && threadMXBean.isThreadAllocatedMemoryEnabled(); } public void setTaskResultsService(TaskResultsService taskResultsService) { @@ -153,6 +179,10 @@ public Task register(String type, String action, TaskAwareRequest request) { logger.trace("register {} [{}] [{}] [{}]", task.getId(), type, action, task.getDescription()); } + if (task.supportsResourceTracking() && isTaskResourceTrackingEnabled()) { + resourceAwareTasks.put(task.getId(), task); + } + if (task instanceof CancellableTask) { registerCancellableTask(task); } else { @@ -205,6 +235,11 @@ public void cancel(CancellableTask task, String reason, Runnable listener) { */ public Task unregister(Task task) { logger.trace("unregister task for id: {}", task.getId()); + + if (task.supportsResourceTracking()) { + resourceAwareTasks.remove(task.getId(), task); + } + if (task instanceof CancellableTask) { CancellableTaskHolder holder = cancellableTasks.remove(task.getId()); if (holder != null) { @@ -327,6 +362,10 @@ public Map getCancellableTasks() { return Collections.unmodifiableMap(taskHashMap); } + public Map getResourceAwareTasks() { + return Collections.unmodifiableMap(resourceAwareTasks); + } + /** * Returns a task with given id, or null if the task is not found. */ @@ -364,6 +403,7 @@ public int getBanCount() { * Bans all tasks with the specified parent task from execution, cancels all tasks that are currently executing. *

* This method is called when a parent task that has children is cancelled. + * * @return a list of pending cancellable child tasks */ public List setBan(TaskId parentTaskId, String reason) { @@ -482,6 +522,46 @@ public ThreadContext.StoredContext addTaskIdInThreadContext(@Nullable Task task) return storedContext; } + /** + * Called when a thread starts working on a task's runnable. + * + * @param taskId of the task for which runnable is starting + * @param threadId of the thread which will be executing the runnable and we need to check resource usage for this + * thread + */ + @Override + public void taskExecutionStartedOnThread(long taskId, long threadId) { + if (resourceAwareTasks.containsKey(taskId)) { + resourceAwareTasks.get(taskId).startThreadResourceTracking(threadId, WORKER_STATS, getResourceUsageMetricsForThread(threadId)); + } + } + + /** + * Called when a thread finishes working on a task's runnable. + * + * @param taskId of the task for which runnable is complete + * @param threadId of the thread which executed the runnable and we need to check resource usage for this thread + */ + @Override + public void taskExecutionFinishedOnThread(long taskId, long threadId) { + if (resourceAwareTasks.containsKey(taskId)) { + resourceAwareTasks.get(taskId).stopThreadResourceTracking(threadId, WORKER_STATS, getResourceUsageMetricsForThread(threadId)); + } + } + + private ResourceUsageMetric[] getResourceUsageMetricsForThread(long threadId) { + ResourceUsageMetric currentMemoryUsage = new ResourceUsageMetric( + ResourceStats.MEMORY, + threadMXBean.getThreadAllocatedBytes(threadId) + ); + ResourceUsageMetric currentCPUUsage = new ResourceUsageMetric(ResourceStats.CPU, threadMXBean.getThreadCpuTime(threadId)); + return new ResourceUsageMetric[] { currentMemoryUsage, currentCPUUsage }; + } + + public void setTaskResourceTrackingEnabled(boolean taskResourceTrackingEnabled) { + this.taskResourceTrackingEnabled = taskResourceTrackingEnabled; + } + private static class 
CancellableTaskHolder { private final CancellableTask task; private boolean finished = false; diff --git a/server/src/main/java/org/opensearch/tasks/TaskResourceStatsUtil.java b/server/src/main/java/org/opensearch/tasks/TaskResourceStatsUtil.java new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/server/src/main/java/org/opensearch/threadpool/RunnableTaskExecutionListener.java b/server/src/main/java/org/opensearch/threadpool/RunnableTaskExecutionListener.java new file mode 100644 index 0000000000000..03cd66f80d044 --- /dev/null +++ b/server/src/main/java/org/opensearch/threadpool/RunnableTaskExecutionListener.java @@ -0,0 +1,33 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.threadpool; + +/** + * Listener for events when a runnable execution starts or finishes on a thread and is aware of the task for which the + * runnable is associated to. 
+ */ +public interface RunnableTaskExecutionListener { + + /** + * Sends an update when ever a task's execution start on a thread + * + * @param taskId of task which has started + * @param threadId of thread which is executing the task + */ + void taskExecutionStartedOnThread(long taskId, long threadId); + + /** + * + * Sends an update when task execution finishes on a thread + * + * @param taskId of task which has finished + * @param threadId of thread which executed the task + */ + void taskExecutionFinishedOnThread(long taskId, long threadId); +} diff --git a/server/src/main/java/org/opensearch/threadpool/TaskAwareRunnable.java b/server/src/main/java/org/opensearch/threadpool/TaskAwareRunnable.java new file mode 100644 index 0000000000000..1500a8bd7fdd0 --- /dev/null +++ b/server/src/main/java/org/opensearch/threadpool/TaskAwareRunnable.java @@ -0,0 +1,85 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.threadpool; + +import org.opensearch.ExceptionsHelper; +import org.opensearch.common.util.concurrent.AbstractRunnable; +import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.common.util.concurrent.WrappedRunnable; + +import java.util.Objects; + +import static java.lang.Thread.currentThread; +import static org.opensearch.tasks.TaskManager.TASK_ID; + +/** + * Responsible for wrapping the original task's runnable and sending updates on when it starts and finishes to + * entities listening to the events. + * + * It's able to associate runnable with a task with the help of task Id available in thread context. 
+ */ +public class TaskAwareRunnable extends AbstractRunnable implements WrappedRunnable { + + private static RunnableTaskExecutionListener listener; + + private final Runnable original; + private final ThreadContext threadContext; + + public TaskAwareRunnable(ThreadContext threadContext, final Runnable original) { + this.original = original; + this.threadContext = threadContext; + } + + public static void setListener(RunnableTaskExecutionListener l) { + listener = l; + } + + @Override + public void onFailure(Exception e) { + ExceptionsHelper.reThrowIfNotNull(e); + } + + @Override + public boolean isForceExecution() { + return original instanceof AbstractRunnable && ((AbstractRunnable) original).isForceExecution(); + } + + @Override + public void onRejection(final Exception e) { + if (original instanceof AbstractRunnable) { + ((AbstractRunnable) original).onRejection(e); + } else { + ExceptionsHelper.reThrowIfNotNull(e); + } + } + + @Override + protected void doRun() throws Exception { + assert listener != null : "Listener should be attached"; + + Long taskId = threadContext.getTransient(TASK_ID); + + if (Objects.nonNull(taskId)) { + listener.taskExecutionStartedOnThread(taskId, currentThread().getId()); + } + try { + original.run(); + } finally { + if (Objects.nonNull(taskId)) { + listener.taskExecutionFinishedOnThread(taskId, currentThread().getId()); + } + } + + } + + @Override + public Runnable unwrap() { + return original; + } +} diff --git a/server/src/main/java/org/opensearch/transport/TransportService.java b/server/src/main/java/org/opensearch/transport/TransportService.java index ad46ed742c806..2fa9c515b98fa 100644 --- a/server/src/main/java/org/opensearch/transport/TransportService.java +++ b/server/src/main/java/org/opensearch/transport/TransportService.java @@ -218,6 +218,10 @@ public TransportService( remoteClusterService.listenForUpdates(clusterSettings); } clusterSettings.addSettingsUpdateConsumer(TransportSettings.SLOW_OPERATION_THRESHOLD_SETTING, 
transport::setSlowLogThreshold); + clusterSettings.addSettingsUpdateConsumer( + TaskManager.TASK_RESOURCE_TRACKING_ENABLED, + taskManager::setTaskResourceTrackingEnabled + ); } registerRequestHandler( HANDSHAKE_ACTION_NAME, diff --git a/server/src/test/java/org/opensearch/action/admin/cluster/node/tasks/ResourceAwareTasksTests.java b/server/src/test/java/org/opensearch/action/admin/cluster/node/tasks/ResourceAwareTasksTests.java new file mode 100644 index 0000000000000..9c83d55bf6e9a --- /dev/null +++ b/server/src/test/java/org/opensearch/action/admin/cluster/node/tasks/ResourceAwareTasksTests.java @@ -0,0 +1,519 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.action.admin.cluster.node.tasks; + +import org.opensearch.ExceptionsHelper; +import org.opensearch.action.ActionListener; +import org.opensearch.action.admin.cluster.node.tasks.cancel.CancelTasksRequest; +import org.opensearch.action.admin.cluster.node.tasks.list.ListTasksRequest; +import org.opensearch.action.admin.cluster.node.tasks.list.ListTasksResponse; +import org.opensearch.action.support.ActionTestUtils; +import org.opensearch.action.support.nodes.BaseNodeRequest; +import org.opensearch.action.support.nodes.BaseNodesRequest; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.io.stream.StreamInput; +import org.opensearch.common.io.stream.StreamOutput; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.util.concurrent.AbstractRunnable; +import org.opensearch.tasks.CancellableTask; +import org.opensearch.tasks.Task; +import org.opensearch.tasks.TaskCancelledException; +import org.opensearch.tasks.TaskId; +import org.opensearch.tasks.TaskManager; +import org.opensearch.threadpool.TaskAwareRunnable; +import org.opensearch.threadpool.ThreadPool; +import 
org.opensearch.transport.TransportService; + +import java.io.IOException; +import java.util.Map; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.Future; +import java.util.concurrent.atomic.AtomicLong; +import java.util.concurrent.atomic.AtomicReference; +import java.util.function.Consumer; + +public class ResourceAwareTasksTests extends TaskManagerTestCase { + + // For every task there's a general overhead before and after the actual task operation code is executed. + // This includes things like creating threadContext, Transport Channel, Tracking task cancellation etc. + // For the tasks used for this test that maximum memory overhead can be 450Kb + private static final int TASK_MAX_GENERAL_MEMORY_OVERHEAD = 450000; + + public static class ResourceAwareNodeRequest extends BaseNodeRequest { + protected String requestName; + + public ResourceAwareNodeRequest() { + super(); + } + + public ResourceAwareNodeRequest(StreamInput in) throws IOException { + super(in); + requestName = in.readString(); + } + + public ResourceAwareNodeRequest(NodesRequest request) { + requestName = request.requestName; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); + out.writeString(requestName); + } + + @Override + public String getDescription() { + return "ResourceAwareNodeRequest[" + requestName + "]"; + } + + @Override + public Task createTask(long id, String type, String action, TaskId parentTaskId, Map headers) { + return new CancellableTask(id, type, action, getDescription(), parentTaskId, headers) { + @Override + public boolean shouldCancelChildrenOnCancellation() { + return false; + } + + @Override + public boolean supportsResourceTracking() { + return true; + } + }; + } + } + + public static class NodesRequest extends BaseNodesRequest { + private final String requestName; + + private NodesRequest(StreamInput in) throws IOException { + super(in); + 
requestName = in.readString(); + } + + public NodesRequest(String requestName, String... nodesIds) { + super(nodesIds); + this.requestName = requestName; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); + out.writeString(requestName); + } + + @Override + public String getDescription() { + return "NodesRequest[" + requestName + "]"; + } + + @Override + public Task createTask(long id, String type, String action, TaskId parentTaskId, Map headers) { + return new CancellableTask(id, type, action, getDescription(), parentTaskId, headers) { + @Override + public boolean shouldCancelChildrenOnCancellation() { + return true; + } + }; + } + } + + /** + * Simulates a task which executes work on search executor. + */ + class ResourceAwareNodesAction extends AbstractTestNodesAction { + + private final TaskTestContext taskTestContext; + private final boolean blockForCancellation; + + ResourceAwareNodesAction( + String actionName, + ThreadPool threadPool, + ClusterService clusterService, + TransportService transportService, + boolean shouldBlock, + TaskTestContext taskTestContext + ) { + super(actionName, threadPool, clusterService, transportService, NodesRequest::new, ResourceAwareNodeRequest::new); + this.taskTestContext = taskTestContext; + this.blockForCancellation = shouldBlock; + } + + @Override + protected ResourceAwareNodeRequest newNodeRequest(NodesRequest request) { + return new ResourceAwareNodeRequest(request); + } + + @Override + protected NodeResponse nodeOperation(ResourceAwareNodeRequest request, Task task) { + assert task.supportsResourceTracking(); + + AtomicLong threadId = new AtomicLong(); + Future result = threadPool.executor(ThreadPool.Names.SEARCH).submit(new AbstractRunnable() { + @Override + public void onFailure(Exception e) { + ExceptionsHelper.reThrowIfNotNull(e); + } + + @Override + protected void doRun() { + threadId.set(Thread.currentThread().getId()); + + if (taskTestContext.operationStartValidator 
!= null) { + try { + taskTestContext.operationStartValidator.accept(threadId.get()); + } catch (AssertionError error) { + throw new RuntimeException(error); + } + } + + Object[] allocation1 = new Object[1000000]; // 4MB + + if (blockForCancellation) { + // Simulate a job that takes forever to finish + // Using periodic checks method to identify that the task was cancelled + try { + boolean taskCancelled = waitUntil(((CancellableTask) task)::isCancelled); + if (taskCancelled) { + throw new TaskCancelledException("Task Cancelled"); + } else { + fail("It should have thrown an exception"); + } + } catch (InterruptedException ex) { + Thread.currentThread().interrupt(); + } + + } + + Object[] allocation2 = new Object[1000000]; // 4MB + } + }); + + try { + result.get(); + } catch (InterruptedException | ExecutionException e) { + throw new RuntimeException(e.getCause()); + } + if (taskTestContext.operationFinishedValidator != null) { + taskTestContext.operationFinishedValidator.accept(threadId.get()); + } + + return new NodeResponse(clusterService.localNode()); + } + + @Override + protected NodeResponse nodeOperation(ResourceAwareNodeRequest request) { + throw new UnsupportedOperationException("the task parameter is required"); + } + } + + private TaskTestContext startResourceAwareNodesAction( + TestNode node, + boolean blockForCancellation, + TaskTestContext taskTestContext, + ActionListener listener + ) { + NodesRequest request = new NodesRequest("Test Request", node.getNodeId()); + + taskTestContext.requestCompleteLatch = new CountDownLatch(1); + + ResourceAwareNodesAction action = new ResourceAwareNodesAction( + "internal:resourceAction", + threadPool, + node.clusterService, + node.transportService, + blockForCancellation, + taskTestContext + ); + taskTestContext.mainTask = action.execute(request, listener); + return taskTestContext; + } + + private static class TaskTestContext { + private Task mainTask; + private CountDownLatch requestCompleteLatch; + private Consumer 
operationStartValidator; + private Consumer operationFinishedValidator; + } + + public void testBasicTaskResourceTracking() throws Exception { + setup(true); + + final AtomicReference throwableReference = new AtomicReference<>(); + final AtomicReference responseReference = new AtomicReference<>(); + TaskTestContext taskTestContext = new TaskTestContext(); + + Map resourceTasks = testNodes[0].transportService.getTaskManager().getResourceAwareTasks(); + + taskTestContext.operationStartValidator = threadId -> { + Task task = resourceTasks.values().stream().findAny().get(); + + // One thread is currently working on task but not finished + assertEquals(1, resourceTasks.size()); + assertEquals(1, task.getResourceStats().size()); + assertEquals(1, task.getResourceStats().get(threadId).size()); + assertTrue(task.getResourceStats().get(threadId).get(0).isActive()); + assertEquals(0, task.getTotalResourceStats().getCpuTimeInNanos()); + assertEquals(0, task.getTotalResourceStats().getMemoryInBytes()); + }; + + taskTestContext.operationFinishedValidator = threadId -> { + Task task = resourceTasks.values().stream().findAny().get(); + + // Thread has finished working on the task's runnable + assertEquals(1, resourceTasks.size()); + assertEquals(1, task.getResourceStats().size()); + assertEquals(1, task.getResourceStats().get(threadId).size()); + assertFalse(task.getResourceStats().get(threadId).get(0).isActive()); + + long expectedArrayAllocationOverhead = 2 * 4012688; // Task's memory overhead due to array allocations + long actualTaskMemoryOverhead = task.getTotalResourceStats().getMemoryInBytes(); + assertTrue(Math.abs(actualTaskMemoryOverhead - expectedArrayAllocationOverhead) < TASK_MAX_GENERAL_MEMORY_OVERHEAD); + assertTrue(task.getTotalResourceStats().getCpuTimeInNanos() > 0); + }; + + startResourceAwareNodesAction(testNodes[0], false, taskTestContext, new ActionListener() { + @Override + public void onResponse(NodesResponse listTasksResponse) { + 
responseReference.set(listTasksResponse); + taskTestContext.requestCompleteLatch.countDown(); + } + + @Override + public void onFailure(Exception e) { + throwableReference.set(e); + taskTestContext.requestCompleteLatch.countDown(); + } + }); + + // Waiting for whole request to complete and return successfully till client + taskTestContext.requestCompleteLatch.await(); + + assertTasksRequestFinishedSuccessfully(resourceTasks.size(), responseReference.get(), throwableReference.get()); + } + + public void testTaskResourceTrackingDuringTaskCancellation() throws Exception { + setup(true); + + final AtomicReference throwableReference = new AtomicReference<>(); + final AtomicReference responseReference = new AtomicReference<>(); + TaskTestContext taskTestContext = new TaskTestContext(); + + Map resourceTasks = testNodes[0].transportService.getTaskManager().getResourceAwareTasks(); + + taskTestContext.operationStartValidator = threadId -> { + Task task = resourceTasks.values().stream().findAny().get(); + + // One thread is currently working on task but not finished + assertEquals(1, resourceTasks.size()); + assertEquals(1, task.getResourceStats().size()); + assertEquals(1, task.getResourceStats().get(threadId).size()); + assertTrue(task.getResourceStats().get(threadId).get(0).isActive()); + assertEquals(0, task.getTotalResourceStats().getCpuTimeInNanos()); + assertEquals(0, task.getTotalResourceStats().getMemoryInBytes()); + }; + + taskTestContext.operationFinishedValidator = threadId -> { + Task task = resourceTasks.values().stream().findAny().get(); + + // Thread has finished working on the task's runnable + assertEquals(1, resourceTasks.size()); + assertEquals(1, task.getResourceStats().size()); + assertEquals(1, task.getResourceStats().get(threadId).size()); + assertFalse(task.getResourceStats().get(threadId).get(0).isActive()); + + long expectedArrayAllocationOverhead = 4012688; // Task's memory overhead due to array allocations. 
Only one out of 2 + // allocations are completed before the task is cancelled + long actualArrayAllocationOverhead = task.getTotalResourceStats().getMemoryInBytes(); + + assertTrue(Math.abs(actualArrayAllocationOverhead - expectedArrayAllocationOverhead) < TASK_MAX_GENERAL_MEMORY_OVERHEAD); + assertTrue(task.getTotalResourceStats().getCpuTimeInNanos() > 0); + }; + + startResourceAwareNodesAction(testNodes[0], true, taskTestContext, new ActionListener() { + @Override + public void onResponse(NodesResponse listTasksResponse) { + responseReference.set(listTasksResponse); + taskTestContext.requestCompleteLatch.countDown(); + } + + @Override + public void onFailure(Exception e) { + throwableReference.set(e); + taskTestContext.requestCompleteLatch.countDown(); + } + }); + + // Cancel main task + CancelTasksRequest request = new CancelTasksRequest(); + request.setReason("Cancelling request to verify Task resource tracking behaviour"); + request.setTaskId(new TaskId(testNodes[0].getNodeId(), taskTestContext.mainTask.getId())); + ActionTestUtils.executeBlocking(testNodes[0].transportCancelTasksAction, request); + + // Waiting for whole request to complete and return successfully till client + taskTestContext.requestCompleteLatch.await(); + + assertEquals(0, resourceTasks.size()); + assertNull(throwableReference.get()); + assertNotNull(responseReference.get()); + assertEquals(1, responseReference.get().failureCount()); + assertEquals(TaskCancelledException.class, findActualException(responseReference.get().failures().get(0)).getClass()); + } + + public void testTaskResourceTrackingDisabled() throws Exception { + setup(false); + + final AtomicReference throwableReference = new AtomicReference<>(); + final AtomicReference responseReference = new AtomicReference<>(); + TaskTestContext taskTestContext = new TaskTestContext(); + + Map resourceTasks = testNodes[0].transportService.getTaskManager().getResourceAwareTasks(); + + taskTestContext.operationStartValidator = threadId -> { 
assertEquals(0, resourceTasks.size()); }; + + taskTestContext.operationFinishedValidator = threadId -> { assertEquals(0, resourceTasks.size()); }; + + startResourceAwareNodesAction(testNodes[0], false, taskTestContext, new ActionListener() { + @Override + public void onResponse(NodesResponse listTasksResponse) { + responseReference.set(listTasksResponse); + taskTestContext.requestCompleteLatch.countDown(); + } + + @Override + public void onFailure(Exception e) { + throwableReference.set(e); + taskTestContext.requestCompleteLatch.countDown(); + } + }); + + // Waiting for whole request to complete and return successfully till client + taskTestContext.requestCompleteLatch.await(); + + assertTasksRequestFinishedSuccessfully(resourceTasks.size(), responseReference.get(), throwableReference.get()); + } + + public void testTaskResourceTrackingDisabledWhileTaskInProgress() throws Exception { + setup(true); + + final AtomicReference throwableReference = new AtomicReference<>(); + final AtomicReference responseReference = new AtomicReference<>(); + TaskTestContext taskTestContext = new TaskTestContext(); + + TaskManager taskManager = testNodes[0].transportService.getTaskManager(); + Map resourceTasks = taskManager.getResourceAwareTasks(); + + taskTestContext.operationStartValidator = threadId -> { + Task task = resourceTasks.values().stream().findAny().get(); + // One thread is currently working on task but not finished + assertEquals(1, resourceTasks.size()); + assertEquals(1, task.getResourceStats().size()); + assertEquals(1, task.getResourceStats().get(threadId).size()); + assertTrue(task.getResourceStats().get(threadId).get(0).isActive()); + assertEquals(0, task.getTotalResourceStats().getCpuTimeInNanos()); + assertEquals(0, task.getTotalResourceStats().getMemoryInBytes()); + + taskManager.setTaskResourceTrackingEnabled(false); + }; + + taskTestContext.operationFinishedValidator = threadId -> { + Task task = resourceTasks.values().stream().findAny().get(); + // Thread 
has finished working on the task's runnable + assertEquals(1, resourceTasks.size()); + assertEquals(1, task.getResourceStats().size()); + assertEquals(1, task.getResourceStats().get(threadId).size()); + assertFalse(task.getResourceStats().get(threadId).get(0).isActive()); + + long expectedArrayAllocationOverhead = 2 * 4012688; // Task's memory overhead due to array allocations + long actualTaskMemoryOverhead = task.getTotalResourceStats().getMemoryInBytes(); + assertTrue(Math.abs(actualTaskMemoryOverhead - expectedArrayAllocationOverhead) < TASK_MAX_GENERAL_MEMORY_OVERHEAD); + assertTrue(task.getTotalResourceStats().getCpuTimeInNanos() > 0); + }; + + startResourceAwareNodesAction(testNodes[0], false, taskTestContext, new ActionListener() { + @Override + public void onResponse(NodesResponse listTasksResponse) { + responseReference.set(listTasksResponse); + taskTestContext.requestCompleteLatch.countDown(); + } + + @Override + public void onFailure(Exception e) { + throwableReference.set(e); + taskTestContext.requestCompleteLatch.countDown(); + } + }); + + // Waiting for whole request to complete and return successfully till client + taskTestContext.requestCompleteLatch.await(); + + assertTasksRequestFinishedSuccessfully(resourceTasks.size(), responseReference.get(), throwableReference.get()); + } + + public void testTaskResourceTrackingEnabledWhileTaskInProgress() throws Exception { + setup(false); + + final AtomicReference throwableReference = new AtomicReference<>(); + final AtomicReference responseReference = new AtomicReference<>(); + TaskTestContext taskTestContext = new TaskTestContext(); + + TaskManager taskManager = testNodes[0].transportService.getTaskManager(); + Map resourceTasks = taskManager.getResourceAwareTasks(); + + taskTestContext.operationStartValidator = threadId -> { + assertEquals(0, resourceTasks.size()); + + taskManager.setTaskResourceTrackingEnabled(true); + }; + + taskTestContext.operationFinishedValidator = threadId -> { assertEquals(0, 
resourceTasks.size()); }; + + startResourceAwareNodesAction(testNodes[0], false, taskTestContext, new ActionListener() { + @Override + public void onResponse(NodesResponse listTasksResponse) { + responseReference.set(listTasksResponse); + taskTestContext.requestCompleteLatch.countDown(); + } + + @Override + public void onFailure(Exception e) { + throwableReference.set(e); + taskTestContext.requestCompleteLatch.countDown(); + } + }); + + // Waiting for whole request to complete and return successfully till client + taskTestContext.requestCompleteLatch.await(); + + assertTasksRequestFinishedSuccessfully(resourceTasks.size(), responseReference.get(), throwableReference.get()); + } + + private void setup(boolean resourceTrackingEnabled) { + Settings settings = Settings.builder().put("task_resource_tracking.enabled", resourceTrackingEnabled).build(); + setupTestNodes(settings); + connectNodes(testNodes[0]); + TaskAwareRunnable.setListener(testNodes[0].transportService.getTaskManager()); + } + + private Throwable findActualException(Exception e) { + Throwable throwable = e.getCause(); + while (throwable.getCause() != null) { + throwable = throwable.getCause(); + } + return throwable; + } + + private void assertTasksRequestFinishedSuccessfully(int activeResourceTasks, NodesResponse nodesResponse, Throwable throwable) { + assertEquals(0, activeResourceTasks); + assertNull(throwable); + assertNotNull(nodesResponse); + assertEquals(0, nodesResponse.failureCount()); + } + +} From a01aac2ee41d06063644c2c595c0ceeb5be49eed Mon Sep 17 00:00:00 2001 From: Tushar Kharbanda Date: Wed, 23 Mar 2022 20:25:39 +0530 Subject: [PATCH 03/26] List tasks action support for task resource refresh Signed-off-by: Tushar Kharbanda --- .../tasks/list/TransportListTasksAction.java | 6 +++ .../org/opensearch/tasks/TaskManager.java | 22 +++++++++ .../node/tasks/ResourceAwareTasksTests.java | 45 +++++++++++++++++++ .../opensearch/tasks/TaskManagerTests.java | 41 +++++++++++++++++ 4 files changed, 114 
insertions(+) diff --git a/server/src/main/java/org/opensearch/action/admin/cluster/node/tasks/list/TransportListTasksAction.java b/server/src/main/java/org/opensearch/action/admin/cluster/node/tasks/list/TransportListTasksAction.java index b7875c5f99774..4e87bfc202d8b 100644 --- a/server/src/main/java/org/opensearch/action/admin/cluster/node/tasks/list/TransportListTasksAction.java +++ b/server/src/main/java/org/opensearch/action/admin/cluster/node/tasks/list/TransportListTasksAction.java @@ -101,6 +101,12 @@ protected void processTasks(ListTasksRequest request, Consumer operation) } taskManager.waitForTaskCompletion(task, timeoutNanos); }); + } else { + operation = operation.andThen(task -> { + if (task.supportsResourceTracking()) { + taskManager.refreshResourceStats(task); + } + }); } super.processTasks(request, operation); } diff --git a/server/src/main/java/org/opensearch/tasks/TaskManager.java b/server/src/main/java/org/opensearch/tasks/TaskManager.java index cfd2873a79bb0..a6dd4c4f19caf 100644 --- a/server/src/main/java/org/opensearch/tasks/TaskManager.java +++ b/server/src/main/java/org/opensearch/tasks/TaskManager.java @@ -558,6 +558,28 @@ private ResourceUsageMetric[] getResourceUsageMetricsForThread(long threadId) { return new ResourceUsageMetric[] { currentMemoryUsage, currentCPUUsage }; } + public void refreshResourceStats(Task... 
tasks) { + if (isTaskResourceTrackingEnabled() == false) { + return; + } + + for (Task task : tasks) { + if (task.supportsResourceTracking() && resourceAwareTasks.containsKey(task.getId())) { + refreshResourceStats(task); + } + } + } + + private void refreshResourceStats(Task resourceAwareTask) { + resourceAwareTask.getResourceStats().forEach((threadId, threadResourceInfos) -> { + for (ThreadResourceInfo threadResourceInfo : threadResourceInfos) { + if (threadResourceInfo.isActive()) { + resourceAwareTask.updateThreadResourceStats(threadId, WORKER_STATS, getResourceUsageMetricsForThread(threadId)); + } + } + }); + } + public void setTaskResourceTrackingEnabled(boolean taskResourceTrackingEnabled) { this.taskResourceTrackingEnabled = taskResourceTrackingEnabled; } diff --git a/server/src/test/java/org/opensearch/action/admin/cluster/node/tasks/ResourceAwareTasksTests.java b/server/src/test/java/org/opensearch/action/admin/cluster/node/tasks/ResourceAwareTasksTests.java index 9c83d55bf6e9a..e85df6ae900c2 100644 --- a/server/src/test/java/org/opensearch/action/admin/cluster/node/tasks/ResourceAwareTasksTests.java +++ b/server/src/test/java/org/opensearch/action/admin/cluster/node/tasks/ResourceAwareTasksTests.java @@ -25,6 +25,7 @@ import org.opensearch.tasks.Task; import org.opensearch.tasks.TaskCancelledException; import org.opensearch.tasks.TaskId; +import org.opensearch.tasks.TaskInfo; import org.opensearch.tasks.TaskManager; import org.opensearch.threadpool.TaskAwareRunnable; import org.opensearch.threadpool.ThreadPool; @@ -494,6 +495,50 @@ public void onFailure(Exception e) { assertTasksRequestFinishedSuccessfully(resourceTasks.size(), responseReference.get(), throwableReference.get()); } + public void testOnDemandRefreshWhileFetchingTasks() throws InterruptedException { + setup(true); + + final AtomicReference throwableReference = new AtomicReference<>(); + final AtomicReference responseReference = new AtomicReference<>(); + + TaskTestContext taskTestContext = 
new TaskTestContext(); + + Map resourceTasks = testNodes[0].transportService.getTaskManager().getResourceAwareTasks(); + + taskTestContext.operationStartValidator = threadId -> { + ListTasksResponse listTasksResponse = ActionTestUtils.executeBlocking( + testNodes[0].transportListTasksAction, + new ListTasksRequest().setActions("internal:resourceAction*").setDetailed(true) + ); + + TaskInfo taskInfo = listTasksResponse.getTasks().get(1); + + assertNotNull(taskInfo.getResourceStats()); + assertNotNull(taskInfo.getResourceStats().getResourceUsageInfo()); + assertTrue(taskInfo.getResourceStats().getResourceUsageInfo().get("total").getCpuTimeInNanos() > 0); + assertTrue(taskInfo.getResourceStats().getResourceUsageInfo().get("total").getMemoryInBytes() > 0); + }; + + startResourceAwareNodesAction(testNodes[0], false, taskTestContext, new ActionListener() { + @Override + public void onResponse(NodesResponse listTasksResponse) { + responseReference.set(listTasksResponse); + taskTestContext.requestCompleteLatch.countDown(); + } + + @Override + public void onFailure(Exception e) { + throwableReference.set(e); + taskTestContext.requestCompleteLatch.countDown(); + } + }); + + // Waiting for whole request to complete and return successfully till client + taskTestContext.requestCompleteLatch.await(); + + assertTasksRequestFinishedSuccessfully(resourceTasks.size(), responseReference.get(), throwableReference.get()); + } + private void setup(boolean resourceTrackingEnabled) { Settings settings = Settings.builder().put("task_resource_tracking.enabled", resourceTrackingEnabled).build(); setupTestNodes(settings); diff --git a/server/src/test/java/org/opensearch/tasks/TaskManagerTests.java b/server/src/test/java/org/opensearch/tasks/TaskManagerTests.java index d8ed4b81973f8..41cff587b3470 100644 --- a/server/src/test/java/org/opensearch/tasks/TaskManagerTests.java +++ b/server/src/test/java/org/opensearch/tasks/TaskManagerTests.java @@ -34,13 +34,16 @@ import 
org.opensearch.action.ActionListener; import org.opensearch.action.admin.cluster.node.tasks.TransportTasksActionTests; +import org.opensearch.action.search.SearchRequest; import org.opensearch.common.lease.Releasable; import org.opensearch.common.lease.Releasables; import org.opensearch.common.settings.Settings; import org.opensearch.common.unit.TimeValue; +import org.opensearch.common.util.concurrent.AbstractRunnable; import org.opensearch.common.util.concurrent.ConcurrentCollections; import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.threadpool.TaskAwareRunnable; import org.opensearch.threadpool.TestThreadPool; import org.opensearch.threadpool.ThreadPool; import org.opensearch.transport.FakeTcpChannel; @@ -58,6 +61,7 @@ import java.util.List; import java.util.Map; import java.util.Set; +import java.util.concurrent.CountDownLatch; import java.util.concurrent.Phaser; import java.util.concurrent.TimeUnit; @@ -117,6 +121,43 @@ public void testAddTaskIdToThreadContext() { assertNull(threadPool.getThreadContext().getTransient(TASK_ID)); } + public void testTaskStatsResourceRefresh() { + Settings settings = Settings.builder().put("task_resource_tracking.enabled", true).build(); + + final TaskManager taskManager = new TaskManager(settings, threadPool, Collections.emptySet()); + final Task task = taskManager.register("transport", "test", new SearchRequest()); + + TaskAwareRunnable.setListener(taskManager); + threadPool.getThreadContext().putTransient(TASK_ID, task.getId()); + CountDownLatch latch = new CountDownLatch(1); + + threadPool.executor(ThreadPool.Names.SEARCH).submit(new AbstractRunnable() { + @Override + public void onFailure(Exception e) { + + } + + @Override + protected void doRun() throws InterruptedException { + Object[] allocation = new Object[100]; + latch.await(); + } + }); + + long cpuConsumptionAtStart = task.getTotalResourceStats().getCpuTimeInNanos(); + long 
memoryConsumptionAtStart = task.getTotalResourceStats().getMemoryInBytes(); + + taskManager.refreshResourceStats(task); + + long cpuConsumptionAfterRefresh = task.getTotalResourceStats().getCpuTimeInNanos(); + long memoryConsumptionAfterRefresh = task.getTotalResourceStats().getMemoryInBytes(); + + latch.countDown(); + + assertTrue(cpuConsumptionAtStart < cpuConsumptionAfterRefresh); + assertTrue(memoryConsumptionAtStart < memoryConsumptionAfterRefresh); + } + private void verifyThreadContextFixedHeaders(String key, String value) { assertEquals(threadPool.getThreadContext().getHeader(key), value); assertEquals(threadPool.getThreadContext().getTransient(key), value); From 36d2de18e3d6b3e4eac8faca64256734cd59c730 Mon Sep 17 00:00:00 2001 From: Tushar Kharbanda Date: Tue, 29 Mar 2022 23:11:06 +0530 Subject: [PATCH 04/26] Handle task unregistration case on same thread Signed-off-by: Tushar Kharbanda --- server/src/main/java/org/opensearch/tasks/TaskManager.java | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/server/src/main/java/org/opensearch/tasks/TaskManager.java b/server/src/main/java/org/opensearch/tasks/TaskManager.java index a6dd4c4f19caf..3fe32194c0d46 100644 --- a/server/src/main/java/org/opensearch/tasks/TaskManager.java +++ b/server/src/main/java/org/opensearch/tasks/TaskManager.java @@ -178,7 +178,6 @@ public Task register(String type, String action, TaskAwareRequest request) { if (logger.isTraceEnabled()) { logger.trace("register {} [{}] [{}] [{}]", task.getId(), type, action, task.getDescription()); } - if (task.supportsResourceTracking() && isTaskResourceTrackingEnabled()) { resourceAwareTasks.put(task.getId(), task); } @@ -237,7 +236,11 @@ public Task unregister(Task task) { logger.trace("unregister task for id: {}", task.getId()); if (task.supportsResourceTracking()) { - resourceAwareTasks.remove(task.getId(), task); + try { + taskExecutionFinishedOnThread(task.getId(), Thread.currentThread().getId()); + } catch (Exception 
ignored) {} finally { + resourceAwareTasks.remove(task.getId(), task); + } } if (task instanceof CancellableTask) { From be7cb830b879db71cb4cf22e70b06dcc6f781ebf Mon Sep 17 00:00:00 2001 From: Tushar Kharbanda Date: Thu, 31 Mar 2022 13:11:31 +0530 Subject: [PATCH 05/26] Add lazy initialisation for RunnableTaskExecutionListener Signed-off-by: Tushar Kharbanda --- .../util/concurrent/OpenSearchExecutors.java | 17 +++++++---- .../main/java/org/opensearch/node/Node.java | 10 +++++-- .../AutoQueueAdjustingExecutorBuilder.java | 18 +++++++++++- .../RunnableTaskListenerFactory.java | 29 +++++++++++++++++++ .../threadpool/TaskAwareRunnable.java | 16 ++++------ .../org/opensearch/threadpool/ThreadPool.java | 21 ++++++++++++-- .../node/tasks/ResourceAwareTasksTests.java | 4 +-- .../node/tasks/TaskManagerTestCase.java | 5 +++- .../opensearch/tasks/TaskManagerTests.java | 8 +++-- .../opensearch/threadpool/TestThreadPool.java | 15 +++++++++- 10 files changed, 115 insertions(+), 28 deletions(-) create mode 100644 server/src/main/java/org/opensearch/threadpool/RunnableTaskListenerFactory.java diff --git a/server/src/main/java/org/opensearch/common/util/concurrent/OpenSearchExecutors.java b/server/src/main/java/org/opensearch/common/util/concurrent/OpenSearchExecutors.java index f44464d95efd4..2c2ded62fc25d 100644 --- a/server/src/main/java/org/opensearch/common/util/concurrent/OpenSearchExecutors.java +++ b/server/src/main/java/org/opensearch/common/util/concurrent/OpenSearchExecutors.java @@ -40,6 +40,7 @@ import org.opensearch.common.settings.Settings; import org.opensearch.common.unit.TimeValue; import org.opensearch.node.Node; +import org.opensearch.threadpool.RunnableTaskListenerFactory; import org.opensearch.threadpool.TaskAwareRunnable; import java.util.List; @@ -191,7 +192,8 @@ public static OpenSearchThreadPoolExecutor newAutoQueueFixed( int frameSize, TimeValue targetedResponseTime, ThreadFactory threadFactory, - ThreadContext contextHolder + ThreadContext 
contextHolder, + RunnableTaskListenerFactory runnableTaskListener ) { if (initialQueueCapacity <= 0) { throw new IllegalArgumentException( @@ -203,10 +205,15 @@ public static OpenSearchThreadPoolExecutor newAutoQueueFixed( initialQueueCapacity ); - Function runnableWrapper = (runnable) -> { - TaskAwareRunnable taskAwareRunnable = new TaskAwareRunnable(contextHolder, runnable); - return new TimedRunnable(taskAwareRunnable); - }; + Function runnableWrapper; + if (runnableTaskListener != null) { + runnableWrapper = (runnable) -> { + TaskAwareRunnable taskAwareRunnable = new TaskAwareRunnable(contextHolder, runnable, runnableTaskListener); + return new TimedRunnable(taskAwareRunnable); + }; + } else { + runnableWrapper = TimedRunnable::new; + } return new QueueResizingOpenSearchThreadPoolExecutor( name, diff --git a/server/src/main/java/org/opensearch/node/Node.java b/server/src/main/java/org/opensearch/node/Node.java index da28c8edcefe9..a1a32274fcb6c 100644 --- a/server/src/main/java/org/opensearch/node/Node.java +++ b/server/src/main/java/org/opensearch/node/Node.java @@ -37,7 +37,7 @@ import org.apache.lucene.util.Constants; import org.apache.lucene.util.SetOnce; import org.opensearch.index.IndexingPressureService; -import org.opensearch.threadpool.TaskAwareRunnable; +import org.opensearch.threadpool.RunnableTaskListenerFactory; import org.opensearch.watcher.ResourceWatcherService; import org.opensearch.Assertions; import org.opensearch.Build; @@ -434,7 +434,8 @@ protected Node( final List> executorBuilders = pluginsService.getExecutorBuilders(settings); - final ThreadPool threadPool = new ThreadPool(settings, executorBuilders.toArray(new ExecutorBuilder[0])); + RunnableTaskListenerFactory runnableTaskListener = new RunnableTaskListenerFactory(); + final ThreadPool threadPool = new ThreadPool(settings, runnableTaskListener, executorBuilders.toArray(new ExecutorBuilder[0])); resourcesToClose.add(() -> ThreadPool.terminate(threadPool, 10, TimeUnit.SECONDS)); final 
ResourceWatcherService resourceWatcherService = new ResourceWatcherService(settings, threadPool); resourcesToClose.add(resourceWatcherService); @@ -941,6 +942,7 @@ protected Node( b.bind(ShardLimitValidator.class).toInstance(shardLimitValidator); b.bind(FsHealthService.class).toInstance(fsHealthService); b.bind(SystemIndices.class).toInstance(systemIndices); + b.bind(RunnableTaskListenerFactory.class).toInstance(runnableTaskListener); }); injector = modules.createInjector(); @@ -1059,7 +1061,8 @@ public Node start() throws NodeValidationException { transportService.getTaskManager().setTaskResultsService(injector.getInstance(TaskResultsService.class)); transportService.getTaskManager().setTaskCancellationService(new TaskCancellationService(transportService)); - TaskAwareRunnable.setListener(transportService.getTaskManager()); + RunnableTaskListenerFactory runnableTaskListener = injector.getInstance(RunnableTaskListenerFactory.class); + runnableTaskListener.apply(transportService.getTaskManager()); transportService.start(); assert localNodeFactory.getNode() != null; @@ -1494,4 +1497,5 @@ DiscoveryNode getNode() { return localNode.get(); } } + } diff --git a/server/src/main/java/org/opensearch/threadpool/AutoQueueAdjustingExecutorBuilder.java b/server/src/main/java/org/opensearch/threadpool/AutoQueueAdjustingExecutorBuilder.java index 2bac5eba9fc28..ca018d46260d6 100644 --- a/server/src/main/java/org/opensearch/threadpool/AutoQueueAdjustingExecutorBuilder.java +++ b/server/src/main/java/org/opensearch/threadpool/AutoQueueAdjustingExecutorBuilder.java @@ -61,6 +61,7 @@ public final class AutoQueueAdjustingExecutorBuilder extends ExecutorBuilder maxQueueSizeSetting; private final Setting targetedResponseTimeSetting; private final Setting frameSizeSetting; + private final RunnableTaskListenerFactory runnableTaskListener; AutoQueueAdjustingExecutorBuilder( final Settings settings, @@ -70,6 +71,19 @@ public final class AutoQueueAdjustingExecutorBuilder extends 
ExecutorBuilder> settings() { Setting.Property.Deprecated, Setting.Property.Deprecated ); + this.runnableTaskListener = runnableTaskListener; } @Override @@ -230,7 +245,8 @@ ThreadPool.ExecutorHolder build(final AutoExecutorSettings settings, final Threa frameSize, targetedResponseTime, threadFactory, - threadContext + threadContext, + runnableTaskListener ); // TODO: in a subsequent change we hope to extend ThreadPool.Info to be more specific for the thread pool type final ThreadPool.Info info = new ThreadPool.Info( diff --git a/server/src/main/java/org/opensearch/threadpool/RunnableTaskListenerFactory.java b/server/src/main/java/org/opensearch/threadpool/RunnableTaskListenerFactory.java new file mode 100644 index 0000000000000..3848723f1b40a --- /dev/null +++ b/server/src/main/java/org/opensearch/threadpool/RunnableTaskListenerFactory.java @@ -0,0 +1,29 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. 
+ */ + +package org.opensearch.threadpool; + +import org.apache.lucene.util.SetOnce; + +import java.util.function.Function; + +public class RunnableTaskListenerFactory implements Function { + + private final SetOnce listener = new SetOnce<>(); + + @Override + public RunnableTaskExecutionListener apply(RunnableTaskExecutionListener runnableTaskExecutionListener) { + listener.set(runnableTaskExecutionListener); + return listener.get(); + } + + public RunnableTaskExecutionListener get() { + assert listener.get() != null; + return listener.get(); + } +} diff --git a/server/src/main/java/org/opensearch/threadpool/TaskAwareRunnable.java b/server/src/main/java/org/opensearch/threadpool/TaskAwareRunnable.java index 1500a8bd7fdd0..b83868cffbbed 100644 --- a/server/src/main/java/org/opensearch/threadpool/TaskAwareRunnable.java +++ b/server/src/main/java/org/opensearch/threadpool/TaskAwareRunnable.java @@ -26,18 +26,14 @@ */ public class TaskAwareRunnable extends AbstractRunnable implements WrappedRunnable { - private static RunnableTaskExecutionListener listener; - private final Runnable original; private final ThreadContext threadContext; + private final RunnableTaskListenerFactory runnableTaskListener; - public TaskAwareRunnable(ThreadContext threadContext, final Runnable original) { + public TaskAwareRunnable(ThreadContext threadContext, final Runnable original, final RunnableTaskListenerFactory runnableTaskListener) { this.original = original; this.threadContext = threadContext; - } - - public static void setListener(RunnableTaskExecutionListener l) { - listener = l; + this.runnableTaskListener = runnableTaskListener; } @Override @@ -61,18 +57,18 @@ public void onRejection(final Exception e) { @Override protected void doRun() throws Exception { - assert listener != null : "Listener should be attached"; + assert runnableTaskListener.get() != null : "Listener should be attached"; Long taskId = threadContext.getTransient(TASK_ID); if (Objects.nonNull(taskId)) { - 
listener.taskExecutionStartedOnThread(taskId, currentThread().getId()); + runnableTaskListener.get().taskExecutionStartedOnThread(taskId, currentThread().getId()); } try { original.run(); } finally { if (Objects.nonNull(taskId)) { - listener.taskExecutionFinishedOnThread(taskId, currentThread().getId()); + runnableTaskListener.get().taskExecutionFinishedOnThread(taskId, currentThread().getId()); } } diff --git a/server/src/main/java/org/opensearch/threadpool/ThreadPool.java b/server/src/main/java/org/opensearch/threadpool/ThreadPool.java index c2530ccee5588..4371eb0bf617b 100644 --- a/server/src/main/java/org/opensearch/threadpool/ThreadPool.java +++ b/server/src/main/java/org/opensearch/threadpool/ThreadPool.java @@ -184,6 +184,14 @@ public Collection builders() { ); public ThreadPool(final Settings settings, final ExecutorBuilder... customBuilders) { + this(settings, null, customBuilders); + } + + public ThreadPool( + final Settings settings, + final RunnableTaskListenerFactory runnableTaskListener, + final ExecutorBuilder... customBuilders + ) { assert Node.NODE_NAME_SETTING.exists(settings); final Map builders = new HashMap<>(); @@ -197,11 +205,20 @@ public ThreadPool(final Settings settings, final ExecutorBuilder... 
customBui builders.put(Names.ANALYZE, new FixedExecutorBuilder(settings, Names.ANALYZE, 1, 16)); builders.put( Names.SEARCH, - new AutoQueueAdjustingExecutorBuilder(settings, Names.SEARCH, searchThreadPoolSize(allocatedProcessors), 1000, 1000, 1000, 2000) + new AutoQueueAdjustingExecutorBuilder( + settings, + Names.SEARCH, + searchThreadPoolSize(allocatedProcessors), + 1000, + 1000, + 1000, + 2000, + runnableTaskListener + ) ); builders.put( Names.SEARCH_THROTTLED, - new AutoQueueAdjustingExecutorBuilder(settings, Names.SEARCH_THROTTLED, 1, 100, 100, 100, 200) + new AutoQueueAdjustingExecutorBuilder(settings, Names.SEARCH_THROTTLED, 1, 100, 100, 100, 200, runnableTaskListener) ); builders.put(Names.MANAGEMENT, new ScalingExecutorBuilder(Names.MANAGEMENT, 1, 5, TimeValue.timeValueMinutes(5))); // no queue as this means clients will need to handle rejections on listener queue even if the operation succeeded diff --git a/server/src/test/java/org/opensearch/action/admin/cluster/node/tasks/ResourceAwareTasksTests.java b/server/src/test/java/org/opensearch/action/admin/cluster/node/tasks/ResourceAwareTasksTests.java index e85df6ae900c2..d078c5c0e7ca5 100644 --- a/server/src/test/java/org/opensearch/action/admin/cluster/node/tasks/ResourceAwareTasksTests.java +++ b/server/src/test/java/org/opensearch/action/admin/cluster/node/tasks/ResourceAwareTasksTests.java @@ -27,7 +27,6 @@ import org.opensearch.tasks.TaskId; import org.opensearch.tasks.TaskInfo; import org.opensearch.tasks.TaskManager; -import org.opensearch.threadpool.TaskAwareRunnable; import org.opensearch.threadpool.ThreadPool; import org.opensearch.transport.TransportService; @@ -543,7 +542,8 @@ private void setup(boolean resourceTrackingEnabled) { Settings settings = Settings.builder().put("task_resource_tracking.enabled", resourceTrackingEnabled).build(); setupTestNodes(settings); connectNodes(testNodes[0]); - TaskAwareRunnable.setListener(testNodes[0].transportService.getTaskManager()); + + 
runnableTaskListener.apply(testNodes[0].transportService.getTaskManager()); } private Throwable findActualException(Exception e) { diff --git a/server/src/test/java/org/opensearch/action/admin/cluster/node/tasks/TaskManagerTestCase.java b/server/src/test/java/org/opensearch/action/admin/cluster/node/tasks/TaskManagerTestCase.java index c8411b31e0709..a4b3c2be545b1 100644 --- a/server/src/test/java/org/opensearch/action/admin/cluster/node/tasks/TaskManagerTestCase.java +++ b/server/src/test/java/org/opensearch/action/admin/cluster/node/tasks/TaskManagerTestCase.java @@ -61,6 +61,7 @@ import org.opensearch.tasks.TaskManager; import org.opensearch.test.OpenSearchTestCase; import org.opensearch.test.tasks.MockTaskManager; +import org.opensearch.threadpool.RunnableTaskListenerFactory; import org.opensearch.threadpool.TestThreadPool; import org.opensearch.threadpool.ThreadPool; import org.opensearch.transport.TransportService; @@ -89,10 +90,12 @@ public abstract class TaskManagerTestCase extends OpenSearchTestCase { protected ThreadPool threadPool; protected TestNode[] testNodes; protected int nodesCount; + protected RunnableTaskListenerFactory runnableTaskListener; @Before public void setupThreadPool() { - threadPool = new TestThreadPool(TransportTasksActionTests.class.getSimpleName()); + runnableTaskListener = new RunnableTaskListenerFactory(); + threadPool = new TestThreadPool(TransportTasksActionTests.class.getSimpleName(), runnableTaskListener); } public void setupTestNodes(Settings settings) { diff --git a/server/src/test/java/org/opensearch/tasks/TaskManagerTests.java b/server/src/test/java/org/opensearch/tasks/TaskManagerTests.java index 41cff587b3470..4c6cc88e9892b 100644 --- a/server/src/test/java/org/opensearch/tasks/TaskManagerTests.java +++ b/server/src/test/java/org/opensearch/tasks/TaskManagerTests.java @@ -43,7 +43,7 @@ import org.opensearch.common.util.concurrent.ConcurrentCollections; import org.opensearch.common.util.concurrent.ThreadContext; import 
org.opensearch.test.OpenSearchTestCase; -import org.opensearch.threadpool.TaskAwareRunnable; +import org.opensearch.threadpool.RunnableTaskListenerFactory; import org.opensearch.threadpool.TestThreadPool; import org.opensearch.threadpool.ThreadPool; import org.opensearch.transport.FakeTcpChannel; @@ -73,10 +73,12 @@ public class TaskManagerTests extends OpenSearchTestCase { private ThreadPool threadPool; + private RunnableTaskListenerFactory runnableTaskListener; @Before public void setupThreadPool() { - threadPool = new TestThreadPool(TransportTasksActionTests.class.getSimpleName()); + runnableTaskListener = new RunnableTaskListenerFactory(); + threadPool = new TestThreadPool(TransportTasksActionTests.class.getSimpleName(), runnableTaskListener); } @After @@ -127,7 +129,7 @@ public void testTaskStatsResourceRefresh() { final TaskManager taskManager = new TaskManager(settings, threadPool, Collections.emptySet()); final Task task = taskManager.register("transport", "test", new SearchRequest()); - TaskAwareRunnable.setListener(taskManager); + runnableTaskListener.apply(taskManager); threadPool.getThreadContext().putTransient(TASK_ID, task.getId()); CountDownLatch latch = new CountDownLatch(1); diff --git a/test/framework/src/main/java/org/opensearch/threadpool/TestThreadPool.java b/test/framework/src/main/java/org/opensearch/threadpool/TestThreadPool.java index 5f8611d99f0a0..eeca3e4719ac9 100644 --- a/test/framework/src/main/java/org/opensearch/threadpool/TestThreadPool.java +++ b/test/framework/src/main/java/org/opensearch/threadpool/TestThreadPool.java @@ -47,12 +47,25 @@ public class TestThreadPool extends ThreadPool { private volatile boolean returnRejectingExecutor = false; private volatile ThreadPoolExecutor rejectingExecutor; + public TestThreadPool(String name, RunnableTaskListenerFactory runnableTaskListener, ExecutorBuilder... 
customBuilders) { + this(name, Settings.EMPTY, runnableTaskListener, customBuilders); + } + public TestThreadPool(String name, ExecutorBuilder... customBuilders) { this(name, Settings.EMPTY, customBuilders); } public TestThreadPool(String name, Settings settings, ExecutorBuilder... customBuilders) { - super(Settings.builder().put(Node.NODE_NAME_SETTING.getKey(), name).put(settings).build(), customBuilders); + this(name, settings, null, customBuilders); + } + + public TestThreadPool( + String name, + Settings settings, + RunnableTaskListenerFactory runnableTaskListener, + ExecutorBuilder... customBuilders + ) { + super(Settings.builder().put(Node.NODE_NAME_SETTING.getKey(), name).put(settings).build(), runnableTaskListener, customBuilders); } @Override From f61ef7de720b0e0e241b670af76fa8b3faf800bf Mon Sep 17 00:00:00 2001 From: Tushar Kharbanda Date: Thu, 31 Mar 2022 16:31:33 +0530 Subject: [PATCH 06/26] Segregate resource tracking logic to a separate service. Signed-off-by: Tushar Kharbanda --- .../tasks/list/TransportListTasksAction.java | 6 +- .../common/settings/ClusterSettings.java | 4 +- .../main/java/org/opensearch/node/Node.java | 9 +- .../org/opensearch/tasks/TaskManager.java | 107 ++------------ .../tasks/TaskResourceTrackingService.java | 132 ++++++++++++++++++ .../transport/TransportService.java | 4 - .../node/tasks/ResourceAwareTasksTests.java | 21 ++- .../node/tasks/TaskManagerTestCase.java | 4 + .../client/node/NodeClientHeadersTests.java | 3 + .../snapshots/SnapshotResiliencyTests.java | 2 + .../opensearch/tasks/TaskManagerTests.java | 40 ------ .../test/transport/MockTransportService.java | 5 + 12 files changed, 177 insertions(+), 160 deletions(-) create mode 100644 server/src/main/java/org/opensearch/tasks/TaskResourceTrackingService.java diff --git a/server/src/main/java/org/opensearch/action/admin/cluster/node/tasks/list/TransportListTasksAction.java 
b/server/src/main/java/org/opensearch/action/admin/cluster/node/tasks/list/TransportListTasksAction.java index 4e87bfc202d8b..960d82b106cb6 100644 --- a/server/src/main/java/org/opensearch/action/admin/cluster/node/tasks/list/TransportListTasksAction.java +++ b/server/src/main/java/org/opensearch/action/admin/cluster/node/tasks/list/TransportListTasksAction.java @@ -102,11 +102,7 @@ protected void processTasks(ListTasksRequest request, Consumer operation) taskManager.waitForTaskCompletion(task, timeoutNanos); }); } else { - operation = operation.andThen(task -> { - if (task.supportsResourceTracking()) { - taskManager.refreshResourceStats(task); - } - }); + operation = operation.andThen(taskManager::refreshTasksInfo); } super.processTasks(request, operation); } diff --git a/server/src/main/java/org/opensearch/common/settings/ClusterSettings.java b/server/src/main/java/org/opensearch/common/settings/ClusterSettings.java index d79de1f9a4179..4cacc3bcf37eb 100644 --- a/server/src/main/java/org/opensearch/common/settings/ClusterSettings.java +++ b/server/src/main/java/org/opensearch/common/settings/ClusterSettings.java @@ -40,7 +40,7 @@ import org.opensearch.index.ShardIndexingPressureMemoryManager; import org.opensearch.index.ShardIndexingPressureSettings; import org.opensearch.index.ShardIndexingPressureStore; -import org.opensearch.tasks.TaskManager; +import org.opensearch.tasks.TaskResourceTrackingService; import org.opensearch.watcher.ResourceWatcherService; import org.opensearch.action.admin.cluster.configuration.TransportAddVotingConfigExclusionsAction; import org.opensearch.action.admin.indices.close.TransportCloseIndexAction; @@ -570,7 +570,7 @@ public void apply(Settings value, Settings current, Settings previous) { ShardIndexingPressureMemoryManager.SUCCESSFUL_REQUEST_ELAPSED_TIMEOUT, ShardIndexingPressureMemoryManager.MAX_OUTSTANDING_REQUESTS, IndexingPressure.MAX_INDEXING_BYTES, - TaskManager.TASK_RESOURCE_TRACKING_ENABLED + 
TaskResourceTrackingService.TASK_RESOURCE_TRACKING_ENABLED ) ) ); diff --git a/server/src/main/java/org/opensearch/node/Node.java b/server/src/main/java/org/opensearch/node/Node.java index a1a32274fcb6c..f97f0b46ef476 100644 --- a/server/src/main/java/org/opensearch/node/Node.java +++ b/server/src/main/java/org/opensearch/node/Node.java @@ -37,6 +37,7 @@ import org.apache.lucene.util.Constants; import org.apache.lucene.util.SetOnce; import org.opensearch.index.IndexingPressureService; +import org.opensearch.tasks.TaskResourceTrackingService; import org.opensearch.threadpool.RunnableTaskListenerFactory; import org.opensearch.watcher.ResourceWatcherService; import org.opensearch.Assertions; @@ -1061,8 +1062,14 @@ public Node start() throws NodeValidationException { transportService.getTaskManager().setTaskResultsService(injector.getInstance(TaskResultsService.class)); transportService.getTaskManager().setTaskCancellationService(new TaskCancellationService(transportService)); + TaskResourceTrackingService taskResourceTrackingService = new TaskResourceTrackingService( + settings(), + clusterService.getClusterSettings() + ); + transportService.getTaskManager().setTaskResourceTrackingService(taskResourceTrackingService); + RunnableTaskListenerFactory runnableTaskListener = injector.getInstance(RunnableTaskListenerFactory.class); - runnableTaskListener.apply(transportService.getTaskManager()); + runnableTaskListener.apply(taskResourceTrackingService); transportService.start(); assert localNodeFactory.getNode() != null; diff --git a/server/src/main/java/org/opensearch/tasks/TaskManager.java b/server/src/main/java/org/opensearch/tasks/TaskManager.java index 3fe32194c0d46..994203a396ebd 100644 --- a/server/src/main/java/org/opensearch/tasks/TaskManager.java +++ b/server/src/main/java/org/opensearch/tasks/TaskManager.java @@ -34,7 +34,6 @@ import com.carrotsearch.hppc.ObjectIntHashMap; import com.carrotsearch.hppc.ObjectIntMap; -import com.sun.management.ThreadMXBean; import 
org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.apache.logging.log4j.message.ParameterizedMessage; @@ -52,7 +51,6 @@ import org.opensearch.common.Nullable; import org.opensearch.common.lease.Releasable; import org.opensearch.common.lease.Releasables; -import org.opensearch.common.settings.Setting; import org.opensearch.common.settings.Settings; import org.opensearch.common.unit.ByteSizeValue; import org.opensearch.common.unit.TimeValue; @@ -60,12 +58,10 @@ import org.opensearch.common.util.concurrent.ConcurrentCollections; import org.opensearch.common.util.concurrent.ConcurrentMapLong; import org.opensearch.common.util.concurrent.ThreadContext; -import org.opensearch.threadpool.RunnableTaskExecutionListener; import org.opensearch.threadpool.ThreadPool; import org.opensearch.transport.TcpChannel; import java.io.IOException; -import java.lang.management.ManagementFactory; import java.util.ArrayList; import java.util.Collection; import java.util.Collections; @@ -84,19 +80,11 @@ import static org.opensearch.common.unit.TimeValue.timeValueMillis; import static org.opensearch.http.HttpTransportSettings.SETTING_HTTP_MAX_HEADER_SIZE; -import static org.opensearch.tasks.ResourceStatsType.WORKER_STATS; /** * Task Manager service for keeping track of currently running tasks on the nodes */ -public class TaskManager implements ClusterStateApplier, RunnableTaskExecutionListener { - - public static final Setting TASK_RESOURCE_TRACKING_ENABLED = Setting.boolSetting( - "task_resource_tracking.enabled", - false, - Setting.Property.Dynamic, - Setting.Property.NodeScope - ); +public class TaskManager implements ClusterStateApplier { public static final String TASK_ID = "TASK_ID"; @@ -104,8 +92,6 @@ public class TaskManager implements ClusterStateApplier, RunnableTaskExecutionLi private static final TimeValue WAIT_FOR_COMPLETION_POLL = timeValueMillis(100); - private static final ThreadMXBean threadMXBean = (ThreadMXBean) 
ManagementFactory.getThreadMXBean(); - /** * Rest headers that are copied to the task */ @@ -122,27 +108,18 @@ public class TaskManager implements ClusterStateApplier, RunnableTaskExecutionLi private final Map banedParents = new ConcurrentHashMap<>(); private TaskResultsService taskResultsService; + private SetOnce taskResourceTrackingService = new SetOnce<>(); private volatile DiscoveryNodes lastDiscoveryNodes = DiscoveryNodes.EMPTY_NODES; private final ByteSizeValue maxHeaderSize; - private volatile boolean taskResourceTrackingEnabled; private final Map channelPendingTaskTrackers = ConcurrentCollections.newConcurrentMap(); private final SetOnce cancellationService = new SetOnce<>(); - private final ConcurrentMapLong resourceAwareTasks = ConcurrentCollections.newConcurrentMapLongWithAggressiveConcurrency(); - public TaskManager(Settings settings, ThreadPool threadPool, Set taskHeaders) { this.threadPool = threadPool; this.taskHeaders = new ArrayList<>(taskHeaders); this.maxHeaderSize = SETTING_HTTP_MAX_HEADER_SIZE.get(settings); - this.taskResourceTrackingEnabled = TASK_RESOURCE_TRACKING_ENABLED.get(settings); - } - - public boolean isTaskResourceTrackingEnabled() { - return taskResourceTrackingEnabled - && threadMXBean.isThreadAllocatedMemorySupported() - && threadMXBean.isThreadAllocatedMemoryEnabled(); } public void setTaskResultsService(TaskResultsService taskResultsService) { @@ -154,6 +131,10 @@ public void setTaskCancellationService(TaskCancellationService taskCancellationS this.cancellationService.set(taskCancellationService); } + public void setTaskResourceTrackingService(TaskResourceTrackingService taskResourceTrackingService) { + this.taskResourceTrackingService.set(taskResourceTrackingService); + } + /** * Registers a task without parent task */ @@ -178,8 +159,8 @@ public Task register(String type, String action, TaskAwareRequest request) { if (logger.isTraceEnabled()) { logger.trace("register {} [{}] [{}] [{}]", task.getId(), type, action, 
task.getDescription()); } - if (task.supportsResourceTracking() && isTaskResourceTrackingEnabled()) { - resourceAwareTasks.put(task.getId(), task); + if (task.supportsResourceTracking()) { + taskResourceTrackingService.get().registerTask(task); } if (task instanceof CancellableTask) { @@ -236,11 +217,7 @@ public Task unregister(Task task) { logger.trace("unregister task for id: {}", task.getId()); if (task.supportsResourceTracking()) { - try { - taskExecutionFinishedOnThread(task.getId(), Thread.currentThread().getId()); - } catch (Exception ignored) {} finally { - resourceAwareTasks.remove(task.getId(), task); - } + taskResourceTrackingService.get().unregisterTask(task); } if (task instanceof CancellableTask) { @@ -365,10 +342,6 @@ public Map getCancellableTasks() { return Collections.unmodifiableMap(taskHashMap); } - public Map getResourceAwareTasks() { - return Collections.unmodifiableMap(resourceAwareTasks); - } - /** * Returns a task with given id, or null if the task is not found. */ @@ -525,66 +498,8 @@ public ThreadContext.StoredContext addTaskIdInThreadContext(@Nullable Task task) return storedContext; } - /** - * Called when a thread starts working on a task's runnable. - * - * @param taskId of the task for which runnable is starting - * @param threadId of the thread which will be executing the runnable and we need to check resource usage for this - * thread - */ - @Override - public void taskExecutionStartedOnThread(long taskId, long threadId) { - if (resourceAwareTasks.containsKey(taskId)) { - resourceAwareTasks.get(taskId).startThreadResourceTracking(threadId, WORKER_STATS, getResourceUsageMetricsForThread(threadId)); - } - } - - /** - * Called when a thread finishes working on a task's runnable. 
- * - * @param taskId of the task for which runnable is complete - * @param threadId of the thread which executed the runnable and we need to check resource usage for this thread - */ - @Override - public void taskExecutionFinishedOnThread(long taskId, long threadId) { - if (resourceAwareTasks.containsKey(taskId)) { - resourceAwareTasks.get(taskId).stopThreadResourceTracking(threadId, WORKER_STATS, getResourceUsageMetricsForThread(threadId)); - } - } - - private ResourceUsageMetric[] getResourceUsageMetricsForThread(long threadId) { - ResourceUsageMetric currentMemoryUsage = new ResourceUsageMetric( - ResourceStats.MEMORY, - threadMXBean.getThreadAllocatedBytes(threadId) - ); - ResourceUsageMetric currentCPUUsage = new ResourceUsageMetric(ResourceStats.CPU, threadMXBean.getThreadCpuTime(threadId)); - return new ResourceUsageMetric[] { currentMemoryUsage, currentCPUUsage }; - } - - public void refreshResourceStats(Task... tasks) { - if (isTaskResourceTrackingEnabled() == false) { - return; - } - - for (Task task : tasks) { - if (task.supportsResourceTracking() && resourceAwareTasks.containsKey(task.getId())) { - refreshResourceStats(task); - } - } - } - - private void refreshResourceStats(Task resourceAwareTask) { - resourceAwareTask.getResourceStats().forEach((threadId, threadResourceInfos) -> { - for (ThreadResourceInfo threadResourceInfo : threadResourceInfos) { - if (threadResourceInfo.isActive()) { - resourceAwareTask.updateThreadResourceStats(threadId, WORKER_STATS, getResourceUsageMetricsForThread(threadId)); - } - } - }); - } - - public void setTaskResourceTrackingEnabled(boolean taskResourceTrackingEnabled) { - this.taskResourceTrackingEnabled = taskResourceTrackingEnabled; + public void refreshTasksInfo(Task... 
tasks) { + taskResourceTrackingService.get().refreshResourceStats(tasks); } private static class CancellableTaskHolder { diff --git a/server/src/main/java/org/opensearch/tasks/TaskResourceTrackingService.java b/server/src/main/java/org/opensearch/tasks/TaskResourceTrackingService.java new file mode 100644 index 0000000000000..0e8926370f6ec --- /dev/null +++ b/server/src/main/java/org/opensearch/tasks/TaskResourceTrackingService.java @@ -0,0 +1,132 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.tasks; + +import com.sun.management.ThreadMXBean; +import org.opensearch.common.settings.ClusterSettings; +import org.opensearch.common.settings.Setting; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.util.concurrent.ConcurrentCollections; +import org.opensearch.common.util.concurrent.ConcurrentMapLong; +import org.opensearch.threadpool.RunnableTaskExecutionListener; + +import java.lang.management.ManagementFactory; +import java.util.Collections; +import java.util.Map; + +import static org.opensearch.tasks.ResourceStatsType.WORKER_STATS; + +/** + * Service that helps track resource usage of tasks running on a node. 
+ */ +public class TaskResourceTrackingService implements RunnableTaskExecutionListener { + + public static final Setting TASK_RESOURCE_TRACKING_ENABLED = Setting.boolSetting( + "task_resource_tracking.enabled", + false, + Setting.Property.Dynamic, + Setting.Property.NodeScope + ); + private static final ThreadMXBean threadMXBean = (ThreadMXBean) ManagementFactory.getThreadMXBean(); + + private final ConcurrentMapLong resourceAwareTasks = ConcurrentCollections.newConcurrentMapLongWithAggressiveConcurrency(); + private volatile boolean taskResourceTrackingEnabled; + + public TaskResourceTrackingService(Settings settings, ClusterSettings clusterSettings) { + this.taskResourceTrackingEnabled = TASK_RESOURCE_TRACKING_ENABLED.get(settings); + + clusterSettings.addSettingsUpdateConsumer(TASK_RESOURCE_TRACKING_ENABLED, this::setTaskResourceTrackingEnabled); + } + + public boolean isTaskResourceTrackingEnabled() { + return taskResourceTrackingEnabled + && threadMXBean.isThreadAllocatedMemorySupported() + && threadMXBean.isThreadAllocatedMemoryEnabled(); + } + + public void registerTask(Task task) { + if (isTaskResourceTrackingEnabled()) { + resourceAwareTasks.put(task.getId(), task); + } + } + + public void unregisterTask(Task task) { + try { + taskExecutionFinishedOnThread(task.getId(), Thread.currentThread().getId()); + } catch (Exception ignored) {} finally { + resourceAwareTasks.remove(task.getId(), task); + } + } + + public void refreshResourceStats(Task... 
tasks) { + if (isTaskResourceTrackingEnabled() == false) { + return; + } + + for (Task task : tasks) { + if (task.supportsResourceTracking() && resourceAwareTasks.containsKey(task.getId())) { + refreshResourceStats(task); + } + } + } + + private void refreshResourceStats(Task resourceAwareTask) { + resourceAwareTask.getResourceStats().forEach((threadId, threadResourceInfos) -> { + for (ThreadResourceInfo threadResourceInfo : threadResourceInfos) { + if (threadResourceInfo.isActive()) { + resourceAwareTask.updateThreadResourceStats(threadId, WORKER_STATS, getResourceUsageMetricsForThread(threadId)); + } + } + }); + } + + public void setTaskResourceTrackingEnabled(boolean taskResourceTrackingEnabled) { + this.taskResourceTrackingEnabled = taskResourceTrackingEnabled; + } + + /** + * Called when a thread starts working on a task's runnable. + * + * @param taskId of the task for which runnable is starting + * @param threadId of the thread which will be executing the runnable and we need to check resource usage for this + * thread + */ + @Override + public void taskExecutionStartedOnThread(long taskId, long threadId) { + if (resourceAwareTasks.containsKey(taskId)) { + resourceAwareTasks.get(taskId).startThreadResourceTracking(threadId, WORKER_STATS, getResourceUsageMetricsForThread(threadId)); + } + } + + /** + * Called when a thread finishes working on a task's runnable. 
+ * + * @param taskId of the task for which runnable is complete + * @param threadId of the thread which executed the runnable and we need to check resource usage for this thread + */ + @Override + public void taskExecutionFinishedOnThread(long taskId, long threadId) { + if (resourceAwareTasks.containsKey(taskId)) { + resourceAwareTasks.get(taskId).stopThreadResourceTracking(threadId, WORKER_STATS, getResourceUsageMetricsForThread(threadId)); + } + } + + public Map getResourceAwareTasks() { + return Collections.unmodifiableMap(resourceAwareTasks); + } + + private ResourceUsageMetric[] getResourceUsageMetricsForThread(long threadId) { + ResourceUsageMetric currentMemoryUsage = new ResourceUsageMetric( + ResourceStats.MEMORY, + threadMXBean.getThreadAllocatedBytes(threadId) + ); + ResourceUsageMetric currentCPUUsage = new ResourceUsageMetric(ResourceStats.CPU, threadMXBean.getThreadCpuTime(threadId)); + return new ResourceUsageMetric[] { currentMemoryUsage, currentCPUUsage }; + } +} diff --git a/server/src/main/java/org/opensearch/transport/TransportService.java b/server/src/main/java/org/opensearch/transport/TransportService.java index 2fa9c515b98fa..ad46ed742c806 100644 --- a/server/src/main/java/org/opensearch/transport/TransportService.java +++ b/server/src/main/java/org/opensearch/transport/TransportService.java @@ -218,10 +218,6 @@ public TransportService( remoteClusterService.listenForUpdates(clusterSettings); } clusterSettings.addSettingsUpdateConsumer(TransportSettings.SLOW_OPERATION_THRESHOLD_SETTING, transport::setSlowLogThreshold); - clusterSettings.addSettingsUpdateConsumer( - TaskManager.TASK_RESOURCE_TRACKING_ENABLED, - taskManager::setTaskResourceTrackingEnabled - ); } registerRequestHandler( HANDSHAKE_ACTION_NAME, diff --git a/server/src/test/java/org/opensearch/action/admin/cluster/node/tasks/ResourceAwareTasksTests.java b/server/src/test/java/org/opensearch/action/admin/cluster/node/tasks/ResourceAwareTasksTests.java index 
d078c5c0e7ca5..564ccfa73332f 100644 --- a/server/src/test/java/org/opensearch/action/admin/cluster/node/tasks/ResourceAwareTasksTests.java +++ b/server/src/test/java/org/opensearch/action/admin/cluster/node/tasks/ResourceAwareTasksTests.java @@ -26,7 +26,6 @@ import org.opensearch.tasks.TaskCancelledException; import org.opensearch.tasks.TaskId; import org.opensearch.tasks.TaskInfo; -import org.opensearch.tasks.TaskManager; import org.opensearch.threadpool.ThreadPool; import org.opensearch.transport.TransportService; @@ -249,7 +248,7 @@ public void testBasicTaskResourceTracking() throws Exception { final AtomicReference responseReference = new AtomicReference<>(); TaskTestContext taskTestContext = new TaskTestContext(); - Map resourceTasks = testNodes[0].transportService.getTaskManager().getResourceAwareTasks(); + Map resourceTasks = testNodes[0].taskResourceTrackingService.getResourceAwareTasks(); taskTestContext.operationStartValidator = threadId -> { Task task = resourceTasks.values().stream().findAny().get(); @@ -305,7 +304,7 @@ public void testTaskResourceTrackingDuringTaskCancellation() throws Exception { final AtomicReference responseReference = new AtomicReference<>(); TaskTestContext taskTestContext = new TaskTestContext(); - Map resourceTasks = testNodes[0].transportService.getTaskManager().getResourceAwareTasks(); + Map resourceTasks = testNodes[0].taskResourceTrackingService.getResourceAwareTasks(); taskTestContext.operationStartValidator = threadId -> { Task task = resourceTasks.values().stream().findAny().get(); @@ -373,7 +372,7 @@ public void testTaskResourceTrackingDisabled() throws Exception { final AtomicReference responseReference = new AtomicReference<>(); TaskTestContext taskTestContext = new TaskTestContext(); - Map resourceTasks = testNodes[0].transportService.getTaskManager().getResourceAwareTasks(); + Map resourceTasks = testNodes[0].taskResourceTrackingService.getResourceAwareTasks(); taskTestContext.operationStartValidator = threadId -> { 
assertEquals(0, resourceTasks.size()); }; @@ -406,8 +405,7 @@ public void testTaskResourceTrackingDisabledWhileTaskInProgress() throws Excepti final AtomicReference responseReference = new AtomicReference<>(); TaskTestContext taskTestContext = new TaskTestContext(); - TaskManager taskManager = testNodes[0].transportService.getTaskManager(); - Map resourceTasks = taskManager.getResourceAwareTasks(); + Map resourceTasks = testNodes[0].taskResourceTrackingService.getResourceAwareTasks(); taskTestContext.operationStartValidator = threadId -> { Task task = resourceTasks.values().stream().findAny().get(); @@ -419,7 +417,7 @@ public void testTaskResourceTrackingDisabledWhileTaskInProgress() throws Excepti assertEquals(0, task.getTotalResourceStats().getCpuTimeInNanos()); assertEquals(0, task.getTotalResourceStats().getMemoryInBytes()); - taskManager.setTaskResourceTrackingEnabled(false); + testNodes[0].taskResourceTrackingService.setTaskResourceTrackingEnabled(false); }; taskTestContext.operationFinishedValidator = threadId -> { @@ -463,13 +461,12 @@ public void testTaskResourceTrackingEnabledWhileTaskInProgress() throws Exceptio final AtomicReference responseReference = new AtomicReference<>(); TaskTestContext taskTestContext = new TaskTestContext(); - TaskManager taskManager = testNodes[0].transportService.getTaskManager(); - Map resourceTasks = taskManager.getResourceAwareTasks(); + Map resourceTasks = testNodes[0].taskResourceTrackingService.getResourceAwareTasks(); taskTestContext.operationStartValidator = threadId -> { assertEquals(0, resourceTasks.size()); - taskManager.setTaskResourceTrackingEnabled(true); + testNodes[0].taskResourceTrackingService.setTaskResourceTrackingEnabled(true); }; taskTestContext.operationFinishedValidator = threadId -> { assertEquals(0, resourceTasks.size()); }; @@ -502,7 +499,7 @@ public void testOnDemandRefreshWhileFetchingTasks() throws InterruptedException TaskTestContext taskTestContext = new TaskTestContext(); - Map resourceTasks = 
testNodes[0].transportService.getTaskManager().getResourceAwareTasks(); + Map resourceTasks = testNodes[0].taskResourceTrackingService.getResourceAwareTasks(); taskTestContext.operationStartValidator = threadId -> { ListTasksResponse listTasksResponse = ActionTestUtils.executeBlocking( @@ -543,7 +540,7 @@ private void setup(boolean resourceTrackingEnabled) { setupTestNodes(settings); connectNodes(testNodes[0]); - runnableTaskListener.apply(testNodes[0].transportService.getTaskManager()); + runnableTaskListener.apply(testNodes[0].taskResourceTrackingService); } private Throwable findActualException(Exception e) { diff --git a/server/src/test/java/org/opensearch/action/admin/cluster/node/tasks/TaskManagerTestCase.java b/server/src/test/java/org/opensearch/action/admin/cluster/node/tasks/TaskManagerTestCase.java index a4b3c2be545b1..5f720b1b53485 100644 --- a/server/src/test/java/org/opensearch/action/admin/cluster/node/tasks/TaskManagerTestCase.java +++ b/server/src/test/java/org/opensearch/action/admin/cluster/node/tasks/TaskManagerTestCase.java @@ -59,6 +59,7 @@ import org.opensearch.indices.breaker.NoneCircuitBreakerService; import org.opensearch.tasks.TaskCancellationService; import org.opensearch.tasks.TaskManager; +import org.opensearch.tasks.TaskResourceTrackingService; import org.opensearch.test.OpenSearchTestCase; import org.opensearch.test.tasks.MockTaskManager; import org.opensearch.threadpool.RunnableTaskListenerFactory; @@ -228,6 +229,8 @@ protected TaskManager createTaskManager(Settings settings, ThreadPool threadPool transportService.start(); clusterService = createClusterService(threadPool, discoveryNode.get()); clusterService.addStateApplier(transportService.getTaskManager()); + taskResourceTrackingService = new TaskResourceTrackingService(settings, clusterService.getClusterSettings()); + transportService.getTaskManager().setTaskResourceTrackingService(taskResourceTrackingService); ActionFilters actionFilters = new ActionFilters(emptySet()); 
transportListTasksAction = new TransportListTasksAction(clusterService, transportService, actionFilters); transportCancelTasksAction = new TransportCancelTasksAction(clusterService, transportService, actionFilters); @@ -236,6 +239,7 @@ protected TaskManager createTaskManager(Settings settings, ThreadPool threadPool public final ClusterService clusterService; public final TransportService transportService; + public final TaskResourceTrackingService taskResourceTrackingService; private final SetOnce discoveryNode = new SetOnce<>(); public final TransportListTasksAction transportListTasksAction; public final TransportCancelTasksAction transportCancelTasksAction; diff --git a/server/src/test/java/org/opensearch/client/node/NodeClientHeadersTests.java b/server/src/test/java/org/opensearch/client/node/NodeClientHeadersTests.java index cb9e3a6a19388..dddd7fd1350dd 100644 --- a/server/src/test/java/org/opensearch/client/node/NodeClientHeadersTests.java +++ b/server/src/test/java/org/opensearch/client/node/NodeClientHeadersTests.java @@ -43,11 +43,14 @@ import org.opensearch.common.settings.Settings; import org.opensearch.tasks.Task; import org.opensearch.tasks.TaskManager; +import org.opensearch.tasks.TaskResourceTrackingService; import org.opensearch.threadpool.ThreadPool; import java.util.Collections; import java.util.HashMap; +import static org.mockito.Mockito.mock; + public class NodeClientHeadersTests extends AbstractClientHeadersTestCase { private static final ActionFilters EMPTY_FILTERS = new ActionFilters(Collections.emptySet()); diff --git a/server/src/test/java/org/opensearch/snapshots/SnapshotResiliencyTests.java b/server/src/test/java/org/opensearch/snapshots/SnapshotResiliencyTests.java index 26e19e532b6bc..821906a763f38 100644 --- a/server/src/test/java/org/opensearch/snapshots/SnapshotResiliencyTests.java +++ b/server/src/test/java/org/opensearch/snapshots/SnapshotResiliencyTests.java @@ -198,6 +198,7 @@ import org.opensearch.search.fetch.FetchPhase; import 
org.opensearch.search.query.QueryPhase; import org.opensearch.snapshots.mockstore.MockEventuallyConsistentRepository; +import org.opensearch.tasks.TaskResourceTrackingService; import org.opensearch.test.OpenSearchTestCase; import org.opensearch.test.disruption.DisruptableMockTransport; import org.opensearch.threadpool.ThreadPool; @@ -1738,6 +1739,7 @@ public void onFailure(final Exception e) { final IndexNameExpressionResolver indexNameExpressionResolver = new IndexNameExpressionResolver( new ThreadContext(Settings.EMPTY) ); + transportService.getTaskManager().setTaskResourceTrackingService(new TaskResourceTrackingService(settings, clusterSettings)); repositoriesService = new RepositoriesService( settings, clusterService, diff --git a/server/src/test/java/org/opensearch/tasks/TaskManagerTests.java b/server/src/test/java/org/opensearch/tasks/TaskManagerTests.java index 4c6cc88e9892b..5271b2af8880d 100644 --- a/server/src/test/java/org/opensearch/tasks/TaskManagerTests.java +++ b/server/src/test/java/org/opensearch/tasks/TaskManagerTests.java @@ -34,12 +34,10 @@ import org.opensearch.action.ActionListener; import org.opensearch.action.admin.cluster.node.tasks.TransportTasksActionTests; -import org.opensearch.action.search.SearchRequest; import org.opensearch.common.lease.Releasable; import org.opensearch.common.lease.Releasables; import org.opensearch.common.settings.Settings; import org.opensearch.common.unit.TimeValue; -import org.opensearch.common.util.concurrent.AbstractRunnable; import org.opensearch.common.util.concurrent.ConcurrentCollections; import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.test.OpenSearchTestCase; @@ -61,7 +59,6 @@ import java.util.List; import java.util.Map; import java.util.Set; -import java.util.concurrent.CountDownLatch; import java.util.concurrent.Phaser; import java.util.concurrent.TimeUnit; @@ -123,43 +120,6 @@ public void testAddTaskIdToThreadContext() { 
assertNull(threadPool.getThreadContext().getTransient(TASK_ID)); } - public void testTaskStatsResourceRefresh() { - Settings settings = Settings.builder().put("task_resource_tracking.enabled", true).build(); - - final TaskManager taskManager = new TaskManager(settings, threadPool, Collections.emptySet()); - final Task task = taskManager.register("transport", "test", new SearchRequest()); - - runnableTaskListener.apply(taskManager); - threadPool.getThreadContext().putTransient(TASK_ID, task.getId()); - CountDownLatch latch = new CountDownLatch(1); - - threadPool.executor(ThreadPool.Names.SEARCH).submit(new AbstractRunnable() { - @Override - public void onFailure(Exception e) { - - } - - @Override - protected void doRun() throws InterruptedException { - Object[] allocation = new Object[100]; - latch.await(); - } - }); - - long cpuConsumptionAtStart = task.getTotalResourceStats().getCpuTimeInNanos(); - long memoryConsumptionAtStart = task.getTotalResourceStats().getMemoryInBytes(); - - taskManager.refreshResourceStats(task); - - long cpuConsumptionAfterRefresh = task.getTotalResourceStats().getCpuTimeInNanos(); - long memoryConsumptionAfterRefresh = task.getTotalResourceStats().getMemoryInBytes(); - - latch.countDown(); - - assertTrue(cpuConsumptionAtStart < cpuConsumptionAfterRefresh); - assertTrue(memoryConsumptionAtStart < memoryConsumptionAfterRefresh); - } - private void verifyThreadContextFixedHeaders(String key, String value) { assertEquals(threadPool.getThreadContext().getHeader(key), value); assertEquals(threadPool.getThreadContext().getTransient(key), value); diff --git a/test/framework/src/main/java/org/opensearch/test/transport/MockTransportService.java b/test/framework/src/main/java/org/opensearch/test/transport/MockTransportService.java index 9b9baebd540c3..637f198e3ed1a 100644 --- a/test/framework/src/main/java/org/opensearch/test/transport/MockTransportService.java +++ 
b/test/framework/src/main/java/org/opensearch/test/transport/MockTransportService.java @@ -57,6 +57,7 @@ import org.opensearch.node.Node; import org.opensearch.plugins.Plugin; import org.opensearch.tasks.TaskManager; +import org.opensearch.tasks.TaskResourceTrackingService; import org.opensearch.test.OpenSearchTestCase; import org.opensearch.test.tasks.MockTaskManager; import org.opensearch.threadpool.ThreadPool; @@ -89,6 +90,8 @@ import java.util.function.Function; import java.util.function.Supplier; +import static org.mockito.Mockito.mock; + /** * A mock delegate service that allows to simulate different network topology failures. * Internally it maps TransportAddress objects to rules that inject failures. @@ -250,6 +253,8 @@ private MockTransportService( new StubbableConnectionManager(new ClusterConnectionManager(settings, transport)) ); this.original = transport.getDelegate(); + + this.taskManager.setTaskResourceTrackingService(mock(TaskResourceTrackingService.class)); } private static TransportAddress[] extractTransportAddresses(TransportService transportService) { From 7058bd7d57ee60b35b39e41cfa50fd678004bb47 Mon Sep 17 00:00:00 2001 From: Tushar Kharbanda Date: Thu, 31 Mar 2022 19:51:42 +0530 Subject: [PATCH 07/26] Check for running threads during task unregister Signed-off-by: Tushar Kharbanda --- .../tasks/TaskResourceTrackingService.java | 31 +++++++++++++++++-- 1 file changed, 28 insertions(+), 3 deletions(-) diff --git a/server/src/main/java/org/opensearch/tasks/TaskResourceTrackingService.java b/server/src/main/java/org/opensearch/tasks/TaskResourceTrackingService.java index 0e8926370f6ec..ace729c090f26 100644 --- a/server/src/main/java/org/opensearch/tasks/TaskResourceTrackingService.java +++ b/server/src/main/java/org/opensearch/tasks/TaskResourceTrackingService.java @@ -18,6 +18,7 @@ import java.lang.management.ManagementFactory; import java.util.Collections; +import java.util.List; import java.util.Map; import static 
org.opensearch.tasks.ResourceStatsType.WORKER_STATS; @@ -56,12 +57,25 @@ public void registerTask(Task task) { } } + /** + * unregisters tasks registered earlier. + * + * It doesn't have feature enabled check to avoid any issues if setting was disable while the task was in progress. + * + * It's also responsible to stop tracking the current thread's resources against this task if not already done. + * This happens when the thread handling the request itself calls the unregister method. So in this case unregister + * happens before runnable finishes. + * + * @param task + */ public void unregisterTask(Task task) { - try { + if (!resourceAwareTasks.containsKey(task.getId())) { + return; + } + if (isThreadWorkingOnTask(task, Thread.currentThread().getId())) { taskExecutionFinishedOnThread(task.getId(), Thread.currentThread().getId()); - } catch (Exception ignored) {} finally { - resourceAwareTasks.remove(task.getId(), task); } + resourceAwareTasks.remove(task.getId(), task); } public void refreshResourceStats(Task... 
tasks) { @@ -129,4 +143,15 @@ private ResourceUsageMetric[] getResourceUsageMetricsForThread(long threadId) { ResourceUsageMetric currentCPUUsage = new ResourceUsageMetric(ResourceStats.CPU, threadMXBean.getThreadCpuTime(threadId)); return new ResourceUsageMetric[] { currentMemoryUsage, currentCPUUsage }; } + + private boolean isThreadWorkingOnTask(Task task, long threadId) { + List threadResourceInfos = task.getResourceStats().getOrDefault(threadId, Collections.emptyList()); + + for (ThreadResourceInfo threadResourceInfo : threadResourceInfos) { + if (threadResourceInfo.isActive()) { + return true; + } + } + return false; + } } From aa35b82d4d17a380f75d3396a13b2e4063638d25 Mon Sep 17 00:00:00 2001 From: Tushar Kharbanda Date: Tue, 5 Apr 2022 14:25:35 +0530 Subject: [PATCH 08/26] Moved thread context logic to resource tracking service Signed-off-by: Tushar Kharbanda --- .../action/support/TransportAction.java | 67 +++++++++--------- .../main/java/org/opensearch/node/Node.java | 3 +- .../org/opensearch/tasks/TaskManager.java | 38 +--------- .../tasks/TaskResourceTrackingService.java | 70 +++++++++++++++---- .../threadpool/TaskAwareRunnable.java | 2 +- .../transport/RequestHandlerRegistry.java | 2 +- .../node/tasks/TaskManagerTestCase.java | 2 +- .../node/tasks/TransportTasksActionTests.java | 4 +- .../snapshots/SnapshotResiliencyTests.java | 3 +- .../opensearch/tasks/TaskManagerTests.java | 8 +-- .../test/tasks/MockTaskManager.java | 4 +- 11 files changed, 111 insertions(+), 92 deletions(-) diff --git a/server/src/main/java/org/opensearch/action/support/TransportAction.java b/server/src/main/java/org/opensearch/action/support/TransportAction.java index 97b975c255d2b..5ec218f54dec0 100644 --- a/server/src/main/java/org/opensearch/action/support/TransportAction.java +++ b/server/src/main/java/org/opensearch/action/support/TransportAction.java @@ -89,31 +89,39 @@ public final Task execute(Request request, ActionListener listener) { */ final Releasable 
unregisterChildNode = registerChildNode(request.getParentTask()); final Task task; + try { task = taskManager.register("transport", actionName, request); } catch (TaskCancelledException e) { unregisterChildNode.close(); throw e; } - execute(task, request, new ActionListener() { - @Override - public void onResponse(Response response) { - try { - Releasables.close(unregisterChildNode, () -> taskManager.unregister(task)); - } finally { - listener.onResponse(response); + + ThreadContext.StoredContext storedContext = taskManager.taskExecutionStarted(task); + try { + execute(task, request, new ActionListener() { + @Override + public void onResponse(Response response) { + try { + Releasables.close(unregisterChildNode, () -> taskManager.unregister(task)); + } finally { + listener.onResponse(response); + } } - } - @Override - public void onFailure(Exception e) { - try { - Releasables.close(unregisterChildNode, () -> taskManager.unregister(task)); - } finally { - listener.onFailure(e); + @Override + public void onFailure(Exception e) { + try { + Releasables.close(unregisterChildNode, () -> taskManager.unregister(task)); + } finally { + listener.onFailure(e); + } } - } - }); + }); + } finally { + storedContext.close(); + } + return task; } @@ -156,25 +164,18 @@ public void onFailure(Exception e) { * Use this method when the transport action should continue to run in the context of the current task */ public final void execute(Task task, Request request, ActionListener listener) { - ThreadContext.StoredContext storedContext = taskManager.addTaskIdInThreadContext(task); - - try { - ActionRequestValidationException validationException = request.validate(); - if (validationException != null) { - listener.onFailure(validationException); - return; - } - - if (task != null && request.getShouldStoreResult()) { - listener = new TaskResultStoringActionListener<>(taskManager, task, listener); - } - - RequestFilterChain requestFilterChain = new RequestFilterChain<>(this, logger); - 
requestFilterChain.proceed(task, actionName, request, listener); + ActionRequestValidationException validationException = request.validate(); + if (validationException != null) { + listener.onFailure(validationException); + return; + } - } finally { - storedContext.restore(); + if (task != null && request.getShouldStoreResult()) { + listener = new TaskResultStoringActionListener<>(taskManager, task, listener); } + + RequestFilterChain requestFilterChain = new RequestFilterChain<>(this, logger); + requestFilterChain.proceed(task, actionName, request, listener); } protected abstract void doExecute(Task task, Request request, ActionListener listener); diff --git a/server/src/main/java/org/opensearch/node/Node.java b/server/src/main/java/org/opensearch/node/Node.java index f97f0b46ef476..254b2d85de050 100644 --- a/server/src/main/java/org/opensearch/node/Node.java +++ b/server/src/main/java/org/opensearch/node/Node.java @@ -1064,7 +1064,8 @@ public Node start() throws NodeValidationException { TaskResourceTrackingService taskResourceTrackingService = new TaskResourceTrackingService( settings(), - clusterService.getClusterSettings() + clusterService.getClusterSettings(), + injector.getInstance(ThreadPool.class) ); transportService.getTaskManager().setTaskResourceTrackingService(taskResourceTrackingService); diff --git a/server/src/main/java/org/opensearch/tasks/TaskManager.java b/server/src/main/java/org/opensearch/tasks/TaskManager.java index 994203a396ebd..6defc9a829c64 100644 --- a/server/src/main/java/org/opensearch/tasks/TaskManager.java +++ b/server/src/main/java/org/opensearch/tasks/TaskManager.java @@ -86,8 +86,6 @@ */ public class TaskManager implements ClusterStateApplier { - public static final String TASK_ID = "TASK_ID"; - private static final Logger logger = LogManager.getLogger(TaskManager.class); private static final TimeValue WAIT_FOR_COMPLETION_POLL = timeValueMillis(100); @@ -159,9 +157,6 @@ public Task register(String type, String action, 
TaskAwareRequest request) { if (logger.isTraceEnabled()) { logger.trace("register {} [{}] [{}] [{}]", task.getId(), type, action, task.getDescription()); } - if (task.supportsResourceTracking()) { - taskResourceTrackingService.get().registerTask(task); - } if (task instanceof CancellableTask) { registerCancellableTask(task); @@ -217,7 +212,7 @@ public Task unregister(Task task) { logger.trace("unregister task for id: {}", task.getId()); if (task.supportsResourceTracking()) { - taskResourceTrackingService.get().unregisterTask(task); + taskResourceTrackingService.get().stopTracking(task); } if (task instanceof CancellableTask) { @@ -467,35 +462,8 @@ public void waitForTaskCompletion(Task task, long untilInNanos) { throw new OpenSearchTimeoutException("Timed out waiting for completion of [{}]", task); } - /** - * Adds Task Id in the ThreadContext. - *

- * Stashes the existing ThreadContext and preserves all the existing ThreadContext's data in the new ThreadContext - * as well. - * - * @param task for which Task Id needs to be added in ThreadContext. - * @return StoredContext reference to restore the ThreadContext from which we created a new one. - * Caller can call context.restore() to get the existing ThreadContext back. - */ - public ThreadContext.StoredContext addTaskIdInThreadContext(@Nullable Task task) { - if (task == null) { - return () -> {}; - } - - ThreadContext threadContext = threadPool.getThreadContext(); - - if (threadContext.getTransient(TASK_ID) != null) { - logger.warn( - "Task Id already present in the thread context. Thread Id: {}, Existing Task Id: {}, New Task Id: {}. Overwriting", - Thread.currentThread().getId(), - threadContext.getTransient(TASK_ID), - task.getId() - ); - } - - ThreadContext.StoredContext storedContext = threadContext.newStoredContext(true, Collections.singletonList(TASK_ID)); - threadContext.putTransient(TASK_ID, task.getId()); - return storedContext; + public ThreadContext.StoredContext taskExecutionStarted(Task task) { + return taskResourceTrackingService.get().startTracking(task); } public void refreshTasksInfo(Task... 
tasks) { diff --git a/server/src/main/java/org/opensearch/tasks/TaskResourceTrackingService.java b/server/src/main/java/org/opensearch/tasks/TaskResourceTrackingService.java index ace729c090f26..f49f9f50c90e1 100644 --- a/server/src/main/java/org/opensearch/tasks/TaskResourceTrackingService.java +++ b/server/src/main/java/org/opensearch/tasks/TaskResourceTrackingService.java @@ -9,12 +9,16 @@ package org.opensearch.tasks; import com.sun.management.ThreadMXBean; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; import org.opensearch.common.settings.ClusterSettings; import org.opensearch.common.settings.Setting; import org.opensearch.common.settings.Settings; import org.opensearch.common.util.concurrent.ConcurrentCollections; import org.opensearch.common.util.concurrent.ConcurrentMapLong; +import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.threadpool.RunnableTaskExecutionListener; +import org.opensearch.threadpool.ThreadPool; import java.lang.management.ManagementFactory; import java.util.Collections; @@ -28,19 +32,25 @@ */ public class TaskResourceTrackingService implements RunnableTaskExecutionListener { + private static final Logger logger = LogManager.getLogger(TaskManager.class); + public static final Setting TASK_RESOURCE_TRACKING_ENABLED = Setting.boolSetting( "task_resource_tracking.enabled", false, Setting.Property.Dynamic, Setting.Property.NodeScope ); + public static final String TASK_ID = "TASK_ID"; + private static final ThreadMXBean threadMXBean = (ThreadMXBean) ManagementFactory.getThreadMXBean(); - private final ConcurrentMapLong resourceAwareTasks = ConcurrentCollections.newConcurrentMapLongWithAggressiveConcurrency(); + public final ConcurrentMapLong resourceAwareTasks = ConcurrentCollections.newConcurrentMapLongWithAggressiveConcurrency(); private volatile boolean taskResourceTrackingEnabled; + private ThreadPool threadPool; - public TaskResourceTrackingService(Settings settings, 
ClusterSettings clusterSettings) { + public TaskResourceTrackingService(Settings settings, ClusterSettings clusterSettings, ThreadPool threadPool) { this.taskResourceTrackingEnabled = TASK_RESOURCE_TRACKING_ENABLED.get(settings); + this.threadPool = threadPool; clusterSettings.addSettingsUpdateConsumer(TASK_RESOURCE_TRACKING_ENABLED, this::setTaskResourceTrackingEnabled); } @@ -51,31 +61,44 @@ public boolean isTaskResourceTrackingEnabled() { && threadMXBean.isThreadAllocatedMemoryEnabled(); } - public void registerTask(Task task) { - if (isTaskResourceTrackingEnabled()) { - resourceAwareTasks.put(task.getId(), task); + public ThreadContext.StoredContext startTracking(Task task) { + if (task.supportsResourceTracking() == false || isTaskResourceTrackingEnabled() == false) { + return () -> {}; } + + resourceAwareTasks.put(task.getId(), task); + return addTaskIdToThreadContext(task); + } /** * unregisters tasks registered earlier. - * + *

* It doesn't have feature enabled check to avoid any issues if setting was disable while the task was in progress. - * + *

* It's also responsible to stop tracking the current thread's resources against this task if not already done. * This happens when the thread handling the request itself calls the unregister method. So in this case unregister * happens before runnable finishes. * * @param task */ - public void unregisterTask(Task task) { - if (!resourceAwareTasks.containsKey(task.getId())) { - return; - } + public void stopTracking(Task task) { if (isThreadWorkingOnTask(task, Thread.currentThread().getId())) { taskExecutionFinishedOnThread(task.getId(), Thread.currentThread().getId()); } - resourceAwareTasks.remove(task.getId(), task); + + assert validateNoActiveThread(task) : "No thread should be active when task is finished"; + + resourceAwareTasks.remove(task.getId()); + } + + private boolean validateNoActiveThread(Task task) { + for (List threadResourceInfos : task.getResourceStats().values()) { + for (ThreadResourceInfo threadResourceInfo : threadResourceInfos) { + if (threadResourceInfo.isActive()) return false; + } + } + return true; } public void refreshResourceStats(Task... tasks) { @@ -154,4 +177,27 @@ private boolean isThreadWorkingOnTask(Task task, long threadId) { } return false; } + + /** + * Adds Task Id in the ThreadContext. + *

+ * Stashes the existing ThreadContext and preserves all the existing ThreadContext's data in the new ThreadContext + * as well. + * + * @param task for which Task Id needs to be added in ThreadContext. + * @return StoredContext reference to restore the ThreadContext from which we created a new one. + * Caller can call context.restore() to get the existing ThreadContext back. + */ + private ThreadContext.StoredContext addTaskIdToThreadContext(Task task) { + ThreadContext threadContext = threadPool.getThreadContext(); + + boolean noStaleTaskIdPresentInThreadContext = threadContext.getTransient(TASK_ID) == null + || resourceAwareTasks.containsKey((long) threadContext.getTransient(TASK_ID)); + assert noStaleTaskIdPresentInThreadContext : "Stale Task Id shouldn't be present in thread context"; + + ThreadContext.StoredContext storedContext = threadContext.newStoredContext(true, Collections.singletonList(TASK_ID)); + threadContext.putTransient(TASK_ID, task.getId()); + return storedContext; + } + } diff --git a/server/src/main/java/org/opensearch/threadpool/TaskAwareRunnable.java b/server/src/main/java/org/opensearch/threadpool/TaskAwareRunnable.java index b83868cffbbed..7a0b55a2cf011 100644 --- a/server/src/main/java/org/opensearch/threadpool/TaskAwareRunnable.java +++ b/server/src/main/java/org/opensearch/threadpool/TaskAwareRunnable.java @@ -16,7 +16,7 @@ import java.util.Objects; import static java.lang.Thread.currentThread; -import static org.opensearch.tasks.TaskManager.TASK_ID; +import static org.opensearch.tasks.TaskResourceTrackingService.TASK_ID; /** * Responsible for wrapping the original task's runnable and sending updates on when it starts and finishes to diff --git a/server/src/main/java/org/opensearch/transport/RequestHandlerRegistry.java b/server/src/main/java/org/opensearch/transport/RequestHandlerRegistry.java index d5de970f728e3..4b987036909e4 100644 --- a/server/src/main/java/org/opensearch/transport/RequestHandlerRegistry.java +++ 
b/server/src/main/java/org/opensearch/transport/RequestHandlerRegistry.java @@ -82,7 +82,7 @@ public Request newRequest(StreamInput in) throws IOException { public void processMessageReceived(Request request, TransportChannel channel) throws Exception { final Task task = taskManager.register(channel.getChannelType(), action, request); - ThreadContext.StoredContext storedContext = taskManager.addTaskIdInThreadContext(task); + ThreadContext.StoredContext storedContext = taskManager.taskExecutionStarted(task); Releasable unregisterTask = () -> taskManager.unregister(task); try { diff --git a/server/src/test/java/org/opensearch/action/admin/cluster/node/tasks/TaskManagerTestCase.java b/server/src/test/java/org/opensearch/action/admin/cluster/node/tasks/TaskManagerTestCase.java index 5f720b1b53485..bca1115cce570 100644 --- a/server/src/test/java/org/opensearch/action/admin/cluster/node/tasks/TaskManagerTestCase.java +++ b/server/src/test/java/org/opensearch/action/admin/cluster/node/tasks/TaskManagerTestCase.java @@ -229,7 +229,7 @@ protected TaskManager createTaskManager(Settings settings, ThreadPool threadPool transportService.start(); clusterService = createClusterService(threadPool, discoveryNode.get()); clusterService.addStateApplier(transportService.getTaskManager()); - taskResourceTrackingService = new TaskResourceTrackingService(settings, clusterService.getClusterSettings()); + taskResourceTrackingService = new TaskResourceTrackingService(settings, clusterService.getClusterSettings(), threadPool); transportService.getTaskManager().setTaskResourceTrackingService(taskResourceTrackingService); ActionFilters actionFilters = new ActionFilters(emptySet()); transportListTasksAction = new TransportListTasksAction(clusterService, transportService, actionFilters); diff --git a/server/src/test/java/org/opensearch/action/admin/cluster/node/tasks/TransportTasksActionTests.java 
b/server/src/test/java/org/opensearch/action/admin/cluster/node/tasks/TransportTasksActionTests.java index fb7734c5f6199..dc2fa00c22008 100644 --- a/server/src/test/java/org/opensearch/action/admin/cluster/node/tasks/TransportTasksActionTests.java +++ b/server/src/test/java/org/opensearch/action/admin/cluster/node/tasks/TransportTasksActionTests.java @@ -90,7 +90,7 @@ import static org.hamcrest.Matchers.greaterThanOrEqualTo; import static org.hamcrest.Matchers.hasSize; import static org.hamcrest.Matchers.not; -import static org.opensearch.tasks.TaskManager.TASK_ID; +import static org.opensearch.tasks.TaskResourceTrackingService.TASK_ID; public class TransportTasksActionTests extends TaskManagerTestCase { @@ -651,6 +651,7 @@ protected void taskOperation(TestTasksRequest request, Task task, ActionListener assertEquals(0, responses.failureCount()); } +/* public void testTaskIdPersistsInThreadContext() { Settings settings = Settings.builder().put(MockTaskManager.USE_MOCK_TASK_MANAGER_SETTING.getKey(), true).build(); setupTestNodes(settings); @@ -699,6 +700,7 @@ protected void taskOperation(TestTasksRequest request, Task task, ActionListener assertEquals(expectedTaskIdInThreadContext[0], actualTaskIdInThreadContext[0]); assertThat(taskIdsAddedToThreadContext, containsInAnyOrder(taskIdsRemovedFromThreadContext.toArray())); } +*/ /** * This test starts nodes actions that blocks on all nodes. 
While node actions are blocked in the middle of execution diff --git a/server/src/test/java/org/opensearch/snapshots/SnapshotResiliencyTests.java b/server/src/test/java/org/opensearch/snapshots/SnapshotResiliencyTests.java index 821906a763f38..b69869c0cc5a7 100644 --- a/server/src/test/java/org/opensearch/snapshots/SnapshotResiliencyTests.java +++ b/server/src/test/java/org/opensearch/snapshots/SnapshotResiliencyTests.java @@ -1739,7 +1739,8 @@ public void onFailure(final Exception e) { final IndexNameExpressionResolver indexNameExpressionResolver = new IndexNameExpressionResolver( new ThreadContext(Settings.EMPTY) ); - transportService.getTaskManager().setTaskResourceTrackingService(new TaskResourceTrackingService(settings, clusterSettings)); + transportService.getTaskManager() + .setTaskResourceTrackingService(new TaskResourceTrackingService(settings, clusterSettings, threadPool)); repositoriesService = new RepositoriesService( settings, clusterService, diff --git a/server/src/test/java/org/opensearch/tasks/TaskManagerTests.java b/server/src/test/java/org/opensearch/tasks/TaskManagerTests.java index 5271b2af8880d..f099935f7be5a 100644 --- a/server/src/test/java/org/opensearch/tasks/TaskManagerTests.java +++ b/server/src/test/java/org/opensearch/tasks/TaskManagerTests.java @@ -66,7 +66,7 @@ import static org.hamcrest.Matchers.everyItem; import static org.hamcrest.Matchers.in; import static org.mockito.Mockito.mock; -import static org.opensearch.tasks.TaskManager.TASK_ID; +import static org.opensearch.tasks.TaskResourceTrackingService.TASK_ID; public class TaskManagerTests extends OpenSearchTestCase { private ThreadPool threadPool; @@ -96,7 +96,7 @@ public void testResultsServiceRetryTotalTime() { assertEquals(600000L, total); } - public void testAddTaskIdToThreadContext() { + /* public void testAddTaskIdToThreadContext() { final TaskManager taskManager = new TaskManager(Settings.EMPTY, threadPool, Collections.emptySet()); final Task task = 
taskManager.register("transport", "test", new CancellableRequest("1")); String key = "KEY"; @@ -107,7 +107,7 @@ public void testAddTaskIdToThreadContext() { threadPool.getThreadContext().putTransient(key, value); threadPool.getThreadContext().addResponseHeader(key, value); - ThreadContext.StoredContext storedContext = taskManager.addTaskIdInThreadContext(task); + ThreadContext.StoredContext storedContext = taskManager.taskExecutionStarted(task); // All headers should be preserved and Task Id should also be included in thread context verifyThreadContextFixedHeaders(key, value); @@ -118,7 +118,7 @@ public void testAddTaskIdToThreadContext() { // Post restore only task id should be removed from the thread context verifyThreadContextFixedHeaders(key, value); assertNull(threadPool.getThreadContext().getTransient(TASK_ID)); - } + }*/ private void verifyThreadContextFixedHeaders(String key, String value) { assertEquals(threadPool.getThreadContext().getHeader(key), value); diff --git a/test/framework/src/main/java/org/opensearch/test/tasks/MockTaskManager.java b/test/framework/src/main/java/org/opensearch/test/tasks/MockTaskManager.java index d3f360e0d5414..92193553b2b9a 100644 --- a/test/framework/src/main/java/org/opensearch/test/tasks/MockTaskManager.java +++ b/test/framework/src/main/java/org/opensearch/test/tasks/MockTaskManager.java @@ -129,12 +129,12 @@ public void waitForTaskCompletion(Task task, long untilInNanos) { } @Override - public ThreadContext.StoredContext addTaskIdInThreadContext(Task task) { + public ThreadContext.StoredContext taskExecutionStarted(Task task) { for (MockTaskManagerListener listener : listeners) { listener.onThreadContextUpdate(task, true); } - ThreadContext.StoredContext storedContext = super.addTaskIdInThreadContext(task); + ThreadContext.StoredContext storedContext = super.taskExecutionStarted(task); return () -> { for (MockTaskManagerListener listener : listeners) { listener.onThreadContextUpdate(task, false); From 
23d639be3b353692878c5f66e5ff131c9d8666ac Mon Sep 17 00:00:00 2001 From: Tushar Kharbanda Date: Tue, 5 Apr 2022 14:25:56 +0530 Subject: [PATCH 09/26] preserve task id in thread context even after stash Signed-off-by: Tushar Kharbanda --- .../common/util/concurrent/ThreadContext.java | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/server/src/main/java/org/opensearch/common/util/concurrent/ThreadContext.java b/server/src/main/java/org/opensearch/common/util/concurrent/ThreadContext.java index d844a8f158ea4..52fa4957029e7 100644 --- a/server/src/main/java/org/opensearch/common/util/concurrent/ThreadContext.java +++ b/server/src/main/java/org/opensearch/common/util/concurrent/ThreadContext.java @@ -66,6 +66,7 @@ import static org.opensearch.http.HttpTransportSettings.SETTING_HTTP_MAX_WARNING_HEADER_COUNT; import static org.opensearch.http.HttpTransportSettings.SETTING_HTTP_MAX_WARNING_HEADER_SIZE; +import static org.opensearch.tasks.TaskResourceTrackingService.TASK_ID; /** * A ThreadContext is a map of string headers and a transient map of keyed objects that are associated with @@ -134,16 +135,23 @@ public StoredContext stashContext() { * This is needed so the DeprecationLogger in another thread can see the value of X-Opaque-ID provided by a user. * Otherwise when context is stash, it should be empty. 
*/ + + ThreadContextStruct threadContextStruct = DEFAULT_CONTEXT; + if (context.requestHeaders.containsKey(Task.X_OPAQUE_ID)) { - ThreadContextStruct threadContextStruct = DEFAULT_CONTEXT.putHeaders( + threadContextStruct.putHeaders( MapBuilder.newMapBuilder() .put(Task.X_OPAQUE_ID, context.requestHeaders.get(Task.X_OPAQUE_ID)) .immutableMap() ); - threadLocal.set(threadContextStruct); - } else { - threadLocal.set(DEFAULT_CONTEXT); } + + if (context.requestHeaders.containsKey(TASK_ID)) { + threadContextStruct.putTransient(TASK_ID, context.transientHeaders.get(TASK_ID)); + } + + threadLocal.set(threadContextStruct); + return () -> { // If the node and thus the threadLocal get closed while this task // is still executing, we don't want this runnable to fail with an From 930edaef41392ed3c6c218bd66e498c78352687d Mon Sep 17 00:00:00 2001 From: Tushar Kharbanda Date: Tue, 5 Apr 2022 17:19:52 +0530 Subject: [PATCH 10/26] Add null check for resource tracking service Signed-off-by: Tushar Kharbanda --- server/src/main/java/org/opensearch/tasks/TaskManager.java | 5 +++-- .../cluster/node/tasks/TransportTasksActionTests.java | 7 ++----- .../test/java/org/opensearch/tasks/TaskManagerTests.java | 4 +--- .../opensearch/test/transport/MockTransportService.java | 5 ----- 4 files changed, 6 insertions(+), 15 deletions(-) diff --git a/server/src/main/java/org/opensearch/tasks/TaskManager.java b/server/src/main/java/org/opensearch/tasks/TaskManager.java index 6defc9a829c64..a7989cf8281d1 100644 --- a/server/src/main/java/org/opensearch/tasks/TaskManager.java +++ b/server/src/main/java/org/opensearch/tasks/TaskManager.java @@ -48,7 +48,6 @@ import org.opensearch.cluster.ClusterStateApplier; import org.opensearch.cluster.node.DiscoveryNode; import org.opensearch.cluster.node.DiscoveryNodes; -import org.opensearch.common.Nullable; import org.opensearch.common.lease.Releasable; import org.opensearch.common.lease.Releasables; import org.opensearch.common.settings.Settings; @@ -211,7 
+210,7 @@ public void cancel(CancellableTask task, String reason, Runnable listener) { public Task unregister(Task task) { logger.trace("unregister task for id: {}", task.getId()); - if (task.supportsResourceTracking()) { + if (taskResourceTrackingService.get() != null && task.supportsResourceTracking()) { taskResourceTrackingService.get().stopTracking(task); } @@ -463,6 +462,8 @@ public void waitForTaskCompletion(Task task, long untilInNanos) { } public ThreadContext.StoredContext taskExecutionStarted(Task task) { + if (taskResourceTrackingService.get() == null) return () -> {}; + return taskResourceTrackingService.get().startTracking(task); } diff --git a/server/src/test/java/org/opensearch/action/admin/cluster/node/tasks/TransportTasksActionTests.java b/server/src/test/java/org/opensearch/action/admin/cluster/node/tasks/TransportTasksActionTests.java index dc2fa00c22008..ae8b62676c124 100644 --- a/server/src/test/java/org/opensearch/action/admin/cluster/node/tasks/TransportTasksActionTests.java +++ b/server/src/test/java/org/opensearch/action/admin/cluster/node/tasks/TransportTasksActionTests.java @@ -66,7 +66,6 @@ import org.opensearch.tasks.TaskId; import org.opensearch.tasks.TaskInfo; import org.opensearch.test.tasks.MockTaskManager; -import org.opensearch.test.tasks.MockTaskManagerListener; import org.opensearch.threadpool.ThreadPool; import org.opensearch.transport.TransportService; @@ -83,14 +82,12 @@ import java.util.concurrent.atomic.AtomicReference; import java.util.stream.Collectors; -import static org.hamcrest.Matchers.containsInAnyOrder; import static org.opensearch.action.support.PlainActionFuture.newFuture; import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.endsWith; import static org.hamcrest.Matchers.greaterThanOrEqualTo; import static org.hamcrest.Matchers.hasSize; import static org.hamcrest.Matchers.not; -import static org.opensearch.tasks.TaskResourceTrackingService.TASK_ID; public class 
TransportTasksActionTests extends TaskManagerTestCase { @@ -651,7 +648,7 @@ protected void taskOperation(TestTasksRequest request, Task task, ActionListener assertEquals(0, responses.failureCount()); } -/* + /* public void testTaskIdPersistsInThreadContext() { Settings settings = Settings.builder().put(MockTaskManager.USE_MOCK_TASK_MANAGER_SETTING.getKey(), true).build(); setupTestNodes(settings); @@ -700,7 +697,7 @@ protected void taskOperation(TestTasksRequest request, Task task, ActionListener assertEquals(expectedTaskIdInThreadContext[0], actualTaskIdInThreadContext[0]); assertThat(taskIdsAddedToThreadContext, containsInAnyOrder(taskIdsRemovedFromThreadContext.toArray())); } -*/ + */ /** * This test starts nodes actions that blocks on all nodes. While node actions are blocked in the middle of execution diff --git a/server/src/test/java/org/opensearch/tasks/TaskManagerTests.java b/server/src/test/java/org/opensearch/tasks/TaskManagerTests.java index f099935f7be5a..1817e3b3a709c 100644 --- a/server/src/test/java/org/opensearch/tasks/TaskManagerTests.java +++ b/server/src/test/java/org/opensearch/tasks/TaskManagerTests.java @@ -39,7 +39,6 @@ import org.opensearch.common.settings.Settings; import org.opensearch.common.unit.TimeValue; import org.opensearch.common.util.concurrent.ConcurrentCollections; -import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.test.OpenSearchTestCase; import org.opensearch.threadpool.RunnableTaskListenerFactory; import org.opensearch.threadpool.TestThreadPool; @@ -66,7 +65,6 @@ import static org.hamcrest.Matchers.everyItem; import static org.hamcrest.Matchers.in; import static org.mockito.Mockito.mock; -import static org.opensearch.tasks.TaskResourceTrackingService.TASK_ID; public class TaskManagerTests extends OpenSearchTestCase { private ThreadPool threadPool; @@ -96,7 +94,7 @@ public void testResultsServiceRetryTotalTime() { assertEquals(600000L, total); } - /* public void testAddTaskIdToThreadContext() { + 
/* public void testAddTaskIdToThreadContext() { final TaskManager taskManager = new TaskManager(Settings.EMPTY, threadPool, Collections.emptySet()); final Task task = taskManager.register("transport", "test", new CancellableRequest("1")); String key = "KEY"; diff --git a/test/framework/src/main/java/org/opensearch/test/transport/MockTransportService.java b/test/framework/src/main/java/org/opensearch/test/transport/MockTransportService.java index 637f198e3ed1a..9b9baebd540c3 100644 --- a/test/framework/src/main/java/org/opensearch/test/transport/MockTransportService.java +++ b/test/framework/src/main/java/org/opensearch/test/transport/MockTransportService.java @@ -57,7 +57,6 @@ import org.opensearch.node.Node; import org.opensearch.plugins.Plugin; import org.opensearch.tasks.TaskManager; -import org.opensearch.tasks.TaskResourceTrackingService; import org.opensearch.test.OpenSearchTestCase; import org.opensearch.test.tasks.MockTaskManager; import org.opensearch.threadpool.ThreadPool; @@ -90,8 +89,6 @@ import java.util.function.Function; import java.util.function.Supplier; -import static org.mockito.Mockito.mock; - /** * A mock delegate service that allows to simulate different network topology failures. * Internally it maps TransportAddress objects to rules that inject failures. 
@@ -253,8 +250,6 @@ private MockTransportService( new StubbableConnectionManager(new ClusterConnectionManager(settings, transport)) ); this.original = transport.getDelegate(); - - this.taskManager.setTaskResourceTrackingService(mock(TaskResourceTrackingService.class)); } private static TransportAddress[] extractTransportAddresses(TransportService transportService) { From f68e91d5c70f840dfbdee6e08f0d0cb86db42971 Mon Sep 17 00:00:00 2001 From: Tushar Kharbanda Date: Thu, 7 Apr 2022 16:36:24 +0530 Subject: [PATCH 11/26] Tracking service tests and minor refactoring Signed-off-by: Tushar Kharbanda --- .../tasks/list/TransportListTasksAction.java | 13 ++- .../org/opensearch/cluster/ClusterModule.java | 2 + .../main/java/org/opensearch/node/Node.java | 6 +- .../org/opensearch/tasks/TaskManager.java | 10 +- .../tasks/TaskResourceStatsUtil.java | 0 .../tasks/TaskResourceTrackingService.java | 57 ++++++----- .../node/tasks/ResourceAwareTasksTests.java | 78 +++++++++++++-- .../node/tasks/TaskManagerTestCase.java | 7 +- .../client/node/NodeClientHeadersTests.java | 3 - .../opensearch/tasks/TaskManagerTests.java | 30 ------ .../TaskResourceTrackingServiceTests.java | 97 +++++++++++++++++++ 11 files changed, 228 insertions(+), 75 deletions(-) delete mode 100644 server/src/main/java/org/opensearch/tasks/TaskResourceStatsUtil.java create mode 100644 server/src/test/java/org/opensearch/tasks/TaskResourceTrackingServiceTests.java diff --git a/server/src/main/java/org/opensearch/action/admin/cluster/node/tasks/list/TransportListTasksAction.java b/server/src/main/java/org/opensearch/action/admin/cluster/node/tasks/list/TransportListTasksAction.java index 960d82b106cb6..df448d2665434 100644 --- a/server/src/main/java/org/opensearch/action/admin/cluster/node/tasks/list/TransportListTasksAction.java +++ b/server/src/main/java/org/opensearch/action/admin/cluster/node/tasks/list/TransportListTasksAction.java @@ -42,6 +42,7 @@ import org.opensearch.common.unit.TimeValue; import 
org.opensearch.tasks.Task; import org.opensearch.tasks.TaskInfo; +import org.opensearch.tasks.TaskResourceTrackingService; import org.opensearch.threadpool.ThreadPool; import org.opensearch.transport.TransportService; @@ -60,8 +61,15 @@ public static long waitForCompletionTimeout(TimeValue timeout) { private static final TimeValue DEFAULT_WAIT_FOR_COMPLETION_TIMEOUT = timeValueSeconds(30); + private final TaskResourceTrackingService taskResourceTrackingService; + @Inject - public TransportListTasksAction(ClusterService clusterService, TransportService transportService, ActionFilters actionFilters) { + public TransportListTasksAction( + ClusterService clusterService, + TransportService transportService, + ActionFilters actionFilters, + TaskResourceTrackingService taskResourceTrackingService + ) { super( ListTasksAction.NAME, clusterService, @@ -72,6 +80,7 @@ public TransportListTasksAction(ClusterService clusterService, TransportService TaskInfo::new, ThreadPool.Names.MANAGEMENT ); + this.taskResourceTrackingService = taskResourceTrackingService; } @Override @@ -102,7 +111,7 @@ protected void processTasks(ListTasksRequest request, Consumer operation) taskManager.waitForTaskCompletion(task, timeoutNanos); }); } else { - operation = operation.andThen(taskManager::refreshTasksInfo); + operation = operation.andThen(taskResourceTrackingService::refreshResourceStats); } super.processTasks(request, operation); } diff --git a/server/src/main/java/org/opensearch/cluster/ClusterModule.java b/server/src/main/java/org/opensearch/cluster/ClusterModule.java index c85691b80d7c3..b9f3a2a99f0b7 100644 --- a/server/src/main/java/org/opensearch/cluster/ClusterModule.java +++ b/server/src/main/java/org/opensearch/cluster/ClusterModule.java @@ -94,6 +94,7 @@ import org.opensearch.script.ScriptMetadata; import org.opensearch.snapshots.SnapshotsInfoService; import org.opensearch.tasks.Task; +import org.opensearch.tasks.TaskResourceTrackingService; import 
org.opensearch.tasks.TaskResultsService; import java.util.ArrayList; @@ -394,6 +395,7 @@ protected void configure() { bind(NodeMappingRefreshAction.class).asEagerSingleton(); bind(MappingUpdatedAction.class).asEagerSingleton(); bind(TaskResultsService.class).asEagerSingleton(); + bind(TaskResourceTrackingService.class).asEagerSingleton(); bind(AllocationDeciders.class).toInstance(allocationDeciders); bind(ShardsAllocator.class).toInstance(shardsAllocator); } diff --git a/server/src/main/java/org/opensearch/node/Node.java b/server/src/main/java/org/opensearch/node/Node.java index 254b2d85de050..aa95088dc1394 100644 --- a/server/src/main/java/org/opensearch/node/Node.java +++ b/server/src/main/java/org/opensearch/node/Node.java @@ -1062,11 +1062,7 @@ public Node start() throws NodeValidationException { transportService.getTaskManager().setTaskResultsService(injector.getInstance(TaskResultsService.class)); transportService.getTaskManager().setTaskCancellationService(new TaskCancellationService(transportService)); - TaskResourceTrackingService taskResourceTrackingService = new TaskResourceTrackingService( - settings(), - clusterService.getClusterSettings(), - injector.getInstance(ThreadPool.class) - ); + TaskResourceTrackingService taskResourceTrackingService = injector.getInstance(TaskResourceTrackingService.class); transportService.getTaskManager().setTaskResourceTrackingService(taskResourceTrackingService); RunnableTaskListenerFactory runnableTaskListener = injector.getInstance(RunnableTaskListenerFactory.class); diff --git a/server/src/main/java/org/opensearch/tasks/TaskManager.java b/server/src/main/java/org/opensearch/tasks/TaskManager.java index a7989cf8281d1..09f435c971ecc 100644 --- a/server/src/main/java/org/opensearch/tasks/TaskManager.java +++ b/server/src/main/java/org/opensearch/tasks/TaskManager.java @@ -461,16 +461,18 @@ public void waitForTaskCompletion(Task task, long untilInNanos) { throw new OpenSearchTimeoutException("Timed out waiting for 
completion of [{}]", task); } + /** + * Takes actions when a task is registered and its execution starts + * + * @param task getting executed. + * @return AutoCloseable to free up resources (clean up thread context) when task execution block returns + */ public ThreadContext.StoredContext taskExecutionStarted(Task task) { if (taskResourceTrackingService.get() == null) return () -> {}; return taskResourceTrackingService.get().startTracking(task); } - public void refreshTasksInfo(Task... tasks) { - taskResourceTrackingService.get().refreshResourceStats(tasks); - } - private static class CancellableTaskHolder { private final CancellableTask task; private boolean finished = false; diff --git a/server/src/main/java/org/opensearch/tasks/TaskResourceStatsUtil.java b/server/src/main/java/org/opensearch/tasks/TaskResourceStatsUtil.java deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/server/src/main/java/org/opensearch/tasks/TaskResourceTrackingService.java b/server/src/main/java/org/opensearch/tasks/TaskResourceTrackingService.java index f49f9f50c90e1..cf8e07fc263a7 100644 --- a/server/src/main/java/org/opensearch/tasks/TaskResourceTrackingService.java +++ b/server/src/main/java/org/opensearch/tasks/TaskResourceTrackingService.java @@ -9,8 +9,7 @@ package org.opensearch.tasks; import com.sun.management.ThreadMXBean; -import org.apache.logging.log4j.LogManager; -import org.apache.logging.log4j.Logger; +import org.opensearch.common.inject.Inject; import org.opensearch.common.settings.ClusterSettings; import org.opensearch.common.settings.Setting; import org.opensearch.common.settings.Settings; @@ -32,8 +31,6 @@ */ public class TaskResourceTrackingService implements RunnableTaskExecutionListener { - private static final Logger logger = LogManager.getLogger(TaskManager.class); - public static final Setting TASK_RESOURCE_TRACKING_ENABLED = Setting.boolSetting( "task_resource_tracking.enabled", false, @@ -44,10 +41,11 @@ public class 
TaskResourceTrackingService implements RunnableTaskExecutionListene private static final ThreadMXBean threadMXBean = (ThreadMXBean) ManagementFactory.getThreadMXBean(); - public final ConcurrentMapLong resourceAwareTasks = ConcurrentCollections.newConcurrentMapLongWithAggressiveConcurrency(); + private final ConcurrentMapLong resourceAwareTasks = ConcurrentCollections.newConcurrentMapLongWithAggressiveConcurrency(); + private final ThreadPool threadPool; private volatile boolean taskResourceTrackingEnabled; - private ThreadPool threadPool; + @Inject public TaskResourceTrackingService(Settings settings, ClusterSettings clusterSettings, ThreadPool threadPool) { this.taskResourceTrackingEnabled = TASK_RESOURCE_TRACKING_ENABLED.get(settings); this.threadPool = threadPool; @@ -55,12 +53,25 @@ public TaskResourceTrackingService(Settings settings, ClusterSettings clusterSet clusterSettings.addSettingsUpdateConsumer(TASK_RESOURCE_TRACKING_ENABLED, this::setTaskResourceTrackingEnabled); } + public void setTaskResourceTrackingEnabled(boolean taskResourceTrackingEnabled) { + this.taskResourceTrackingEnabled = taskResourceTrackingEnabled; + } + public boolean isTaskResourceTrackingEnabled() { return taskResourceTrackingEnabled && threadMXBean.isThreadAllocatedMemorySupported() && threadMXBean.isThreadAllocatedMemoryEnabled(); } + /** + * Executes logic only if task supports resource tracking and resource tracking setting is enabled. + *

+ * 1. Starts tracking the task in map of resourceAwareTasks. + * 2. Adds Task Id in thread context to make sure it's available while task is processed across multiple threads. + * + * @param task for which resources needs to be tracked + * @return Autocloseable stored context to restore ThreadContext to the state before this method changed it. + */ public ThreadContext.StoredContext startTracking(Task task) { if (task.supportsResourceTracking() == false || isTaskResourceTrackingEnabled() == false) { return () -> {}; @@ -72,15 +83,15 @@ public ThreadContext.StoredContext startTracking(Task task) { } /** - * unregisters tasks registered earlier. + * Stops tracking task registered earlier for tracking. *

* It doesn't have a feature-enabled check, to avoid any issues if the setting was disabled while the task was in progress. *

* It's also responsible to stop tracking the current thread's resources against this task if not already done. - * This happens when the thread handling the request itself calls the unregister method. So in this case unregister + * This happens when the thread executing the request logic itself calls the unregister method. So in this case unregister * happens before runnable finishes. * - * @param task + * @param task task which has finished and doesn't need resource tracking. */ public void stopTracking(Task task) { if (isThreadWorkingOnTask(task, Thread.currentThread().getId())) { @@ -92,15 +103,12 @@ public void stopTracking(Task task) { resourceAwareTasks.remove(task.getId()); } - private boolean validateNoActiveThread(Task task) { - for (List threadResourceInfos : task.getResourceStats().values()) { - for (ThreadResourceInfo threadResourceInfo : threadResourceInfos) { - if (threadResourceInfo.isActive()) return false; - } - } - return true; - } - + /** + * Refreshes the resource stats for the tasks provided by looking into which threads are actively working on these + * and how much resources these have consumed till now. + * + * @param tasks for which resource stats needs to be refreshed. + */ public void refreshResourceStats(Task... tasks) { if (isTaskResourceTrackingEnabled() == false) { return; @@ -123,10 +131,6 @@ private void refreshResourceStats(Task resourceAwareTask) { }); } - public void setTaskResourceTrackingEnabled(boolean taskResourceTrackingEnabled) { - this.taskResourceTrackingEnabled = taskResourceTrackingEnabled; - } - /** * Called when a thread starts working on a task's runnable. 
* @@ -178,6 +182,15 @@ private boolean isThreadWorkingOnTask(Task task, long threadId) { return false; } + private boolean validateNoActiveThread(Task task) { + for (List threadResourceInfos : task.getResourceStats().values()) { + for (ThreadResourceInfo threadResourceInfo : threadResourceInfos) { + if (threadResourceInfo.isActive()) return false; + } + } + return true; + } + /** * Adds Task Id in the ThreadContext. *

diff --git a/server/src/test/java/org/opensearch/action/admin/cluster/node/tasks/ResourceAwareTasksTests.java b/server/src/test/java/org/opensearch/action/admin/cluster/node/tasks/ResourceAwareTasksTests.java index 564ccfa73332f..5e77064002258 100644 --- a/server/src/test/java/org/opensearch/action/admin/cluster/node/tasks/ResourceAwareTasksTests.java +++ b/server/src/test/java/org/opensearch/action/admin/cluster/node/tasks/ResourceAwareTasksTests.java @@ -26,10 +26,14 @@ import org.opensearch.tasks.TaskCancelledException; import org.opensearch.tasks.TaskId; import org.opensearch.tasks.TaskInfo; +import org.opensearch.test.tasks.MockTaskManager; +import org.opensearch.test.tasks.MockTaskManagerListener; import org.opensearch.threadpool.ThreadPool; import org.opensearch.transport.TransportService; import java.io.IOException; +import java.util.ArrayList; +import java.util.List; import java.util.Map; import java.util.concurrent.CountDownLatch; import java.util.concurrent.ExecutionException; @@ -38,6 +42,9 @@ import java.util.concurrent.atomic.AtomicReference; import java.util.function.Consumer; +import static org.hamcrest.Matchers.containsInAnyOrder; +import static org.opensearch.tasks.TaskResourceTrackingService.TASK_ID; + public class ResourceAwareTasksTests extends TaskManagerTestCase { // For every task there's a general overhead before and after the actual task operation code is executed. 
@@ -242,7 +249,7 @@ private static class TaskTestContext { } public void testBasicTaskResourceTracking() throws Exception { - setup(true); + setup(true, false); final AtomicReference throwableReference = new AtomicReference<>(); final AtomicReference responseReference = new AtomicReference<>(); @@ -298,7 +305,7 @@ public void onFailure(Exception e) { } public void testTaskResourceTrackingDuringTaskCancellation() throws Exception { - setup(true); + setup(true, false); final AtomicReference throwableReference = new AtomicReference<>(); final AtomicReference responseReference = new AtomicReference<>(); @@ -366,7 +373,7 @@ public void onFailure(Exception e) { } public void testTaskResourceTrackingDisabled() throws Exception { - setup(false); + setup(false, false); final AtomicReference throwableReference = new AtomicReference<>(); final AtomicReference responseReference = new AtomicReference<>(); @@ -399,7 +406,7 @@ public void onFailure(Exception e) { } public void testTaskResourceTrackingDisabledWhileTaskInProgress() throws Exception { - setup(true); + setup(true, false); final AtomicReference throwableReference = new AtomicReference<>(); final AtomicReference responseReference = new AtomicReference<>(); @@ -455,7 +462,7 @@ public void onFailure(Exception e) { } public void testTaskResourceTrackingEnabledWhileTaskInProgress() throws Exception { - setup(false); + setup(false, false); final AtomicReference throwableReference = new AtomicReference<>(); final AtomicReference responseReference = new AtomicReference<>(); @@ -492,7 +499,7 @@ public void onFailure(Exception e) { } public void testOnDemandRefreshWhileFetchingTasks() throws InterruptedException { - setup(true); + setup(true, false); final AtomicReference throwableReference = new AtomicReference<>(); final AtomicReference responseReference = new AtomicReference<>(); @@ -535,8 +542,63 @@ public void onFailure(Exception e) { assertTasksRequestFinishedSuccessfully(resourceTasks.size(), responseReference.get(), 
throwableReference.get()); } - private void setup(boolean resourceTrackingEnabled) { - Settings settings = Settings.builder().put("task_resource_tracking.enabled", resourceTrackingEnabled).build(); + public void testTaskIdPersistsInThreadContext() throws InterruptedException { + setup(true, true); + + final List taskIdsAddedToThreadContext = new ArrayList<>(); + final List taskIdsRemovedFromThreadContext = new ArrayList<>(); + AtomicLong actualTaskIdInThreadContext = new AtomicLong(-1); + AtomicLong expectedTaskIdInThreadContext = new AtomicLong(-2); + + ((MockTaskManager) testNodes[0].transportService.getTaskManager()).addListener(new MockTaskManagerListener() { + @Override + public void waitForTaskCompletion(Task task) {} + + @Override + public void taskExecutionStarted(Task task, Boolean closeableInvoked) { + if (closeableInvoked) { + taskIdsRemovedFromThreadContext.add(task.getId()); + } else { + taskIdsAddedToThreadContext.add(task.getId()); + } + } + + @Override + public void onTaskRegistered(Task task) {} + + @Override + public void onTaskUnregistered(Task task) { + if (task.getAction().equals("internal:resourceAction[n]")) { + expectedTaskIdInThreadContext.set(task.getId()); + actualTaskIdInThreadContext.set(threadPool.getThreadContext().getTransient(TASK_ID)); + } + } + }); + + TaskTestContext taskTestContext = new TaskTestContext(); + startResourceAwareNodesAction(testNodes[0], false, taskTestContext, new ActionListener() { + @Override + public void onResponse(NodesResponse listTasksResponse) { + taskTestContext.requestCompleteLatch.countDown(); + } + + @Override + public void onFailure(Exception e) { + taskTestContext.requestCompleteLatch.countDown(); + } + }); + + taskTestContext.requestCompleteLatch.await(); + + assertEquals(expectedTaskIdInThreadContext.get(), actualTaskIdInThreadContext.get()); + assertThat(taskIdsAddedToThreadContext, containsInAnyOrder(taskIdsRemovedFromThreadContext.toArray())); + } + + private void setup(boolean 
resourceTrackingEnabled, boolean useMockTaskManager) { + Settings settings = Settings.builder() + .put("task_resource_tracking.enabled", resourceTrackingEnabled) + .put(MockTaskManager.USE_MOCK_TASK_MANAGER_SETTING.getKey(), useMockTaskManager) + .build(); setupTestNodes(settings); connectNodes(testNodes[0]); diff --git a/server/src/test/java/org/opensearch/action/admin/cluster/node/tasks/TaskManagerTestCase.java b/server/src/test/java/org/opensearch/action/admin/cluster/node/tasks/TaskManagerTestCase.java index bca1115cce570..a3584d8a9b7b1 100644 --- a/server/src/test/java/org/opensearch/action/admin/cluster/node/tasks/TaskManagerTestCase.java +++ b/server/src/test/java/org/opensearch/action/admin/cluster/node/tasks/TaskManagerTestCase.java @@ -232,7 +232,12 @@ protected TaskManager createTaskManager(Settings settings, ThreadPool threadPool taskResourceTrackingService = new TaskResourceTrackingService(settings, clusterService.getClusterSettings(), threadPool); transportService.getTaskManager().setTaskResourceTrackingService(taskResourceTrackingService); ActionFilters actionFilters = new ActionFilters(emptySet()); - transportListTasksAction = new TransportListTasksAction(clusterService, transportService, actionFilters); + transportListTasksAction = new TransportListTasksAction( + clusterService, + transportService, + actionFilters, + taskResourceTrackingService + ); transportCancelTasksAction = new TransportCancelTasksAction(clusterService, transportService, actionFilters); transportService.acceptIncomingRequests(); } diff --git a/server/src/test/java/org/opensearch/client/node/NodeClientHeadersTests.java b/server/src/test/java/org/opensearch/client/node/NodeClientHeadersTests.java index dddd7fd1350dd..cb9e3a6a19388 100644 --- a/server/src/test/java/org/opensearch/client/node/NodeClientHeadersTests.java +++ b/server/src/test/java/org/opensearch/client/node/NodeClientHeadersTests.java @@ -43,14 +43,11 @@ import org.opensearch.common.settings.Settings; import 
org.opensearch.tasks.Task; import org.opensearch.tasks.TaskManager; -import org.opensearch.tasks.TaskResourceTrackingService; import org.opensearch.threadpool.ThreadPool; import java.util.Collections; import java.util.HashMap; -import static org.mockito.Mockito.mock; - public class NodeClientHeadersTests extends AbstractClientHeadersTestCase { private static final ActionFilters EMPTY_FILTERS = new ActionFilters(Collections.emptySet()); diff --git a/server/src/test/java/org/opensearch/tasks/TaskManagerTests.java b/server/src/test/java/org/opensearch/tasks/TaskManagerTests.java index 1817e3b3a709c..1186d630b4bed 100644 --- a/server/src/test/java/org/opensearch/tasks/TaskManagerTests.java +++ b/server/src/test/java/org/opensearch/tasks/TaskManagerTests.java @@ -94,36 +94,6 @@ public void testResultsServiceRetryTotalTime() { assertEquals(600000L, total); } - /* public void testAddTaskIdToThreadContext() { - final TaskManager taskManager = new TaskManager(Settings.EMPTY, threadPool, Collections.emptySet()); - final Task task = taskManager.register("transport", "test", new CancellableRequest("1")); - String key = "KEY"; - String value = "VALUE"; - - // Prepare thread context - threadPool.getThreadContext().putHeader(key, value); - threadPool.getThreadContext().putTransient(key, value); - threadPool.getThreadContext().addResponseHeader(key, value); - - ThreadContext.StoredContext storedContext = taskManager.taskExecutionStarted(task); - - // All headers should be preserved and Task Id should also be included in thread context - verifyThreadContextFixedHeaders(key, value); - assertEquals((long) threadPool.getThreadContext().getTransient(TASK_ID), task.getId()); - - storedContext.restore(); - - // Post restore only task id should be removed from the thread context - verifyThreadContextFixedHeaders(key, value); - assertNull(threadPool.getThreadContext().getTransient(TASK_ID)); - }*/ - - private void verifyThreadContextFixedHeaders(String key, String value) { - 
assertEquals(threadPool.getThreadContext().getHeader(key), value); - assertEquals(threadPool.getThreadContext().getTransient(key), value); - assertEquals(threadPool.getThreadContext().getResponseHeaders().get(key).get(0), value); - } - public void testTrackingChannelTask() throws Exception { final TaskManager taskManager = new TaskManager(Settings.EMPTY, threadPool, Collections.emptySet()); Set cancelledTasks = ConcurrentCollections.newConcurrentSet(); diff --git a/server/src/test/java/org/opensearch/tasks/TaskResourceTrackingServiceTests.java b/server/src/test/java/org/opensearch/tasks/TaskResourceTrackingServiceTests.java new file mode 100644 index 0000000000000..15ddc253dfa6d --- /dev/null +++ b/server/src/test/java/org/opensearch/tasks/TaskResourceTrackingServiceTests.java @@ -0,0 +1,97 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. 
+ */ + +package org.opensearch.tasks; + +import org.junit.After; +import org.junit.Before; +import org.opensearch.action.admin.cluster.node.tasks.TransportTasksActionTests; +import org.opensearch.action.search.SearchTask; +import org.opensearch.common.settings.ClusterSettings; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.threadpool.RunnableTaskListenerFactory; +import org.opensearch.threadpool.TestThreadPool; +import org.opensearch.threadpool.ThreadPool; + +import java.util.HashMap; + +import static org.opensearch.tasks.ResourceStats.MEMORY; +import static org.opensearch.tasks.TaskResourceTrackingService.TASK_ID; + +public class TaskResourceTrackingServiceTests extends OpenSearchTestCase { + + private ThreadPool threadPool; + private TaskResourceTrackingService taskResourceTrackingService; + + @Before + public void setup() { + threadPool = new TestThreadPool(TransportTasksActionTests.class.getSimpleName(), new RunnableTaskListenerFactory()); + taskResourceTrackingService = new TaskResourceTrackingService( + Settings.EMPTY, + new ClusterSettings(Settings.EMPTY, ClusterSettings.BUILT_IN_CLUSTER_SETTINGS), + threadPool + ); + } + + @After + public void terminateThreadPool() { + terminate(threadPool); + } + + public void testThreadContextUpdateOnTrackingStart() { + taskResourceTrackingService.setTaskResourceTrackingEnabled(true); + + Task task = new SearchTask(1, "test", "test", () -> "Test", TaskId.EMPTY_TASK_ID, new HashMap<>()); + + String key = "KEY"; + String value = "VALUE"; + + // Prepare thread context + threadPool.getThreadContext().putHeader(key, value); + threadPool.getThreadContext().putTransient(key, value); + threadPool.getThreadContext().addResponseHeader(key, value); + + ThreadContext.StoredContext storedContext = taskResourceTrackingService.startTracking(task); + + // All headers should be preserved and Task Id should 
also be included in thread context + verifyThreadContextFixedHeaders(key, value); + assertEquals((long) threadPool.getThreadContext().getTransient(TASK_ID), task.getId()); + + storedContext.restore(); + + // Post restore only task id should be removed from the thread context + verifyThreadContextFixedHeaders(key, value); + assertNull(threadPool.getThreadContext().getTransient(TASK_ID)); + } + + public void testStopTrackingHandlesCurrentActiveThread() { + taskResourceTrackingService.setTaskResourceTrackingEnabled(true); + Task task = new SearchTask(1, "test", "test", () -> "Test", TaskId.EMPTY_TASK_ID, new HashMap<>()); + ThreadContext.StoredContext storedContext = taskResourceTrackingService.startTracking(task); + long threadId = Thread.currentThread().getId(); + taskResourceTrackingService.taskExecutionStartedOnThread(task.getId(), threadId); + + assertTrue(task.getResourceStats().get(threadId).get(0).isActive()); + assertEquals(0, task.getResourceStats().get(threadId).get(0).getResourceUsageInfo().getStatsInfo().get(MEMORY).getTotalValue()); + + taskResourceTrackingService.stopTracking(task); + + // Makes sure stop tracking marks the current active thread inactive and refreshes the resource stats before returning. 
+ assertFalse(task.getResourceStats().get(threadId).get(0).isActive()); + assertTrue(task.getResourceStats().get(threadId).get(0).getResourceUsageInfo().getStatsInfo().get(MEMORY).getTotalValue() > 0); + } + + private void verifyThreadContextFixedHeaders(String key, String value) { + assertEquals(threadPool.getThreadContext().getHeader(key), value); + assertEquals(threadPool.getThreadContext().getTransient(key), value); + assertEquals(threadPool.getThreadContext().getResponseHeaders().get(key).get(0), value); + } + +} From e788dfdaf5ecf296aeee74a2d6740a697b0e3dca Mon Sep 17 00:00:00 2001 From: Tushar Kharbanda Date: Thu, 7 Apr 2022 16:37:10 +0530 Subject: [PATCH 12/26] Preserve task id fix with test Signed-off-by: Tushar Kharbanda --- .../common/util/concurrent/ThreadContext.java | 6 +++--- .../common/util/concurrent/ThreadContextTests.java | 10 ++++++++++ 2 files changed, 13 insertions(+), 3 deletions(-) diff --git a/server/src/main/java/org/opensearch/common/util/concurrent/ThreadContext.java b/server/src/main/java/org/opensearch/common/util/concurrent/ThreadContext.java index 52fa4957029e7..35d7d925ce106 100644 --- a/server/src/main/java/org/opensearch/common/util/concurrent/ThreadContext.java +++ b/server/src/main/java/org/opensearch/common/util/concurrent/ThreadContext.java @@ -139,15 +139,15 @@ public StoredContext stashContext() { ThreadContextStruct threadContextStruct = DEFAULT_CONTEXT; if (context.requestHeaders.containsKey(Task.X_OPAQUE_ID)) { - threadContextStruct.putHeaders( + threadContextStruct = threadContextStruct.putHeaders( MapBuilder.newMapBuilder() .put(Task.X_OPAQUE_ID, context.requestHeaders.get(Task.X_OPAQUE_ID)) .immutableMap() ); } - if (context.requestHeaders.containsKey(TASK_ID)) { - threadContextStruct.putTransient(TASK_ID, context.transientHeaders.get(TASK_ID)); + if (context.transientHeaders.containsKey(TASK_ID)) { + threadContextStruct = threadContextStruct.putTransient(TASK_ID, context.transientHeaders.get(TASK_ID)); } 
threadLocal.set(threadContextStruct); diff --git a/server/src/test/java/org/opensearch/common/util/concurrent/ThreadContextTests.java b/server/src/test/java/org/opensearch/common/util/concurrent/ThreadContextTests.java index 9c70accaca3e4..64286e47b4966 100644 --- a/server/src/test/java/org/opensearch/common/util/concurrent/ThreadContextTests.java +++ b/server/src/test/java/org/opensearch/common/util/concurrent/ThreadContextTests.java @@ -48,6 +48,7 @@ import static org.hamcrest.Matchers.hasItem; import static org.hamcrest.Matchers.hasSize; import static org.hamcrest.Matchers.sameInstance; +import static org.opensearch.tasks.TaskResourceTrackingService.TASK_ID; public class ThreadContextTests extends OpenSearchTestCase { @@ -154,6 +155,15 @@ public void testNewContextWithClearedTransients() { assertEquals(1, threadContext.getResponseHeaders().get("baz").size()); } + public void testStashContextWithPreservedTransients() { + ThreadContext threadContext = new ThreadContext(Settings.EMPTY); + threadContext.putTransient("foo", "bar"); + threadContext.putTransient(TASK_ID, 1); + threadContext.stashContext(); + assertNull(threadContext.getTransient("foo")); + assertEquals(1, (int) threadContext.getTransient(TASK_ID)); + } + public void testStashWithOrigin() { final String origin = randomAlphaOfLengthBetween(4, 16); final ThreadContext threadContext = new ThreadContext(Settings.EMPTY); From 39dfc22822939098693b769e4ee1869f176cf825 Mon Sep 17 00:00:00 2001 From: Tushar Kharbanda Date: Thu, 7 Apr 2022 16:37:56 +0530 Subject: [PATCH 13/26] Minor test changes and Task tracking call update Signed-off-by: Tushar Kharbanda --- .../admin/cluster/node/tasks/TasksIT.java | 4 +- .../action/support/TransportAction.java | 37 ++++++++------ .../transport/RequestHandlerRegistry.java | 4 +- .../tasks/RecordingTaskManagerListener.java | 2 +- .../node/tasks/TransportTasksActionTests.java | 51 ------------------- .../test/tasks/MockTaskManager.java | 4 +- 
.../test/tasks/MockTaskManagerListener.java | 6 +-- 7 files changed, 29 insertions(+), 79 deletions(-) diff --git a/server/src/internalClusterTest/java/org/opensearch/action/admin/cluster/node/tasks/TasksIT.java b/server/src/internalClusterTest/java/org/opensearch/action/admin/cluster/node/tasks/TasksIT.java index 4042dc27338fc..c74f992970545 100644 --- a/server/src/internalClusterTest/java/org/opensearch/action/admin/cluster/node/tasks/TasksIT.java +++ b/server/src/internalClusterTest/java/org/opensearch/action/admin/cluster/node/tasks/TasksIT.java @@ -472,7 +472,7 @@ public void onTaskUnregistered(Task task) {} public void waitForTaskCompletion(Task task) {} @Override - public void onThreadContextUpdate(Task task, Boolean taskIdAdded) {} + public void taskExecutionStarted(Task task, Boolean closeableInvoked) {} }); } // Need to run the task in a separate thread because node client's .execute() is blocked by our task listener @@ -655,7 +655,7 @@ public void waitForTaskCompletion(Task task) { } @Override - public void onThreadContextUpdate(Task task, Boolean taskIdAdded) {} + public void taskExecutionStarted(Task task, Boolean closeableInvoked) {} @Override public void onTaskRegistered(Task task) {} diff --git a/server/src/main/java/org/opensearch/action/support/TransportAction.java b/server/src/main/java/org/opensearch/action/support/TransportAction.java index 5ec218f54dec0..83fca715c7e28 100644 --- a/server/src/main/java/org/opensearch/action/support/TransportAction.java +++ b/server/src/main/java/org/opensearch/action/support/TransportAction.java @@ -138,25 +138,30 @@ public final Task execute(Request request, TaskListener listener) { unregisterChildNode.close(); throw e; } - execute(task, request, new ActionListener() { - @Override - public void onResponse(Response response) { - try { - Releasables.close(unregisterChildNode, () -> taskManager.unregister(task)); - } finally { - listener.onResponse(task, response); + ThreadContext.StoredContext storedContext = 
taskManager.taskExecutionStarted(task); + try { + execute(task, request, new ActionListener() { + @Override + public void onResponse(Response response) { + try { + Releasables.close(unregisterChildNode, () -> taskManager.unregister(task)); + } finally { + listener.onResponse(task, response); + } } - } - @Override - public void onFailure(Exception e) { - try { - Releasables.close(unregisterChildNode, () -> taskManager.unregister(task)); - } finally { - listener.onFailure(task, e); + @Override + public void onFailure(Exception e) { + try { + Releasables.close(unregisterChildNode, () -> taskManager.unregister(task)); + } finally { + listener.onFailure(task, e); + } } - } - }); + }); + } finally { + storedContext.close(); + } return task; } diff --git a/server/src/main/java/org/opensearch/transport/RequestHandlerRegistry.java b/server/src/main/java/org/opensearch/transport/RequestHandlerRegistry.java index 4b987036909e4..73be6e5b601e9 100644 --- a/server/src/main/java/org/opensearch/transport/RequestHandlerRegistry.java +++ b/server/src/main/java/org/opensearch/transport/RequestHandlerRegistry.java @@ -82,7 +82,7 @@ public Request newRequest(StreamInput in) throws IOException { public void processMessageReceived(Request request, TransportChannel channel) throws Exception { final Task task = taskManager.register(channel.getChannelType(), action, request); - ThreadContext.StoredContext storedContext = taskManager.taskExecutionStarted(task); + ThreadContext.StoredContext contextToRestore = taskManager.taskExecutionStarted(task); Releasable unregisterTask = () -> taskManager.unregister(task); try { @@ -102,7 +102,7 @@ public void processMessageReceived(Request request, TransportChannel channel) th unregisterTask = null; } finally { Releasables.close(unregisterTask); - storedContext.restore(); + contextToRestore.restore(); } } diff --git a/server/src/test/java/org/opensearch/action/admin/cluster/node/tasks/RecordingTaskManagerListener.java 
b/server/src/test/java/org/opensearch/action/admin/cluster/node/tasks/RecordingTaskManagerListener.java index 7c35a3c79d66f..9bd44185baf24 100644 --- a/server/src/test/java/org/opensearch/action/admin/cluster/node/tasks/RecordingTaskManagerListener.java +++ b/server/src/test/java/org/opensearch/action/admin/cluster/node/tasks/RecordingTaskManagerListener.java @@ -76,7 +76,7 @@ public synchronized void onTaskUnregistered(Task task) { public void waitForTaskCompletion(Task task) {} @Override - public void onThreadContextUpdate(Task task, Boolean taskIdAdded) {} + public void taskExecutionStarted(Task task, Boolean closeableInvoked) {} public synchronized List> getEvents() { return Collections.unmodifiableList(new ArrayList<>(events)); diff --git a/server/src/test/java/org/opensearch/action/admin/cluster/node/tasks/TransportTasksActionTests.java b/server/src/test/java/org/opensearch/action/admin/cluster/node/tasks/TransportTasksActionTests.java index ae8b62676c124..7590bf88eeca0 100644 --- a/server/src/test/java/org/opensearch/action/admin/cluster/node/tasks/TransportTasksActionTests.java +++ b/server/src/test/java/org/opensearch/action/admin/cluster/node/tasks/TransportTasksActionTests.java @@ -648,57 +648,6 @@ protected void taskOperation(TestTasksRequest request, Task task, ActionListener assertEquals(0, responses.failureCount()); } - /* - public void testTaskIdPersistsInThreadContext() { - Settings settings = Settings.builder().put(MockTaskManager.USE_MOCK_TASK_MANAGER_SETTING.getKey(), true).build(); - setupTestNodes(settings); - connectNodes(testNodes); - - final List taskIdsAddedToThreadContext = new ArrayList<>(); - final List taskIdsRemovedFromThreadContext = new ArrayList<>(); - final long[] actualTaskIdInThreadContext = new long[1]; - final long[] expectedTaskIdInThreadContext = new long[1]; - - ((MockTaskManager) testNodes[0].transportService.getTaskManager()).addListener(new MockTaskManagerListener() { - @Override - public void waitForTaskCompletion(Task 
task) {} - - @Override - public void onThreadContextUpdate(Task task, Boolean taskIdAdded) { - if (taskIdAdded) { - taskIdsAddedToThreadContext.add(task.getId()); - } else { - taskIdsRemovedFromThreadContext.add(task.getId()); - } - } - - @Override - public void onTaskRegistered(Task task) {} - - @Override - public void onTaskUnregistered(Task task) { - if (task.getAction().equals("action1")) { - expectedTaskIdInThreadContext[0] = task.getId(); - actualTaskIdInThreadContext[0] = threadPool.getThreadContext().getTransient(TASK_ID); - } - } - }); - - TestTasksAction action = new TestTasksAction("action1", testNodes[0].clusterService, testNodes[0].transportService) { - @Override - protected void taskOperation(TestTasksRequest request, Task task, ActionListener listener) { - listener.onResponse(new TestTaskResponse(testNodes[0].getNodeId())); - } - }; - TestTasksRequest testTasksRequest = new TestTasksRequest(); - testTasksRequest.setActions("action1"); - ActionTestUtils.executeBlocking(action, testTasksRequest); - - assertEquals(expectedTaskIdInThreadContext[0], actualTaskIdInThreadContext[0]); - assertThat(taskIdsAddedToThreadContext, containsInAnyOrder(taskIdsRemovedFromThreadContext.toArray())); - } - */ - /** * This test starts nodes actions that blocks on all nodes. While node actions are blocked in the middle of execution * it executes a tasks action that targets these blocked node actions. 
The test verifies that task actions are only diff --git a/test/framework/src/main/java/org/opensearch/test/tasks/MockTaskManager.java b/test/framework/src/main/java/org/opensearch/test/tasks/MockTaskManager.java index 92193553b2b9a..677ec7a0a6600 100644 --- a/test/framework/src/main/java/org/opensearch/test/tasks/MockTaskManager.java +++ b/test/framework/src/main/java/org/opensearch/test/tasks/MockTaskManager.java @@ -131,13 +131,13 @@ public void waitForTaskCompletion(Task task, long untilInNanos) { @Override public ThreadContext.StoredContext taskExecutionStarted(Task task) { for (MockTaskManagerListener listener : listeners) { - listener.onThreadContextUpdate(task, true); + listener.taskExecutionStarted(task, false); } ThreadContext.StoredContext storedContext = super.taskExecutionStarted(task); return () -> { for (MockTaskManagerListener listener : listeners) { - listener.onThreadContextUpdate(task, false); + listener.taskExecutionStarted(task, true); } storedContext.restore(); }; diff --git a/test/framework/src/main/java/org/opensearch/test/tasks/MockTaskManagerListener.java b/test/framework/src/main/java/org/opensearch/test/tasks/MockTaskManagerListener.java index 1736695d6e33e..f15f878995aa2 100644 --- a/test/framework/src/main/java/org/opensearch/test/tasks/MockTaskManagerListener.java +++ b/test/framework/src/main/java/org/opensearch/test/tasks/MockTaskManagerListener.java @@ -44,10 +44,6 @@ public interface MockTaskManagerListener { void waitForTaskCompletion(Task task); - /** - * - * @param taskIdAdded if false then task id is removed from the thread context. 
Null if no change - */ - void onThreadContextUpdate(Task task, Boolean taskIdAdded); + void taskExecutionStarted(Task task, Boolean closeableInvoked); } From 4b26d2d3d5723fe6b8502ae3266b8be860fa9f6b Mon Sep 17 00:00:00 2001 From: Tushar Kharbanda Date: Thu, 7 Apr 2022 18:53:54 +0530 Subject: [PATCH 14/26] Fix Auto Queue executor method's signature Signed-off-by: Tushar Kharbanda --- .../util/concurrent/OpenSearchExecutors.java | 25 +++++++++++++++++++ 1 file changed, 25 insertions(+) diff --git a/server/src/main/java/org/opensearch/common/util/concurrent/OpenSearchExecutors.java b/server/src/main/java/org/opensearch/common/util/concurrent/OpenSearchExecutors.java index 2c2ded62fc25d..a20f196aa2f5d 100644 --- a/server/src/main/java/org/opensearch/common/util/concurrent/OpenSearchExecutors.java +++ b/server/src/main/java/org/opensearch/common/util/concurrent/OpenSearchExecutors.java @@ -174,6 +174,31 @@ public static OpenSearchThreadPoolExecutor newFixed( ); } + public static OpenSearchThreadPoolExecutor newAutoQueueFixed( + String name, + int size, + int initialQueueCapacity, + int minQueueSize, + int maxQueueSize, + int frameSize, + TimeValue targetedResponseTime, + ThreadFactory threadFactory, + ThreadContext contextHolder + ) { + return newAutoQueueFixed( + name, + size, + initialQueueCapacity, + minQueueSize, + maxQueueSize, + frameSize, + targetedResponseTime, + threadFactory, + contextHolder, + null + ); + } + /** * Return a new executor that will automatically adjust the queue size based on queue throughput. 
* From c76ce40e7071f4dc14a1a1c35506214f59220a1b Mon Sep 17 00:00:00 2001 From: Tushar Kharbanda Date: Thu, 7 Apr 2022 18:54:48 +0530 Subject: [PATCH 15/26] Make task runnable task listener factory implement consumer Signed-off-by: Tushar Kharbanda --- server/src/main/java/org/opensearch/node/Node.java | 2 +- .../threadpool/RunnableTaskListenerFactory.java | 8 ++++---- .../admin/cluster/node/tasks/ResourceAwareTasksTests.java | 2 +- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/server/src/main/java/org/opensearch/node/Node.java b/server/src/main/java/org/opensearch/node/Node.java index aa95088dc1394..49f231e6f5057 100644 --- a/server/src/main/java/org/opensearch/node/Node.java +++ b/server/src/main/java/org/opensearch/node/Node.java @@ -1066,7 +1066,7 @@ public Node start() throws NodeValidationException { transportService.getTaskManager().setTaskResourceTrackingService(taskResourceTrackingService); RunnableTaskListenerFactory runnableTaskListener = injector.getInstance(RunnableTaskListenerFactory.class); - runnableTaskListener.apply(taskResourceTrackingService); + runnableTaskListener.accept(taskResourceTrackingService); transportService.start(); assert localNodeFactory.getNode() != null; diff --git a/server/src/main/java/org/opensearch/threadpool/RunnableTaskListenerFactory.java b/server/src/main/java/org/opensearch/threadpool/RunnableTaskListenerFactory.java index 3848723f1b40a..40f7663beb860 100644 --- a/server/src/main/java/org/opensearch/threadpool/RunnableTaskListenerFactory.java +++ b/server/src/main/java/org/opensearch/threadpool/RunnableTaskListenerFactory.java @@ -10,20 +10,20 @@ import org.apache.lucene.util.SetOnce; -import java.util.function.Function; +import java.util.function.Consumer; -public class RunnableTaskListenerFactory implements Function { +public class RunnableTaskListenerFactory implements Consumer { private final SetOnce listener = new SetOnce<>(); @Override - public RunnableTaskExecutionListener 
apply(RunnableTaskExecutionListener runnableTaskExecutionListener) { + public void accept(RunnableTaskExecutionListener runnableTaskExecutionListener) { listener.set(runnableTaskExecutionListener); - return listener.get(); } public RunnableTaskExecutionListener get() { assert listener.get() != null; return listener.get(); } + } diff --git a/server/src/test/java/org/opensearch/action/admin/cluster/node/tasks/ResourceAwareTasksTests.java b/server/src/test/java/org/opensearch/action/admin/cluster/node/tasks/ResourceAwareTasksTests.java index 5e77064002258..394b32cd70611 100644 --- a/server/src/test/java/org/opensearch/action/admin/cluster/node/tasks/ResourceAwareTasksTests.java +++ b/server/src/test/java/org/opensearch/action/admin/cluster/node/tasks/ResourceAwareTasksTests.java @@ -602,7 +602,7 @@ private void setup(boolean resourceTrackingEnabled, boolean useMockTaskManager) setupTestNodes(settings); connectNodes(testNodes[0]); - runnableTaskListener.apply(testNodes[0].taskResourceTrackingService); + runnableTaskListener.accept(testNodes[0].taskResourceTrackingService); } private Throwable findActualException(Exception e) { From 6915d171a25da0a3150e4f6e92941f7f84ff42ec Mon Sep 17 00:00:00 2001 From: Tushar Kharbanda Date: Thu, 7 Apr 2022 18:55:21 +0530 Subject: [PATCH 16/26] Use reflection for ThreadMXBean Signed-off-by: Tushar Kharbanda --- .../tasks/TaskResourceTrackingService.java | 44 +++++++++++++++---- 1 file changed, 36 insertions(+), 8 deletions(-) diff --git a/server/src/main/java/org/opensearch/tasks/TaskResourceTrackingService.java b/server/src/main/java/org/opensearch/tasks/TaskResourceTrackingService.java index cf8e07fc263a7..8a9e73253636d 100644 --- a/server/src/main/java/org/opensearch/tasks/TaskResourceTrackingService.java +++ b/server/src/main/java/org/opensearch/tasks/TaskResourceTrackingService.java @@ -8,7 +8,6 @@ package org.opensearch.tasks; -import com.sun.management.ThreadMXBean; import org.opensearch.common.inject.Inject; import 
org.opensearch.common.settings.ClusterSettings; import org.opensearch.common.settings.Setting; @@ -20,6 +19,8 @@ import org.opensearch.threadpool.ThreadPool; import java.lang.management.ManagementFactory; +import java.lang.management.ThreadMXBean; +import java.lang.reflect.Method; import java.util.Collections; import java.util.List; import java.util.Map; @@ -39,7 +40,26 @@ public class TaskResourceTrackingService implements RunnableTaskExecutionListene ); public static final String TASK_ID = "TASK_ID"; - private static final ThreadMXBean threadMXBean = (ThreadMXBean) ManagementFactory.getThreadMXBean(); + private static ThreadMXBean threadMXBean; + private static Method isThreadAllocatedMemorySupported; + private static Method isThreadAllocatedMemoryEnabled; + private static Method getThreadAllocatedBytes; + private static Method getThreadCpuTime; + + static { + try { + isThreadAllocatedMemorySupported = Class.forName("com.sun.management.ThreadMXBean").getMethod("isThreadAllocatedMemorySupported"); + isThreadAllocatedMemoryEnabled = Class.forName("com.sun.management.ThreadMXBean").getMethod("isThreadAllocatedMemoryEnabled"); + getThreadAllocatedBytes = Class.forName("com.sun.management.ThreadMXBean").getMethod("getThreadAllocatedBytes", long.class); + getThreadCpuTime = Class.forName("com.sun.management.ThreadMXBean").getMethod("getThreadCpuTime", long.class); + threadMXBean = ManagementFactory.getThreadMXBean(); + } catch (Exception e) { + isThreadAllocatedMemorySupported = null; + isThreadAllocatedMemoryEnabled = null; + getThreadAllocatedBytes = null; + getThreadCpuTime = null; + } + } private final ConcurrentMapLong resourceAwareTasks = ConcurrentCollections.newConcurrentMapLongWithAggressiveConcurrency(); private final ThreadPool threadPool; @@ -58,9 +78,13 @@ public void setTaskResourceTrackingEnabled(boolean taskResourceTrackingEnabled) } public boolean isTaskResourceTrackingEnabled() { - return taskResourceTrackingEnabled - && 
threadMXBean.isThreadAllocatedMemorySupported() - && threadMXBean.isThreadAllocatedMemoryEnabled(); + try { + return taskResourceTrackingEnabled + && (boolean) isThreadAllocatedMemorySupported.invoke(threadMXBean) + && (boolean) isThreadAllocatedMemoryEnabled.invoke(threadMXBean); + } catch (Exception e) { + return false; + } } /** @@ -163,12 +187,16 @@ public Map getResourceAwareTasks() { } private ResourceUsageMetric[] getResourceUsageMetricsForThread(long threadId) { + long bytes = 0; + try { + bytes = (long) getThreadAllocatedBytes.invoke(threadMXBean, threadId); + } catch (Exception e) {} + ResourceUsageMetric currentMemoryUsage = new ResourceUsageMetric( - ResourceStats.MEMORY, - threadMXBean.getThreadAllocatedBytes(threadId) + ResourceStats.MEMORY, bytes ); ResourceUsageMetric currentCPUUsage = new ResourceUsageMetric(ResourceStats.CPU, threadMXBean.getThreadCpuTime(threadId)); - return new ResourceUsageMetric[] { currentMemoryUsage, currentCPUUsage }; + return new ResourceUsageMetric[]{currentMemoryUsage, currentCPUUsage}; } private boolean isThreadWorkingOnTask(Task task, long threadId) { From 576a4779338017682f4bf940b6c746f1cb63923b Mon Sep 17 00:00:00 2001 From: Tushar Kharbanda Date: Thu, 7 Apr 2022 19:40:02 +0530 Subject: [PATCH 17/26] Formatting Signed-off-by: Tushar Kharbanda --- .../opensearch/tasks/TaskResourceTrackingService.java | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/server/src/main/java/org/opensearch/tasks/TaskResourceTrackingService.java b/server/src/main/java/org/opensearch/tasks/TaskResourceTrackingService.java index 8a9e73253636d..fb779182c0a74 100644 --- a/server/src/main/java/org/opensearch/tasks/TaskResourceTrackingService.java +++ b/server/src/main/java/org/opensearch/tasks/TaskResourceTrackingService.java @@ -48,7 +48,8 @@ public class TaskResourceTrackingService implements RunnableTaskExecutionListene static { try { - isThreadAllocatedMemorySupported = 
Class.forName("com.sun.management.ThreadMXBean").getMethod("isThreadAllocatedMemorySupported"); + isThreadAllocatedMemorySupported = Class.forName("com.sun.management.ThreadMXBean") + .getMethod("isThreadAllocatedMemorySupported"); isThreadAllocatedMemoryEnabled = Class.forName("com.sun.management.ThreadMXBean").getMethod("isThreadAllocatedMemoryEnabled"); getThreadAllocatedBytes = Class.forName("com.sun.management.ThreadMXBean").getMethod("getThreadAllocatedBytes", long.class); getThreadCpuTime = Class.forName("com.sun.management.ThreadMXBean").getMethod("getThreadCpuTime", long.class); @@ -192,11 +193,9 @@ private ResourceUsageMetric[] getResourceUsageMetricsForThread(long threadId) { bytes = (long) getThreadAllocatedBytes.invoke(threadMXBean, threadId); } catch (Exception e) {} - ResourceUsageMetric currentMemoryUsage = new ResourceUsageMetric( - ResourceStats.MEMORY, bytes - ); + ResourceUsageMetric currentMemoryUsage = new ResourceUsageMetric(ResourceStats.MEMORY, bytes); ResourceUsageMetric currentCPUUsage = new ResourceUsageMetric(ResourceStats.CPU, threadMXBean.getThreadCpuTime(threadId)); - return new ResourceUsageMetric[]{currentMemoryUsage, currentCPUUsage}; + return new ResourceUsageMetric[] { currentMemoryUsage, currentCPUUsage }; } private boolean isThreadWorkingOnTask(Task task, long threadId) { From 046c652c3071f89b6c8f54cef6cbbb3062e888fc Mon Sep 17 00:00:00 2001 From: Tushar Kharbanda Date: Thu, 7 Apr 2022 20:24:18 +0530 Subject: [PATCH 18/26] Replace RunnableTaskExecutionListenerFactory with AtomicReference Signed-off-by: Tushar Kharbanda --- .../util/concurrent/OpenSearchExecutors.java | 5 ++-- .../main/java/org/opensearch/node/Node.java | 11 ++++--- .../AutoQueueAdjustingExecutorBuilder.java | 5 ++-- .../RunnableTaskListenerFactory.java | 29 ------------------- .../threadpool/TaskAwareRunnable.java | 9 ++++-- .../org/opensearch/threadpool/ThreadPool.java | 3 +- .../node/tasks/ResourceAwareTasksTests.java | 2 +- 
.../node/tasks/TaskManagerTestCase.java | 7 +++-- .../opensearch/tasks/TaskManagerTests.java | 7 +++-- .../TaskResourceTrackingServiceTests.java | 4 +-- .../opensearch/threadpool/TestThreadPool.java | 9 ++++-- 11 files changed, 38 insertions(+), 53 deletions(-) delete mode 100644 server/src/main/java/org/opensearch/threadpool/RunnableTaskListenerFactory.java diff --git a/server/src/main/java/org/opensearch/common/util/concurrent/OpenSearchExecutors.java b/server/src/main/java/org/opensearch/common/util/concurrent/OpenSearchExecutors.java index a20f196aa2f5d..9e28bb2b795c3 100644 --- a/server/src/main/java/org/opensearch/common/util/concurrent/OpenSearchExecutors.java +++ b/server/src/main/java/org/opensearch/common/util/concurrent/OpenSearchExecutors.java @@ -40,7 +40,7 @@ import org.opensearch.common.settings.Settings; import org.opensearch.common.unit.TimeValue; import org.opensearch.node.Node; -import org.opensearch.threadpool.RunnableTaskListenerFactory; +import org.opensearch.threadpool.RunnableTaskExecutionListener; import org.opensearch.threadpool.TaskAwareRunnable; import java.util.List; @@ -57,6 +57,7 @@ import java.util.concurrent.ThreadPoolExecutor; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.atomic.AtomicReference; import java.util.function.Function; public class OpenSearchExecutors { @@ -218,7 +219,7 @@ public static OpenSearchThreadPoolExecutor newAutoQueueFixed( TimeValue targetedResponseTime, ThreadFactory threadFactory, ThreadContext contextHolder, - RunnableTaskListenerFactory runnableTaskListener + AtomicReference runnableTaskListener ) { if (initialQueueCapacity <= 0) { throw new IllegalArgumentException( diff --git a/server/src/main/java/org/opensearch/node/Node.java b/server/src/main/java/org/opensearch/node/Node.java index 49f231e6f5057..7c8d67394bf20 100644 --- a/server/src/main/java/org/opensearch/node/Node.java +++ b/server/src/main/java/org/opensearch/node/Node.java 
@@ -38,7 +38,7 @@ import org.apache.lucene.util.SetOnce; import org.opensearch.index.IndexingPressureService; import org.opensearch.tasks.TaskResourceTrackingService; -import org.opensearch.threadpool.RunnableTaskListenerFactory; +import org.opensearch.threadpool.RunnableTaskExecutionListener; import org.opensearch.watcher.ResourceWatcherService; import org.opensearch.Assertions; import org.opensearch.Build; @@ -215,6 +215,7 @@ import java.util.concurrent.CountDownLatch; import java.util.concurrent.Executor; import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicReference; import java.util.function.Function; import java.util.function.UnaryOperator; import java.util.stream.Collectors; @@ -326,6 +327,7 @@ public static class DiscoverySettings { private final LocalNodeFactory localNodeFactory; private final NodeService nodeService; final NamedWriteableRegistry namedWriteableRegistry; + private final AtomicReference runnableTaskListener; public Node(Environment environment) { this(environment, Collections.emptyList(), true); @@ -435,7 +437,7 @@ protected Node( final List> executorBuilders = pluginsService.getExecutorBuilders(settings); - RunnableTaskListenerFactory runnableTaskListener = new RunnableTaskListenerFactory(); + runnableTaskListener = new AtomicReference<>(); final ThreadPool threadPool = new ThreadPool(settings, runnableTaskListener, executorBuilders.toArray(new ExecutorBuilder[0])); resourcesToClose.add(() -> ThreadPool.terminate(threadPool, 10, TimeUnit.SECONDS)); final ResourceWatcherService resourceWatcherService = new ResourceWatcherService(settings, threadPool); @@ -943,7 +945,6 @@ protected Node( b.bind(ShardLimitValidator.class).toInstance(shardLimitValidator); b.bind(FsHealthService.class).toInstance(fsHealthService); b.bind(SystemIndices.class).toInstance(systemIndices); - b.bind(RunnableTaskListenerFactory.class).toInstance(runnableTaskListener); }); injector = modules.createInjector(); @@ -1064,9 +1065,7 @@ public Node 
start() throws NodeValidationException { TaskResourceTrackingService taskResourceTrackingService = injector.getInstance(TaskResourceTrackingService.class); transportService.getTaskManager().setTaskResourceTrackingService(taskResourceTrackingService); - - RunnableTaskListenerFactory runnableTaskListener = injector.getInstance(RunnableTaskListenerFactory.class); - runnableTaskListener.accept(taskResourceTrackingService); + runnableTaskListener.set(taskResourceTrackingService); transportService.start(); assert localNodeFactory.getNode() != null; diff --git a/server/src/main/java/org/opensearch/threadpool/AutoQueueAdjustingExecutorBuilder.java b/server/src/main/java/org/opensearch/threadpool/AutoQueueAdjustingExecutorBuilder.java index ca018d46260d6..55b92c5d8bfcb 100644 --- a/server/src/main/java/org/opensearch/threadpool/AutoQueueAdjustingExecutorBuilder.java +++ b/server/src/main/java/org/opensearch/threadpool/AutoQueueAdjustingExecutorBuilder.java @@ -48,6 +48,7 @@ import java.util.Map; import java.util.concurrent.ExecutorService; import java.util.concurrent.ThreadFactory; +import java.util.concurrent.atomic.AtomicReference; /** * A builder for executors that automatically adjust the queue length as needed, depending on @@ -61,7 +62,7 @@ public final class AutoQueueAdjustingExecutorBuilder extends ExecutorBuilder maxQueueSizeSetting; private final Setting targetedResponseTimeSetting; private final Setting frameSizeSetting; - private final RunnableTaskListenerFactory runnableTaskListener; + private final AtomicReference runnableTaskListener; AutoQueueAdjustingExecutorBuilder( final Settings settings, @@ -83,7 +84,7 @@ public final class AutoQueueAdjustingExecutorBuilder extends ExecutorBuilder runnableTaskListener ) { super(name); final String prefix = "thread_pool." 
+ name; diff --git a/server/src/main/java/org/opensearch/threadpool/RunnableTaskListenerFactory.java b/server/src/main/java/org/opensearch/threadpool/RunnableTaskListenerFactory.java deleted file mode 100644 index 40f7663beb860..0000000000000 --- a/server/src/main/java/org/opensearch/threadpool/RunnableTaskListenerFactory.java +++ /dev/null @@ -1,29 +0,0 @@ -/* - * SPDX-License-Identifier: Apache-2.0 - * - * The OpenSearch Contributors require contributions made to - * this file be licensed under the Apache-2.0 license or a - * compatible open source license. - */ - -package org.opensearch.threadpool; - -import org.apache.lucene.util.SetOnce; - -import java.util.function.Consumer; - -public class RunnableTaskListenerFactory implements Consumer { - - private final SetOnce listener = new SetOnce<>(); - - @Override - public void accept(RunnableTaskExecutionListener runnableTaskExecutionListener) { - listener.set(runnableTaskExecutionListener); - } - - public RunnableTaskExecutionListener get() { - assert listener.get() != null; - return listener.get(); - } - -} diff --git a/server/src/main/java/org/opensearch/threadpool/TaskAwareRunnable.java b/server/src/main/java/org/opensearch/threadpool/TaskAwareRunnable.java index 7a0b55a2cf011..ed3c0d7523d17 100644 --- a/server/src/main/java/org/opensearch/threadpool/TaskAwareRunnable.java +++ b/server/src/main/java/org/opensearch/threadpool/TaskAwareRunnable.java @@ -14,6 +14,7 @@ import org.opensearch.common.util.concurrent.WrappedRunnable; import java.util.Objects; +import java.util.concurrent.atomic.AtomicReference; import static java.lang.Thread.currentThread; import static org.opensearch.tasks.TaskResourceTrackingService.TASK_ID; @@ -28,9 +29,13 @@ public class TaskAwareRunnable extends AbstractRunnable implements WrappedRunnab private final Runnable original; private final ThreadContext threadContext; - private final RunnableTaskListenerFactory runnableTaskListener; + private final AtomicReference runnableTaskListener; - 
public TaskAwareRunnable(ThreadContext threadContext, final Runnable original, final RunnableTaskListenerFactory runnableTaskListener) { + public TaskAwareRunnable( + ThreadContext threadContext, + final Runnable original, + final AtomicReference runnableTaskListener + ) { this.original = original; this.threadContext = threadContext; this.runnableTaskListener = runnableTaskListener; diff --git a/server/src/main/java/org/opensearch/threadpool/ThreadPool.java b/server/src/main/java/org/opensearch/threadpool/ThreadPool.java index 4371eb0bf617b..5e8f515f6c577 100644 --- a/server/src/main/java/org/opensearch/threadpool/ThreadPool.java +++ b/server/src/main/java/org/opensearch/threadpool/ThreadPool.java @@ -68,6 +68,7 @@ import java.util.concurrent.ScheduledThreadPoolExecutor; import java.util.concurrent.ThreadPoolExecutor; import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicReference; import java.util.stream.Collectors; import static java.util.Collections.unmodifiableMap; @@ -189,7 +190,7 @@ public ThreadPool(final Settings settings, final ExecutorBuilder... customBui public ThreadPool( final Settings settings, - final RunnableTaskListenerFactory runnableTaskListener, + final AtomicReference runnableTaskListener, final ExecutorBuilder... 
customBuilders ) { assert Node.NODE_NAME_SETTING.exists(settings); diff --git a/server/src/test/java/org/opensearch/action/admin/cluster/node/tasks/ResourceAwareTasksTests.java b/server/src/test/java/org/opensearch/action/admin/cluster/node/tasks/ResourceAwareTasksTests.java index 394b32cd70611..e3dd9ff616380 100644 --- a/server/src/test/java/org/opensearch/action/admin/cluster/node/tasks/ResourceAwareTasksTests.java +++ b/server/src/test/java/org/opensearch/action/admin/cluster/node/tasks/ResourceAwareTasksTests.java @@ -602,7 +602,7 @@ private void setup(boolean resourceTrackingEnabled, boolean useMockTaskManager) setupTestNodes(settings); connectNodes(testNodes[0]); - runnableTaskListener.accept(testNodes[0].taskResourceTrackingService); + runnableTaskListener.set(testNodes[0].taskResourceTrackingService); } private Throwable findActualException(Exception e) { diff --git a/server/src/test/java/org/opensearch/action/admin/cluster/node/tasks/TaskManagerTestCase.java b/server/src/test/java/org/opensearch/action/admin/cluster/node/tasks/TaskManagerTestCase.java index a3584d8a9b7b1..51fc5d80f2de3 100644 --- a/server/src/test/java/org/opensearch/action/admin/cluster/node/tasks/TaskManagerTestCase.java +++ b/server/src/test/java/org/opensearch/action/admin/cluster/node/tasks/TaskManagerTestCase.java @@ -62,7 +62,7 @@ import org.opensearch.tasks.TaskResourceTrackingService; import org.opensearch.test.OpenSearchTestCase; import org.opensearch.test.tasks.MockTaskManager; -import org.opensearch.threadpool.RunnableTaskListenerFactory; +import org.opensearch.threadpool.RunnableTaskExecutionListener; import org.opensearch.threadpool.TestThreadPool; import org.opensearch.threadpool.ThreadPool; import org.opensearch.transport.TransportService; @@ -76,6 +76,7 @@ import java.util.List; import java.util.Set; import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicReference; import java.util.function.Function; import static java.util.Collections.emptyMap; @@ 
-91,11 +92,11 @@ public abstract class TaskManagerTestCase extends OpenSearchTestCase { protected ThreadPool threadPool; protected TestNode[] testNodes; protected int nodesCount; - protected RunnableTaskListenerFactory runnableTaskListener; + protected AtomicReference runnableTaskListener; @Before public void setupThreadPool() { - runnableTaskListener = new RunnableTaskListenerFactory(); + runnableTaskListener = new AtomicReference<>(); threadPool = new TestThreadPool(TransportTasksActionTests.class.getSimpleName(), runnableTaskListener); } diff --git a/server/src/test/java/org/opensearch/tasks/TaskManagerTests.java b/server/src/test/java/org/opensearch/tasks/TaskManagerTests.java index 1186d630b4bed..ab49109eb8247 100644 --- a/server/src/test/java/org/opensearch/tasks/TaskManagerTests.java +++ b/server/src/test/java/org/opensearch/tasks/TaskManagerTests.java @@ -40,7 +40,7 @@ import org.opensearch.common.unit.TimeValue; import org.opensearch.common.util.concurrent.ConcurrentCollections; import org.opensearch.test.OpenSearchTestCase; -import org.opensearch.threadpool.RunnableTaskListenerFactory; +import org.opensearch.threadpool.RunnableTaskExecutionListener; import org.opensearch.threadpool.TestThreadPool; import org.opensearch.threadpool.ThreadPool; import org.opensearch.transport.FakeTcpChannel; @@ -60,6 +60,7 @@ import java.util.Set; import java.util.concurrent.Phaser; import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicReference; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.everyItem; @@ -68,11 +69,11 @@ public class TaskManagerTests extends OpenSearchTestCase { private ThreadPool threadPool; - private RunnableTaskListenerFactory runnableTaskListener; + private AtomicReference runnableTaskListener; @Before public void setupThreadPool() { - runnableTaskListener = new RunnableTaskListenerFactory(); + runnableTaskListener = new AtomicReference<>(); threadPool = new 
TestThreadPool(TransportTasksActionTests.class.getSimpleName(), runnableTaskListener); } diff --git a/server/src/test/java/org/opensearch/tasks/TaskResourceTrackingServiceTests.java b/server/src/test/java/org/opensearch/tasks/TaskResourceTrackingServiceTests.java index 15ddc253dfa6d..8ba23c5d3219c 100644 --- a/server/src/test/java/org/opensearch/tasks/TaskResourceTrackingServiceTests.java +++ b/server/src/test/java/org/opensearch/tasks/TaskResourceTrackingServiceTests.java @@ -16,11 +16,11 @@ import org.opensearch.common.settings.Settings; import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.test.OpenSearchTestCase; -import org.opensearch.threadpool.RunnableTaskListenerFactory; import org.opensearch.threadpool.TestThreadPool; import org.opensearch.threadpool.ThreadPool; import java.util.HashMap; +import java.util.concurrent.atomic.AtomicReference; import static org.opensearch.tasks.ResourceStats.MEMORY; import static org.opensearch.tasks.TaskResourceTrackingService.TASK_ID; @@ -32,7 +32,7 @@ public class TaskResourceTrackingServiceTests extends OpenSearchTestCase { @Before public void setup() { - threadPool = new TestThreadPool(TransportTasksActionTests.class.getSimpleName(), new RunnableTaskListenerFactory()); + threadPool = new TestThreadPool(TransportTasksActionTests.class.getSimpleName(), new AtomicReference<>()); taskResourceTrackingService = new TaskResourceTrackingService( Settings.EMPTY, new ClusterSettings(Settings.EMPTY, ClusterSettings.BUILT_IN_CLUSTER_SETTINGS), diff --git a/test/framework/src/main/java/org/opensearch/threadpool/TestThreadPool.java b/test/framework/src/main/java/org/opensearch/threadpool/TestThreadPool.java index eeca3e4719ac9..2d97d5bffee01 100644 --- a/test/framework/src/main/java/org/opensearch/threadpool/TestThreadPool.java +++ b/test/framework/src/main/java/org/opensearch/threadpool/TestThreadPool.java @@ -40,6 +40,7 @@ import java.util.concurrent.ExecutorService; import 
java.util.concurrent.ThreadFactory; import java.util.concurrent.ThreadPoolExecutor; +import java.util.concurrent.atomic.AtomicReference; public class TestThreadPool extends ThreadPool { @@ -47,7 +48,11 @@ public class TestThreadPool extends ThreadPool { private volatile boolean returnRejectingExecutor = false; private volatile ThreadPoolExecutor rejectingExecutor; - public TestThreadPool(String name, RunnableTaskListenerFactory runnableTaskListener, ExecutorBuilder... customBuilders) { + public TestThreadPool( + String name, + AtomicReference runnableTaskListener, + ExecutorBuilder... customBuilders + ) { this(name, Settings.EMPTY, runnableTaskListener, customBuilders); } @@ -62,7 +67,7 @@ public TestThreadPool(String name, Settings settings, ExecutorBuilder... cust public TestThreadPool( String name, Settings settings, - RunnableTaskListenerFactory runnableTaskListener, + AtomicReference runnableTaskListener, ExecutorBuilder... customBuilders ) { super(Settings.builder().put(Node.NODE_NAME_SETTING.getKey(), name).put(settings).build(), runnableTaskListener, customBuilders); From f135cf125079663530a90db0afeb55536a63701d Mon Sep 17 00:00:00 2001 From: Tushar Kharbanda Date: Fri, 8 Apr 2022 12:54:29 +0530 Subject: [PATCH 19/26] Revert "Use reflection for ThreadMXBean" This reverts commit cbcf3c525bf516fb7164f0221491a7b25c1f96ec. 
Signed-off-by: Tushar Kharbanda --- .../tasks/TaskResourceTrackingService.java | 45 ++++--------------- 1 file changed, 9 insertions(+), 36 deletions(-) diff --git a/server/src/main/java/org/opensearch/tasks/TaskResourceTrackingService.java b/server/src/main/java/org/opensearch/tasks/TaskResourceTrackingService.java index fb779182c0a74..cf8e07fc263a7 100644 --- a/server/src/main/java/org/opensearch/tasks/TaskResourceTrackingService.java +++ b/server/src/main/java/org/opensearch/tasks/TaskResourceTrackingService.java @@ -8,6 +8,7 @@ package org.opensearch.tasks; +import com.sun.management.ThreadMXBean; import org.opensearch.common.inject.Inject; import org.opensearch.common.settings.ClusterSettings; import org.opensearch.common.settings.Setting; @@ -19,8 +20,6 @@ import org.opensearch.threadpool.ThreadPool; import java.lang.management.ManagementFactory; -import java.lang.management.ThreadMXBean; -import java.lang.reflect.Method; import java.util.Collections; import java.util.List; import java.util.Map; @@ -40,27 +39,7 @@ public class TaskResourceTrackingService implements RunnableTaskExecutionListene ); public static final String TASK_ID = "TASK_ID"; - private static ThreadMXBean threadMXBean; - private static Method isThreadAllocatedMemorySupported; - private static Method isThreadAllocatedMemoryEnabled; - private static Method getThreadAllocatedBytes; - private static Method getThreadCpuTime; - - static { - try { - isThreadAllocatedMemorySupported = Class.forName("com.sun.management.ThreadMXBean") - .getMethod("isThreadAllocatedMemorySupported"); - isThreadAllocatedMemoryEnabled = Class.forName("com.sun.management.ThreadMXBean").getMethod("isThreadAllocatedMemoryEnabled"); - getThreadAllocatedBytes = Class.forName("com.sun.management.ThreadMXBean").getMethod("getThreadAllocatedBytes", long.class); - getThreadCpuTime = Class.forName("com.sun.management.ThreadMXBean").getMethod("getThreadCpuTime", long.class); - threadMXBean = ManagementFactory.getThreadMXBean(); - 
} catch (Exception e) { - isThreadAllocatedMemorySupported = null; - isThreadAllocatedMemoryEnabled = null; - getThreadAllocatedBytes = null; - getThreadCpuTime = null; - } - } + private static final ThreadMXBean threadMXBean = (ThreadMXBean) ManagementFactory.getThreadMXBean(); private final ConcurrentMapLong resourceAwareTasks = ConcurrentCollections.newConcurrentMapLongWithAggressiveConcurrency(); private final ThreadPool threadPool; @@ -79,13 +58,9 @@ public void setTaskResourceTrackingEnabled(boolean taskResourceTrackingEnabled) } public boolean isTaskResourceTrackingEnabled() { - try { - return taskResourceTrackingEnabled - && (boolean) isThreadAllocatedMemorySupported.invoke(threadMXBean) - && (boolean) isThreadAllocatedMemoryEnabled.invoke(threadMXBean); - } catch (Exception e) { - return false; - } + return taskResourceTrackingEnabled + && threadMXBean.isThreadAllocatedMemorySupported() + && threadMXBean.isThreadAllocatedMemoryEnabled(); } /** @@ -188,12 +163,10 @@ public Map getResourceAwareTasks() { } private ResourceUsageMetric[] getResourceUsageMetricsForThread(long threadId) { - long bytes = 0; - try { - bytes = (long) getThreadAllocatedBytes.invoke(threadMXBean, threadId); - } catch (Exception e) {} - - ResourceUsageMetric currentMemoryUsage = new ResourceUsageMetric(ResourceStats.MEMORY, bytes); + ResourceUsageMetric currentMemoryUsage = new ResourceUsageMetric( + ResourceStats.MEMORY, + threadMXBean.getThreadAllocatedBytes(threadId) + ); ResourceUsageMetric currentCPUUsage = new ResourceUsageMetric(ResourceStats.CPU, threadMXBean.getThreadCpuTime(threadId)); return new ResourceUsageMetric[] { currentMemoryUsage, currentCPUUsage }; } From dea288b953941253f1b562fa8cf716f5e0e274d9 Mon Sep 17 00:00:00 2001 From: Tushar Kharbanda Date: Fri, 8 Apr 2022 12:57:08 +0530 Subject: [PATCH 20/26] Suppress Warning related to ThreadMXBean Signed-off-by: Tushar Kharbanda --- .../java/org/opensearch/tasks/TaskResourceTrackingService.java | 2 ++ 1 file changed, 2 
insertions(+) diff --git a/server/src/main/java/org/opensearch/tasks/TaskResourceTrackingService.java b/server/src/main/java/org/opensearch/tasks/TaskResourceTrackingService.java index cf8e07fc263a7..6e9f1800f9d8a 100644 --- a/server/src/main/java/org/opensearch/tasks/TaskResourceTrackingService.java +++ b/server/src/main/java/org/opensearch/tasks/TaskResourceTrackingService.java @@ -9,6 +9,7 @@ package org.opensearch.tasks; import com.sun.management.ThreadMXBean; +import org.opensearch.common.SuppressForbidden; import org.opensearch.common.inject.Inject; import org.opensearch.common.settings.ClusterSettings; import org.opensearch.common.settings.Setting; @@ -29,6 +30,7 @@ /** * Service that helps track resource usage of tasks running on a node. */ +@SuppressForbidden(reason = "ThreadMXBean#getThreadAllocatedBytes") public class TaskResourceTrackingService implements RunnableTaskExecutionListener { public static final Setting TASK_RESOURCE_TRACKING_ENABLED = Setting.boolSetting( From 167086ab857518017d6a96a452df2f6c817915fa Mon Sep 17 00:00:00 2001 From: Tushar Kharbanda Date: Mon, 11 Apr 2022 14:24:56 +0530 Subject: [PATCH 21/26] Add separate method for task resource tracking supported check Signed-off-by: Tushar Kharbanda --- .../opensearch/tasks/TaskResourceTrackingService.java | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/server/src/main/java/org/opensearch/tasks/TaskResourceTrackingService.java b/server/src/main/java/org/opensearch/tasks/TaskResourceTrackingService.java index 6e9f1800f9d8a..a4a93c2086ae5 100644 --- a/server/src/main/java/org/opensearch/tasks/TaskResourceTrackingService.java +++ b/server/src/main/java/org/opensearch/tasks/TaskResourceTrackingService.java @@ -60,7 +60,11 @@ public void setTaskResourceTrackingEnabled(boolean taskResourceTrackingEnabled) } public boolean isTaskResourceTrackingEnabled() { - return taskResourceTrackingEnabled + return taskResourceTrackingEnabled; + } + + public boolean 
isTaskResourceTrackingSupported() { + return isTaskResourceTrackingEnabled() && threadMXBean.isThreadAllocatedMemorySupported() && threadMXBean.isThreadAllocatedMemoryEnabled(); } @@ -75,7 +79,7 @@ public boolean isTaskResourceTrackingEnabled() { * @return Autocloseable stored context to restore ThreadContext to the state before this method changed it. */ public ThreadContext.StoredContext startTracking(Task task) { - if (task.supportsResourceTracking() == false || isTaskResourceTrackingEnabled() == false) { + if (task.supportsResourceTracking() == false || isTaskResourceTrackingSupported() == false) { return () -> {}; } @@ -112,7 +116,7 @@ public void stopTracking(Task task) { * @param tasks for which resource stats needs to be refreshed. */ public void refreshResourceStats(Task... tasks) { - if (isTaskResourceTrackingEnabled() == false) { + if (isTaskResourceTrackingSupported() == false) { return; } From ff4a9eb65fda2a9f75450f2373c492b0fa6563c9 Mon Sep 17 00:00:00 2001 From: Tushar Kharbanda Date: Mon, 11 Apr 2022 23:55:39 +0530 Subject: [PATCH 22/26] Enabled setting by default Signed-off-by: Tushar Kharbanda --- .../tasks/TaskResourceTrackingService.java | 12 ++++++------ .../org/opensearch/threadpool/TaskAwareRunnable.java | 3 --- 2 files changed, 6 insertions(+), 9 deletions(-) diff --git a/server/src/main/java/org/opensearch/tasks/TaskResourceTrackingService.java b/server/src/main/java/org/opensearch/tasks/TaskResourceTrackingService.java index a4a93c2086ae5..89995cf03eea3 100644 --- a/server/src/main/java/org/opensearch/tasks/TaskResourceTrackingService.java +++ b/server/src/main/java/org/opensearch/tasks/TaskResourceTrackingService.java @@ -35,7 +35,7 @@ public class TaskResourceTrackingService implements RunnableTaskExecutionListene public static final Setting TASK_RESOURCE_TRACKING_ENABLED = Setting.boolSetting( "task_resource_tracking.enabled", - false, + true, Setting.Property.Dynamic, Setting.Property.NodeScope ); @@ -64,9 +64,7 @@ public boolean 
isTaskResourceTrackingEnabled() { } public boolean isTaskResourceTrackingSupported() { - return isTaskResourceTrackingEnabled() - && threadMXBean.isThreadAllocatedMemorySupported() - && threadMXBean.isThreadAllocatedMemoryEnabled(); + return threadMXBean.isThreadAllocatedMemorySupported() && threadMXBean.isThreadAllocatedMemoryEnabled(); } /** @@ -79,7 +77,9 @@ public boolean isTaskResourceTrackingSupported() { * @return Autocloseable stored context to restore ThreadContext to the state before this method changed it. */ public ThreadContext.StoredContext startTracking(Task task) { - if (task.supportsResourceTracking() == false || isTaskResourceTrackingSupported() == false) { + if (task.supportsResourceTracking() == false + || isTaskResourceTrackingEnabled() == false + || isTaskResourceTrackingSupported() == false) { return () -> {}; } @@ -116,7 +116,7 @@ public void stopTracking(Task task) { * @param tasks for which resource stats needs to be refreshed. */ public void refreshResourceStats(Task... 
tasks) { - if (isTaskResourceTrackingSupported() == false) { + if (isTaskResourceTrackingEnabled() == false || isTaskResourceTrackingSupported() == false) { return; } diff --git a/server/src/main/java/org/opensearch/threadpool/TaskAwareRunnable.java b/server/src/main/java/org/opensearch/threadpool/TaskAwareRunnable.java index ed3c0d7523d17..3c0b0d7d68d11 100644 --- a/server/src/main/java/org/opensearch/threadpool/TaskAwareRunnable.java +++ b/server/src/main/java/org/opensearch/threadpool/TaskAwareRunnable.java @@ -63,9 +63,7 @@ public void onRejection(final Exception e) { @Override protected void doRun() throws Exception { assert runnableTaskListener.get() != null : "Listener should be attached"; - Long taskId = threadContext.getTransient(TASK_ID); - if (Objects.nonNull(taskId)) { runnableTaskListener.get().taskExecutionStartedOnThread(taskId, currentThread().getId()); } @@ -76,7 +74,6 @@ protected void doRun() throws Exception { runnableTaskListener.get().taskExecutionFinishedOnThread(taskId, currentThread().getId()); } } - } @Override From 3df4d63e9938efa52ff66b22c85f61965771a401 Mon Sep 17 00:00:00 2001 From: Tushar Kharbanda Date: Tue, 12 Apr 2022 00:15:53 +0530 Subject: [PATCH 23/26] Add debug logs for stale context id Signed-off-by: Tushar Kharbanda --- .../tasks/TaskResourceTrackingService.java | 18 +++++++++++++++--- 1 file changed, 15 insertions(+), 3 deletions(-) diff --git a/server/src/main/java/org/opensearch/tasks/TaskResourceTrackingService.java b/server/src/main/java/org/opensearch/tasks/TaskResourceTrackingService.java index 89995cf03eea3..964c2e13da5a9 100644 --- a/server/src/main/java/org/opensearch/tasks/TaskResourceTrackingService.java +++ b/server/src/main/java/org/opensearch/tasks/TaskResourceTrackingService.java @@ -9,6 +9,8 @@ package org.opensearch.tasks; import com.sun.management.ThreadMXBean; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; import org.opensearch.common.SuppressForbidden; import 
org.opensearch.common.inject.Inject; import org.opensearch.common.settings.ClusterSettings; @@ -33,6 +35,8 @@ @SuppressForbidden(reason = "ThreadMXBean#getThreadAllocatedBytes") public class TaskResourceTrackingService implements RunnableTaskExecutionListener { + private static final Logger logger = LogManager.getLogger(TaskManager.class); + public static final Setting TASK_RESOURCE_TRACKING_ENABLED = Setting.boolSetting( "task_resource_tracking.enabled", true, @@ -210,9 +214,17 @@ private boolean validateNoActiveThread(Task task) { private ThreadContext.StoredContext addTaskIdToThreadContext(Task task) { ThreadContext threadContext = threadPool.getThreadContext(); - boolean noStaleTaskIdPresentInThreadContext = threadContext.getTransient(TASK_ID) == null - || resourceAwareTasks.containsKey((long) threadContext.getTransient(TASK_ID)); - assert noStaleTaskIdPresentInThreadContext : "Stale Task Id shouldn't be present in thread context"; + boolean staleIdPresentInThreadContext = threadContext.getTransient(TASK_ID) != null + && !resourceAwareTasks.containsKey((long) threadContext.getTransient(TASK_ID)); + + if (staleIdPresentInThreadContext) { + logger.debug( + "Stale Task Id should ideally be not present in thread context. 
Current task Id: {}, Stale task Id: {}, Thread id: {}", + task.getId(), + threadContext.getTransient(TASK_ID), + Thread.currentThread().getId() + ); + } ThreadContext.StoredContext storedContext = threadContext.newStoredContext(true, Collections.singletonList(TASK_ID)); threadContext.putTransient(TASK_ID, task.getId()); From 5dcd53efc06d6dbe20e2ab05df60dc7970db03fb Mon Sep 17 00:00:00 2001 From: Tushar Kharbanda Date: Thu, 14 Apr 2022 20:14:05 +0530 Subject: [PATCH 24/26] Remove hardcoded task overhead in tests Signed-off-by: Tushar Kharbanda --- .../threadpool/TaskAwareRunnable.java | 2 +- .../node/tasks/ResourceAwareTasksTests.java | 36 ++++++++++++------- 2 files changed, 24 insertions(+), 14 deletions(-) diff --git a/server/src/main/java/org/opensearch/threadpool/TaskAwareRunnable.java b/server/src/main/java/org/opensearch/threadpool/TaskAwareRunnable.java index 3c0b0d7d68d11..793c4f289845e 100644 --- a/server/src/main/java/org/opensearch/threadpool/TaskAwareRunnable.java +++ b/server/src/main/java/org/opensearch/threadpool/TaskAwareRunnable.java @@ -32,7 +32,7 @@ public class TaskAwareRunnable extends AbstractRunnable implements WrappedRunnab private final AtomicReference runnableTaskListener; public TaskAwareRunnable( - ThreadContext threadContext, + final ThreadContext threadContext, final Runnable original, final AtomicReference runnableTaskListener ) { diff --git a/server/src/test/java/org/opensearch/action/admin/cluster/node/tasks/ResourceAwareTasksTests.java b/server/src/test/java/org/opensearch/action/admin/cluster/node/tasks/ResourceAwareTasksTests.java index e3dd9ff616380..23877ac0b7395 100644 --- a/server/src/test/java/org/opensearch/action/admin/cluster/node/tasks/ResourceAwareTasksTests.java +++ b/server/src/test/java/org/opensearch/action/admin/cluster/node/tasks/ResourceAwareTasksTests.java @@ -8,6 +8,7 @@ package org.opensearch.action.admin.cluster.node.tasks; +import com.sun.management.ThreadMXBean; import org.opensearch.ExceptionsHelper; 
import org.opensearch.action.ActionListener; import org.opensearch.action.admin.cluster.node.tasks.cancel.CancelTasksRequest; @@ -17,6 +18,7 @@ import org.opensearch.action.support.nodes.BaseNodeRequest; import org.opensearch.action.support.nodes.BaseNodesRequest; import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.SuppressForbidden; import org.opensearch.common.io.stream.StreamInput; import org.opensearch.common.io.stream.StreamOutput; import org.opensearch.common.settings.Settings; @@ -32,6 +34,7 @@ import org.opensearch.transport.TransportService; import java.io.IOException; +import java.lang.management.ManagementFactory; import java.util.ArrayList; import java.util.List; import java.util.Map; @@ -45,12 +48,10 @@ import static org.hamcrest.Matchers.containsInAnyOrder; import static org.opensearch.tasks.TaskResourceTrackingService.TASK_ID; +@SuppressForbidden(reason = "ThreadMXBean#getThreadAllocatedBytes") public class ResourceAwareTasksTests extends TaskManagerTestCase { - // For every task there's a general overhead before and after the actual task operation code is executed. - // This includes things like creating threadContext, Transport Channel, Tracking task cancellation etc. - // For the tasks used for this test that maximum memory overhead can be 450Kb - private static final int TASK_MAX_GENERAL_MEMORY_OVERHEAD = 450000; + private static final ThreadMXBean threadMXBean = (ThreadMXBean) ManagementFactory.getThreadMXBean(); public static class ResourceAwareNodeRequest extends BaseNodeRequest { protected String requestName; @@ -134,7 +135,6 @@ public boolean shouldCancelChildrenOnCancellation() { * Simulates a task which executes work on search executor. 
*/ class ResourceAwareNodesAction extends AbstractTestNodesAction { - private final TaskTestContext taskTestContext; private final boolean blockForCancellation; @@ -168,7 +168,11 @@ public void onFailure(Exception e) { } @Override + @SuppressForbidden(reason = "ThreadMXBean#getThreadAllocatedBytes") protected void doRun() { + taskTestContext.memoryConsumptionWhenExecutionStarts = threadMXBean.getThreadAllocatedBytes( + Thread.currentThread().getId() + ); threadId.set(Thread.currentThread().getId()); if (taskTestContext.operationStartValidator != null) { @@ -205,9 +209,10 @@ protected void doRun() { result.get(); } catch (InterruptedException | ExecutionException e) { throw new RuntimeException(e.getCause()); - } - if (taskTestContext.operationFinishedValidator != null) { - taskTestContext.operationFinishedValidator.accept(threadId.get()); + } finally { + if (taskTestContext.operationFinishedValidator != null) { + taskTestContext.operationFinishedValidator.accept(threadId.get()); + } } return new NodeResponse(clusterService.localNode()); @@ -246,6 +251,7 @@ private static class TaskTestContext { private CountDownLatch requestCompleteLatch; private Consumer operationStartValidator; private Consumer operationFinishedValidator; + private long memoryConsumptionWhenExecutionStarts; } public void testBasicTaskResourceTracking() throws Exception { @@ -280,7 +286,8 @@ public void testBasicTaskResourceTracking() throws Exception { long expectedArrayAllocationOverhead = 2 * 4012688; // Task's memory overhead due to array allocations long actualTaskMemoryOverhead = task.getTotalResourceStats().getMemoryInBytes(); - assertTrue(Math.abs(actualTaskMemoryOverhead - expectedArrayAllocationOverhead) < TASK_MAX_GENERAL_MEMORY_OVERHEAD); + + assertTrue(actualTaskMemoryOverhead - expectedArrayAllocationOverhead < taskTestContext.memoryConsumptionWhenExecutionStarts); assertTrue(task.getTotalResourceStats().getCpuTimeInNanos() > 0); }; @@ -334,11 +341,13 @@ public void 
testTaskResourceTrackingDuringTaskCancellation() throws Exception { assertEquals(1, task.getResourceStats().get(threadId).size()); assertFalse(task.getResourceStats().get(threadId).get(0).isActive()); - long expectedArrayAllocationOverhead = 4012688; // Task's memory overhead due to array allocations. Only one out of 2 // allocations are completed before the task is cancelled - long actualArrayAllocationOverhead = task.getTotalResourceStats().getMemoryInBytes(); + long expectedArrayAllocationOverhead = 4012688; // Task's memory overhead due to array allocations + long taskCancellationOverhead = 30000; // Task cancellation overhead ~ 30Kb + long actualTaskMemoryOverhead = task.getTotalResourceStats().getMemoryInBytes(); - assertTrue(Math.abs(actualArrayAllocationOverhead - expectedArrayAllocationOverhead) < TASK_MAX_GENERAL_MEMORY_OVERHEAD); + long expectedOverhead = expectedArrayAllocationOverhead + taskCancellationOverhead; + assertTrue(actualTaskMemoryOverhead - expectedOverhead < taskTestContext.memoryConsumptionWhenExecutionStarts); assertTrue(task.getTotalResourceStats().getCpuTimeInNanos() > 0); }; @@ -437,7 +446,8 @@ public void testTaskResourceTrackingDisabledWhileTaskInProgress() throws Excepti long expectedArrayAllocationOverhead = 2 * 4012688; // Task's memory overhead due to array allocations long actualTaskMemoryOverhead = task.getTotalResourceStats().getMemoryInBytes(); - assertTrue(Math.abs(actualTaskMemoryOverhead - expectedArrayAllocationOverhead) < TASK_MAX_GENERAL_MEMORY_OVERHEAD); + + assertTrue(actualTaskMemoryOverhead - expectedArrayAllocationOverhead < taskTestContext.memoryConsumptionWhenExecutionStarts); assertTrue(task.getTotalResourceStats().getCpuTimeInNanos() > 0); }; From 9bd32cfd312acb336d4b2458ae4f670b732b52f3 Mon Sep 17 00:00:00 2001 From: Tushar Kharbanda Date: Fri, 15 Apr 2022 01:21:48 +0530 Subject: [PATCH 25/26] Bump stale task id in thread context log level to warn Signed-off-by: Tushar Kharbanda --- 
.../org/opensearch/tasks/TaskResourceTrackingService.java | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/server/src/main/java/org/opensearch/tasks/TaskResourceTrackingService.java b/server/src/main/java/org/opensearch/tasks/TaskResourceTrackingService.java index 964c2e13da5a9..1de591bcf1a2d 100644 --- a/server/src/main/java/org/opensearch/tasks/TaskResourceTrackingService.java +++ b/server/src/main/java/org/opensearch/tasks/TaskResourceTrackingService.java @@ -218,8 +218,8 @@ private ThreadContext.StoredContext addTaskIdToThreadContext(Task task) { && !resourceAwareTasks.containsKey((long) threadContext.getTransient(TASK_ID)); if (staleIdPresentInThreadContext) { - logger.debug( - "Stale Task Id should ideally be not present in thread context. Current task Id: {}, Stale task Id: {}, Thread id: {}", + logger.warn( + "Previous task ID was not removed from thread context. Current task Id: {}, Stale task Id: {}, Thread id: {}", task.getId(), threadContext.getTransient(TASK_ID), Thread.currentThread().getId() From 0c301e1d0efc29c3d3ffa6be1fc51a1802a4b32b Mon Sep 17 00:00:00 2001 From: Tushar Kharbanda Date: Tue, 19 Apr 2022 11:38:40 +0530 Subject: [PATCH 26/26] Improve assertions and logging Signed-off-by: Tushar Kharbanda --- .../main/java/org/opensearch/tasks/Task.java | 2 +- .../org/opensearch/tasks/TaskManager.java | 2 +- .../tasks/TaskResourceTrackingService.java | 93 ++++++++++++------- .../opensearch/tasks/ThreadResourceInfo.java | 10 +- .../threadpool/TaskAwareRunnable.java | 7 ++ 5 files changed, 74 insertions(+), 40 deletions(-) diff --git a/server/src/main/java/org/opensearch/tasks/Task.java b/server/src/main/java/org/opensearch/tasks/Task.java index 9aad853b070ba..a51af17ae8ea2 100644 --- a/server/src/main/java/org/opensearch/tasks/Task.java +++ b/server/src/main/java/org/opensearch/tasks/Task.java @@ -285,7 +285,7 @@ public void startThreadResourceTracking(long threadId, ResourceStatsType statsTy ); } } - 
threadResourceInfoList.add(new ThreadResourceInfo(statsType, resourceUsageMetrics)); + threadResourceInfoList.add(new ThreadResourceInfo(threadId, statsType, resourceUsageMetrics)); } /** diff --git a/server/src/main/java/org/opensearch/tasks/TaskManager.java b/server/src/main/java/org/opensearch/tasks/TaskManager.java index 09f435c971ecc..37c10dfc0e6ab 100644 --- a/server/src/main/java/org/opensearch/tasks/TaskManager.java +++ b/server/src/main/java/org/opensearch/tasks/TaskManager.java @@ -105,7 +105,7 @@ public class TaskManager implements ClusterStateApplier { private final Map banedParents = new ConcurrentHashMap<>(); private TaskResultsService taskResultsService; - private SetOnce taskResourceTrackingService = new SetOnce<>(); + private final SetOnce taskResourceTrackingService = new SetOnce<>(); private volatile DiscoveryNodes lastDiscoveryNodes = DiscoveryNodes.EMPTY_NODES; diff --git a/server/src/main/java/org/opensearch/tasks/TaskResourceTrackingService.java b/server/src/main/java/org/opensearch/tasks/TaskResourceTrackingService.java index 1de591bcf1a2d..71b829e023385 100644 --- a/server/src/main/java/org/opensearch/tasks/TaskResourceTrackingService.java +++ b/server/src/main/java/org/opensearch/tasks/TaskResourceTrackingService.java @@ -11,6 +11,7 @@ import com.sun.management.ThreadMXBean; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; +import org.apache.logging.log4j.message.ParameterizedMessage; import org.opensearch.common.SuppressForbidden; import org.opensearch.common.inject.Inject; import org.opensearch.common.settings.ClusterSettings; @@ -23,6 +24,7 @@ import org.opensearch.threadpool.ThreadPool; import java.lang.management.ManagementFactory; +import java.util.ArrayList; import java.util.Collections; import java.util.List; import java.util.Map; @@ -55,7 +57,6 @@ public class TaskResourceTrackingService implements RunnableTaskExecutionListene public TaskResourceTrackingService(Settings settings, ClusterSettings 
clusterSettings, ThreadPool threadPool) { this.taskResourceTrackingEnabled = TASK_RESOURCE_TRACKING_ENABLED.get(settings); this.threadPool = threadPool; - clusterSettings.addSettingsUpdateConsumer(TASK_RESOURCE_TRACKING_ENABLED, this::setTaskResourceTrackingEnabled); } @@ -87,9 +88,9 @@ public ThreadContext.StoredContext startTracking(Task task) { return () -> {}; } + logger.debug("Starting resource tracking for task: {}", task.getId()); resourceAwareTasks.put(task.getId(), task); return addTaskIdToThreadContext(task); - } /** @@ -104,13 +105,23 @@ public ThreadContext.StoredContext startTracking(Task task) { * @param task task which has finished and doesn't need resource tracking. */ public void stopTracking(Task task) { - if (isThreadWorkingOnTask(task, Thread.currentThread().getId())) { - taskExecutionFinishedOnThread(task.getId(), Thread.currentThread().getId()); - } - - assert validateNoActiveThread(task) : "No thread should be active when task is finished"; + logger.debug("Stopping resource tracking for task: {}", task.getId()); + try { + if (isCurrentThreadWorkingOnTask(task)) { + taskExecutionFinishedOnThread(task.getId(), Thread.currentThread().getId()); + } - resourceAwareTasks.remove(task.getId()); + List threadsWorkingOnTask = getThreadsWorkingOnTask(task); + if (threadsWorkingOnTask.size() > 0) { + logger.warn("No thread should be active when task finishes. Active threads: {}", threadsWorkingOnTask); + assert false : "No thread should be marked active when task finishes"; + } + } catch (Exception e) { + logger.warn("Failed while trying to mark the task execution on current thread completed.", e); + assert false; + } finally { + resourceAwareTasks.remove(task.getId()); + } } /** @@ -132,13 +143,16 @@ public void refreshResourceStats(Task... 
tasks) { } private void refreshResourceStats(Task resourceAwareTask) { - resourceAwareTask.getResourceStats().forEach((threadId, threadResourceInfos) -> { - for (ThreadResourceInfo threadResourceInfo : threadResourceInfos) { - if (threadResourceInfo.isActive()) { - resourceAwareTask.updateThreadResourceStats(threadId, WORKER_STATS, getResourceUsageMetricsForThread(threadId)); - } - } - }); + try { + logger.debug("Refreshing resource stats for Task: {}", resourceAwareTask.getId()); + List threadsWorkingOnTask = getThreadsWorkingOnTask(resourceAwareTask); + threadsWorkingOnTask.forEach( + threadId -> resourceAwareTask.updateThreadResourceStats(threadId, WORKER_STATS, getResourceUsageMetricsForThread(threadId)) + ); + } catch (IllegalStateException e) { + logger.debug("Resource stats already updated."); + } + } /** @@ -150,9 +164,18 @@ private void refreshResourceStats(Task resourceAwareTask) { */ @Override public void taskExecutionStartedOnThread(long taskId, long threadId) { - if (resourceAwareTasks.containsKey(taskId)) { - resourceAwareTasks.get(taskId).startThreadResourceTracking(threadId, WORKER_STATS, getResourceUsageMetricsForThread(threadId)); + try { + if (resourceAwareTasks.containsKey(taskId)) { + logger.debug("Task execution started on thread. 
Task: {}, Thread: {}", taskId, threadId); + + resourceAwareTasks.get(taskId) + .startThreadResourceTracking(threadId, WORKER_STATS, getResourceUsageMetricsForThread(threadId)); + } + } catch (Exception e) { + logger.warn(new ParameterizedMessage("Failed to mark thread execution started for task: [{}]", taskId), e); + assert false; } + } /** @@ -163,8 +186,15 @@ public void taskExecutionStartedOnThread(long taskId, long threadId) { */ @Override public void taskExecutionFinishedOnThread(long taskId, long threadId) { - if (resourceAwareTasks.containsKey(taskId)) { - resourceAwareTasks.get(taskId).stopThreadResourceTracking(threadId, WORKER_STATS, getResourceUsageMetricsForThread(threadId)); + try { + if (resourceAwareTasks.containsKey(taskId)) { + logger.debug("Task execution finished on thread. Task: {}, Thread: {}", taskId, threadId); + resourceAwareTasks.get(taskId) + .stopThreadResourceTracking(threadId, WORKER_STATS, getResourceUsageMetricsForThread(threadId)); + } + } catch (Exception e) { + logger.warn(new ParameterizedMessage("Failed to mark thread execution finished for task: [{}]", taskId), e); + assert false; } } @@ -181,7 +211,8 @@ private ResourceUsageMetric[] getResourceUsageMetricsForThread(long threadId) { return new ResourceUsageMetric[] { currentMemoryUsage, currentCPUUsage }; } - private boolean isThreadWorkingOnTask(Task task, long threadId) { + private boolean isCurrentThreadWorkingOnTask(Task task) { + long threadId = Thread.currentThread().getId(); List threadResourceInfos = task.getResourceStats().getOrDefault(threadId, Collections.emptyList()); for (ThreadResourceInfo threadResourceInfo : threadResourceInfos) { @@ -192,13 +223,16 @@ private boolean isThreadWorkingOnTask(Task task, long threadId) { return false; } - private boolean validateNoActiveThread(Task task) { + private List getThreadsWorkingOnTask(Task task) { + List activeThreads = new ArrayList<>(); for (List threadResourceInfos : task.getResourceStats().values()) { for 
(ThreadResourceInfo threadResourceInfo : threadResourceInfos) { - if (threadResourceInfo.isActive()) return false; + if (threadResourceInfo.isActive()) { + activeThreads.add(threadResourceInfo.getThreadId()); + } } } - return true; + return activeThreads; } /** @@ -213,19 +247,6 @@ private boolean validateNoActiveThread(Task task) { */ private ThreadContext.StoredContext addTaskIdToThreadContext(Task task) { ThreadContext threadContext = threadPool.getThreadContext(); - - boolean staleIdPresentInThreadContext = threadContext.getTransient(TASK_ID) != null - && !resourceAwareTasks.containsKey((long) threadContext.getTransient(TASK_ID)); - - if (staleIdPresentInThreadContext) { - logger.warn( - "Previous task ID was not removed from thread context. Current task Id: {}, Stale task Id: {}, Thread id: {}", - task.getId(), - threadContext.getTransient(TASK_ID), - Thread.currentThread().getId() - ); - } - ThreadContext.StoredContext storedContext = threadContext.newStoredContext(true, Collections.singletonList(TASK_ID)); threadContext.putTransient(TASK_ID, task.getId()); return storedContext; diff --git a/server/src/main/java/org/opensearch/tasks/ThreadResourceInfo.java b/server/src/main/java/org/opensearch/tasks/ThreadResourceInfo.java index 8b45c38c8fb63..9ee683e3928f6 100644 --- a/server/src/main/java/org/opensearch/tasks/ThreadResourceInfo.java +++ b/server/src/main/java/org/opensearch/tasks/ThreadResourceInfo.java @@ -15,11 +15,13 @@ * for a specific stats type like worker_stats or response_stats etc., */ public class ThreadResourceInfo { + private final long threadId; private volatile boolean isActive = true; private final ResourceStatsType statsType; private final ResourceUsageInfo resourceUsageInfo; - public ThreadResourceInfo(ResourceStatsType statsType, ResourceUsageMetric... resourceUsageMetrics) { + public ThreadResourceInfo(long threadId, ResourceStatsType statsType, ResourceUsageMetric... 
resourceUsageMetrics) { + this.threadId = threadId; this.statsType = statsType; this.resourceUsageInfo = new ResourceUsageInfo(resourceUsageMetrics); } @@ -43,12 +45,16 @@ public ResourceStatsType getStatsType() { return statsType; } + public long getThreadId() { + return threadId; + } + public ResourceUsageInfo getResourceUsageInfo() { return resourceUsageInfo; } @Override public String toString() { - return resourceUsageInfo + ", stats_type=" + statsType + ", is_active=" + isActive; + return resourceUsageInfo + ", stats_type=" + statsType + ", is_active=" + isActive + ", threadId=" + threadId; } } diff --git a/server/src/main/java/org/opensearch/threadpool/TaskAwareRunnable.java b/server/src/main/java/org/opensearch/threadpool/TaskAwareRunnable.java index 793c4f289845e..183b9b2f4cf9a 100644 --- a/server/src/main/java/org/opensearch/threadpool/TaskAwareRunnable.java +++ b/server/src/main/java/org/opensearch/threadpool/TaskAwareRunnable.java @@ -8,10 +8,13 @@ package org.opensearch.threadpool; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; import org.opensearch.ExceptionsHelper; import org.opensearch.common.util.concurrent.AbstractRunnable; import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.common.util.concurrent.WrappedRunnable; +import org.opensearch.tasks.TaskManager; import java.util.Objects; import java.util.concurrent.atomic.AtomicReference; @@ -27,6 +30,8 @@ */ public class TaskAwareRunnable extends AbstractRunnable implements WrappedRunnable { + private static final Logger logger = LogManager.getLogger(TaskManager.class); + private final Runnable original; private final ThreadContext threadContext; private final AtomicReference runnableTaskListener; @@ -66,6 +71,8 @@ protected void doRun() throws Exception { Long taskId = threadContext.getTransient(TASK_ID); if (Objects.nonNull(taskId)) { runnableTaskListener.get().taskExecutionStartedOnThread(taskId, currentThread().getId()); + } else { 
+ logger.debug("Task Id not available in thread context. Skipping update. Thread Info: {}", Thread.currentThread()); } try { original.run();