Skip to content

Commit

Permalink
Fix task cancellation on remote cluster when original request fails (e…
Browse files Browse the repository at this point in the history
…lastic#109440)

Fixes a bug where the task on the remote cluster node is not cancelled
when the original request (that started the task) fails (returns an
exception).
  • Loading branch information
albertzaharovits authored Jun 7, 2024
1 parent a609258 commit df96199
Show file tree
Hide file tree
Showing 6 changed files with 119 additions and 9 deletions.
5 changes: 5 additions & 0 deletions docs/changelog/109440.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
pr: 109440
summary: Fix task cancellation on remote cluster when original request fails
area: Network
type: bug
issues: []
Original file line number Diff line number Diff line change
Expand Up @@ -439,7 +439,7 @@ public void cancelChildRemote(TaskId parentTask, long childRequestId, Transport.
reason
);
final CancelChildRequest request = CancelChildRequest.createCancelChildRequest(parentTask, childRequestId, reason);
transportService.sendRequest(childNode, CANCEL_CHILD_ACTION_NAME, request, TransportRequestOptions.EMPTY, NOOP_HANDLER);
transportService.sendRequest(childConnection, CANCEL_CHILD_ACTION_NAME, request, TransportRequestOptions.EMPTY, NOOP_HANDLER);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@
import static java.util.Collections.emptySet;
import static org.elasticsearch.test.ClusterServiceUtils.createClusterService;
import static org.elasticsearch.test.ClusterServiceUtils.setState;
import static org.elasticsearch.test.transport.MockTransportService.createTaskManager;

/**
* The test case for unit testing task manager and related transport actions
Expand Down Expand Up @@ -176,12 +177,7 @@ public TestNode(String name, ThreadPool threadPool, Settings settings) {
discoveryNode.set(DiscoveryNodeUtils.create(name, address.publishAddress(), emptyMap(), emptySet()));
return discoveryNode.get();
};
TaskManager taskManager;
if (MockTaskManager.USE_MOCK_TASK_MANAGER_SETTING.get(settings)) {
taskManager = new MockTaskManager(settings, threadPool, emptySet());
} else {
taskManager = new TaskManager(settings, threadPool, emptySet());
}
TaskManager taskManager = createTaskManager(settings, threadPool, emptySet(), Tracer.NOOP);
transportService = new TransportService(
settings,
new Netty4Transport(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,9 @@
import org.elasticsearch.common.util.concurrent.ThreadContext;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.index.query.MatchAllQueryBuilder;
import org.elasticsearch.tasks.Task;
import org.elasticsearch.tasks.TaskCancellationService;
import org.elasticsearch.tasks.TaskId;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.test.transport.MockTransportService;
import org.elasticsearch.threadpool.ScalingExecutorBuilder;
Expand All @@ -31,11 +34,19 @@
import java.util.Collections;
import java.util.List;
import java.util.concurrent.CopyOnWriteArrayList;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicLong;

import static org.elasticsearch.test.tasks.MockTaskManager.SPY_TASK_MANAGER_SETTING;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.instanceOf;
import static org.mockito.ArgumentMatchers.anyString;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.verify;

public class RemoteClusterAwareClientTests extends ESTestCase {

Expand All @@ -62,6 +73,89 @@ private MockTransportService startTransport(String id, List<DiscoveryNode> known
);
}

/**
 * Verifies that the task started on a remote cluster node is cancelled when the original
 * request that started it fails with an exception.
 * <p>
 * The test starts a "remote" transport whose task manager is a Mockito spy, connects a local
 * transport service to it as remote cluster {@code cluster1}, and sends a search-shards request
 * that is expected to fail. It then waits until the cancel-child action has been both sent by the
 * local node and handled by the remote node, and finally verifies that the remote task manager
 * was asked to cancel the child identified by the parent task id and the original request id.
 *
 * @throws Exception if starting or closing one of the transport services fails
 */
public void testRemoteTaskCancellationOnFailedResponse() throws Exception {
    // the remote node uses a spy task manager so that the cancellation call can be verified at the end
    Settings.Builder remoteTransportSettingsBuilder = Settings.builder();
    remoteTransportSettingsBuilder.put(SPY_TASK_MANAGER_SETTING.getKey(), true);
    try (
        MockTransportService remoteTransport = RemoteClusterConnectionTests.startTransport(
            "seed_node",
            new CopyOnWriteArrayList<>(),
            VersionInformation.CURRENT,
            TransportVersion.current(),
            threadPool,
            remoteTransportSettingsBuilder.build()
        )
    ) {
        remoteTransport.getTaskManager().setTaskCancellationService(new TaskCancellationService(remoteTransport));
        Settings.Builder builder = Settings.builder();
        builder.putList("cluster.remote.cluster1.seeds", remoteTransport.getLocalDiscoNode().getAddress().toString());
        try (
            MockTransportService localService = MockTransportService.createNewService(
                builder.build(),
                VersionInformation.CURRENT,
                TransportVersion.current(),
                threadPool,
                null
            )
        ) {
            // the TaskCancellationService references the same TransportService instance;
            // this is identical to how it works in the Node constructor
            localService.getTaskManager().setTaskCancellationService(new TaskCancellationService(localService));
            localService.start();
            localService.acceptIncomingRequests();

            SearchShardsRequest searchShardsRequest = new SearchShardsRequest(
                new String[] { "test-index" },
                IndicesOptions.strictExpandOpen(),
                new MatchAllQueryBuilder(),
                null,
                "index_not_found", // this request must fail
                randomBoolean(),
                null
            );
            // register a local parent task so that the remote request is tracked as its child
            Task parentTask = localService.getTaskManager().register("test_type", "test_action", searchShardsRequest);
            TaskId parentTaskId = new TaskId("test-mock-node-id", parentTask.getId());
            searchShardsRequest.setParentTask(parentTaskId);
            var client = new RemoteClusterAwareClient(
                localService,
                "cluster1",
                threadPool.executor(TEST_THREAD_POOL_NAME),
                randomBoolean()
            );

            // counted down once the remote node has finished handling the cancel-child request
            CountDownLatch cancelChildReceived = new CountDownLatch(1);
            remoteTransport.addRequestHandlingBehavior(
                TaskCancellationService.CANCEL_CHILD_ACTION_NAME,
                (handler, request, channel, task) -> {
                    handler.messageReceived(request, channel, task);
                    cancelChildReceived.countDown();
                }
            );
            AtomicLong searchShardsRequestId = new AtomicLong(-1);
            // counted down once the local node has sent the cancel-child request
            CountDownLatch cancelChildSent = new CountDownLatch(1);
            localService.addSendBehavior(remoteTransport, (connection, requestId, action, request, options) -> {
                if (action.equals("indices:admin/search/search_shards")) {
                    // record the id BEFORE the request goes out: the failure response and the resulting
                    // cancellation are processed on other threads, so storing it only after the send
                    // returns would leave the final verification racing against this write
                    searchShardsRequestId.set(requestId);
                }
                connection.sendRequest(requestId, action, request, options);
                if (action.equals(TaskCancellationService.CANCEL_CHILD_ACTION_NAME)) {
                    cancelChildSent.countDown();
                }
            });

            // assert original request failed
            var future = new PlainActionFuture<SearchShardsResponse>();
            client.execute(TransportSearchShardsAction.REMOTE_TYPE, searchShardsRequest, future);
            ExecutionException e = expectThrows(ExecutionException.class, future::get);
            assertThat(e.getCause(), instanceOf(RemoteTransportException.class));

            // assert remote task is cancelled
            safeAwait(cancelChildSent);
            safeAwait(cancelChildReceived);
            verify(remoteTransport.getTaskManager()).cancelChildLocal(eq(parentTaskId), eq(searchShardsRequestId.get()), anyString());
        }
    }
}

public void testSearchShards() throws Exception {
List<DiscoveryNode> knownNodes = new CopyOnWriteArrayList<>();
try (
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,12 @@ public class MockTaskManager extends TaskManager {
Property.NodeScope
);

/**
 * Boolean, node-scoped test setting with key {@code tests.spy.taskmanager.enabled}; defaults to {@code false}.
 * When it is {@code true}, {@code MockTransportService.createTaskManager} wraps the task manager it creates
 * in a Mockito spy, which lets tests verify the calls made on it.
 */
public static final Setting<Boolean> SPY_TASK_MANAGER_SETTING = Setting.boolSetting(
"tests.spy.taskmanager.enabled",
false,
Property.NodeScope
);

private final Collection<MockTaskManagerListener> listeners = new CopyOnWriteArrayList<>();

public MockTaskManager(Settings settings, ThreadPool threadPool, Set<String> taskHeaders) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@
import java.util.function.Supplier;

import static org.junit.Assert.assertNotNull;
import static org.mockito.Mockito.spy;

/**
* A mock delegate service that allows to simulate different network topology failures.
Expand All @@ -102,7 +103,7 @@ public class MockTransportService extends TransportService {
public static class TestPlugin extends Plugin {
@Override
public List<Setting<?>> getSettings() {
return List.of(MockTaskManager.USE_MOCK_TASK_MANAGER_SETTING);
return List.of(MockTaskManager.USE_MOCK_TASK_MANAGER_SETTING, MockTaskManager.SPY_TASK_MANAGER_SETTING);
}
}

Expand Down Expand Up @@ -310,7 +311,15 @@ private static TransportAddress[] extractTransportAddresses(TransportService tra
return transportAddresses.toArray(new TransportAddress[transportAddresses.size()]);
}

private static TaskManager createTaskManager(Settings settings, ThreadPool threadPool, Set<String> taskHeaders, Tracer tracer) {
/**
 * Creates the task manager used by test transport services.
 * <p>
 * The instance is obtained from {@code createMockTaskManager}; when
 * {@link MockTaskManager#SPY_TASK_MANAGER_SETTING} is enabled in {@code settings} it is additionally
 * wrapped in a Mockito spy so that tests can verify the calls made on it.
 *
 * @param settings    settings consulted for the spy flag and passed on to the task manager factory
 * @param threadPool  thread pool passed on to the task manager factory
 * @param taskHeaders task headers passed on to the task manager factory
 * @param tracer      tracer passed on to the task manager factory
 * @return the created task manager, wrapped in a spy if the spy setting is enabled
 */
public static TaskManager createTaskManager(Settings settings, ThreadPool threadPool, Set<String> taskHeaders, Tracer tracer) {
    // read the flag first, exactly as the branching version did, then build the delegate once
    final boolean spyEnabled = MockTaskManager.SPY_TASK_MANAGER_SETTING.get(settings);
    final TaskManager delegate = createMockTaskManager(settings, threadPool, taskHeaders, tracer);
    return spyEnabled ? spy(delegate) : delegate;
}

private static TaskManager createMockTaskManager(Settings settings, ThreadPool threadPool, Set<String> taskHeaders, Tracer tracer) {
if (MockTaskManager.USE_MOCK_TASK_MANAGER_SETTING.get(settings)) {
return new MockTaskManager(settings, threadPool, taskHeaders);
} else {
Expand Down

0 comments on commit df96199

Please sign in to comment.