diff --git a/server/src/main/java/org/elasticsearch/action/ActionListener.java b/server/src/main/java/org/elasticsearch/action/ActionListener.java index 1b918186009c4..6a41a8205b783 100644 --- a/server/src/main/java/org/elasticsearch/action/ActionListener.java +++ b/server/src/main/java/org/elasticsearch/action/ActionListener.java @@ -16,6 +16,7 @@ import java.util.ArrayList; import java.util.List; +import java.util.concurrent.atomic.AtomicReference; import java.util.function.BiConsumer; import java.util.function.Consumer; @@ -418,15 +419,27 @@ public String toString() { * and {@link #onFailure(Exception)} of the provided listener will be called at most once. */ static ActionListener notifyOnce(ActionListener delegate) { - return new NotifyOnceListener() { + final var delegateRef = new AtomicReference<>(delegate); + return new ActionListener<>() { @Override - protected void innerOnResponse(Response response) { - delegate.onResponse(response); + public void onResponse(Response response) { + final var acquired = delegateRef.getAndSet(null); + if (acquired != null) { + acquired.onResponse(response); + } } @Override - protected void innerOnFailure(Exception e) { - delegate.onFailure(e); + public void onFailure(Exception e) { + final var acquired = delegateRef.getAndSet(null); + if (acquired != null) { + acquired.onFailure(e); + } + } + + @Override + public String toString() { + return "notifyOnce[" + delegate + "]"; } }; } diff --git a/server/src/main/java/org/elasticsearch/action/NotifyOnceListener.java b/server/src/main/java/org/elasticsearch/action/NotifyOnceListener.java deleted file mode 100644 index 582290f2a4349..0000000000000 --- a/server/src/main/java/org/elasticsearch/action/NotifyOnceListener.java +++ /dev/null @@ -1,39 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0 and the Server Side Public License, v 1; you may not use this file except - * in compliance with, at your election, the Elastic License 2.0 or the Server - * Side Public License, v 1. - */ - -package org.elasticsearch.action; - -import java.util.concurrent.atomic.AtomicBoolean; - -/** - * A listener that ensures that only one of onResponse or onFailure is called. And the method - * the is called is only called once. Subclasses should implement notification logic with - * innerOnResponse and innerOnFailure. - */ -public abstract class NotifyOnceListener implements ActionListener { - - private final AtomicBoolean hasBeenCalled = new AtomicBoolean(false); - - protected abstract void innerOnResponse(Response response); - - protected abstract void innerOnFailure(Exception e); - - @Override - public final void onResponse(Response response) { - if (hasBeenCalled.compareAndSet(false, true)) { - innerOnResponse(response); - } - } - - @Override - public final void onFailure(Exception e) { - if (hasBeenCalled.compareAndSet(false, true)) { - innerOnFailure(e); - } - } -} diff --git a/server/src/main/java/org/elasticsearch/action/StepListener.java b/server/src/main/java/org/elasticsearch/action/StepListener.java index dab36040e3e4f..e36b799b92903 100644 --- a/server/src/main/java/org/elasticsearch/action/StepListener.java +++ b/server/src/main/java/org/elasticsearch/action/StepListener.java @@ -14,6 +14,7 @@ import java.util.concurrent.Future; import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; import java.util.function.BiFunction; import java.util.function.Consumer; @@ -40,7 +41,9 @@ * } */ -public final class StepListener extends NotifyOnceListener { +public final class StepListener implements ActionListener { + + private final AtomicBoolean hasBeenCalled = new AtomicBoolean(false); private final ListenableFuture delegate; public StepListener() { @@ -48,13 +51,17 @@ public StepListener() { } @Override - protected void innerOnResponse(Response response) { - delegate.onResponse(response); + public void onResponse(Response response) { + if (hasBeenCalled.compareAndSet(false, true)) { + delegate.onResponse(response); + } } @Override - protected void innerOnFailure(Exception e) { - delegate.onFailure(e); + public void onFailure(Exception e) { + if (hasBeenCalled.compareAndSet(false, true)) { + delegate.onFailure(e); + } } /** diff --git a/server/src/main/java/org/elasticsearch/cluster/metadata/MetadataIndexStateService.java b/server/src/main/java/org/elasticsearch/cluster/metadata/MetadataIndexStateService.java index 592654ac8c125..b77511f7f4088 100644 --- a/server/src/main/java/org/elasticsearch/cluster/metadata/MetadataIndexStateService.java +++ b/server/src/main/java/org/elasticsearch/cluster/metadata/MetadataIndexStateService.java @@ -14,7 +14,6 @@ import org.elasticsearch.Version; import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.ActionRunnable; -import org.elasticsearch.action.NotifyOnceListener; import org.elasticsearch.action.admin.indices.close.CloseIndexClusterStateUpdateRequest; import org.elasticsearch.action.admin.indices.close.CloseIndexResponse; import org.elasticsearch.action.admin.indices.close.CloseIndexResponse.IndexResult; @@ -619,9 +618,9 @@ private void waitForShardsReadyForClosing( for (int i = 0; i < indexRoutingTable.size(); i++) { IndexShardRoutingTable shardRoutingTable = indexRoutingTable.shard(i); final int shardId = shardRoutingTable.shardId().id(); - sendVerifyShardBeforeCloseRequest(shardRoutingTable, closingBlock, new NotifyOnceListener<>() { + sendVerifyShardBeforeCloseRequest(shardRoutingTable, closingBlock, ActionListener.notifyOnce(new ActionListener<>() { @Override - public void innerOnResponse(final ReplicationResponse replicationResponse) { + public void onResponse(ReplicationResponse replicationResponse) { ShardResult.Failure[] failures = Arrays.stream(replicationResponse.getShardInfo().getFailures()) .map(f -> new ShardResult.Failure(f.index(), f.shardId(), f.getCause(), f.nodeId())) .toArray(ShardResult.Failure[]::new); @@ -630,7 +629,7 @@ public void innerOnResponse(final ReplicationResponse replicationResponse) { } @Override - public void innerOnFailure(final Exception e) { + public void onFailure(Exception e) { ShardResult.Failure failure = new ShardResult.Failure(index.getName(), shardId, e); results.setOnce(shardId, new ShardResult(shardId, new ShardResult.Failure[] { failure })); processIfFinished(); @@ -641,7 +640,7 @@ private void processIfFinished() { onResponse.accept(new IndexResult(index, results.toArray(new ShardResult[results.length()]))); } } - }); + })); } } @@ -749,9 +748,9 @@ private void waitForShardsReady( for (int i = 0; i < indexRoutingTable.size(); i++) { IndexShardRoutingTable shardRoutingTable = indexRoutingTable.shard(i); final int shardId = shardRoutingTable.shardId().id(); - sendVerifyShardBlockRequest(shardRoutingTable, clusterBlock, new NotifyOnceListener<>() { + sendVerifyShardBlockRequest(shardRoutingTable, clusterBlock, ActionListener.notifyOnce(new ActionListener<>() { @Override - public void innerOnResponse(final ReplicationResponse replicationResponse) { + public void onResponse(ReplicationResponse replicationResponse) { AddBlockShardResult.Failure[] failures = Arrays.stream(replicationResponse.getShardInfo().getFailures()) .map(f -> new AddBlockShardResult.Failure(f.index(), f.shardId(), f.getCause(), f.nodeId())) .toArray(AddBlockShardResult.Failure[]::new); @@ -760,7 +759,7 @@ public void innerOnResponse(final ReplicationResponse replicationResponse) { } @Override - public void innerOnFailure(final Exception e) { + public void onFailure(Exception e) { AddBlockShardResult.Failure failure = new AddBlockShardResult.Failure(index.getName(), shardId, e); results.setOnce(shardId, new AddBlockShardResult(shardId, new AddBlockShardResult.Failure[] { failure })); processIfFinished(); @@ -773,7 +772,7 @@ private void processIfFinished() { onResponse.accept(result); } } - }); + })); } } diff --git a/server/src/main/java/org/elasticsearch/discovery/HandshakingTransportAddressConnector.java b/server/src/main/java/org/elasticsearch/discovery/HandshakingTransportAddressConnector.java index f852a4f584e61..74b1fda553ab6 100644 --- a/server/src/main/java/org/elasticsearch/discovery/HandshakingTransportAddressConnector.java +++ b/server/src/main/java/org/elasticsearch/discovery/HandshakingTransportAddressConnector.java @@ -12,7 +12,6 @@ import org.apache.logging.log4j.Logger; import org.elasticsearch.Version; import org.elasticsearch.action.ActionListener; -import org.elasticsearch.action.NotifyOnceListener; import org.elasticsearch.cluster.node.DiscoveryNode; import org.elasticsearch.common.Randomness; import org.elasticsearch.common.UUIDs; @@ -95,10 +94,10 @@ public void connectToRemoteMasterNode(TransportAddress transportAddress, ActionL // use NotifyOnceListener to make sure the following line does not result in onFailure being called when // the connection is closed in the onResponse handler - transportService.handshake(connection, probeHandshakeTimeout, new NotifyOnceListener<>() { + transportService.handshake(connection, probeHandshakeTimeout, ActionListener.notifyOnce(new ActionListener<>() { @Override - protected void innerOnResponse(DiscoveryNode remoteNode) { + public void onResponse(DiscoveryNode remoteNode) { try { // success means (amongst other things) that the cluster names match logger.trace("[{}] handshake successful: {}", transportAddress, remoteNode); @@ -166,7 +165,7 @@ public void onFailure(Exception e) { } @Override - protected void innerOnFailure(Exception e) { + public void onFailure(Exception e) { // we opened a connection and successfully performed a low-level handshake, so we were definitely // talking to an Elasticsearch node, but the high-level handshake failed indicating some kind of // mismatched configurations (e.g. cluster name) that the user should address @@ -175,7 +174,7 @@ protected void innerOnFailure(Exception e) { listener.onFailure(e); } - }); + })); }) ); diff --git a/server/src/test/java/org/elasticsearch/action/ActionListenerTests.java b/server/src/test/java/org/elasticsearch/action/ActionListenerTests.java index 84fb08f03adc9..a43864a9938c0 100644 --- a/server/src/test/java/org/elasticsearch/action/ActionListenerTests.java +++ b/server/src/test/java/org/elasticsearch/action/ActionListenerTests.java @@ -14,7 +14,9 @@ import java.io.IOException; import java.util.ArrayList; import java.util.List; +import java.util.concurrent.CyclicBarrier; import java.util.concurrent.ExecutionException; +import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicReference; @@ -209,6 +211,53 @@ public void onFailure(Exception e) { } } + public void testConcurrentNotifyOnce() throws InterruptedException { + final var completed = new AtomicBoolean(); + final var listener = ActionListener.notifyOnce(new ActionListener() { + @Override + public void onResponse(Void o) { + assertTrue(completed.compareAndSet(false, true)); + } + + @Override + public void onFailure(Exception e) { + assertTrue(completed.compareAndSet(false, true)); + } + + @Override + public String toString() { + return "inner-listener"; + } + }); + assertThat(listener.toString(), equalTo("notifyOnce[inner-listener]")); + + final var threads = new Thread[between(1, 10)]; + final var startBarrier = new CyclicBarrier(threads.length); + for (int i = 0; i < threads.length; i++) { + threads[i] = new Thread(() -> { + try { + startBarrier.await(10, TimeUnit.SECONDS); + } catch (Exception e) { + throw new AssertionError(e); + } + if (randomBoolean()) { + listener.onResponse(null); + } else { + listener.onFailure(new RuntimeException("test")); + } + }); + } + + for (Thread thread : threads) { + thread.start(); + } + for (Thread thread : threads) { + thread.join(); + } + + assertTrue(completed.get()); + } + public void testCompleteWith() { PlainActionFuture onResponseListener = new PlainActionFuture<>(); ActionListener.completeWith(onResponseListener, () -> 100); diff --git a/server/src/test/java/org/elasticsearch/action/NotifyOnceListenerTests.java b/server/src/test/java/org/elasticsearch/action/NotifyOnceListenerTests.java deleted file mode 100644 index fa6761b2bf3c9..0000000000000 --- a/server/src/test/java/org/elasticsearch/action/NotifyOnceListenerTests.java +++ /dev/null @@ -1,65 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0 and the Server Side Public License, v 1; you may not use this file except - * in compliance with, at your election, the Elastic License 2.0 or the Server - * Side Public License, v 1. - */ - -package org.elasticsearch.action; - -import org.elasticsearch.test.ESTestCase; - -import java.util.concurrent.atomic.AtomicReference; - -public class NotifyOnceListenerTests extends ESTestCase { - - public void testWhenSuccessCannotNotifyMultipleTimes() { - AtomicReference response = new AtomicReference<>(); - AtomicReference exception = new AtomicReference<>(); - - NotifyOnceListener listener = new NotifyOnceListener() { - @Override - public void innerOnResponse(String s) { - response.set(s); - } - - @Override - public void innerOnFailure(Exception e) { - exception.set(e); - } - }; - - listener.onResponse("response"); - listener.onResponse("wrong-response"); - listener.onFailure(new RuntimeException()); - - assertNull(exception.get()); - assertEquals("response", response.get()); - } - - public void testWhenErrorCannotNotifyMultipleTimes() { - AtomicReference response = new AtomicReference<>(); - AtomicReference exception = new AtomicReference<>(); - - NotifyOnceListener listener = new NotifyOnceListener() { - @Override - public void innerOnResponse(String s) { - response.set(s); - } - - @Override - public void innerOnFailure(Exception e) { - exception.set(e); - } - }; - - RuntimeException expected = new RuntimeException(); - listener.onFailure(expected); - listener.onFailure(new IllegalArgumentException()); - listener.onResponse("response"); - - assertNull(response.get()); - assertSame(expected, exception.get()); - } -}