From be7e98dedd695221e44c2b6b676faeca3c1a6711 Mon Sep 17 00:00:00 2001 From: Kaushal Kumar Date: Mon, 22 Jul 2024 10:16:57 -0700 Subject: [PATCH] address comments Signed-off-by: Kaushal Kumar --- .../action/search/TransportSearchAction.java | 2 +- .../org/opensearch/search/SearchService.java | 12 ++++----- .../main/java/org/opensearch/tasks/Task.java | 14 +++++++--- .../admin/cluster/node/tasks/TaskTests.java | 26 +++++++++++++++++-- 4 files changed, 42 insertions(+), 12 deletions(-) diff --git a/server/src/main/java/org/opensearch/action/search/TransportSearchAction.java b/server/src/main/java/org/opensearch/action/search/TransportSearchAction.java index 3353b8ec93b98..e92725e5bfe78 100644 --- a/server/src/main/java/org/opensearch/action/search/TransportSearchAction.java +++ b/server/src/main/java/org/opensearch/action/search/TransportSearchAction.java @@ -464,7 +464,7 @@ private void executeRequest( // At this point either the QUERY_GROUP_ID header will be present in ThreadContext either via ActionFilter // or HTTP header (HTTP header will be deprecated once ActionFilter is implemented) - task.addQueryGroupHeadersTo(threadPool.getThreadContext()); + task.addQueryGroupHeaders(threadPool.getThreadContext()); PipelinedRequest searchRequest; ActionListener listener; diff --git a/server/src/main/java/org/opensearch/search/SearchService.java b/server/src/main/java/org/opensearch/search/SearchService.java index 82ada89e5ae49..aa3e409190ae5 100644 --- a/server/src/main/java/org/opensearch/search/SearchService.java +++ b/server/src/main/java/org/opensearch/search/SearchService.java @@ -557,7 +557,7 @@ public void executeDfsPhase( ActionListener listener ) { final IndexShard shard = getShard(request); - task.addQueryGroupHeadersTo(threadPool.getThreadContext()); + task.addQueryGroupHeaders(threadPool.getThreadContext()); rewriteAndFetchShardRequest(shard, request, new ActionListener() { @Override public void onResponse(ShardSearchRequest rewritten) { @@ -611,7 +611,7 @@ public void executeQueryPhase( ) { assert request.canReturnNullResponseIfMatchNoDocs() == false || request.numberOfShards() > 1 : "empty responses require more than one shard"; - task.addQueryGroupHeadersTo(threadPool.getThreadContext()); + task.addQueryGroupHeaders(threadPool.getThreadContext()); final IndexShard shard = getShard(request); rewriteAndFetchShardRequest(shard, request, new ActionListener() { @Override @@ -721,7 +721,7 @@ public void executeQueryPhase( freeReaderContext(readerContext.id()); throw e; } - task.addQueryGroupHeadersTo(threadPool.getThreadContext()); + task.addQueryGroupHeaders(threadPool.getThreadContext()); runAsync(getExecutor(readerContext.indexShard()), () -> { final ShardSearchRequest shardSearchRequest = readerContext.getShardSearchRequest(null); try ( @@ -748,7 +748,7 @@ public void executeQueryPhase(QuerySearchRequest request, SearchShardTask task, final ReaderContext readerContext = findReaderContext(request.contextId(), request.shardSearchRequest()); final ShardSearchRequest shardSearchRequest = readerContext.getShardSearchRequest(request.shardSearchRequest()); final Releasable markAsUsed = readerContext.markAsUsed(getKeepAlive(shardSearchRequest)); - task.addQueryGroupHeadersTo(threadPool.getThreadContext()); + task.addQueryGroupHeaders(threadPool.getThreadContext()); runAsync(getExecutor(readerContext.indexShard()), () -> { readerContext.setAggregatedDfs(request.dfs()); try ( @@ -799,7 +799,7 @@ public void executeFetchPhase( ) { final LegacyReaderContext readerContext = (LegacyReaderContext) findReaderContext(request.contextId(), request); final Releasable markAsUsed; - task.addQueryGroupHeadersTo(threadPool.getThreadContext()); + task.addQueryGroupHeaders(threadPool.getThreadContext()); try { markAsUsed = readerContext.markAsUsed(getScrollKeepAlive(request.scroll())); } catch (Exception e) { @@ -835,7 +835,7 @@ public void executeFetchPhase(ShardFetchRequest request, SearchShardTask task, A final ReaderContext readerContext = findReaderContext(request.contextId(), request); final ShardSearchRequest shardSearchRequest = readerContext.getShardSearchRequest(request.getShardSearchRequest()); final Releasable markAsUsed = readerContext.markAsUsed(getKeepAlive(shardSearchRequest)); - task.addQueryGroupHeadersTo(threadPool.getThreadContext()); + task.addQueryGroupHeaders(threadPool.getThreadContext()); runAsync(getExecutor(readerContext.indexShard()), () -> { try (SearchContext searchContext = createContext(readerContext, shardSearchRequest, task, false)) { if (request.lastEmittedDoc() != null) { diff --git a/server/src/main/java/org/opensearch/tasks/Task.java b/server/src/main/java/org/opensearch/tasks/Task.java index 707bda3e49c69..ff3ffaccff429 100644 --- a/server/src/main/java/org/opensearch/tasks/Task.java +++ b/server/src/main/java/org/opensearch/tasks/Task.java @@ -278,7 +278,6 @@ public TaskId getParentTaskId() { return parentTask; } - /** * Build a status for this task or null if this task doesn't have status. * Since most tasks don't have status this defaults to returning null. While @@ -525,12 +524,21 @@ public String getHeader(String header) { return headers.get(header); } - public void addQueryGroupHeadersTo(final ThreadContext threadContext) { + /** + * This method adds the queryGroupHeader in the task headers, We need this method since the query group is not determined at the task creation time + * hence it is not possible to copy this header from request headers. This header is required to group the tasks into queryGroups to account for the QueryGroup level resource footprint + * @param threadContext + */ + public void addQueryGroupHeaders(final ThreadContext threadContext) { // For now this header will be coming from HTTP headers but in second phase this header // We will use this constant from QueryGroup Service once the framework changes are done final String QUERY_GROUP_ID_HEADER = "queryGroupId"; - final String requestQueryGroupId = threadContext.getHeader(QUERY_GROUP_ID_HEADER); + String requestQueryGroupId = threadContext.getHeader(QUERY_GROUP_ID_HEADER); + + if (requestQueryGroupId == null) { + requestQueryGroupId = "DEFAULT_QUERY_GROUP_ID"; // TODO: move this constant either to QueryGroupService or Tracking equivalent + } final Map newHeaders = new HashMap<>(headers); diff --git a/server/src/test/java/org/opensearch/action/admin/cluster/node/tasks/TaskTests.java b/server/src/test/java/org/opensearch/action/admin/cluster/node/tasks/TaskTests.java index fed91afa88e38..ad95ffc59e5ac 100644 --- a/server/src/test/java/org/opensearch/action/admin/cluster/node/tasks/TaskTests.java +++ b/server/src/test/java/org/opensearch/action/admin/cluster/node/tasks/TaskTests.java @@ -239,7 +239,7 @@ public void testTaskResourceStats() { } } - public void testAddQueryGroupHeadersTo() { + public void testAddQueryGroupHeaders() { ThreadPool threadPool = new TestThreadPool(getClass().getName()); try { Task task = new Task( @@ -253,7 +253,7 @@ public void testAddQueryGroupHeadersTo() { threadPool.getThreadContext().putHeader("queryGroupId", "afakgkagj09532059"); - task.addQueryGroupHeadersTo(threadPool.getThreadContext()); + task.addQueryGroupHeaders(threadPool.getThreadContext()); String queryGroupId = task.getHeader("queryGroupId"); @@ -262,4 +262,26 @@ public void testAddQueryGroupHeadersTo() { threadPool.shutdown(); } } + + public void testAddQueryGroupHeadersWhenHeaderIsNotPresentInThreadContext() { + ThreadPool threadPool = new TestThreadPool(getClass().getName()); + try { + Task task = new Task( + randomLong(), + "transport", + SearchAction.NAME, + "description", + new TaskId(randomLong() + ":" + randomLong()), + Collections.emptyMap() + ); + + task.addQueryGroupHeaders(threadPool.getThreadContext()); + + String queryGroupId = task.getHeader("queryGroupId"); + + assertEquals("DEFAULT_QUERY_GROUP_ID", queryGroupId); + } finally { + threadPool.shutdown(); + } + } }