Skip to content

Commit

Permalink
Ensure websocket compression is enabled when server is configured wit…
Browse files Browse the repository at this point in the history
…h HttpProtocol.H2C and HttpProtocol.HTTP1.1 (#3037)

Websocket compression handler has to be located after the HTTP codec.

Fixes #3036
  • Loading branch information
violetagg authored Jan 23, 2024
1 parent 9d6df28 commit 245aa54
Show file tree
Hide file tree
Showing 2 changed files with 96 additions and 7 deletions.
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2011-2023 VMware, Inc. or its affiliates, All Rights Reserved.
* Copyright (c) 2011-2024 VMware, Inc. or its affiliates, All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -20,12 +20,15 @@
import io.netty.buffer.ByteBuf;
import io.netty.channel.Channel;
import io.netty.channel.ChannelFutureListener;
import io.netty.channel.ChannelHandler;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelPipeline;
import io.netty.channel.ChannelPromise;
import io.netty.handler.codec.http.DefaultFullHttpRequest;
import io.netty.handler.codec.http.HttpHeaderNames;
import io.netty.handler.codec.http.HttpHeaders;
import io.netty.handler.codec.http.HttpRequest;
import io.netty.handler.codec.http.HttpServerCodec;
import io.netty.handler.codec.http.LastHttpContent;
import io.netty.handler.codec.http.websocketx.CloseWebSocketFrame;
import io.netty.handler.codec.http.websocketx.PingWebSocketFrame;
Expand Down Expand Up @@ -103,11 +106,21 @@ final class WebsocketServerOperations extends HttpServerOperations
WebSocketServerCompressionHandler wsServerCompressionHandler =
new WebSocketServerCompressionHandler();
try {
wsServerCompressionHandler.channelRead(channel.pipeline()
.context(NettyPipeline.ReactiveBridge),
request);

addHandlerFirst(NettyPipeline.WsCompressionHandler, wsServerCompressionHandler);
ChannelPipeline pipeline = channel.pipeline();
wsServerCompressionHandler.channelRead(pipeline.context(NettyPipeline.ReactiveBridge), request);

String baseName = null;
if (pipeline.get(NettyPipeline.HttpCodec) != null) {
baseName = NettyPipeline.HttpCodec;
}
else {
ChannelHandler httpServerCodec = pipeline.get(HttpServerCodec.class);
if (httpServerCodec != null) {
baseName = pipeline.context(httpServerCodec).name();
}
}

pipeline.addAfter(baseName, NettyPipeline.WsCompressionHandler, wsServerCompressionHandler);
}
catch (Throwable e) {
log.error(format(channel(), ""), e);
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2011-2023 VMware, Inc. or its affiliates, All Rights Reserved.
* Copyright (c) 2011-2024 VMware, Inc. or its affiliates, All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -17,6 +17,7 @@

import java.net.URI;
import java.nio.charset.Charset;
import java.security.cert.CertificateException;
import java.time.Duration;
import java.util.ArrayList;
import java.util.List;
Expand All @@ -29,6 +30,7 @@
import java.util.function.Consumer;
import java.util.function.Function;
import java.util.function.Predicate;
import java.util.stream.Stream;

import io.netty.buffer.Unpooled;
import io.netty.channel.embedded.EmbeddedChannel;
Expand All @@ -45,7 +47,14 @@
import io.netty.handler.codec.http.websocketx.WebSocketCloseStatus;
import io.netty.handler.codec.http.websocketx.WebSocketFrame;
import io.netty.handler.codec.http.websocketx.WebSocketHandshakeException;
import io.netty.handler.ssl.util.InsecureTrustManagerFactory;
import io.netty.handler.ssl.util.SelfSignedCertificate;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.Named;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource;
import org.reactivestreams.Publisher;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;
Expand All @@ -55,11 +64,16 @@
import reactor.netty.Connection;
import reactor.netty.ConnectionObserver;
import reactor.netty.channel.AbortedException;
import reactor.netty.http.Http11SslContextSpec;
import reactor.netty.http.Http2SslContextSpec;
import reactor.netty.http.HttpProtocol;
import reactor.netty.http.logging.ReactorNettyHttpMessageLogFactory;
import reactor.netty.http.server.HttpServer;
import reactor.netty.http.server.WebsocketServerSpec;
import reactor.netty.http.websocket.WebsocketInbound;
import reactor.netty.http.websocket.WebsocketOutbound;
import reactor.netty.resources.ConnectionProvider;
import reactor.netty.tcp.SslProvider;
import reactor.test.StepVerifier;
import reactor.util.Logger;
import reactor.util.Loggers;
Expand All @@ -81,6 +95,23 @@ class WebsocketTest extends BaseHttpTest {

static final Logger log = Loggers.getLogger(WebsocketTest.class);

static SelfSignedCertificate ssc;
static Http11SslContextSpec serverCtx11;
static Http2SslContextSpec serverCtx2;
static Http11SslContextSpec clientCtx11;
static Http2SslContextSpec clientCtx2;

@BeforeAll
static void createSelfSignedCertificate() throws CertificateException {
ssc = new SelfSignedCertificate();
serverCtx11 = Http11SslContextSpec.forServer(ssc.certificate(), ssc.privateKey());
serverCtx2 = Http2SslContextSpec.forServer(ssc.certificate(), ssc.privateKey());
clientCtx11 = Http11SslContextSpec.forClient()
.configure(builder -> builder.trustManager(InsecureTrustManagerFactory.INSTANCE));
clientCtx2 = Http2SslContextSpec.forClient()
.configure(builder -> builder.trustManager(InsecureTrustManagerFactory.INSTANCE));
}

@Test
void simpleTest() {
disposableServer = createServer()
Expand Down Expand Up @@ -1433,4 +1464,49 @@ private void doTestConnectionClosedWhenFailedUpgrade(

assertThat(latch.await(5, TimeUnit.SECONDS)).as("latch await").isTrue();
}

@ParameterizedTest
@MethodSource("http11CompatibleProtocols")
public void testIssue3036(HttpProtocol[] serverProtocols, HttpProtocol[] clientProtocols,
@Nullable SslProvider.ProtocolSslContextSpec serverCtx, @Nullable SslProvider.ProtocolSslContextSpec clientCtx) {
WebsocketServerSpec websocketServerSpec = WebsocketServerSpec.builder().compress(true).build();

HttpServer httpServer = createServer().protocol(serverProtocols);
if (serverCtx != null) {
httpServer = httpServer.secure(spec -> spec.sslContext(serverCtx));
}

disposableServer =
httpServer.handle((req, res) -> res.sendWebsocket((in, out) -> out.sendString(Mono.just("test")), websocketServerSpec))
.bindNow();

WebsocketClientSpec webSocketClientSpec = WebsocketClientSpec.builder().compress(true).build();

HttpClient httpClient = createClient(disposableServer::address).protocol(clientProtocols);
if (clientCtx != null) {
httpClient = httpClient.secure(spec -> spec.sslContext(clientCtx));
}

AtomicReference<List<String>> responseHeaders = new AtomicReference<>(new ArrayList<>());
httpClient.websocket(webSocketClientSpec)
.handle((in, out) -> {
responseHeaders.set(in.headers().getAll(HttpHeaderNames.SEC_WEBSOCKET_EXTENSIONS));
return out.sendClose();
})
.then()
.block(Duration.ofSeconds(5));

assertThat(responseHeaders.get()).contains("permessage-deflate");
}

static Stream<Arguments> http11CompatibleProtocols() {
return Stream.of(
Arguments.of(new HttpProtocol[]{HttpProtocol.HTTP11}, new HttpProtocol[]{HttpProtocol.HTTP11}, null, null),
Arguments.of(new HttpProtocol[]{HttpProtocol.HTTP11}, new HttpProtocol[]{HttpProtocol.HTTP11},
Named.of("Http11SslContextSpec", serverCtx11), Named.of("Http11SslContextSpec", clientCtx11)),
Arguments.of(new HttpProtocol[]{HttpProtocol.H2, HttpProtocol.HTTP11}, new HttpProtocol[]{HttpProtocol.HTTP11},
Named.of("Http2SslContextSpec", serverCtx2), Named.of("Http11SslContextSpec", clientCtx11)),
Arguments.of(new HttpProtocol[]{HttpProtocol.H2C, HttpProtocol.HTTP11}, new HttpProtocol[]{HttpProtocol.HTTP11}, null, null)
);
}
}

0 comments on commit 245aa54

Please sign in to comment.