Skip to content

Commit

Permalink
[Backport 1.3] Allow customization of netty channel handles before an…
Browse files Browse the repository at this point in the history
…d during decompression (opensearch-project#10261)

Signed-off-by: Peter Nied <petern@amazon.com>
  • Loading branch information
cwperks authored and peternied committed Nov 3, 2023
1 parent 77adb21 commit fa8ccf0
Show file tree
Hide file tree
Showing 5 changed files with 225 additions and 3 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ Inspired from [Keep a Changelog](https://keepachangelog.com/en/1.0.0/)
## [Unreleased 1.3.x]

### Added
- Improve compressed request handling ([#10261](https://github.com/opensearch-project/OpenSearch/pull/10261))

### Dependencies
- Bump asm from 9.5 to 9.6 ([#10302](https://github.com/opensearch-project/OpenSearch/pull/10302))
- Bump netty from 4.1.97.Final to 4.1.99.Final ([#10303](https://github.com/opensearch-project/OpenSearch/pull/10303))
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
/*
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*/

package org.opensearch.http.netty4;

import org.opensearch.OpenSearchNetty4IntegTestCase;
import org.opensearch.core.common.transport.TransportAddress;
import org.opensearch.http.HttpServerTransport;
import org.opensearch.plugins.Plugin;
import org.opensearch.test.OpenSearchIntegTestCase.ClusterScope;
import org.opensearch.test.OpenSearchIntegTestCase.Scope;
import org.opensearch.transport.Netty4BlockingPlugin;

import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.List;

import io.netty.buffer.ByteBufUtil;
import io.netty.handler.codec.http.DefaultFullHttpRequest;
import io.netty.handler.codec.http.FullHttpRequest;
import io.netty.handler.codec.http.FullHttpResponse;
import io.netty.handler.codec.http.HttpMethod;
import io.netty.handler.codec.http.HttpVersion;
import io.netty.handler.codec.http2.HttpConversionUtil;
import io.netty.util.ReferenceCounted;

import static org.hamcrest.CoreMatchers.containsString;
import static org.hamcrest.CoreMatchers.equalTo;
import static io.netty.handler.codec.http.HttpHeaderNames.HOST;

@ClusterScope(scope = Scope.TEST, supportsDedicatedMasters = false, numDataNodes = 1)
public class Netty4HeaderVerifierIT extends OpenSearchNetty4IntegTestCase {

@Override
protected boolean addMockHttpTransport() {
return false; // enable http
}

@Override
protected Collection<Class<? extends Plugin>> nodePlugins() {
return Collections.singletonList(Netty4BlockingPlugin.class);
}

public void testThatNettyHttpServerRequestBlockedWithHeaderVerifier() throws Exception {
HttpServerTransport httpServerTransport = internalCluster().getInstance(HttpServerTransport.class);
TransportAddress[] boundAddresses = httpServerTransport.boundAddress().boundAddresses();
TransportAddress transportAddress = randomFrom(boundAddresses);

final FullHttpRequest blockedRequest = new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.GET, "/");
blockedRequest.headers().add("blockme", "Not Allowed");
blockedRequest.headers().add(HOST, "localhost");
blockedRequest.headers().add(HttpConversionUtil.ExtensionHeaderNames.SCHEME.text(), "http");

final List<FullHttpResponse> responses = new ArrayList<>();
try (Netty4HttpClient nettyHttpClient = Netty4HttpClient.http2()) {
try {
FullHttpResponse blockedResponse = nettyHttpClient.send(transportAddress.address(), blockedRequest);
responses.add(blockedResponse);
String blockedResponseContent = new String(ByteBufUtil.getBytes(blockedResponse.content()), StandardCharsets.UTF_8);
assertThat(blockedResponseContent, containsString("Hit header_verifier"));
assertThat(blockedResponse.status().code(), equalTo(401));
} finally {
responses.forEach(ReferenceCounted::release);
}
}
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
/*
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*/

package org.opensearch.transport;

import org.opensearch.common.network.NetworkService;
import org.opensearch.common.settings.ClusterSettings;
import org.opensearch.common.settings.Settings;
import org.opensearch.common.util.BigArrays;
import org.opensearch.common.util.PageCacheRecycler;
import org.opensearch.core.indices.breaker.CircuitBreakerService;
import org.opensearch.core.xcontent.NamedXContentRegistry;
import org.opensearch.http.HttpServerTransport;
import org.opensearch.http.netty4.Netty4HttpServerTransport;
import org.opensearch.telemetry.tracing.Tracer;
import org.opensearch.threadpool.ThreadPool;

import java.nio.charset.StandardCharsets;
import java.util.Collections;
import java.util.Map;
import java.util.function.Supplier;

import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;
import io.netty.channel.ChannelFutureListener;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerAdapter;
import io.netty.channel.SimpleChannelInboundHandler;
import io.netty.handler.codec.http.DefaultFullHttpResponse;
import io.netty.handler.codec.http.DefaultHttpRequest;
import io.netty.handler.codec.http.FullHttpResponse;
import io.netty.handler.codec.http.HttpRequest;
import io.netty.handler.codec.http.HttpResponseStatus;
import io.netty.util.ReferenceCountUtil;

public class Netty4BlockingPlugin extends Netty4ModulePlugin {

public class Netty4BlockingHttpServerTransport extends Netty4HttpServerTransport {

public Netty4BlockingHttpServerTransport(
Settings settings,
NetworkService networkService,
BigArrays bigArrays,
ThreadPool threadPool,
NamedXContentRegistry xContentRegistry,
Dispatcher dispatcher,
ClusterSettings clusterSettings,
SharedGroupFactory sharedGroupFactory,
Tracer tracer
) {
super(
settings,
networkService,
bigArrays,
threadPool,
xContentRegistry,
dispatcher,
clusterSettings,
sharedGroupFactory,
tracer
);
}

@Override
protected ChannelInboundHandlerAdapter createHeaderVerifier() {
return new ExampleBlockingNetty4HeaderVerifier();
}
}

@Override
public Map<String, Supplier<HttpServerTransport>> getHttpTransports(
Settings settings,
ThreadPool threadPool,
BigArrays bigArrays,
PageCacheRecycler pageCacheRecycler,
CircuitBreakerService circuitBreakerService,
NamedXContentRegistry xContentRegistry,
NetworkService networkService,
HttpServerTransport.Dispatcher dispatcher,
ClusterSettings clusterSettings,
Tracer tracer
) {
return Collections.singletonMap(
NETTY_HTTP_TRANSPORT_NAME,
() -> new Netty4BlockingHttpServerTransport(
settings,
networkService,
bigArrays,
threadPool,
xContentRegistry,
dispatcher,
clusterSettings,
getSharedGroupFactory(settings),
tracer
)
);
}

/** POC for how an external header verifier would be implemented */
public class ExampleBlockingNetty4HeaderVerifier extends SimpleChannelInboundHandler<DefaultHttpRequest> {

@Override
public void channelRead0(ChannelHandlerContext ctx, DefaultHttpRequest msg) throws Exception {
ReferenceCountUtil.retain(msg);
if (isBlocked(msg)) {
ByteBuf buf = Unpooled.copiedBuffer("Hit header_verifier".getBytes(StandardCharsets.UTF_8));
final FullHttpResponse response = new DefaultFullHttpResponse(msg.protocolVersion(), HttpResponseStatus.UNAUTHORIZED, buf);
ctx.writeAndFlush(response).addListener(ChannelFutureListener.CLOSE);
ReferenceCountUtil.release(msg);
} else {
// Lets the request pass to the next channel handler
ctx.fireChannelRead(msg);
}
}

private boolean isBlocked(HttpRequest request) {
final boolean shouldBlock = request.headers().contains("blockme");

return shouldBlock;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -314,7 +314,7 @@ public ChannelHandler configureServerChannelHandler() {
return new HttpChannelHandler(this, handlingSettings);
}

static final AttributeKey<Netty4HttpChannel> HTTP_CHANNEL_KEY = AttributeKey.newInstance("es-http-channel");
public static final AttributeKey<Netty4HttpChannel> HTTP_CHANNEL_KEY = AttributeKey.newInstance("es-http-channel");
static final AttributeKey<Netty4HttpServerChannel> HTTP_SERVER_CHANNEL_KEY = AttributeKey.newInstance("es-http-server-channel");

protected static class HttpChannelHandler extends ChannelInitializer<Channel> {
Expand Down Expand Up @@ -348,7 +348,8 @@ protected void initChannel(Channel ch) throws Exception {
);
decoder.setCumulator(ByteToMessageDecoder.COMPOSITE_CUMULATOR);
ch.pipeline().addLast("decoder", decoder);
ch.pipeline().addLast("decoder_compress", new HttpContentDecompressor());
ch.pipeline().addLast("header_verifier", transport.createHeaderVerifier());
ch.pipeline().addLast("decoder_compress", transport.createDecompressor());
ch.pipeline().addLast("encoder", new HttpResponseEncoder());
final HttpObjectAggregator aggregator = new HttpObjectAggregator(handlingSettings.getMaxContentLength());
aggregator.setMaxCumulationBufferComponents(transport.maxCompositeBufferComponents);
Expand Down Expand Up @@ -390,4 +391,21 @@ public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) {
}
}
}

/**
* Extension point that allows a NetworkPlugin to extend the netty pipeline and inspect headers after request decoding
*/
protected ChannelInboundHandlerAdapter createHeaderVerifier() {
// pass-through
return new ChannelInboundHandlerAdapter();
}

/**
* Extension point that allows a NetworkPlugin to override the default netty HttpContentDecompressor and supply a custom decompressor.
*
* Used in instances to conditionally decompress depending on the outcome from header verification
*/
protected ChannelInboundHandlerAdapter createDecompressor() {
return new HttpContentDecompressor();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ public Map<String, Supplier<HttpServerTransport>> getHttpTransports(
);
}

private SharedGroupFactory getSharedGroupFactory(Settings settings) {
SharedGroupFactory getSharedGroupFactory(Settings settings) {
SharedGroupFactory groupFactory = this.groupFactory.get();
if (groupFactory != null) {
assert groupFactory.getSettings().equals(settings) : "Different settings than originally provided";
Expand Down

0 comments on commit fa8ccf0

Please sign in to comment.