Skip to content

Commit

Permalink
ActionListener#notifyOnce should release delegate on completion (#92452)
Browse files Browse the repository at this point in the history
There's no need to keep hold of the delegate after completing it, and in
some cases this might hold on to excessive heap. With this commit we
drop the reference to the delegate when it's complete.

Closes #92451
  • Loading branch information
DaveCTurner authored Dec 19, 2022
1 parent 4be7743 commit 661ea5f
Show file tree
Hide file tree
Showing 7 changed files with 91 additions and 128 deletions.
23 changes: 18 additions & 5 deletions server/src/main/java/org/elasticsearch/action/ActionListener.java
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.BiConsumer;
import java.util.function.Consumer;

Expand Down Expand Up @@ -418,15 +419,27 @@ public String toString() {
* and {@link #onFailure(Exception)} of the provided listener will be called at most once.
*/
static <Response> ActionListener<Response> notifyOnce(ActionListener<Response> delegate) {
return new NotifyOnceListener<Response>() {
final var delegateRef = new AtomicReference<>(delegate);
return new ActionListener<>() {
@Override
protected void innerOnResponse(Response response) {
delegate.onResponse(response);
public void onResponse(Response response) {
final var acquired = delegateRef.getAndSet(null);
if (acquired != null) {
acquired.onResponse(response);
}
}

@Override
protected void innerOnFailure(Exception e) {
delegate.onFailure(e);
public void onFailure(Exception e) {
final var acquired = delegateRef.getAndSet(null);
if (acquired != null) {
acquired.onFailure(e);
}
}

@Override
public String toString() {
return "notifyOnce[" + delegate + "]";
}
};
}
Expand Down

This file was deleted.

17 changes: 12 additions & 5 deletions server/src/main/java/org/elasticsearch/action/StepListener.java
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

import java.util.concurrent.Future;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.function.BiFunction;
import java.util.function.Consumer;

Expand All @@ -40,21 +41,27 @@
* }</pre>
*/

public final class StepListener<Response> extends NotifyOnceListener<Response> {
public final class StepListener<Response> implements ActionListener<Response> {

private final AtomicBoolean hasBeenCalled = new AtomicBoolean(false);
private final ListenableFuture<Response> delegate;

public StepListener() {
this.delegate = new ListenableFuture<>();
}

@Override
protected void innerOnResponse(Response response) {
delegate.onResponse(response);
public void onResponse(Response response) {
if (hasBeenCalled.compareAndSet(false, true)) {
delegate.onResponse(response);
}
}

@Override
protected void innerOnFailure(Exception e) {
delegate.onFailure(e);
public void onFailure(Exception e) {
if (hasBeenCalled.compareAndSet(false, true)) {
delegate.onFailure(e);
}
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
import org.elasticsearch.Version;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.ActionRunnable;
import org.elasticsearch.action.NotifyOnceListener;
import org.elasticsearch.action.admin.indices.close.CloseIndexClusterStateUpdateRequest;
import org.elasticsearch.action.admin.indices.close.CloseIndexResponse;
import org.elasticsearch.action.admin.indices.close.CloseIndexResponse.IndexResult;
Expand Down Expand Up @@ -619,9 +618,9 @@ private void waitForShardsReadyForClosing(
for (int i = 0; i < indexRoutingTable.size(); i++) {
IndexShardRoutingTable shardRoutingTable = indexRoutingTable.shard(i);
final int shardId = shardRoutingTable.shardId().id();
sendVerifyShardBeforeCloseRequest(shardRoutingTable, closingBlock, new NotifyOnceListener<>() {
sendVerifyShardBeforeCloseRequest(shardRoutingTable, closingBlock, ActionListener.notifyOnce(new ActionListener<>() {
@Override
public void innerOnResponse(final ReplicationResponse replicationResponse) {
public void onResponse(ReplicationResponse replicationResponse) {
ShardResult.Failure[] failures = Arrays.stream(replicationResponse.getShardInfo().getFailures())
.map(f -> new ShardResult.Failure(f.index(), f.shardId(), f.getCause(), f.nodeId()))
.toArray(ShardResult.Failure[]::new);
Expand All @@ -630,7 +629,7 @@ public void innerOnResponse(final ReplicationResponse replicationResponse) {
}

@Override
public void innerOnFailure(final Exception e) {
public void onFailure(Exception e) {
ShardResult.Failure failure = new ShardResult.Failure(index.getName(), shardId, e);
results.setOnce(shardId, new ShardResult(shardId, new ShardResult.Failure[] { failure }));
processIfFinished();
Expand All @@ -641,7 +640,7 @@ private void processIfFinished() {
onResponse.accept(new IndexResult(index, results.toArray(new ShardResult[results.length()])));
}
}
});
}));
}
}

Expand Down Expand Up @@ -749,9 +748,9 @@ private void waitForShardsReady(
for (int i = 0; i < indexRoutingTable.size(); i++) {
IndexShardRoutingTable shardRoutingTable = indexRoutingTable.shard(i);
final int shardId = shardRoutingTable.shardId().id();
sendVerifyShardBlockRequest(shardRoutingTable, clusterBlock, new NotifyOnceListener<>() {
sendVerifyShardBlockRequest(shardRoutingTable, clusterBlock, ActionListener.notifyOnce(new ActionListener<>() {
@Override
public void innerOnResponse(final ReplicationResponse replicationResponse) {
public void onResponse(ReplicationResponse replicationResponse) {
AddBlockShardResult.Failure[] failures = Arrays.stream(replicationResponse.getShardInfo().getFailures())
.map(f -> new AddBlockShardResult.Failure(f.index(), f.shardId(), f.getCause(), f.nodeId()))
.toArray(AddBlockShardResult.Failure[]::new);
Expand All @@ -760,7 +759,7 @@ public void innerOnResponse(final ReplicationResponse replicationResponse) {
}

@Override
public void innerOnFailure(final Exception e) {
public void onFailure(Exception e) {
AddBlockShardResult.Failure failure = new AddBlockShardResult.Failure(index.getName(), shardId, e);
results.setOnce(shardId, new AddBlockShardResult(shardId, new AddBlockShardResult.Failure[] { failure }));
processIfFinished();
Expand All @@ -773,7 +772,7 @@ private void processIfFinished() {
onResponse.accept(result);
}
}
});
}));
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
import org.apache.logging.log4j.Logger;
import org.elasticsearch.Version;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.NotifyOnceListener;
import org.elasticsearch.cluster.node.DiscoveryNode;
import org.elasticsearch.common.Randomness;
import org.elasticsearch.common.UUIDs;
Expand Down Expand Up @@ -95,10 +94,10 @@ public void connectToRemoteMasterNode(TransportAddress transportAddress, ActionL

// use NotifyOnceListener to make sure the following line does not result in onFailure being called when
// the connection is closed in the onResponse handler
transportService.handshake(connection, probeHandshakeTimeout, new NotifyOnceListener<>() {
transportService.handshake(connection, probeHandshakeTimeout, ActionListener.notifyOnce(new ActionListener<>() {

@Override
protected void innerOnResponse(DiscoveryNode remoteNode) {
public void onResponse(DiscoveryNode remoteNode) {
try {
// success means (amongst other things) that the cluster names match
logger.trace("[{}] handshake successful: {}", transportAddress, remoteNode);
Expand Down Expand Up @@ -166,7 +165,7 @@ public void onFailure(Exception e) {
}

@Override
protected void innerOnFailure(Exception e) {
public void onFailure(Exception e) {
// we opened a connection and successfully performed a low-level handshake, so we were definitely
// talking to an Elasticsearch node, but the high-level handshake failed indicating some kind of
// mismatched configurations (e.g. cluster name) that the user should address
Expand All @@ -175,7 +174,7 @@ protected void innerOnFailure(Exception e) {
listener.onFailure(e);
}

});
}));

})
);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,9 @@
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.CyclicBarrier;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicReference;
Expand Down Expand Up @@ -209,6 +211,53 @@ public void onFailure(Exception e) {
}
}

public void testConcurrentNotifyOnce() throws InterruptedException {
final var completed = new AtomicBoolean();
final var listener = ActionListener.notifyOnce(new ActionListener<Void>() {
@Override
public void onResponse(Void o) {
assertTrue(completed.compareAndSet(false, true));
}

@Override
public void onFailure(Exception e) {
assertTrue(completed.compareAndSet(false, true));
}

@Override
public String toString() {
return "inner-listener";
}
});
assertThat(listener.toString(), equalTo("notifyOnce[inner-listener]"));

final var threads = new Thread[between(1, 10)];
final var startBarrier = new CyclicBarrier(threads.length);
for (int i = 0; i < threads.length; i++) {
threads[i] = new Thread(() -> {
try {
startBarrier.await(10, TimeUnit.SECONDS);
} catch (Exception e) {
throw new AssertionError(e);
}
if (randomBoolean()) {
listener.onResponse(null);
} else {
listener.onFailure(new RuntimeException("test"));
}
});
}

for (Thread thread : threads) {
thread.start();
}
for (Thread thread : threads) {
thread.join();
}

assertTrue(completed.get());
}

public void testCompleteWith() {
PlainActionFuture<Integer> onResponseListener = new PlainActionFuture<>();
ActionListener.completeWith(onResponseListener, () -> 100);
Expand Down

This file was deleted.

0 comments on commit 661ea5f

Please sign in to comment.