From 6e57d393ba15c8ff80ff6780dc848c23322af737 Mon Sep 17 00:00:00 2001 From: Martin Kouba Date: Wed, 23 Oct 2024 10:49:12 +0200 Subject: [PATCH] WebSockets Next: create a new event loop context for each client - we want to avoid a situation where if multiple clients/connections are created in a row, the same event loop is used and so writing/receiving messages is de-facto serialized --- .../client/BasicConnectorContextTest.java | 89 ++++++++++ .../next/test/client/ClientContextTest.java | 104 +++++++++++ .../runtime/BasicWebSocketConnectorImpl.java | 163 ++++++++++-------- .../next/runtime/WebSocketConnectorImpl.java | 78 ++++++--- 4 files changed, 341 insertions(+), 93 deletions(-) create mode 100644 extensions/websockets-next/deployment/src/test/java/io/quarkus/websockets/next/test/client/BasicConnectorContextTest.java create mode 100644 extensions/websockets-next/deployment/src/test/java/io/quarkus/websockets/next/test/client/ClientContextTest.java diff --git a/extensions/websockets-next/deployment/src/test/java/io/quarkus/websockets/next/test/client/BasicConnectorContextTest.java b/extensions/websockets-next/deployment/src/test/java/io/quarkus/websockets/next/test/client/BasicConnectorContextTest.java new file mode 100644 index 00000000000000..fcddf32bf6ffab --- /dev/null +++ b/extensions/websockets-next/deployment/src/test/java/io/quarkus/websockets/next/test/client/BasicConnectorContextTest.java @@ -0,0 +1,89 @@ +package io.quarkus.websockets.next.test.client; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import java.net.URI; +import java.util.Set; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.RegisterExtension; + +import io.quarkus.test.QuarkusUnitTest; +import io.quarkus.test.common.http.TestHTTPResource; +import io.quarkus.websockets.next.BasicWebSocketConnector; +import io.quarkus.websockets.next.OnClose; +import io.quarkus.websockets.next.OnOpen; +import io.quarkus.websockets.next.WebSocket; +import io.quarkus.websockets.next.WebSocketClientConnection; + +public class BasicConnectorContextTest { + + @RegisterExtension + public static final QuarkusUnitTest test = new QuarkusUnitTest() + .withApplicationRoot(root -> { + root.addClasses(ServerEndpoint.class); + }); + + @TestHTTPResource("/end") + URI uri; + + static final CountDownLatch MESSAGE_LATCH = new CountDownLatch(2); + + static final Set THREADS = ConcurrentHashMap.newKeySet(); + + static final CountDownLatch CLOSED_LATCH = new CountDownLatch(2); + + @Test + void testClient() throws InterruptedException { + BasicWebSocketConnector connector = BasicWebSocketConnector.create(); + connector + .executionModel(BasicWebSocketConnector.ExecutionModel.NON_BLOCKING) + .onTextMessage((c, m) -> { + String thread = Thread.currentThread().getName(); + THREADS.add(thread); + MESSAGE_LATCH.countDown(); + }) + .onClose((c, cr) -> { + CLOSED_LATCH.countDown(); + }) + .baseUri(uri); + WebSocketClientConnection conn1 = connector.connectAndAwait(); + WebSocketClientConnection conn2 = connector.connectAndAwait(); + assertTrue(MESSAGE_LATCH.await(10, TimeUnit.SECONDS)); + if (Runtime.getRuntime().availableProcessors() > 1) { + // Each client should be executed on a dedicated event loop thread + assertEquals(2, THREADS.size()); + } else { + // Single core - the event pool is shared + // Due to some CI weirdness it might happen that the system incorrectly reports single core + // Therefore, the assert checks if the number of threads used is >= 1 + assertTrue(THREADS.size() >= 1); + } + conn1.closeAndAwait(); + conn2.closeAndAwait(); + assertTrue(ServerEndpoint.CLOSED_LATCH.await(5, TimeUnit.SECONDS)); + assertTrue(CLOSED_LATCH.await(5, TimeUnit.SECONDS)); + } + + @WebSocket(path = "/end") + public static class ServerEndpoint { + + static final CountDownLatch CLOSED_LATCH = new CountDownLatch(1); + + @OnOpen + String open() { + return "Hello!"; + } + + @OnClose + void close() { + CLOSED_LATCH.countDown(); + } + + } + +} diff --git a/extensions/websockets-next/deployment/src/test/java/io/quarkus/websockets/next/test/client/ClientContextTest.java b/extensions/websockets-next/deployment/src/test/java/io/quarkus/websockets/next/test/client/ClientContextTest.java new file mode 100644 index 00000000000000..0ea1a055434e51 --- /dev/null +++ b/extensions/websockets-next/deployment/src/test/java/io/quarkus/websockets/next/test/client/ClientContextTest.java @@ -0,0 +1,104 @@ +package io.quarkus.websockets.next.test.client; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import java.net.URI; +import java.util.Set; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; + +import jakarta.inject.Inject; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.RegisterExtension; + +import io.quarkus.test.QuarkusUnitTest; +import io.quarkus.test.common.http.TestHTTPResource; +import io.quarkus.websockets.next.OnClose; +import io.quarkus.websockets.next.OnOpen; +import io.quarkus.websockets.next.OnTextMessage; +import io.quarkus.websockets.next.WebSocket; +import io.quarkus.websockets.next.WebSocketClient; +import io.quarkus.websockets.next.WebSocketClientConnection; +import io.quarkus.websockets.next.WebSocketConnector; +import io.smallrye.mutiny.Uni; + +public class ClientContextTest { + + @RegisterExtension + public static final QuarkusUnitTest test = new QuarkusUnitTest() + .withApplicationRoot(root -> { + root.addClasses(ServerEndpoint.class, ClientEndpoint.class); + }); + + @Inject + WebSocketConnector connector; + + @TestHTTPResource("/") + URI uri; + + @Test + void testClient() throws InterruptedException { + connector.baseUri(uri); + WebSocketClientConnection conn1 = connector.connectAndAwait(); + WebSocketClientConnection conn2 = connector.connectAndAwait(); + assertTrue(ClientEndpoint.MESSAGE_LATCH.await(10, TimeUnit.SECONDS)); + if (Runtime.getRuntime().availableProcessors() > 1) { + // Each client should be executed on a dedicated event loop thread + assertEquals(2, ClientEndpoint.THREADS.size()); + } else { + // Single core - the event pool is shared + // Due to some CI weirdness it might happen that the system incorrectly reports single core + // Therefore, the assert checks if the number of threads used is >= 1 + assertTrue(ClientEndpoint.THREADS.size() >= 1); + } + conn1.closeAndAwait(); + conn2.closeAndAwait(); + assertTrue(ClientEndpoint.CLOSED_LATCH.await(5, TimeUnit.SECONDS)); + assertTrue(ServerEndpoint.CLOSED_LATCH.await(5, TimeUnit.SECONDS)); + } + + @WebSocket(path = "/end") + public static class ServerEndpoint { + + static final CountDownLatch CLOSED_LATCH = new CountDownLatch(1); + + @OnOpen + String open() { + return "Hello!"; + } + + @OnClose + void close() { + CLOSED_LATCH.countDown(); + } + + } + + @WebSocketClient(path = "/end") + public static class ClientEndpoint { + + static final CountDownLatch MESSAGE_LATCH = new CountDownLatch(2); + + static final Set THREADS = ConcurrentHashMap.newKeySet(); + + static final CountDownLatch CLOSED_LATCH = new CountDownLatch(2); + + @OnTextMessage + Uni onMessage(String message) { + String thread = Thread.currentThread().getName(); + THREADS.add(thread); + MESSAGE_LATCH.countDown(); + return Uni.createFrom().voidItem(); + } + + @OnClose + void close() { + CLOSED_LATCH.countDown(); + } + + } + +} diff --git a/extensions/websockets-next/runtime/src/main/java/io/quarkus/websockets/next/runtime/BasicWebSocketConnectorImpl.java b/extensions/websockets-next/runtime/src/main/java/io/quarkus/websockets/next/runtime/BasicWebSocketConnectorImpl.java index b9e3af7e7395ff..80af2b91f40075 100644 --- a/extensions/websockets-next/runtime/src/main/java/io/quarkus/websockets/next/runtime/BasicWebSocketConnectorImpl.java +++ b/extensions/websockets-next/runtime/src/main/java/io/quarkus/websockets/next/runtime/BasicWebSocketConnectorImpl.java @@ -6,6 +6,7 @@ import java.util.Map.Entry; import java.util.Objects; import java.util.concurrent.Callable; +import java.util.concurrent.atomic.AtomicReference; import java.util.function.BiConsumer; import java.util.function.Consumer; @@ -22,12 +23,16 @@ import io.quarkus.websockets.next.WebSocketClientException; import io.quarkus.websockets.next.WebSocketsClientRuntimeConfig; import io.smallrye.mutiny.Uni; +import io.vertx.core.AsyncResult; import io.vertx.core.Context; import io.vertx.core.Handler; import io.vertx.core.Vertx; import io.vertx.core.buffer.Buffer; +import io.vertx.core.http.WebSocket; import io.vertx.core.http.WebSocketClient; import io.vertx.core.http.WebSocketConnectOptions; +import io.vertx.core.impl.ContextImpl; +import io.vertx.core.impl.VertxImpl; @Typed(BasicWebSocketConnector.class) @Dependent @@ -111,10 +116,10 @@ public Uni connect() { throw new WebSocketClientException("Endpoint URI not set!"); } - // Currently we create a new client for each connection + // A new client is created for each connection + // The client is created when the returned Uni is subscribed // The client is closed when the connection is closed - // TODO would it make sense to share clients? - WebSocketClient client = vertx.createWebSocketClient(populateClientOptions()); + AtomicReference client = new AtomicReference<>(); WebSocketConnectOptions connectOptions = newConnectOptions(baseUri); StringBuilder requestUri = new StringBuilder(); @@ -140,87 +145,109 @@ public Uni connect() { throw new WebSocketClientException(e); } - return Uni.createFrom().completionStage(() -> client.connect(connectOptions).toCompletionStage()) - .map(ws -> { - String clientId = BasicWebSocketConnector.class.getName(); - TrafficLogger trafficLogger = TrafficLogger.forClient(config); - WebSocketClientConnectionImpl connection = new WebSocketClientConnectionImpl(clientId, ws, - codecs, - pathParams, - serverEndpointUri, - headers, trafficLogger); - if (trafficLogger != null) { - trafficLogger.connectionOpened(connection); - } - connectionManager.add(BasicWebSocketConnectorImpl.class.getName(), connection); + Uni websocket = Uni.createFrom(). emitter(e -> { + // Create a new event loop context for each client, otherwise the current context is used + // We want to avoid a situation where if multiple clients/connections are created in a row, + // the same event loop is used and so writing/receiving messages is de-facto serialized + ContextImpl context = ((VertxImpl) vertx).createEventLoopContext(); + context.dispatch(new Handler() { + @Override + public void handle(Void event) { + WebSocketClient c = vertx.createWebSocketClient(populateClientOptions()); + client.setPlain(c); + c.connect(connectOptions, new Handler>() { + @Override + public void handle(AsyncResult r) { + if (r.succeeded()) { + e.complete(r.result()); + } else { + e.fail(r.cause()); + } + } + }); + } + }); + }); + return websocket.map(ws -> { + String clientId = BasicWebSocketConnector.class.getName(); + TrafficLogger trafficLogger = TrafficLogger.forClient(config); + WebSocketClientConnectionImpl connection = new WebSocketClientConnectionImpl(clientId, ws, + codecs, + pathParams, + serverEndpointUri, + headers, trafficLogger); + if (trafficLogger != null) { + trafficLogger.connectionOpened(connection); + } + connectionManager.add(BasicWebSocketConnectorImpl.class.getName(), connection); - if (openHandler != null) { - doExecute(connection, null, (c, ignored) -> openHandler.accept(c)); - } + if (openHandler != null) { + doExecute(connection, null, (c, ignored) -> openHandler.accept(c)); + } - if (textMessageHandler != null) { - ws.textMessageHandler(new Handler() { - @Override - public void handle(String message) { - if (trafficLogger != null) { - trafficLogger.textMessageReceived(connection, message); - } - doExecute(connection, message, textMessageHandler); - } - }); + if (textMessageHandler != null) { + ws.textMessageHandler(new Handler() { + @Override + public void handle(String message) { + if (trafficLogger != null) { + trafficLogger.textMessageReceived(connection, message); + } + doExecute(connection, message, textMessageHandler); } + }); + } - if (binaryMessageHandler != null) { - ws.binaryMessageHandler(new Handler() { + if (binaryMessageHandler != null) { + ws.binaryMessageHandler(new Handler() { - @Override - public void handle(Buffer message) { - if (trafficLogger != null) { - trafficLogger.binaryMessageReceived(connection, message); - } - doExecute(connection, message, binaryMessageHandler); - } - }); + @Override + public void handle(Buffer message) { + if (trafficLogger != null) { + trafficLogger.binaryMessageReceived(connection, message); + } + doExecute(connection, message, binaryMessageHandler); } + }); + } - if (pongMessageHandler != null) { - ws.pongHandler(new Handler() { + if (pongMessageHandler != null) { + ws.pongHandler(new Handler() { - @Override - public void handle(Buffer event) { - doExecute(connection, event, pongMessageHandler); - } - }); + @Override + public void handle(Buffer event) { + doExecute(connection, event, pongMessageHandler); } + }); + } - if (errorHandler != null) { - ws.exceptionHandler(new Handler() { + if (errorHandler != null) { + ws.exceptionHandler(new Handler() { - @Override - public void handle(Throwable event) { - doExecute(connection, event, errorHandler); - } - }); + @Override + public void handle(Throwable event) { + doExecute(connection, event, errorHandler); } + }); + } - ws.closeHandler(new Handler() { + ws.closeHandler(new Handler() { - @Override - public void handle(Void event) { - if (trafficLogger != null) { - trafficLogger.connectionClosed(connection); - } - if (closeHandler != null) { - doExecute(connection, new CloseReason(ws.closeStatusCode(), ws.closeReason()), closeHandler); - } - connectionManager.remove(BasicWebSocketConnectorImpl.class.getName(), connection); - client.close(); - } + @Override + public void handle(Void event) { + if (trafficLogger != null) { + trafficLogger.connectionClosed(connection); + } + if (closeHandler != null) { + doExecute(connection, new CloseReason(ws.closeStatusCode(), ws.closeReason()), closeHandler); + } + connectionManager.remove(BasicWebSocketConnectorImpl.class.getName(), connection); + client.get().close(); + } - }); + }); - return connection; - }); + return connection; + }); } private void doExecute(WebSocketClientConnectionImpl connection, MESSAGE message, diff --git a/extensions/websockets-next/runtime/src/main/java/io/quarkus/websockets/next/runtime/WebSocketConnectorImpl.java b/extensions/websockets-next/runtime/src/main/java/io/quarkus/websockets/next/runtime/WebSocketConnectorImpl.java index 05b41ce6a336c9..f5525bd6269f9c 100644 --- a/extensions/websockets-next/runtime/src/main/java/io/quarkus/websockets/next/runtime/WebSocketConnectorImpl.java +++ b/extensions/websockets-next/runtime/src/main/java/io/quarkus/websockets/next/runtime/WebSocketConnectorImpl.java @@ -7,6 +7,7 @@ import java.util.Map.Entry; import java.util.Objects; import java.util.Optional; +import java.util.concurrent.atomic.AtomicReference; import jakarta.enterprise.context.Dependent; import jakarta.enterprise.inject.Typed; @@ -23,9 +24,14 @@ import io.quarkus.websockets.next.runtime.WebSocketClientRecorder.ClientEndpoint; import io.quarkus.websockets.next.runtime.WebSocketClientRecorder.ClientEndpointsContext; import io.smallrye.mutiny.Uni; +import io.vertx.core.AsyncResult; +import io.vertx.core.Handler; import io.vertx.core.Vertx; +import io.vertx.core.http.WebSocket; import io.vertx.core.http.WebSocketClient; import io.vertx.core.http.WebSocketConnectOptions; +import io.vertx.core.impl.ContextImpl; +import io.vertx.core.impl.VertxImpl; @Typed(WebSocketConnector.class) @Dependent @@ -46,10 +52,10 @@ public class WebSocketConnectorImpl extends WebSocketConnectorBase connect() { - // Currently we create a new client for each connection + // A new client is created for each connection + // The client is created when the returned Uni is subscribed // The client is closed when the connection is closed - // TODO would it make sense to share clients? - WebSocketClient client = vertx.createWebSocketClient(populateClientOptions()); + AtomicReference client = new AtomicReference<>(); StringBuilder serverEndpoint = new StringBuilder(); if (baseUri != null) { @@ -88,28 +94,50 @@ public Uni connect() { } subprotocols.forEach(connectOptions::addSubProtocol); - return Uni.createFrom().completionStage(() -> client.connect(connectOptions).toCompletionStage()) - .map(ws -> { - TrafficLogger trafficLogger = TrafficLogger.forClient(config); - WebSocketClientConnectionImpl connection = new WebSocketClientConnectionImpl(clientEndpoint.clientId, ws, - codecs, - pathParams, - serverEndpointUri, headers, trafficLogger); - if (trafficLogger != null) { - trafficLogger.connectionOpened(connection); - } - connectionManager.add(clientEndpoint.generatedEndpointClass, connection); - - Endpoints.initialize(vertx, Arc.container(), codecs, connection, ws, - clientEndpoint.generatedEndpointClass, config.autoPingInterval(), SecuritySupport.NOOP, - config.unhandledFailureStrategy(), trafficLogger, - () -> { - connectionManager.remove(clientEndpoint.generatedEndpointClass, connection); - client.close(); - }, true); - - return connection; - }); + Uni websocket = Uni.createFrom(). emitter(e -> { + // Create a new event loop context for each client, otherwise the current context is used + // We want to avoid a situation where if multiple clients/connections are created in a row, + // the same event loop is used and so writing/receiving messages is de-facto serialized + ContextImpl context = ((VertxImpl) vertx).createEventLoopContext(); + context.dispatch(new Handler() { + @Override + public void handle(Void event) { + WebSocketClient c = vertx.createWebSocketClient(populateClientOptions()); + client.setPlain(c); + c.connect(connectOptions, new Handler>() { + @Override + public void handle(AsyncResult r) { + if (r.succeeded()) { + e.complete(r.result()); + } else { + e.fail(r.cause()); + } + } + }); + } + }); + }); + return websocket.map(ws -> { + TrafficLogger trafficLogger = TrafficLogger.forClient(config); + WebSocketClientConnectionImpl connection = new WebSocketClientConnectionImpl(clientEndpoint.clientId, ws, + codecs, + pathParams, + serverEndpointUri, headers, trafficLogger); + if (trafficLogger != null) { + trafficLogger.connectionOpened(connection); + } + connectionManager.add(clientEndpoint.generatedEndpointClass, connection); + + Endpoints.initialize(vertx, Arc.container(), codecs, connection, ws, + clientEndpoint.generatedEndpointClass, config.autoPingInterval(), SecuritySupport.NOOP, + config.unhandledFailureStrategy(), trafficLogger, + () -> { + connectionManager.remove(clientEndpoint.generatedEndpointClass, connection); + client.get().close(); + }, true); + + return connection; + }); } String getEndpointClass(InjectionPoint injectionPoint) {