Skip to content

Commit

Permalink
Fix MasterServiceTests#testThreadContext (elastic#118926) (elastic#119306)
Browse files Browse the repository at this point in the history

This test would fail to see the expected response headers if the task
timed out before it started executing (the headers are only added when
the task executes), which could happen very rarely. It was also not a
very good test because it never actually exercised any of the paths
involving acking.

This commit fixes the rare failure and tightens up the assertions to
verify that the test sees the right thread context while handling the
end of the acking process, and that it always completes the acking
process.

Closes elastic#118914
  • Loading branch information
DaveCTurner authored Dec 27, 2024
1 parent 36d3f6d commit faaede7
Showing 1 changed file with 63 additions and 30 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,8 @@
import org.elasticsearch.common.util.concurrent.EsRejectedExecutionException;
import org.elasticsearch.common.util.concurrent.StoppableExecutorServiceWrapper;
import org.elasticsearch.common.util.concurrent.ThreadContext;
import org.elasticsearch.core.Releasable;
import org.elasticsearch.core.Releasables;
import org.elasticsearch.core.SuppressForbidden;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.core.Tuple;
Expand Down Expand Up @@ -258,9 +260,42 @@ public void clusterStatePublished(ClusterState newClusterState) {
assertThat(registeredActions.toString(), registeredActions, contains(MasterService.STATE_UPDATE_ACTION_NAME));
}

public void testThreadContext() throws InterruptedException {
public void testThreadContext() {
try (var master = createMasterService(true)) {
final CountDownLatch latch = new CountDownLatch(1);

master.setClusterStatePublisher((clusterStatePublicationEvent, publishListener, ackListener) -> {
ClusterServiceUtils.setAllElapsedMillis(clusterStatePublicationEvent);
try (var ignored = threadPool.getThreadContext().newEmptyContext()) {
if (randomBoolean()) {
randomExecutor(threadPool).execute(() -> publishListener.onResponse(null));
randomExecutor(threadPool).execute(() -> ackListener.onCommit(TimeValue.timeValueMillis(randomInt(10000))));
randomExecutor(threadPool).execute(
() -> ackListener.onNodeAck(
clusterStatePublicationEvent.getNewState().nodes().getMasterNode(),
randomBoolean() ? null : new RuntimeException("simulated ack failure")
)
);
} else {
randomExecutor(threadPool).execute(
() -> publishListener.onFailure(new FailedToCommitClusterStateException("simulated publish failure"))
);
}
}
});

final Releasable onPublishComplete;
final Releasable onAckingComplete;
final Runnable awaitComplete;
{
final var publishLatch = new CountDownLatch(1);
final var ackingLatch = new CountDownLatch(1);
onPublishComplete = Releasables.assertOnce(publishLatch::countDown);
onAckingComplete = Releasables.assertOnce(ackingLatch::countDown);
awaitComplete = () -> {
safeAwait(publishLatch);
safeAwait(ackingLatch);
};
}

try (ThreadContext.StoredContext ignored = threadPool.getThreadContext().stashContext()) {

Expand All @@ -271,15 +306,12 @@ public void testThreadContext() throws InterruptedException {
expectedHeaders.put(copiedHeader, randomIdentifier());
}
}

final Map<String, List<String>> expectedResponseHeaders = Collections.singletonMap(
"testResponse",
Collections.singletonList("testResponse")
);
threadPool.getThreadContext().putHeader(expectedHeaders);

final TimeValue ackTimeout = randomBoolean() ? TimeValue.ZERO : TimeValue.timeValueMillis(randomInt(10000));
final TimeValue masterTimeout = randomBoolean() ? TimeValue.ZERO : TimeValue.timeValueMillis(randomInt(10000));
final Map<String, List<String>> expectedResponseHeaders = Map.of("testResponse", List.of(randomIdentifier()));

final TimeValue ackTimeout = randomBoolean() ? TimeValue.MINUS_ONE : TimeValue.timeValueMillis(randomInt(10000));
final TimeValue masterTimeout = randomBoolean() ? TimeValue.MINUS_ONE : TimeValue.timeValueMillis(randomInt(10000));

master.submitUnbatchedStateUpdateTask(
"test",
Expand All @@ -288,8 +320,9 @@ public void testThreadContext() throws InterruptedException {
public ClusterState execute(ClusterState currentState) {
assertTrue(threadPool.getThreadContext().isSystemContext());
assertEquals(Collections.emptyMap(), threadPool.getThreadContext().getHeaders());
threadPool.getThreadContext().addResponseHeader("testResponse", "testResponse");
assertEquals(expectedResponseHeaders, threadPool.getThreadContext().getResponseHeaders());
expectedResponseHeaders.forEach(
(name, values) -> values.forEach(v -> threadPool.getThreadContext().addResponseHeader(name, v))
);

if (randomBoolean()) {
return ClusterState.builder(currentState).build();
Expand All @@ -302,44 +335,44 @@ public ClusterState execute(ClusterState currentState) {

@Override
public void onFailure(Exception e) {
assertFalse(threadPool.getThreadContext().isSystemContext());
assertEquals(expectedHeaders, threadPool.getThreadContext().getHeaders());
assertEquals(expectedResponseHeaders, threadPool.getThreadContext().getResponseHeaders());
latch.countDown();
assertExpectedThreadContext(
e instanceof ProcessClusterEventTimeoutException ? Map.of() : expectedResponseHeaders
);
onPublishComplete.close();
onAckingComplete.close(); // no acking takes place if publication failed
}

@Override
public void clusterStateProcessed(ClusterState oldState, ClusterState newState) {
assertFalse(threadPool.getThreadContext().isSystemContext());
assertEquals(expectedHeaders, threadPool.getThreadContext().getHeaders());
assertEquals(expectedResponseHeaders, threadPool.getThreadContext().getResponseHeaders());
latch.countDown();
assertExpectedThreadContext(expectedResponseHeaders);
onPublishComplete.close();
}

@Override
public void onAllNodesAcked() {
assertFalse(threadPool.getThreadContext().isSystemContext());
assertEquals(expectedHeaders, threadPool.getThreadContext().getHeaders());
assertEquals(expectedResponseHeaders, threadPool.getThreadContext().getResponseHeaders());
latch.countDown();
onAckCompletion();
}

@Override
public void onAckFailure(Exception e) {
assertFalse(threadPool.getThreadContext().isSystemContext());
assertEquals(expectedHeaders, threadPool.getThreadContext().getHeaders());
assertEquals(expectedResponseHeaders, threadPool.getThreadContext().getResponseHeaders());
latch.countDown();
onAckCompletion();
}

@Override
public void onAckTimeout() {
onAckCompletion();
}

private void onAckCompletion() {
assertExpectedThreadContext(expectedResponseHeaders);
onAckingComplete.close();
}

private void assertExpectedThreadContext(Map<String, List<String>> expectedResponseHeaders) {
assertFalse(threadPool.getThreadContext().isSystemContext());
assertEquals(expectedHeaders, threadPool.getThreadContext().getHeaders());
assertEquals(expectedResponseHeaders, threadPool.getThreadContext().getResponseHeaders());
latch.countDown();
}

}
);

Expand All @@ -348,7 +381,7 @@ public void onAckTimeout() {
assertEquals(Collections.emptyMap(), threadPool.getThreadContext().getResponseHeaders());
}

assertTrue(latch.await(10, TimeUnit.SECONDS));
awaitComplete.run();
}
}

Expand Down

0 comments on commit faaede7

Please sign in to comment.