Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ActionListener#notifyOnce should release delegate on completion #92452

Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 13 additions & 5 deletions server/src/main/java/org/elasticsearch/action/ActionListener.java
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.BiConsumer;
import java.util.function.Consumer;

Expand Down Expand Up @@ -418,15 +419,22 @@ public String toString() {
* and {@link #onFailure(Exception)} of the provided listener will be called at most once.
*/
static <Response> ActionListener<Response> notifyOnce(ActionListener<Response> delegate) {
return new NotifyOnceListener<Response>() {
final var delegateRef = new AtomicReference<>(delegate);
return new ActionListener<>() {
@Override
protected void innerOnResponse(Response response) {
delegate.onResponse(response);
public void onResponse(Response response) {
final var acquired = delegateRef.getAndSet(null);
if (acquired != null) {
acquired.onResponse(response);
}
}

@Override
protected void innerOnFailure(Exception e) {
delegate.onFailure(e);
public void onFailure(Exception e) {
final var acquired = delegateRef.getAndSet(null);
if (acquired != null) {
acquired.onFailure(e);
}
}
};
}
Expand Down

This file was deleted.

17 changes: 12 additions & 5 deletions server/src/main/java/org/elasticsearch/action/StepListener.java
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

import java.util.concurrent.Future;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.function.BiFunction;
import java.util.function.Consumer;

Expand All @@ -40,21 +41,27 @@
* }</pre>
*/

public final class StepListener<Response> extends NotifyOnceListener<Response> {
public final class StepListener<Response> implements ActionListener<Response> {

private final AtomicBoolean hasBeenCalled = new AtomicBoolean(false);
private final ListenableFuture<Response> delegate;

public StepListener() {
this.delegate = new ListenableFuture<>();
}

@Override
protected void innerOnResponse(Response response) {
delegate.onResponse(response);
public void onResponse(Response response) {
if (hasBeenCalled.compareAndSet(false, true)) {
delegate.onResponse(response);
}
}

@Override
protected void innerOnFailure(Exception e) {
delegate.onFailure(e);
public void onFailure(Exception e) {
if (hasBeenCalled.compareAndSet(false, true)) {
delegate.onFailure(e);
}
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
import org.elasticsearch.Version;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.ActionRunnable;
import org.elasticsearch.action.NotifyOnceListener;
import org.elasticsearch.action.admin.indices.close.CloseIndexClusterStateUpdateRequest;
import org.elasticsearch.action.admin.indices.close.CloseIndexResponse;
import org.elasticsearch.action.admin.indices.close.CloseIndexResponse.IndexResult;
Expand Down Expand Up @@ -619,9 +618,9 @@ private void waitForShardsReadyForClosing(
for (int i = 0; i < indexRoutingTable.size(); i++) {
IndexShardRoutingTable shardRoutingTable = indexRoutingTable.shard(i);
final int shardId = shardRoutingTable.shardId().id();
sendVerifyShardBeforeCloseRequest(shardRoutingTable, closingBlock, new NotifyOnceListener<>() {
sendVerifyShardBeforeCloseRequest(shardRoutingTable, closingBlock, ActionListener.notifyOnce(new ActionListener<>() {
@Override
public void innerOnResponse(final ReplicationResponse replicationResponse) {
public void onResponse(ReplicationResponse replicationResponse) {
ShardResult.Failure[] failures = Arrays.stream(replicationResponse.getShardInfo().getFailures())
.map(f -> new ShardResult.Failure(f.index(), f.shardId(), f.getCause(), f.nodeId()))
.toArray(ShardResult.Failure[]::new);
Expand All @@ -630,7 +629,7 @@ public void innerOnResponse(final ReplicationResponse replicationResponse) {
}

@Override
public void innerOnFailure(final Exception e) {
public void onFailure(Exception e) {
ShardResult.Failure failure = new ShardResult.Failure(index.getName(), shardId, e);
results.setOnce(shardId, new ShardResult(shardId, new ShardResult.Failure[] { failure }));
processIfFinished();
Expand All @@ -641,7 +640,7 @@ private void processIfFinished() {
onResponse.accept(new IndexResult(index, results.toArray(new ShardResult[results.length()])));
}
}
});
}));
}
}

Expand Down Expand Up @@ -749,9 +748,9 @@ private void waitForShardsReady(
for (int i = 0; i < indexRoutingTable.size(); i++) {
IndexShardRoutingTable shardRoutingTable = indexRoutingTable.shard(i);
final int shardId = shardRoutingTable.shardId().id();
sendVerifyShardBlockRequest(shardRoutingTable, clusterBlock, new NotifyOnceListener<>() {
sendVerifyShardBlockRequest(shardRoutingTable, clusterBlock, ActionListener.notifyOnce(new ActionListener<>() {
@Override
public void innerOnResponse(final ReplicationResponse replicationResponse) {
public void onResponse(ReplicationResponse replicationResponse) {
AddBlockShardResult.Failure[] failures = Arrays.stream(replicationResponse.getShardInfo().getFailures())
.map(f -> new AddBlockShardResult.Failure(f.index(), f.shardId(), f.getCause(), f.nodeId()))
.toArray(AddBlockShardResult.Failure[]::new);
Expand All @@ -760,7 +759,7 @@ public void innerOnResponse(final ReplicationResponse replicationResponse) {
}

@Override
public void innerOnFailure(final Exception e) {
public void onFailure(Exception e) {
AddBlockShardResult.Failure failure = new AddBlockShardResult.Failure(index.getName(), shardId, e);
results.setOnce(shardId, new AddBlockShardResult(shardId, new AddBlockShardResult.Failure[] { failure }));
processIfFinished();
Expand All @@ -773,7 +772,7 @@ private void processIfFinished() {
onResponse.accept(result);
}
}
});
}));
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
import org.apache.logging.log4j.Logger;
import org.elasticsearch.Version;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.NotifyOnceListener;
import org.elasticsearch.cluster.node.DiscoveryNode;
import org.elasticsearch.common.Randomness;
import org.elasticsearch.common.UUIDs;
Expand Down Expand Up @@ -95,10 +94,10 @@ public void connectToRemoteMasterNode(TransportAddress transportAddress, ActionL

// use NotifyOnceListener to make sure the following line does not result in onFailure being called when
// the connection is closed in the onResponse handler
transportService.handshake(connection, probeHandshakeTimeout, new NotifyOnceListener<>() {
transportService.handshake(connection, probeHandshakeTimeout, ActionListener.notifyOnce(new ActionListener<>() {

@Override
protected void innerOnResponse(DiscoveryNode remoteNode) {
public void onResponse(DiscoveryNode remoteNode) {
try {
// success means (amongst other things) that the cluster names match
logger.trace("[{}] handshake successful: {}", transportAddress, remoteNode);
Expand Down Expand Up @@ -166,7 +165,7 @@ public void onFailure(Exception e) {
}

@Override
protected void innerOnFailure(Exception e) {
public void onFailure(Exception e) {
// we opened a connection and successfully performed a low-level handshake, so we were definitely
// talking to an Elasticsearch node, but the high-level handshake failed indicating some kind of
// mismatched configurations (e.g. cluster name) that the user should address
Expand All @@ -175,7 +174,7 @@ protected void innerOnFailure(Exception e) {
listener.onFailure(e);
}

});
}));

})
);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,9 @@
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.CyclicBarrier;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicReference;
Expand Down Expand Up @@ -209,6 +211,46 @@ public void onFailure(Exception e) {
}
}

public void testConcurrentNotifyOnce() throws InterruptedException {
final var completed = new AtomicBoolean();
final var listener = ActionListener.notifyOnce(new ActionListener<Void>() {
@Override
public void onResponse(Void o) {
assertTrue(completed.compareAndSet(false, true));
}

@Override
public void onFailure(Exception e) {
assertTrue(completed.compareAndSet(false, true));
}
});
final var threads = new Thread[between(1, 10)];
final var startBarrier = new CyclicBarrier(threads.length);
for (int i = 0; i < threads.length; i++) {
threads[i] = new Thread(() -> {
try {
startBarrier.await(10, TimeUnit.SECONDS);
} catch (Exception e) {
throw new AssertionError(e);
}
if (randomBoolean()) {
listener.onResponse(null);
} else {
listener.onFailure(new RuntimeException("test"));
}
});
}

for (Thread thread : threads) {
thread.start();
}
for (Thread thread : threads) {
thread.join();
}

assertTrue(completed.get());
}

public void testCompleteWith() {
PlainActionFuture<Integer> onResponseListener = new PlainActionFuture<>();
ActionListener.completeWith(onResponseListener, () -> 100);
Expand Down

This file was deleted.