diff --git a/extensions/websockets-next/server/deployment/src/main/java/io/quarkus/websockets/next/deployment/CallbackArgument.java b/extensions/websockets-next/server/deployment/src/main/java/io/quarkus/websockets/next/deployment/CallbackArgument.java new file mode 100644 index 0000000000000..4981903e4b097 --- /dev/null +++ b/extensions/websockets-next/server/deployment/src/main/java/io/quarkus/websockets/next/deployment/CallbackArgument.java @@ -0,0 +1,114 @@ +package io.quarkus.websockets.next.deployment; + +import java.util.Set; + +import org.jboss.jandex.AnnotationInstance; +import org.jboss.jandex.MethodParameterInfo; +import org.jboss.jandex.Type; + +import io.quarkus.gizmo.BytecodeCreator; +import io.quarkus.gizmo.ResultHandle; +import io.quarkus.websockets.next.OnClose; +import io.quarkus.websockets.next.OnOpen; +import io.quarkus.websockets.next.WebSocketConnection; +import io.quarkus.websockets.next.WebSocketServerException; + +/** + * Provides arguments for method parameters of a callback method declared on a WebSocket endpoint. + */ +interface CallbackArgument { + + /** + * + * @param context + * @return {@code true} if this provider matches the given parameter context, {@code false} otherwise + * @throws WebSocketServerException If an invalid parameter is detected + */ + boolean matches(ParameterContext context); + + /** + * This method is only used if {@link #matches(ParameterContext)} previously returned {@code true} for the same parameter + * context. + * + * @param context + * @return the result handle to be passed as an argument to a callback method + */ + ResultHandle get(InvocationBytecodeContext context); + + /** + * + * @return the priority + */ + default int priotity() { + return DEFAULT_PRIORITY; + } + + static final int DEFAULT_PRIORITY = 1; + + interface ParameterContext { + + /** + * + * @return the endpoint path + */ + String endpointPath(); + + /** + * + * @return the callback marker annotation + */ + AnnotationInstance callbackAnnotation(); + + /** + * + * @return the Java method parameter + */ + MethodParameterInfo parameter(); + + /** + * + * @return the set of parameter annotations, potentially transformed + */ + Set parameterAnnotations(); + + default boolean acceptsMessage() { + return WebSocketDotNames.ON_BINARY_MESSAGE.equals(callbackAnnotation().name()) + || WebSocketDotNames.ON_TEXT_MESSAGE.equals(callbackAnnotation().name()) + || WebSocketDotNames.ON_PONG_MESSAGE.equals(callbackAnnotation().name()); + } + + } + + interface InvocationBytecodeContext extends ParameterContext { + + /** + * + * @return the bytecode + */ + BytecodeCreator bytecode(); + + /** + * Obtains the message directly in the bytecode. + * + * @return the message object or {@code null} for {@link OnOpen} and {@link OnClose} callbacks + */ + ResultHandle getMessage(); + + /** + * Attempts to obtain the decoded message directly in the bytecode. + * + * @param parameterType + * @return the decoded message object or {@code null} for {@link OnOpen} and {@link OnClose} callbacks + */ + ResultHandle getDecodedMessage(Type parameterType); + + /** + * Obtains the current connection directly in the bytecode. + * + * @return the current {@link WebSocketConnection}, never {@code null} + */ + ResultHandle getConnection(); + + } + +} diff --git a/extensions/websockets-next/server/deployment/src/main/java/io/quarkus/websockets/next/deployment/CallbackArgumentBuildItem.java b/extensions/websockets-next/server/deployment/src/main/java/io/quarkus/websockets/next/deployment/CallbackArgumentBuildItem.java new file mode 100644 index 0000000000000..6ae75119a9da4 --- /dev/null +++ b/extensions/websockets-next/server/deployment/src/main/java/io/quarkus/websockets/next/deployment/CallbackArgumentBuildItem.java @@ -0,0 +1,17 @@ +package io.quarkus.websockets.next.deployment; + +import io.quarkus.builder.item.MultiBuildItem; + +final class CallbackArgumentBuildItem extends MultiBuildItem { + + private final CallbackArgument provider; + + CallbackArgumentBuildItem(CallbackArgument provider) { + this.provider = provider; + } + + CallbackArgument getProvider() { + return provider; + } + +} diff --git a/extensions/websockets-next/server/deployment/src/main/java/io/quarkus/websockets/next/deployment/CallbackArgumentsBuildItem.java b/extensions/websockets-next/server/deployment/src/main/java/io/quarkus/websockets/next/deployment/CallbackArgumentsBuildItem.java new file mode 100644 index 0000000000000..8c0c9498d88c7 --- /dev/null +++ b/extensions/websockets-next/server/deployment/src/main/java/io/quarkus/websockets/next/deployment/CallbackArgumentsBuildItem.java @@ -0,0 +1,32 @@ +package io.quarkus.websockets.next.deployment; + +import java.util.ArrayList; +import java.util.List; + +import io.quarkus.builder.item.SimpleBuildItem; +import io.quarkus.websockets.next.deployment.CallbackArgument.ParameterContext; + +final class CallbackArgumentsBuildItem extends SimpleBuildItem { + + final List sortedArguments; + + CallbackArgumentsBuildItem(List providers) { + this.sortedArguments = providers; + } + + /** + * + * @param context + * @return all matching providers, never {@code null} + */ + List findMatching(ParameterContext context) { + List matching = new ArrayList<>(); + for (CallbackArgument argument : sortedArguments) { + if (argument.matches(context)) { + matching.add(argument); + } + } + return matching; + } + +} diff --git a/extensions/websockets-next/server/deployment/src/main/java/io/quarkus/websockets/next/deployment/ConnectionCallbackArgument.java b/extensions/websockets-next/server/deployment/src/main/java/io/quarkus/websockets/next/deployment/ConnectionCallbackArgument.java new file mode 100644 index 0000000000000..d6d16fc468430 --- /dev/null +++ b/extensions/websockets-next/server/deployment/src/main/java/io/quarkus/websockets/next/deployment/ConnectionCallbackArgument.java @@ -0,0 +1,17 @@ +package io.quarkus.websockets.next.deployment; + +import io.quarkus.gizmo.ResultHandle; + +class ConnectionCallbackArgument implements CallbackArgument { + + @Override + public boolean matches(ParameterContext context) { + return context.parameter().type().name().equals(WebSocketDotNames.WEB_SOCKET_CONNECTION); + } + + @Override + public ResultHandle get(InvocationBytecodeContext context) { + return context.getConnection(); + } + +} diff --git a/extensions/websockets-next/server/deployment/src/main/java/io/quarkus/websockets/next/deployment/HandshakeRequestCallbackArgument.java b/extensions/websockets-next/server/deployment/src/main/java/io/quarkus/websockets/next/deployment/HandshakeRequestCallbackArgument.java new file mode 100644 index 0000000000000..0e252ae9e26da --- /dev/null +++ b/extensions/websockets-next/server/deployment/src/main/java/io/quarkus/websockets/next/deployment/HandshakeRequestCallbackArgument.java @@ -0,0 +1,21 @@ +package io.quarkus.websockets.next.deployment; + +import io.quarkus.gizmo.MethodDescriptor; +import io.quarkus.gizmo.ResultHandle; +import io.quarkus.websockets.next.WebSocketConnection; + +class HandshakeRequestCallbackArgument implements CallbackArgument { + + @Override + public boolean matches(ParameterContext context) { + return context.parameter().type().name().equals(WebSocketDotNames.HANDSHAKE_REQUEST); + } + + @Override + public ResultHandle get(InvocationBytecodeContext context) { + ResultHandle connection = context.getConnection(); + return context.bytecode().invokeInterfaceMethod(MethodDescriptor.ofMethod(WebSocketConnection.class, "handshakeRequest", + WebSocketConnection.HandshakeRequest.class), connection); + } + +} diff --git a/extensions/websockets-next/server/deployment/src/main/java/io/quarkus/websockets/next/deployment/MessageCallbackArgument.java b/extensions/websockets-next/server/deployment/src/main/java/io/quarkus/websockets/next/deployment/MessageCallbackArgument.java new file mode 100644 index 0000000000000..45e92e857cf39 --- /dev/null +++ b/extensions/websockets-next/server/deployment/src/main/java/io/quarkus/websockets/next/deployment/MessageCallbackArgument.java @@ -0,0 +1,22 @@ +package io.quarkus.websockets.next.deployment; + +import io.quarkus.gizmo.ResultHandle; + +class MessageCallbackArgument implements CallbackArgument { + + @Override + public boolean matches(ParameterContext context) { + return context.acceptsMessage() && context.parameterAnnotations().isEmpty(); + } + + @Override + public ResultHandle get(InvocationBytecodeContext context) { + return context.getDecodedMessage(context.parameter().type()); + } + + @Override + public int priotity() { + return DEFAULT_PRIORITY - 1; + } + +} diff --git a/extensions/websockets-next/server/deployment/src/main/java/io/quarkus/websockets/next/deployment/PathParamCallbackArgument.java b/extensions/websockets-next/server/deployment/src/main/java/io/quarkus/websockets/next/deployment/PathParamCallbackArgument.java new file mode 100644 index 0000000000000..c821aebd48088 --- /dev/null +++ b/extensions/websockets-next/server/deployment/src/main/java/io/quarkus/websockets/next/deployment/PathParamCallbackArgument.java @@ -0,0 +1,78 @@ +package io.quarkus.websockets.next.deployment; + +import java.util.ArrayList; +import java.util.List; +import java.util.regex.Matcher; + +import org.jboss.jandex.AnnotationInstance; +import org.jboss.jandex.AnnotationValue; + +import io.quarkus.arc.processor.Annotations; +import io.quarkus.gizmo.MethodDescriptor; +import io.quarkus.gizmo.ResultHandle; +import io.quarkus.websockets.next.WebSocketConnection; +import io.quarkus.websockets.next.WebSocketServerException; + +class PathParamCallbackArgument implements CallbackArgument { + + @Override + public boolean matches(ParameterContext context) { + String paramName = getParamName(context); + if (paramName != null) { + if (!context.parameter().type().name().equals(WebSocketDotNames.STRING)) { + throw new WebSocketServerException("Method parameter annotated with @PathParam must be java.lang.String: " + + WebSocketServerProcessor.callbackToString(context.parameter().method())); + } + List pathParams = getPathParamNames(context.endpointPath()); + if (!pathParams.contains(paramName)) { + throw new WebSocketServerException( + String.format( + "@PathParam name [%s] must be used in the endpoint path [%s]: %s", paramName, + context.endpointPath(), + WebSocketServerProcessor.callbackToString(context.parameter().method()))); + } + return true; + } + return false; + } + + @Override + public ResultHandle get(InvocationBytecodeContext context) { + ResultHandle connection = context.getConnection(); + String paramName = getParamName(context); + return context.bytecode().invokeInterfaceMethod( + MethodDescriptor.ofMethod(WebSocketConnection.class, "pathParam", String.class, String.class), connection, + context.bytecode().load(paramName)); + } + + private String getParamName(ParameterContext context) { + AnnotationInstance pathParamAnnotation = Annotations.find(context.parameterAnnotations(), WebSocketDotNames.PATH_PARAM); + if (pathParamAnnotation != null) { + String paramName; + AnnotationValue nameVal = pathParamAnnotation.value(); + if (nameVal != null) { + paramName = nameVal.asString(); + } else { + // Try to use the element name + paramName = context.parameter().name(); + } + if (paramName == null) { + throw new WebSocketServerException(String.format( + "Unable to extract the path parameter name - method parameter names not recorded for %s: compile the class with -parameters", + context.parameter().method().declaringClass().name())); + } + return paramName; + } + return null; + } + + static List getPathParamNames(String path) { + List names = new ArrayList<>(); + Matcher m = WebSocketServerProcessor.TRANSLATED_PATH_PARAM_PATTERN.matcher(path); + while (m.find()) { + names.add(m.group().substring(1)); + } + return names; + } + +} diff --git a/extensions/websockets-next/server/deployment/src/main/java/io/quarkus/websockets/next/deployment/WebSocketDotNames.java b/extensions/websockets-next/server/deployment/src/main/java/io/quarkus/websockets/next/deployment/WebSocketDotNames.java index 6bfe88c3ceca1..c6803b6d7baf1 100644 --- a/extensions/websockets-next/server/deployment/src/main/java/io/quarkus/websockets/next/deployment/WebSocketDotNames.java +++ b/extensions/websockets-next/server/deployment/src/main/java/io/quarkus/websockets/next/deployment/WebSocketDotNames.java @@ -7,6 +7,7 @@ import io.quarkus.websockets.next.OnOpen; import io.quarkus.websockets.next.OnPongMessage; import io.quarkus.websockets.next.OnTextMessage; +import io.quarkus.websockets.next.PathParam; import io.quarkus.websockets.next.WebSocket; import io.quarkus.websockets.next.WebSocketConnection; import io.smallrye.common.annotation.Blocking; @@ -35,4 +36,6 @@ final class WebSocketDotNames { static final DotName JSON_OBJECT = DotName.createSimple(JsonObject.class); static final DotName JSON_ARRAY = DotName.createSimple(JsonArray.class); static final DotName VOID = DotName.createSimple(Void.class); + static final DotName PATH_PARAM = DotName.createSimple(PathParam.class); + static final DotName HANDSHAKE_REQUEST = DotName.createSimple(WebSocketConnection.HandshakeRequest.class); } diff --git a/extensions/websockets-next/server/deployment/src/main/java/io/quarkus/websockets/next/deployment/WebSocketEndpointBuildItem.java b/extensions/websockets-next/server/deployment/src/main/java/io/quarkus/websockets/next/deployment/WebSocketEndpointBuildItem.java index 58f7920d933b8..90064c89aa95f 100644 --- a/extensions/websockets-next/server/deployment/src/main/java/io/quarkus/websockets/next/deployment/WebSocketEndpointBuildItem.java +++ b/extensions/websockets-next/server/deployment/src/main/java/io/quarkus/websockets/next/deployment/WebSocketEndpointBuildItem.java @@ -1,16 +1,32 @@ package io.quarkus.websockets.next.deployment; +import java.util.ArrayList; +import java.util.List; +import java.util.Set; + import org.jboss.jandex.AnnotationInstance; import org.jboss.jandex.AnnotationValue; import org.jboss.jandex.DotName; import org.jboss.jandex.MethodInfo; +import org.jboss.jandex.MethodParameterInfo; import org.jboss.jandex.Type; import org.jboss.jandex.Type.Kind; +import io.quarkus.arc.deployment.TransformedAnnotationsBuildItem; +import io.quarkus.arc.processor.Annotations; import io.quarkus.arc.processor.BeanInfo; +import io.quarkus.arc.processor.DotNames; import io.quarkus.builder.item.MultiBuildItem; +import io.quarkus.gizmo.BytecodeCreator; +import io.quarkus.gizmo.FieldDescriptor; +import io.quarkus.gizmo.ResultHandle; import io.quarkus.websockets.next.WebSocket; +import io.quarkus.websockets.next.WebSocketConnection; +import io.quarkus.websockets.next.WebSocketServerException; +import io.quarkus.websockets.next.deployment.CallbackArgument.InvocationBytecodeContext; +import io.quarkus.websockets.next.deployment.CallbackArgument.ParameterContext; import io.quarkus.websockets.next.runtime.WebSocketEndpoint.ExecutionModel; +import io.quarkus.websockets.next.runtime.WebSocketEndpointBase; /** * This build item represents a WebSocket endpoint class. @@ -44,8 +60,11 @@ public static class Callback { public final MethodInfo method; public final ExecutionModel executionModel; public final MessageType messageType; + public final List arguments; - public Callback(AnnotationInstance annotation, MethodInfo method, ExecutionModel executionModel) { + public Callback(AnnotationInstance annotation, MethodInfo method, ExecutionModel executionModel, + CallbackArgumentsBuildItem callbackArguments, TransformedAnnotationsBuildItem transformedAnnotations, + String endpointPath) { this.method = method; this.annotation = annotation; this.executionModel = executionModel; @@ -58,6 +77,15 @@ public Callback(AnnotationInstance annotation, MethodInfo method, ExecutionModel } else { this.messageType = MessageType.NONE; } + this.arguments = collectArguments(annotation, method, callbackArguments, transformedAnnotations, endpointPath); + } + + public boolean isOnOpen() { + return annotation.name().equals(WebSocketDotNames.ON_OPEN); + } + + public boolean isOnClose() { + return annotation.name().equals(WebSocketDotNames.ON_CLOSE); } public Type returnType() { @@ -118,13 +146,154 @@ private DotName getCodec(String valueName) { return null; } - enum MessageType { + public enum MessageType { NONE, PONG, TEXT, BINARY } + public List messageArguments() { + if (arguments.isEmpty()) { + return List.of(); + } + List ret = new ArrayList<>(); + for (CallbackArgument arg : arguments) { + if (arg instanceof MessageCallbackArgument) { + ret.add(arg); + } + } + return ret; + } + + public ResultHandle[] generateArguments(BytecodeCreator bytecode, + TransformedAnnotationsBuildItem transformedAnnotations, String endpointPath) { + if (arguments.isEmpty()) { + return new ResultHandle[] {}; + } + ResultHandle[] resultHandles = new ResultHandle[arguments.size()]; + int idx = 0; + for (CallbackArgument argument : arguments) { + resultHandles[idx] = argument.get( + invocationBytecodeContext(annotation, method.parameters().get(idx), transformedAnnotations, + endpointPath, bytecode)); + idx++; + } + return resultHandles; + } + + static List collectArguments(AnnotationInstance annotation, MethodInfo method, + CallbackArgumentsBuildItem callbackArguments, TransformedAnnotationsBuildItem transformedAnnotations, + String endpointPath) { + List parameters = method.parameters(); + if (parameters.isEmpty()) { + return List.of(); + } + List arguments = new ArrayList<>(parameters.size()); + for (MethodParameterInfo parameter : parameters) { + List found = callbackArguments + .findMatching(parameterContext(annotation, parameter, transformedAnnotations, endpointPath)); + if (found.isEmpty()) { + String msg = String.format("Unable to inject @%s callback parameter '%s' declared on %s: no injector found", + DotNames.simpleName(annotation.name()), + parameter.name() != null ? parameter.name() : "#" + parameter.position(), + WebSocketServerProcessor.callbackToString(method)); + throw new WebSocketServerException(msg); + } else if (found.size() > 1 && (found.get(0).priotity() == found.get(1).priotity())) { + String msg = String.format( + "Unable to inject @%s callback parameter '%s' declared on %s: ambiguous injectors found: %s", + DotNames.simpleName(annotation.name()), + parameter.name() != null ? parameter.name() : "#" + parameter.position(), + WebSocketServerProcessor.callbackToString(method), + found.stream().map(p -> p.getClass().getSimpleName() + ":" + p.priotity())); + throw new WebSocketServerException(msg); + } + arguments.add(found.get(0)); + } + return arguments; + } + + static ParameterContext parameterContext(AnnotationInstance callbackAnnotation, MethodParameterInfo parameter, + TransformedAnnotationsBuildItem transformedAnnotations, String endpointPath) { + return new ParameterContext() { + + @Override + public MethodParameterInfo parameter() { + return parameter; + } + + @Override + public Set parameterAnnotations() { + return Annotations.getParameterAnnotations( + transformedAnnotations::getAnnotations, parameter.method(), parameter.position()); + } + + @Override + public AnnotationInstance callbackAnnotation() { + return callbackAnnotation; + } + + @Override + public String endpointPath() { + return endpointPath; + } + + }; + } + + private InvocationBytecodeContext invocationBytecodeContext(AnnotationInstance callbackAnnotation, + MethodParameterInfo parameter, TransformedAnnotationsBuildItem transformedAnnotations, String endpointPath, + BytecodeCreator bytecode) { + return new InvocationBytecodeContext() { + + @Override + public AnnotationInstance callbackAnnotation() { + return callbackAnnotation; + } + + @Override + public MethodParameterInfo parameter() { + return parameter; + } + + @Override + public Set parameterAnnotations() { + return Annotations.getParameterAnnotations( + transformedAnnotations::getAnnotations, parameter.method(), parameter.position()); + } + + @Override + public String endpointPath() { + return endpointPath; + } + + @Override + public BytecodeCreator bytecode() { + return bytecode; + } + + @Override + public ResultHandle getMessage() { + return acceptsMessage() ? bytecode.getMethodParam(0) : null; + } + + @Override + public ResultHandle getDecodedMessage(Type parameterType) { + return acceptsMessage() + ? WebSocketServerProcessor.decodeMessage(bytecode, acceptsBinaryMessage(), parameterType, + getMessage(), Callback.this) + : null; + } + + @Override + public ResultHandle getConnection() { + return bytecode.readInstanceField( + FieldDescriptor.of(WebSocketEndpointBase.class, "connection", WebSocketConnection.class), + bytecode.getThis()); + } + }; + } + } } diff --git a/extensions/websockets-next/server/deployment/src/main/java/io/quarkus/websockets/next/deployment/WebSocketServerProcessor.java b/extensions/websockets-next/server/deployment/src/main/java/io/quarkus/websockets/next/deployment/WebSocketServerProcessor.java index a73cb0531725a..9654f58d5a141 100644 --- a/extensions/websockets-next/server/deployment/src/main/java/io/quarkus/websockets/next/deployment/WebSocketServerProcessor.java +++ b/extensions/websockets-next/server/deployment/src/main/java/io/quarkus/websockets/next/deployment/WebSocketServerProcessor.java @@ -3,6 +3,7 @@ import static io.quarkus.deployment.annotations.ExecutionTime.RUNTIME_INIT; import java.util.ArrayList; +import java.util.Comparator; import java.util.HashMap; import java.util.List; import java.util.Map; @@ -32,6 +33,7 @@ import io.quarkus.arc.deployment.ContextRegistrationPhaseBuildItem.ContextConfiguratorBuildItem; import io.quarkus.arc.deployment.CustomScopeBuildItem; import io.quarkus.arc.deployment.SyntheticBeanBuildItem; +import io.quarkus.arc.deployment.TransformedAnnotationsBuildItem; import io.quarkus.arc.deployment.UnremovableBeanBuildItem; import io.quarkus.arc.processor.BeanInfo; import io.quarkus.arc.processor.DotNames; @@ -85,6 +87,7 @@ public class WebSocketServerProcessor { // Parameter names consist of alphanumeric characters and underscore private static final Pattern PATH_PARAM_PATTERN = Pattern.compile("\\{[a-zA-Z0-9_]+\\}"); + public static final Pattern TRANSLATED_PATH_PARAM_PATTERN = Pattern.compile(":[a-zA-Z0-9_]+"); @BuildStep FeatureBuildItem feature() { @@ -104,6 +107,8 @@ void unremovableBeans(BuildProducer unremovableBeans) @BuildStep public void collectEndpoints(BeanArchiveIndexBuildItem beanArchiveIndex, BeanDiscoveryFinishedBuildItem beanDiscoveryFinished, + CallbackArgumentsBuildItem argumentProviders, + TransformedAnnotationsBuildItem transformedAnnotations, BuildProducer endpoints) { IndexView index = beanArchiveIndex.getIndex(); @@ -124,15 +129,16 @@ public void collectEndpoints(BeanArchiveIndexBuildItem beanArchiveIndex, String.format("Multiple endpoints [%s, %s] define the same path: %s", previous, beanClass, path)); } Callback onOpen = findCallback(beanArchiveIndex.getIndex(), beanClass, WebSocketDotNames.ON_OPEN, - this::validateOnOpen); + argumentProviders, transformedAnnotations, path); Callback onTextMessage = findCallback(beanArchiveIndex.getIndex(), beanClass, WebSocketDotNames.ON_TEXT_MESSAGE, - this::validateOnTextMessage); + argumentProviders, transformedAnnotations, path); Callback onBinaryMessage = findCallback(beanArchiveIndex.getIndex(), beanClass, - WebSocketDotNames.ON_BINARY_MESSAGE, - this::validateOnBinaryMessage); + WebSocketDotNames.ON_BINARY_MESSAGE, argumentProviders, transformedAnnotations, path); Callback onPongMessage = findCallback(beanArchiveIndex.getIndex(), beanClass, WebSocketDotNames.ON_PONG_MESSAGE, + argumentProviders, transformedAnnotations, path, this::validateOnPongMessage); Callback onClose = findCallback(beanArchiveIndex.getIndex(), beanClass, WebSocketDotNames.ON_CLOSE, + argumentProviders, transformedAnnotations, path, this::validateOnClose); if (onOpen == null && onTextMessage == null && onBinaryMessage == null && onPongMessage == null) { throw new WebSocketServerException( @@ -152,8 +158,20 @@ public void collectEndpoints(BeanArchiveIndexBuildItem beanArchiveIndex, } } + @BuildStep + CallbackArgumentsBuildItem collectCallbackArguments(List callbackArguments) { + List sorted = new ArrayList<>(); + for (CallbackArgumentBuildItem callbackArgument : callbackArguments) { + sorted.add(callbackArgument.getProvider()); + } + sorted.sort(Comparator.comparingInt(CallbackArgument::priotity).reversed()); + return new CallbackArgumentsBuildItem(sorted); + } + @BuildStep public void generateEndpoints(List endpoints, + CallbackArgumentsBuildItem argumentProviders, + TransformedAnnotationsBuildItem transformedAnnotations, BuildProducer generatedClasses, BuildProducer generatedEndpoints, BuildProducer reflectiveClasses) { @@ -175,7 +193,7 @@ public String apply(String name) { // A new instance of this generated endpoint is created for each client connection // The generated endpoint ensures the correct execution model is used // and delegates callback invocations to the endpoint bean - String generatedName = generateEndpoint(endpoint, classOutput); + String generatedName = generateEndpoint(endpoint, argumentProviders, transformedAnnotations, classOutput); reflectiveClasses.produce(ReflectiveClassBuildItem.builder(generatedName).constructors().build()); generatedEndpoints.produce(new GeneratedEndpointBuildItem(endpoint.bean.getImplClazz().name().toString(), generatedName, endpoint.path)); @@ -230,6 +248,14 @@ CustomScopeBuildItem registerSessionScope() { return new CustomScopeBuildItem(DotName.createSimple(SessionScoped.class.getName())); } + @BuildStep + void builtinCallbackArguments(BuildProducer providers) { + providers.produce(new CallbackArgumentBuildItem(new MessageCallbackArgument())); + providers.produce(new CallbackArgumentBuildItem(new ConnectionCallbackArgument())); + providers.produce(new CallbackArgumentBuildItem(new PathParamCallbackArgument())); + providers.produce(new CallbackArgumentBuildItem(new HandshakeRequestCallbackArgument())); + } + static String mergePath(String prefix, String path) { if (prefix.endsWith("/")) { prefix = prefix.substring(0, prefix.length() - 1); @@ -260,7 +286,7 @@ static String getPath(String path) { return path.startsWith("/") ? sb.toString() : "/" + sb.toString(); } - private String callbackToString(MethodInfo callback) { + static String callbackToString(MethodInfo callback) { return callback.declaringClass().name() + "#" + callback.name() + "()"; } @@ -281,47 +307,24 @@ private String getPathPrefix(IndexView index, DotName enclosingClassName) { return ""; } - private void validateOnOpen(MethodInfo callback) { - if (!callback.parameters().isEmpty()) { - throw new WebSocketServerException( - "@OnOpen callback must not accept any parameters: " + callbackToString(callback)); - } - } - - private void validateOnTextMessage(MethodInfo callback) { - if (callback.parameters().size() != 1) { - throw new WebSocketServerException( - "@OnTextMessage callback must accept exactly one parameter: " + callbackToString(callback)); - } - } - - private void validateOnBinaryMessage(MethodInfo callback) { - if (callback.parameters().size() != 1) { - throw new WebSocketServerException( - "@OnTextMessage callback must accept exactly one parameter: " + callbackToString(callback)); - } - } - - private void validateOnPongMessage(MethodInfo callback) { + private void validateOnPongMessage(Callback callback) { if (callback.returnType().kind() != Kind.VOID && !WebSocketServerProcessor.isUniVoid(callback.returnType())) { throw new WebSocketServerException( - "@OnPongMessage callback must return void or Uni: " + callbackToString(callback)); - } - if (callback.parameters().size() != 1 || !callback.parameterType(0).name().equals(WebSocketDotNames.BUFFER)) { - throw new WebSocketServerException( - "@OnPongMessage callback must accept exactly one parameter of type io.vertx.core.buffer.Buffer: " - + callbackToString(callback)); + "@OnPongMessage callback must return void or Uni: " + callbackToString(callback.method)); } + // TODO validate message arguments + // List> messageArguments = getMessageArguments(providers); + // if (messageArguments.size() != 1 || !messageArguments.get(0).getKey().type().name().equals(WebSocketDotNames.BUFFER)) { + // throw new WebSocketServerException( + // "@OnPongMessage callback must accept exactly one message parameter of type io.vertx.core.buffer.Buffer: " + // + callbackToString(callback.method)); + // } } - private void validateOnClose(MethodInfo callback) { + private void validateOnClose(Callback callback) { if (callback.returnType().kind() != Kind.VOID && !WebSocketServerProcessor.isUniVoid(callback.returnType())) { throw new WebSocketServerException( - "@OnClose callback must return void or Uni: " + callbackToString(callback)); - } - if (!callback.parameters().isEmpty()) { - throw new WebSocketServerException( - "@OnClose callback must not accept any parameters: " + callbackToString(callback)); + "@OnClose callback must return void or Uni: " + callbackToString(callback.method)); } } @@ -360,7 +363,10 @@ private void validateOnClose(MethodInfo callback) { * @param classOutput * @return the name of the generated class */ - private String generateEndpoint(WebSocketEndpointBuildItem endpoint, ClassOutput classOutput) { + private String generateEndpoint(WebSocketEndpointBuildItem endpoint, + CallbackArgumentsBuildItem argumentProviders, + TransformedAnnotationsBuildItem transformedAnnotations, + ClassOutput classOutput) { ClassInfo implClazz = endpoint.bean.getImplClazz(); String baseName; if (implClazz.enclosingClass() != null) { @@ -389,6 +395,7 @@ private String generateEndpoint(WebSocketEndpointBuildItem endpoint, ClassOutput executionMode.returnValue(executionMode.load(endpoint.executionMode)); if (endpoint.onOpen != null) { + Callback callback = endpoint.onOpen; MethodCreator doOnOpen = endpointCreator.getMethodCreator("doOnOpen", Uni.class, Object.class); // Foo foo = beanInstance("foo"); ResultHandle beanInstance = doOnOpen.invokeSpecialMethod( @@ -396,19 +403,21 @@ private String generateEndpoint(WebSocketEndpointBuildItem endpoint, ClassOutput doOnOpen.getThis(), doOnOpen.load(endpoint.bean.getIdentifier())); // Call the business method TryBlock tryBlock = uniFailureTryBlock(doOnOpen); - ResultHandle ret = tryBlock.invokeVirtualMethod(MethodDescriptor.of(endpoint.onOpen.method), beanInstance); - encodeAndReturnResult(tryBlock, endpoint.onOpen, ret); + ResultHandle[] args = callback.generateArguments(tryBlock, transformedAnnotations, endpoint.path); + ResultHandle ret = tryBlock.invokeVirtualMethod(MethodDescriptor.of(callback.method), beanInstance, args); + encodeAndReturnResult(tryBlock, callback, ret); MethodCreator onOpenExecutionModel = endpointCreator.getMethodCreator("onOpenExecutionModel", ExecutionModel.class); - onOpenExecutionModel.returnValue(onOpenExecutionModel.load(endpoint.onOpen.executionModel)); + onOpenExecutionModel.returnValue(onOpenExecutionModel.load(callback.executionModel)); } - generateOnMessage(endpointCreator, endpoint, endpoint.onBinaryMessage); - generateOnMessage(endpointCreator, endpoint, endpoint.onTextMessage); - generateOnMessage(endpointCreator, endpoint, endpoint.onPongMessage); + generateOnMessage(endpointCreator, endpoint, endpoint.onBinaryMessage, argumentProviders, transformedAnnotations); + generateOnMessage(endpointCreator, endpoint, endpoint.onTextMessage, argumentProviders, transformedAnnotations); + generateOnMessage(endpointCreator, endpoint, endpoint.onPongMessage, argumentProviders, transformedAnnotations); if (endpoint.onClose != null) { + Callback callback = endpoint.onClose; MethodCreator doOnClose = endpointCreator.getMethodCreator("doOnClose", Uni.class, Object.class); // Foo foo = beanInstance("foo"); ResultHandle beanInstance = doOnClose.invokeSpecialMethod( @@ -416,19 +425,22 @@ private String generateEndpoint(WebSocketEndpointBuildItem endpoint, ClassOutput doOnClose.getThis(), doOnClose.load(endpoint.bean.getIdentifier())); // Call the business method TryBlock tryBlock = uniFailureTryBlock(doOnClose); - ResultHandle ret = tryBlock.invokeVirtualMethod(MethodDescriptor.of(endpoint.onClose.method), beanInstance); - encodeAndReturnResult(tryBlock, endpoint.onClose, ret); + ResultHandle[] args = callback.generateArguments(tryBlock, transformedAnnotations, endpoint.path); + ResultHandle ret = tryBlock.invokeVirtualMethod(MethodDescriptor.of(callback.method), beanInstance, args); + encodeAndReturnResult(tryBlock, callback, ret); MethodCreator onCloseExecutionModel = endpointCreator.getMethodCreator("onCloseExecutionModel", ExecutionModel.class); - onCloseExecutionModel.returnValue(onCloseExecutionModel.load(endpoint.onClose.executionModel)); + onCloseExecutionModel.returnValue(onCloseExecutionModel.load(callback.executionModel)); } endpointCreator.close(); return generatedName.replace('/', '.'); } - private void generateOnMessage(ClassCreator endpointCreator, WebSocketEndpointBuildItem endpoint, Callback callback) { + private void generateOnMessage(ClassCreator endpointCreator, WebSocketEndpointBuildItem endpoint, Callback callback, + CallbackArgumentsBuildItem paramInjectors, + TransformedAnnotationsBuildItem transformedAnnotations) { if (callback == null) { return; } @@ -456,15 +468,9 @@ private void generateOnMessage(ClassCreator endpointCreator, WebSocketEndpointBu ResultHandle beanInstance = doOnMessage.invokeSpecialMethod( MethodDescriptor.ofMethod(WebSocketEndpointBase.class, "beanInstance", Object.class, String.class), doOnMessage.getThis(), doOnMessage.load(endpoint.bean.getIdentifier())); - ResultHandle[] args; - if (callback.acceptsMessage()) { - args = new ResultHandle[] { decodeMessage(doOnMessage, callback.acceptsBinaryMessage(), - callback.method.parameterType(0), doOnMessage.getMethodParam(0), callback) }; - } else { - args = new ResultHandle[] {}; - } // Call the business method TryBlock tryBlock = uniFailureTryBlock(doOnMessage); + ResultHandle[] args = callback.generateArguments(tryBlock, transformedAnnotations, endpoint.path); ResultHandle ret = tryBlock.invokeVirtualMethod(MethodDescriptor.of(callback.method), beanInstance, args); encodeAndReturnResult(tryBlock, callback, ret); @@ -498,7 +504,7 @@ private TryBlock uniFailureTryBlock(BytecodeCreator method) { return tryBlock; } - private ResultHandle decodeMessage(MethodCreator method, boolean binaryMessage, Type valueType, ResultHandle value, + static ResultHandle decodeMessage(BytecodeCreator method, boolean binaryMessage, Type valueType, ResultHandle value, Callback callback) { if (WebSocketDotNames.MULTI.equals(valueType.name())) { // Multi is decoded at runtime in the recorder @@ -757,7 +763,15 @@ private void encodeAndReturnResult(BytecodeCreator method, Callback callback, Re } private Callback findCallback(IndexView index, ClassInfo beanClass, DotName annotationName, - Consumer validator) { + CallbackArgumentsBuildItem callbackArguments, TransformedAnnotationsBuildItem transformedAnnotations, + String endpointPath) { + return findCallback(index, beanClass, annotationName, callbackArguments, transformedAnnotations, endpointPath, null); + } + + private Callback findCallback(IndexView index, ClassInfo beanClass, DotName annotationName, + CallbackArgumentsBuildItem callbackArguments, TransformedAnnotationsBuildItem transformedAnnotations, + String endpointPath, + Consumer validator) { ClassInfo aClass = beanClass; List annotations = new ArrayList<>(); while (aClass != null) { @@ -776,8 +790,30 @@ private Callback findCallback(IndexView index, ClassInfo beanClass, DotName anno } else if (annotations.size() == 1) { AnnotationInstance annotation = annotations.get(0); MethodInfo method = annotation.target().asMethod(); - validator.accept(method); - return new Callback(annotation, method, executionModel(method)); + Callback callback = new Callback(annotation, method, executionModel(method), callbackArguments, + transformedAnnotations, endpointPath); + int messageArguments = callback.messageArguments().size(); + if (callback.acceptsMessage()) { + if (messageArguments > 1) { + throw new WebSocketServerException( + String.format("@%s callback may accept at most 1 message parameter; found %s: %s", + DotNames.simpleName(callback.annotation.name()), + messageArguments, + callbackToString(callback.method))); + } + } else { + if (messageArguments != 0) { + throw new WebSocketServerException( + String.format("@%s callback must not accept a message parameter; found %s: %s", + DotNames.simpleName(callback.annotation.name()), + messageArguments, + callbackToString(callback.method))); + } + } + if (validator != null) { + validator.accept(callback); + } + return callback; } throw new WebSocketServerException( String.format("There can be only one callback annotated with %s declared on %s", annotationName, beanClass)); diff --git a/extensions/websockets-next/server/deployment/src/main/java/io/quarkus/websockets/next/deployment/devui/WebSocketServerDevUIProcessor.java b/extensions/websockets-next/server/deployment/src/main/java/io/quarkus/websockets/next/deployment/devui/WebSocketServerDevUIProcessor.java index bc5fcba5fe155..23702fe6ff1bf 100644 --- a/extensions/websockets-next/server/deployment/src/main/java/io/quarkus/websockets/next/deployment/devui/WebSocketServerDevUIProcessor.java +++ b/extensions/websockets-next/server/deployment/src/main/java/io/quarkus/websockets/next/deployment/devui/WebSocketServerDevUIProcessor.java @@ -6,7 +6,6 @@ import java.util.List; import java.util.Map; import java.util.regex.Matcher; -import java.util.regex.Pattern; import java.util.stream.Collectors; import io.quarkus.deployment.IsDevelopment; @@ -17,6 +16,7 @@ import io.quarkus.devui.spi.page.Page; import io.quarkus.websockets.next.deployment.GeneratedEndpointBuildItem; import io.quarkus.websockets.next.deployment.WebSocketEndpointBuildItem; +import io.quarkus.websockets.next.deployment.WebSocketServerProcessor; import io.quarkus.websockets.next.runtime.devui.WebSocketNextJsonRPCService; public class WebSocketServerDevUIProcessor { @@ -74,11 +74,9 @@ private void addCallback(WebSocketEndpointBuildItem.Callback callback, List { + root.addClasses(Echo.class, WSClient.class); + }); + + @Inject + Vertx vertx; + + @TestHTTPResource("echo") + URI testUri; + + @Test + void testArgument() { + String message = "ok"; + String header = "fool"; + WSClient client = WSClient.create(vertx).connect(new WebSocketConnectOptions().addHeader("X-Test", header), + testUri); + JsonObject reply = client.sendAndAwaitReply(message).toJsonObject(); + assertEquals(header, reply.getString("header"), reply.toString()); + assertEquals(message, reply.getString("message"), reply.toString()); + } + + @WebSocket(path = "/echo") + public static class Echo { + + @Inject + WebSocketConnection c; + + @OnTextMessage + Uni process(WebSocketConnection connection, String message) throws InterruptedException { + assertEquals(c.id(), connection.id()); + return connection.sendText( + new JsonObject() + .put("id", connection.id()) + .put("message", message) + .put("header", connection.handshakeRequest().header("X-Test"))); + } + + } + +} diff --git a/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/args/HandshakeRequestArgumentTest.java b/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/args/HandshakeRequestArgumentTest.java new file mode 100644 index 0000000000000..5bcb9ac19f0ba --- /dev/null +++ b/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/args/HandshakeRequestArgumentTest.java @@ -0,0 +1,53 @@ +package io.quarkus.websockets.next.test.args; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +import java.net.URI; + +import jakarta.inject.Inject; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.RegisterExtension; + +import io.quarkus.test.QuarkusUnitTest; +import io.quarkus.test.common.http.TestHTTPResource; +import io.quarkus.websockets.next.OnOpen; +import io.quarkus.websockets.next.WebSocket; +import io.quarkus.websockets.next.WebSocketConnection.HandshakeRequest; +import io.quarkus.websockets.next.test.utils.WSClient; +import io.vertx.core.Vertx; +import io.vertx.core.http.WebSocketConnectOptions; + +public class HandshakeRequestArgumentTest { + + @RegisterExtension + public static final QuarkusUnitTest test = new QuarkusUnitTest() + .withApplicationRoot(root -> { + root.addClasses(XTest.class, WSClient.class); + }); + + @Inject + Vertx vertx; + + @TestHTTPResource("xtest") + URI testUri; + + @Test + void testArgument() { + WSClient client = WSClient.create(vertx).connect(new WebSocketConnectOptions().addHeader("X-Test", "fool"), + testUri); + client.waitForMessages(1); + assertEquals("fool", client.getLastMessage().toString()); + } + + @WebSocket(path = "/xtest") + public static class XTest { + + @OnOpen + String open(HandshakeRequest request) { + return request.header("X-Test"); + } + + } + +} diff --git a/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/args/OnCloseInvalidArgumentTest.java b/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/args/OnCloseInvalidArgumentTest.java new file mode 100644 index 0000000000000..c5a934eeb6cdf --- /dev/null +++ b/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/args/OnCloseInvalidArgumentTest.java @@ -0,0 +1,38 @@ +package io.quarkus.websockets.next.test.args; + +import static org.junit.jupiter.api.Assertions.fail; + +import java.util.List; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.RegisterExtension; + +import io.quarkus.test.QuarkusUnitTest; +import io.quarkus.websockets.next.OnClose; +import io.quarkus.websockets.next.WebSocket; +import io.quarkus.websockets.next.WebSocketServerException; + +public class OnCloseInvalidArgumentTest { + + @RegisterExtension + public static final QuarkusUnitTest test = new QuarkusUnitTest() + .withApplicationRoot(root -> { + root.addClasses(Endpoint.class); + }) + .setExpectedException(WebSocketServerException.class); + + @Test + void testInvalidArgument() { + fail(); + } + + @WebSocket(path = "/end") + public static class Endpoint { + + @OnClose + void close(List unsupported) { + } + + } + +} diff --git a/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/args/OnOpenInvalidArgumentTest.java b/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/args/OnOpenInvalidArgumentTest.java new file mode 100644 index 0000000000000..5f3b9071cf546 --- /dev/null +++ b/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/args/OnOpenInvalidArgumentTest.java @@ -0,0 +1,38 @@ +package io.quarkus.websockets.next.test.args; + +import static org.junit.jupiter.api.Assertions.fail; + +import java.util.List; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.RegisterExtension; + +import io.quarkus.test.QuarkusUnitTest; +import io.quarkus.websockets.next.OnOpen; +import io.quarkus.websockets.next.WebSocket; +import io.quarkus.websockets.next.WebSocketServerException; + +public class OnOpenInvalidArgumentTest { + + @RegisterExtension + public static final QuarkusUnitTest test = new QuarkusUnitTest() + .withApplicationRoot(root -> { + root.addClasses(Endpoint.class); + }) + .setExpectedException(WebSocketServerException.class); + + @Test + void testInvalidArgument() { + fail(); + } + + @WebSocket(path = "/end") + public static class Endpoint { + + @OnOpen + void open(List unsupported) { + } + + } + +} diff --git a/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/args/PathParamArgumentExplicitNameTest.java b/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/args/PathParamArgumentExplicitNameTest.java new file mode 100644 index 0000000000000..aa68c510d3eb4 --- /dev/null +++ b/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/args/PathParamArgumentExplicitNameTest.java @@ -0,0 +1,50 @@ +package io.quarkus.websockets.next.test.args; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +import java.net.URI; + +import jakarta.inject.Inject; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.RegisterExtension; + +import io.quarkus.test.QuarkusUnitTest; +import io.quarkus.test.common.http.TestHTTPResource; +import io.quarkus.websockets.next.OnTextMessage; +import io.quarkus.websockets.next.PathParam; +import io.quarkus.websockets.next.WebSocket; +import io.quarkus.websockets.next.test.utils.WSClient; +import io.vertx.core.Vertx; + +public class PathParamArgumentExplicitNameTest { + + @RegisterExtension + public static final QuarkusUnitTest test = new QuarkusUnitTest() + .withApplicationRoot(root -> { + root.addClasses(MontyEcho.class, WSClient.class); + }); + + @Inject + Vertx vertx; + + @TestHTTPResource("echo/monty") + URI testUri; + + @Test + void testArgument() { + WSClient client = WSClient.create(vertx).connect(testUri); + assertEquals("python:monty", client.sendAndAwaitReply("python").toString()); + } + + @WebSocket(path = "/echo/{grail}") + public static class MontyEcho { + + @OnTextMessage + String process(@PathParam("grail") String life, String message) throws InterruptedException { + return message + ":" + life; + } + + } + +} diff --git a/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/args/PathParamArgumentInvalidNameTest.java b/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/args/PathParamArgumentInvalidNameTest.java new file mode 100644 index 0000000000000..f23f0343cdf23 --- /dev/null +++ b/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/args/PathParamArgumentInvalidNameTest.java @@ -0,0 +1,37 @@ +package io.quarkus.websockets.next.test.args; + +import static org.junit.jupiter.api.Assertions.fail; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.RegisterExtension; + +import io.quarkus.test.QuarkusUnitTest; +import io.quarkus.websockets.next.OnTextMessage; +import io.quarkus.websockets.next.PathParam; +import io.quarkus.websockets.next.WebSocket; +import io.quarkus.websockets.next.WebSocketServerException; + +public class PathParamArgumentInvalidNameTest { + + @RegisterExtension + public static final QuarkusUnitTest test = new QuarkusUnitTest() + .withApplicationRoot(root -> { + root.addClasses(MontyEcho.class); + }).setExpectedException(WebSocketServerException.class); + + @Test + void testInvalidArgument() { + fail(); + } + + @WebSocket(path = "/echo/{grail}") + public static class MontyEcho { + + @OnTextMessage + String process(@PathParam String life, String message) throws InterruptedException { + return message + ":" + life; + } + + } + +} diff --git a/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/args/PathParamArgumentInvalidTypeTest.java b/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/args/PathParamArgumentInvalidTypeTest.java new file mode 100644 index 0000000000000..31097c8bf7180 --- /dev/null +++ b/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/args/PathParamArgumentInvalidTypeTest.java @@ -0,0 +1,37 @@ +package io.quarkus.websockets.next.test.args; + +import static org.junit.jupiter.api.Assertions.fail; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.RegisterExtension; + +import io.quarkus.test.QuarkusUnitTest; +import io.quarkus.websockets.next.OnTextMessage; +import io.quarkus.websockets.next.PathParam; +import io.quarkus.websockets.next.WebSocket; +import io.quarkus.websockets.next.WebSocketServerException; + +public class PathParamArgumentInvalidTypeTest { + + @RegisterExtension + public static final QuarkusUnitTest test = new QuarkusUnitTest() + .withApplicationRoot(root -> { + root.addClasses(MontyEcho.class); + }).setExpectedException(WebSocketServerException.class); + + @Test + void testInvalidArgument() { + fail(); + } + + @WebSocket(path = "/echo/{grail}") + public static class MontyEcho { + + @OnTextMessage + String process(@PathParam Double grail, String message) throws InterruptedException { + return message + ":" + grail; + } + + } + +} diff --git a/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/args/PathParamArgumentTest.java b/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/args/PathParamArgumentTest.java new file mode 100644 index 0000000000000..ff97071f3042d --- /dev/null +++ b/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/args/PathParamArgumentTest.java @@ -0,0 +1,50 @@ +package io.quarkus.websockets.next.test.args; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +import java.net.URI; + +import jakarta.inject.Inject; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.RegisterExtension; + +import io.quarkus.test.QuarkusUnitTest; +import io.quarkus.test.common.http.TestHTTPResource; +import io.quarkus.websockets.next.OnTextMessage; +import io.quarkus.websockets.next.PathParam; +import io.quarkus.websockets.next.WebSocket; +import io.quarkus.websockets.next.test.utils.WSClient; +import io.vertx.core.Vertx; + +public class PathParamArgumentTest { + + @RegisterExtension + public static final QuarkusUnitTest test = new QuarkusUnitTest() + .withApplicationRoot(root -> { + root.addClasses(MontyEcho.class, WSClient.class); + }); + + @Inject + Vertx vertx; + + @TestHTTPResource("echo/monty") + URI testUri; + + @Test + void testArgument() { + WSClient client = WSClient.create(vertx).connect(testUri); + assertEquals("python:monty", client.sendAndAwaitReply("python").toString()); + } + + @WebSocket(path = "/echo/{grail}") + public static class MontyEcho { + + @OnTextMessage + String process(@PathParam String grail, String message) throws InterruptedException { + return message + ":" + grail; + } + + } + +} diff --git a/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/args/PathParamConnectionArgumentTest.java b/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/args/PathParamConnectionArgumentTest.java new file mode 100644 index 0000000000000..5a5e0843a9e64 --- /dev/null +++ b/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/args/PathParamConnectionArgumentTest.java @@ -0,0 +1,54 @@ +package io.quarkus.websockets.next.test.args; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +import java.net.URI; + +import jakarta.inject.Inject; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.RegisterExtension; + +import io.quarkus.test.QuarkusUnitTest; +import io.quarkus.test.common.http.TestHTTPResource; +import io.quarkus.websockets.next.OnTextMessage; +import io.quarkus.websockets.next.PathParam; +import io.quarkus.websockets.next.WebSocket; +import io.quarkus.websockets.next.WebSocketConnection; +import io.quarkus.websockets.next.test.utils.WSClient; +import io.vertx.core.Vertx; +import io.vertx.core.http.WebSocketConnectOptions; + +public class PathParamConnectionArgumentTest { + + @RegisterExtension + public static final QuarkusUnitTest test = new QuarkusUnitTest() + .withApplicationRoot(root -> { + root.addClasses(MontyEcho.class, WSClient.class); + }); + + @Inject + Vertx vertx; + + @TestHTTPResource("echo/monty/and/foo") + URI testUri; + + @Test + void testArguments() { + String header = "fool"; + WSClient client = WSClient.create(vertx).connect(new WebSocketConnectOptions().addHeader("X-Test", header), testUri); + assertEquals("foo:python:monty:fool", client.sendAndAwaitReply("python").toString()); + } + + @WebSocket(path = "/echo/{grail}/and/{life}") + public static class MontyEcho { + + @OnTextMessage + String process(@PathParam String life, @PathParam String grail, String message, WebSocketConnection connection) + throws InterruptedException { + return life + ":" + message + ":" + grail + ":" + connection.handshakeRequest().header("X-Test"); + } + + } + +} diff --git a/extensions/websockets-next/server/runtime/src/main/java/io/quarkus/websockets/next/PathParam.java b/extensions/websockets-next/server/runtime/src/main/java/io/quarkus/websockets/next/PathParam.java new file mode 100644 index 0000000000000..353e965e194bf --- /dev/null +++ b/extensions/websockets-next/server/runtime/src/main/java/io/quarkus/websockets/next/PathParam.java @@ -0,0 +1,34 @@ +package io.quarkus.websockets.next; + +import java.lang.annotation.ElementType; +import java.lang.annotation.Retention; +import java.lang.annotation.RetentionPolicy; +import java.lang.annotation.Target; + +/** + * Identifies an endpoint callback method parameter that should be injected with a value returned from + * {@link WebSocketConnection#pathParam(String)}. + *

+ * The parameter type must be {@link String} and the name must be defined in the relevant endpoint path, otherwise + * the build fails. + * + * @see WebSocketConnection#pathParam(String) + * @see WebSocket + */ +@Retention(RetentionPolicy.RUNTIME) +@Target(ElementType.PARAMETER) +public @interface PathParam { + + /** + * Constant value for {@link #value()} indicating that the annotated element's name should be used as-is. + */ + String ELEMENT_NAME = "<>"; + + /** + * The name of the parameter. By default, the element's name is used as-is. + * + * @return the name of the parameter + */ + String value() default ELEMENT_NAME; + +} diff --git a/extensions/websockets-next/server/runtime/src/main/java/io/quarkus/websockets/next/runtime/WebSocketConnectionImpl.java b/extensions/websockets-next/server/runtime/src/main/java/io/quarkus/websockets/next/runtime/WebSocketConnectionImpl.java index 29b7d925fc25d..aa8ac39396031 100644 --- a/extensions/websockets-next/server/runtime/src/main/java/io/quarkus/websockets/next/runtime/WebSocketConnectionImpl.java +++ b/extensions/websockets-next/server/runtime/src/main/java/io/quarkus/websockets/next/runtime/WebSocketConnectionImpl.java @@ -13,11 +13,15 @@ import java.util.function.Predicate; import java.util.stream.Collectors; +import io.quarkus.vertx.core.runtime.VertxBufferImpl; import io.quarkus.websockets.next.WebSocketConnection; import io.smallrye.mutiny.Uni; import io.smallrye.mutiny.vertx.UniHelper; import io.vertx.core.buffer.Buffer; +import io.vertx.core.buffer.impl.BufferImpl; import io.vertx.core.http.ServerWebSocket; +import io.vertx.core.json.JsonArray; +import io.vertx.core.json.JsonObject; import io.vertx.ext.web.RoutingContext; class WebSocketConnectionImpl implements WebSocketConnection { @@ -75,7 +79,17 @@ public Uni sendBinary(Buffer message) { @Override public Uni sendText(M message) { - return UniHelper.toUni(webSocket.writeTextMessage(codecs.textEncode(message, null).toString())); + String text; + // Use the same conversion rules as defined for the OnTextMessage + if (message instanceof JsonObject || message instanceof JsonArray || message instanceof BufferImpl + || message instanceof VertxBufferImpl) { + text = message.toString(); + } else if (message.getClass().isArray() && message.getClass().arrayType().equals(byte.class)) { + text = Buffer.buffer((byte[]) message).toString(); + } else { + text = codecs.textEncode(message, null); + } + return sendText(text); } @Override diff --git a/extensions/websockets-next/server/runtime/src/main/java/io/quarkus/websockets/next/runtime/WebSocketEndpointBase.java b/extensions/websockets-next/server/runtime/src/main/java/io/quarkus/websockets/next/runtime/WebSocketEndpointBase.java index fdb032ae57ca4..8d9620f09c10a 100644 --- a/extensions/websockets-next/server/runtime/src/main/java/io/quarkus/websockets/next/runtime/WebSocketEndpointBase.java +++ b/extensions/websockets-next/server/runtime/src/main/java/io/quarkus/websockets/next/runtime/WebSocketEndpointBase.java @@ -27,7 +27,8 @@ public abstract class WebSocketEndpointBase implements WebSocketEndpoint { private static final Logger LOG = Logger.getLogger(WebSocketEndpointBase.class); - protected final WebSocketConnection connection; + // Keep this field public - there's a problem with ConnectionArgumentProvider reading the protected field in the test mode + public final WebSocketConnection connection; protected final Codecs codecs;