From 40887b2d3ca0b5d8d33009d8682f60a4a43f8b68 Mon Sep 17 00:00:00 2001 From: James Bloom Date: Mon, 25 Dec 2017 07:16:20 +0000 Subject: [PATCH] #257 further clean-up and integration tests for harmonizing mocking and proxy logic --- .../org/mockserver/client/AbstractClient.java | 1 - .../client/{server => }/ClientException.java | 6 +- ...rverClientServerVallidationErrorsTest.java | 2 +- .../client/netty/NettyHttpClient.java | 3 +- .../org/mockserver/echo/http/EchoServer.java | 13 ++ .../echo/http/EchoServerHandler.java | 10 +- .../echo/http/EchoServerInitializer.java | 12 +- .../EchoServerPortUnificationHandler.java} | 63 ++++++--- .../EchoServerUnificationHandler.java | 56 -------- .../PortUnificationEchoServer.java | 2 +- ...vletRequestToMockServerRequestDecoder.java | 20 +-- .../org/mockserver/mock/HttpStateHandler.java | 2 + .../mockserver/mock/action/ActionHandler.java | 11 +- .../org/mockserver/socket/PortFactory.java | 2 +- .../mock/action/ActionHandlerTest.java | 23 ++-- .../org/mockserver/lifecycle/LifeCycle.java | 10 +- .../org/mockserver/mockserver/MockServer.java | 12 +- .../mockserver/MockServerHandler.java | 24 +++- .../mockserver/MockServerInitializer.java | 42 +++--- .../main/java/org/mockserver/proxy/Proxy.java | 23 +++- .../mockserver/proxy/direct/DirectProxy.java | 2 + .../direct/DirectProxyUnificationHandler.java | 14 +- .../org/mockserver/proxy/http/HttpProxy.java | 7 + .../proxy/http/HttpProxyHandler.java | 17 ++- .../http/HttpProxyUnificationHandler.java | 15 ++- .../proxy/relay/RelayConnectHandler.java | 12 +- .../relay/UpstreamProxyRelayHandler.java | 4 +- .../proxy/socks/SocksProxyHandler.java | 14 +- .../unification/HttpContentLengthRemover.java | 4 +- .../unification/PortUnificationHandler.java | 55 ++++++-- ...tAndDirectProxyMockingIntegrationTest.java | 67 ++++++++++ .../ClientAndProxyMockingIntegrationTest.java | 65 +++++++++ ...erverAutoAllocatedPortIntegrationTest.java | 2 +- .../mockserver/MockServerHandlerTest.java | 125 ++++++++++++------ .../DirectProxyUnificationHandlerTest.java | 12 +- .../proxy/http/HttpProxyHandlerTest.java | 54 +++++++- ...ProxyUnificationHandlerSOCKSErrorTest.java | 104 ++++++++------- .../http/HttpProxyUnificationHandlerTest.java | 14 +- .../org/mockserver/proxy/ProxyServlet.java | 20 ++- .../mockserver/proxy/ProxyServletTest.java | 116 +++++++++++++++- .../mockserver/server/MockServerServlet.java | 15 ++- ...ARAbstractClientServerIntegrationTest.java | 3 +- .../server/MockServerServletTest.java | 115 +++++++++++++++- 43 files changed, 890 insertions(+), 303 deletions(-) rename mockserver-client-java/src/main/java/org/mockserver/client/{server => }/ClientException.java (54%) rename mockserver-core/src/main/java/org/mockserver/{server/unification/PortUnificationHandler.java => echo/unification/EchoServerPortUnificationHandler.java} (60%) delete mode 100644 mockserver-core/src/main/java/org/mockserver/echo/unification/EchoServerUnificationHandler.java rename {mockserver-core/src/main/java/org/mockserver/server => mockserver-netty/src/main/java/org/mockserver}/unification/HttpContentLengthRemover.java (88%) rename mockserver-netty/src/main/java/org/mockserver/{proxy => }/unification/PortUnificationHandler.java (74%) create mode 100644 mockserver-netty/src/test/java/org/mockserver/integration/mockserver/ClientAndDirectProxyMockingIntegrationTest.java create mode 100644 mockserver-netty/src/test/java/org/mockserver/integration/mockserver/ClientAndProxyMockingIntegrationTest.java diff --git a/mockserver-client-java/src/main/java/org/mockserver/client/AbstractClient.java b/mockserver-client-java/src/main/java/org/mockserver/client/AbstractClient.java index cd9764e37..9959b9c1e 100644 --- a/mockserver-client-java/src/main/java/org/mockserver/client/AbstractClient.java +++ b/mockserver-client-java/src/main/java/org/mockserver/client/AbstractClient.java @@ -6,7 +6,6 @@ import org.mockserver.client.netty.NettyHttpClient; import org.mockserver.client.netty.SocketConnectionException; import org.mockserver.client.serialization.*; -import org.mockserver.client.server.ClientException; import org.mockserver.matchers.TimeToLive; import org.mockserver.matchers.Times; import org.mockserver.mock.Expectation; diff --git a/mockserver-client-java/src/main/java/org/mockserver/client/server/ClientException.java b/mockserver-client-java/src/main/java/org/mockserver/client/ClientException.java similarity index 54% rename from mockserver-client-java/src/main/java/org/mockserver/client/server/ClientException.java rename to mockserver-client-java/src/main/java/org/mockserver/client/ClientException.java index 5688edc14..8e4ab4109 100644 --- a/mockserver-client-java/src/main/java/org/mockserver/client/server/ClientException.java +++ b/mockserver-client-java/src/main/java/org/mockserver/client/ClientException.java @@ -1,14 +1,10 @@ -package org.mockserver.client.server; +package org.mockserver.client; /** * @author jamesdbloom */ public class ClientException extends RuntimeException { - public ClientException(String message, Throwable cause) { - super(message, cause); - } - public ClientException(String message) { super(message); } diff --git a/mockserver-client-java/src/test/java/org/mockserver/client/server/MockServerClientServerVallidationErrorsTest.java b/mockserver-client-java/src/test/java/org/mockserver/client/server/MockServerClientServerVallidationErrorsTest.java index db2be098b..ec9ed7dfe 100644 --- a/mockserver-client-java/src/test/java/org/mockserver/client/server/MockServerClientServerVallidationErrorsTest.java +++ b/mockserver-client-java/src/test/java/org/mockserver/client/server/MockServerClientServerVallidationErrorsTest.java @@ -4,11 +4,11 @@ import org.junit.Rule; import org.junit.Test; import org.junit.rules.ExpectedException; +import org.mockserver.client.ClientException; import org.mockserver.echo.http.EchoServer; import org.mockserver.socket.PortFactory; import static org.hamcrest.Matchers.containsString; -import static org.mockito.MockitoAnnotations.initMocks; import static org.mockserver.character.Character.NEW_LINE; import static org.mockserver.model.HttpRequest.request; import static org.mockserver.model.HttpResponse.response; diff --git a/mockserver-core/src/main/java/org/mockserver/client/netty/NettyHttpClient.java b/mockserver-core/src/main/java/org/mockserver/client/netty/NettyHttpClient.java index ea25128d3..25db722df 100644 --- a/mockserver-core/src/main/java/org/mockserver/client/netty/NettyHttpClient.java +++ b/mockserver-core/src/main/java/org/mockserver/client/netty/NettyHttpClient.java @@ -37,8 +37,7 @@ private InetSocketAddress socketAddressFromHostHeader(HttpRequest request) { if (!Strings.isNullOrEmpty(request.getFirstHeader(HOST.toString()))) { boolean isSsl = request.isSecure() != null && request.isSecure(); String[] hostHeaderParts = request.getFirstHeader(HOST.toString()).split(":"); - return new InetSocketAddress(hostHeaderParts[0], hostHeaderParts.length > 1 ? Integer.parseInt(hostHeaderParts[1]) : isSsl ? 443 : 80 - ); + return new InetSocketAddress(hostHeaderParts[0], hostHeaderParts.length > 1 ? Integer.parseInt(hostHeaderParts[1]) : isSsl ? 443 : 80); } else { throw new IllegalArgumentException("Host header must be provided for requests being forwarded, the following request does not include the \"Host\" header:" + NEW_LINE + request); } diff --git a/mockserver-core/src/main/java/org/mockserver/echo/http/EchoServer.java b/mockserver-core/src/main/java/org/mockserver/echo/http/EchoServer.java index d95c36648..5de60d996 100644 --- a/mockserver-core/src/main/java/org/mockserver/echo/http/EchoServer.java +++ b/mockserver-core/src/main/java/org/mockserver/echo/http/EchoServer.java @@ -26,9 +26,11 @@ public class EchoServer { static final AttributeKey LOG_FILTER = AttributeKey.valueOf("SERVER_LOG_FILTER"); static final AttributeKey NEXT_RESPONSE = AttributeKey.valueOf("NEXT_RESPONSE"); + static final AttributeKey ONLY_RESPONSE = AttributeKey.valueOf("ONLY_RESPONSE"); private final LogFilter logFilter = new LogFilter(new LoggingFormatter(LoggerFactory.getLogger(this.getClass()), null)); private final NextResponse nextResponse = new NextResponse(); + private final OnlyResponse onlyResponse = new OnlyResponse(); private final NioEventLoopGroup eventLoopGroup = new NioEventLoopGroup(); @@ -52,6 +54,7 @@ public void run() { .childHandler(new EchoServerInitializer(secure, error)) .childAttr(LOG_FILTER, logFilter) .childAttr(NEXT_RESPONSE, nextResponse) + .childAttr(ONLY_RESPONSE, onlyResponse) .bind(port) .addListener(new ChannelFutureListener() { @Override @@ -90,6 +93,12 @@ public EchoServer withNextResponse(HttpResponse... httpResponses) { return this; } + public EchoServer withOnlyResponse(HttpResponse httpResponse) { + // WARNING: this logic is only for unit tests that run in series and is NOT thread safe!!! + onlyResponse.httpResponse = httpResponse; + return this; + } + public enum Error { CLOSE_CONNECTION, LARGER_CONTENT_LENGTH, @@ -100,4 +109,8 @@ public enum Error { public class NextResponse { public final Queue httpResponse = new LinkedList(); } + + public class OnlyResponse { + public HttpResponse httpResponse; + } } diff --git a/mockserver-core/src/main/java/org/mockserver/echo/http/EchoServerHandler.java b/mockserver-core/src/main/java/org/mockserver/echo/http/EchoServerHandler.java index 86c750fb4..37e12dff9 100644 --- a/mockserver-core/src/main/java/org/mockserver/echo/http/EchoServerHandler.java +++ b/mockserver-core/src/main/java/org/mockserver/echo/http/EchoServerHandler.java @@ -28,11 +28,13 @@ public class EchoServerHandler extends SimpleChannelInboundHandler private final EchoServer.Error error; private final LogFilter logFilter; private final EchoServer.NextResponse nextResponse; + private final EchoServer.OnlyResponse onlyResponse; - public EchoServerHandler(EchoServer.Error error, LogFilter logFilter, EchoServer.NextResponse nextResponse) { + EchoServerHandler(EchoServer.Error error, LogFilter logFilter, EchoServer.NextResponse nextResponse, EchoServer.OnlyResponse onlyResponse) { this.error = error; this.logFilter = logFilter; this.nextResponse = nextResponse; + this.onlyResponse = onlyResponse; } protected void channelRead0(ChannelHandlerContext ctx, HttpRequest request) { @@ -41,7 +43,11 @@ protected void channelRead0(ChannelHandlerContext ctx, HttpRequest request) { logFilter.onRequest(new RequestLogEntry(request)); - if (!nextResponse.httpResponse.isEmpty()) { + if (onlyResponse.httpResponse != null) { + // WARNING: this logic is only for unit tests that run in series and is NOT thread safe!!! + DefaultFullHttpResponse httpResponse = new MockServerResponseEncoder().encode(onlyResponse.httpResponse); + ctx.writeAndFlush(httpResponse); + } else if (!nextResponse.httpResponse.isEmpty()) { // WARNING: this logic is only for unit tests that run in series and is NOT thread safe!!! DefaultFullHttpResponse httpResponse = new MockServerResponseEncoder().encode(nextResponse.httpResponse.remove()); ctx.writeAndFlush(httpResponse); diff --git a/mockserver-core/src/main/java/org/mockserver/echo/http/EchoServerInitializer.java b/mockserver-core/src/main/java/org/mockserver/echo/http/EchoServerInitializer.java index fdbbea632..a608733b8 100644 --- a/mockserver-core/src/main/java/org/mockserver/echo/http/EchoServerInitializer.java +++ b/mockserver-core/src/main/java/org/mockserver/echo/http/EchoServerInitializer.java @@ -13,6 +13,7 @@ import static org.mockserver.echo.http.EchoServer.LOG_FILTER; import static org.mockserver.echo.http.EchoServer.NEXT_RESPONSE; +import static org.mockserver.echo.http.EchoServer.ONLY_RESPONSE; import static org.mockserver.socket.NettySslContextFactory.nettySslContextFactory; /** @@ -56,6 +57,15 @@ public void initChannel(SocketChannel channel) throws Exception { pipeline.addLast(new MockServerServerCodec(secure)); - pipeline.addLast(new EchoServerHandler(error, channel.attr(LOG_FILTER).get(), channel.attr(NEXT_RESPONSE).get())); + if (!secure && error == EchoServer.Error.CLOSE_CONNECTION) { + throw new IllegalArgumentException("Error type CLOSE_CONNECTION is not supported in non-secure mode"); + } + + pipeline.addLast(new EchoServerHandler( + error, + channel.attr(LOG_FILTER).get(), + channel.attr(NEXT_RESPONSE).get(), + channel.attr(ONLY_RESPONSE).get() + )); } } diff --git a/mockserver-core/src/main/java/org/mockserver/server/unification/PortUnificationHandler.java b/mockserver-core/src/main/java/org/mockserver/echo/unification/EchoServerPortUnificationHandler.java similarity index 60% rename from mockserver-core/src/main/java/org/mockserver/server/unification/PortUnificationHandler.java rename to mockserver-core/src/main/java/org/mockserver/echo/unification/EchoServerPortUnificationHandler.java index c1b7fabda..ef66f3f99 100644 --- a/mockserver-core/src/main/java/org/mockserver/server/unification/PortUnificationHandler.java +++ b/mockserver-core/src/main/java/org/mockserver/echo/unification/EchoServerPortUnificationHandler.java @@ -1,36 +1,40 @@ -package org.mockserver.server.unification; +package org.mockserver.echo.unification; import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; import io.netty.channel.ChannelHandler; import io.netty.channel.ChannelHandlerContext; import io.netty.channel.ChannelPipeline; import io.netty.channel.SimpleChannelInboundHandler; -import io.netty.handler.codec.http.HttpContentDecompressor; -import io.netty.handler.codec.http.HttpObjectAggregator; -import io.netty.handler.codec.http.HttpServerCodec; -import io.netty.handler.logging.LoggingHandler; +import io.netty.handler.codec.http.*; import io.netty.handler.ssl.SslHandler; import io.netty.util.AttributeKey; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import static io.netty.handler.codec.http.HttpHeaderNames.CONNECTION; +import static io.netty.handler.codec.http.HttpHeaderNames.CONTENT_LENGTH; +import static io.netty.handler.codec.http.HttpResponseStatus.*; +import static io.netty.handler.codec.http.HttpUtil.is100ContinueExpected; +import static io.netty.handler.codec.http.HttpUtil.isKeepAlive; +import static io.netty.handler.codec.http.HttpVersion.HTTP_1_1; import static org.mockserver.socket.NettySslContextFactory.nettySslContextFactory; /** * @author jamesdbloom */ @ChannelHandler.Sharable -public abstract class PortUnificationHandler extends SimpleChannelInboundHandler { +public class EchoServerPortUnificationHandler extends SimpleChannelInboundHandler { - public static final AttributeKey SSL_ENABLED = AttributeKey.valueOf("SSL_ENABLED"); + private static final AttributeKey SSL_ENABLED = AttributeKey.valueOf("SSL_ENABLED"); private final Logger logger = LoggerFactory.getLogger(this.getClass()); - public PortUnificationHandler() { + EchoServerPortUnificationHandler() { super(false); } @Override - protected void channelRead0(ChannelHandlerContext ctx, ByteBuf msg) throws Exception { + protected void channelRead0(ChannelHandlerContext ctx, ByteBuf msg) { // Will use the first five bytes to detect a protocol. if (msg.readableBytes() < 3) { return; @@ -80,7 +84,7 @@ private boolean isHttp(ByteBuf msg) { private void enableSsl(ChannelHandlerContext ctx, ByteBuf msg) { ChannelPipeline pipeline = ctx.pipeline(); pipeline.addFirst(nettySslContextFactory().createServerSslContext().newHandler(ctx.alloc())); - ctx.channel().attr(PortUnificationHandler.SSL_ENABLED).set(Boolean.TRUE); + ctx.channel().attr(SSL_ENABLED).set(Boolean.TRUE); // re-unify (with SSL enabled) ctx.pipeline().fireChannelRead(msg); @@ -91,23 +95,52 @@ private void switchToHttp(ChannelHandlerContext ctx, ByteBuf msg) { addLastIfNotPresent(pipeline, new HttpServerCodec(8192, 8192, 8192)); addLastIfNotPresent(pipeline, new HttpContentDecompressor()); - addLastIfNotPresent(pipeline, new HttpContentLengthRemover()); addLastIfNotPresent(pipeline, new HttpObjectAggregator(Integer.MAX_VALUE)); if (logger.isDebugEnabled()) { - addLastIfNotPresent(pipeline, new LoggingHandler()); + addLastIfNotPresent(pipeline, new io.netty.handler.logging.LoggingHandler()); } - configurePipeline(ctx, pipeline); + configurePipeline(pipeline); pipeline.remove(this); // fire message back through pipeline ctx.fireChannelRead(msg); } - protected void addLastIfNotPresent(ChannelPipeline pipeline, ChannelHandler channelHandler) { + private void addLastIfNotPresent(ChannelPipeline pipeline, ChannelHandler channelHandler) { if (pipeline.get(channelHandler.getClass()) == null) { pipeline.addLast(channelHandler); } } - protected abstract void configurePipeline(ChannelHandlerContext ctx, ChannelPipeline pipeline); + private void configurePipeline(ChannelPipeline pipeline) { + pipeline.addLast(new SimpleChannelInboundHandler() { + + protected void channelRead0(ChannelHandlerContext ctx, FullHttpRequest request) { + HttpResponseStatus responseStatus = OK; + if (request.uri().equals("/not_found")) { + responseStatus = NOT_FOUND; + } + // echo back request headers and body + FullHttpResponse response = new DefaultFullHttpResponse(HTTP_1_1, responseStatus, Unpooled.copiedBuffer(request.content())); + response.headers().add(request.headers()); + + // set hop-by-hop headers + response.headers().set(CONTENT_LENGTH, response.content().readableBytes()); + if (isKeepAlive(request)) { + response.headers().set(CONNECTION, HttpHeaderValues.KEEP_ALIVE); + } + if (is100ContinueExpected(request)) { + ctx.write(new DefaultFullHttpResponse(HTTP_1_1, CONTINUE)); + } + + // write and flush + ctx.writeAndFlush(response); + } + + public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) { + cause.printStackTrace(); + ctx.close(); + } + }); + } } diff --git a/mockserver-core/src/main/java/org/mockserver/echo/unification/EchoServerUnificationHandler.java b/mockserver-core/src/main/java/org/mockserver/echo/unification/EchoServerUnificationHandler.java deleted file mode 100644 index f86f1e043..000000000 --- a/mockserver-core/src/main/java/org/mockserver/echo/unification/EchoServerUnificationHandler.java +++ /dev/null @@ -1,56 +0,0 @@ -package org.mockserver.echo.unification; - -import io.netty.buffer.Unpooled; -import io.netty.channel.ChannelHandler; -import io.netty.channel.ChannelHandlerContext; -import io.netty.channel.ChannelPipeline; -import io.netty.channel.SimpleChannelInboundHandler; -import io.netty.handler.codec.http.*; -import org.mockserver.server.unification.PortUnificationHandler; - -import static io.netty.handler.codec.http.HttpHeaderNames.CONNECTION; -import static io.netty.handler.codec.http.HttpHeaderNames.CONTENT_LENGTH; -import static io.netty.handler.codec.http.HttpResponseStatus.*; -import static io.netty.handler.codec.http.HttpUtil.is100ContinueExpected; -import static io.netty.handler.codec.http.HttpUtil.isKeepAlive; -import static io.netty.handler.codec.http.HttpVersion.HTTP_1_1; - -/** - * @author jamesdbloom - */ -@ChannelHandler.Sharable -public class EchoServerUnificationHandler extends PortUnificationHandler { - - @Override - protected void configurePipeline(ChannelHandlerContext ctx, ChannelPipeline pipeline) { - pipeline.addLast(new SimpleChannelInboundHandler() { - - protected void channelRead0(ChannelHandlerContext ctx, FullHttpRequest request) { - HttpResponseStatus responseStatus = OK; - if (request.uri().equals("/not_found")) { - responseStatus = NOT_FOUND; - } - // echo back request headers and body - FullHttpResponse response = new DefaultFullHttpResponse(HTTP_1_1, responseStatus, Unpooled.copiedBuffer(request.content())); - response.headers().add(request.headers()); - - // set hop-by-hop headers - response.headers().set(CONTENT_LENGTH, response.content().readableBytes()); - if (isKeepAlive(request)) { - response.headers().set(CONNECTION, HttpHeaderValues.KEEP_ALIVE); - } - if (is100ContinueExpected(request)) { - ctx.write(new DefaultFullHttpResponse(HTTP_1_1, CONTINUE)); - } - - // write and flush - ctx.writeAndFlush(response); - } - - public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) { - cause.printStackTrace(); - ctx.close(); - } - }); - } -} diff --git a/mockserver-core/src/main/java/org/mockserver/echo/unification/PortUnificationEchoServer.java b/mockserver-core/src/main/java/org/mockserver/echo/unification/PortUnificationEchoServer.java index 1201a4b8d..60b7ec7e0 100644 --- a/mockserver-core/src/main/java/org/mockserver/echo/unification/PortUnificationEchoServer.java +++ b/mockserver-core/src/main/java/org/mockserver/echo/unification/PortUnificationEchoServer.java @@ -33,7 +33,7 @@ public void run() { .channel(NioServerSocketChannel.class) .option(ChannelOption.SO_BACKLOG, 100) .handler(new LoggingHandler("EchoServer Handler")) - .childHandler(new EchoServerUnificationHandler()) + .childHandler(new EchoServerPortUnificationHandler()) .bind(port) .addListener(new ChannelFutureListener() { @Override diff --git a/mockserver-core/src/main/java/org/mockserver/mappers/HttpServletRequestToMockServerRequestDecoder.java b/mockserver-core/src/main/java/org/mockserver/mappers/HttpServletRequestToMockServerRequestDecoder.java index 5dcc92cd0..245076998 100644 --- a/mockserver-core/src/main/java/org/mockserver/mappers/HttpServletRequestToMockServerRequestDecoder.java +++ b/mockserver-core/src/main/java/org/mockserver/mappers/HttpServletRequestToMockServerRequestDecoder.java @@ -22,19 +22,19 @@ */ public class HttpServletRequestToMockServerRequestDecoder { public HttpRequest mapHttpServletRequestToMockServerRequest(HttpServletRequest httpServletRequest) { - HttpRequest httpRequest = new HttpRequest(); - setMethod(httpRequest, httpServletRequest); + HttpRequest request = new HttpRequest(); + setMethod(request, httpServletRequest); - setPath(httpRequest, httpServletRequest); - setQueryString(httpRequest, httpServletRequest); + setPath(request, httpServletRequest); + setQueryString(request, httpServletRequest); - setBody(httpRequest, httpServletRequest); - setHeaders(httpRequest, httpServletRequest); - setCookies(httpRequest, httpServletRequest); + setBody(request, httpServletRequest); + setHeaders(request, httpServletRequest); + setCookies(request, httpServletRequest); - httpRequest.withKeepAlive(isKeepAlive(httpServletRequest)); - httpRequest.withSecure(httpServletRequest.isSecure()); - return httpRequest; + request.withKeepAlive(isKeepAlive(httpServletRequest)); + request.withSecure(httpServletRequest.isSecure()); + return request; } private void setMethod(HttpRequest httpRequest, HttpServletRequest httpServletRequest) { diff --git a/mockserver-core/src/main/java/org/mockserver/mock/HttpStateHandler.java b/mockserver-core/src/main/java/org/mockserver/mock/HttpStateHandler.java index c2c788794..006a7a17d 100644 --- a/mockserver-core/src/main/java/org/mockserver/mock/HttpStateHandler.java +++ b/mockserver-core/src/main/java/org/mockserver/mock/HttpStateHandler.java @@ -3,6 +3,7 @@ import com.google.common.base.Function; import com.google.common.base.Strings; import com.google.common.collect.Lists; +import io.netty.util.AttributeKey; import org.apache.commons.lang3.StringUtils; import org.mockserver.client.serialization.ExpectationSerializer; import org.mockserver.client.serialization.HttpRequestSerializer; @@ -39,6 +40,7 @@ */ public class HttpStateHandler { + public static final AttributeKey STATE_HANDLER = AttributeKey.valueOf("PROXY_STATE_HANDLER"); public static final String LOG_SEPARATOR = "------------------------------------\n"; // mockserver private LoggingFormatter logFormatter = new LoggingFormatter(LoggerFactory.getLogger(this.getClass()), this); diff --git a/mockserver-core/src/main/java/org/mockserver/mock/action/ActionHandler.java b/mockserver-core/src/main/java/org/mockserver/mock/action/ActionHandler.java index c0e247460..bb901e4c7 100644 --- a/mockserver-core/src/main/java/org/mockserver/mock/action/ActionHandler.java +++ b/mockserver-core/src/main/java/org/mockserver/mock/action/ActionHandler.java @@ -15,7 +15,9 @@ import org.mockserver.responsewriter.ResponseWriter; import java.net.InetSocketAddress; +import java.util.Set; +import static io.netty.handler.codec.http.HttpHeaderNames.HOST; import static org.mockserver.character.Character.NEW_LINE; import static org.mockserver.model.HttpResponse.notFoundResponse; @@ -25,7 +27,7 @@ public class ActionHandler { public static final AttributeKey REMOTE_SOCKET = AttributeKey.valueOf("REMOTE_SOCKET"); - private final boolean forwardNonMatching; + private HttpStateHandler httpStateHandler; private LoggingFormatter logFormatter; private HttpResponseActionHandler httpResponseActionHandler; @@ -41,10 +43,9 @@ public class ActionHandler { private HopByHopHeaderFilter hopByHopHeaderFilter = new HopByHopHeaderFilter(); private HttpRequestToCurlSerializer httpRequestToCurlSerializer = new HttpRequestToCurlSerializer(); - public ActionHandler(HttpStateHandler httpStateHandler, boolean forwardNonMatching) { + public ActionHandler(HttpStateHandler httpStateHandler) { this.httpStateHandler = httpStateHandler; this.logFormatter = httpStateHandler.getLogFormatter(); - this.forwardNonMatching = forwardNonMatching; this.httpResponseActionHandler = new HttpResponseActionHandler(); this.httpResponseTemplateActionHandler = new HttpResponseTemplateActionHandler(logFormatter); this.httpForwardActionHandler = new HttpForwardActionHandler(); @@ -53,7 +54,7 @@ public ActionHandler(HttpStateHandler httpStateHandler, boolean forwardNonMatchi this.httpObjectCallbackActionHandler = new HttpObjectCallbackActionHandler(httpStateHandler); } - public void processAction(HttpRequest request, ResponseWriter responseWriter, ChannelHandlerContext ctx) { + public void processAction(HttpRequest request, ResponseWriter responseWriter, ChannelHandlerContext ctx, Set localAddresses, boolean proxyRequest) { Expectation expectation = httpStateHandler.firstMatchingExpectation(request); if (expectation != null && expectation.getAction() != null) { Action action = expectation.getAction(); @@ -105,7 +106,7 @@ public void processAction(HttpRequest request, ResponseWriter responseWriter, Ch break; } } - } else if (forwardNonMatching) { + } else if (proxyRequest || !localAddresses.contains(request.getFirstHeader(HOST.toString()))) { InetSocketAddress remoteAddress = ctx != null ? ctx.channel().attr(REMOTE_SOCKET).get() : null; HttpResponse response = httpClient.sendRequest(hopByHopHeaderFilter.onRequest(request), remoteAddress); if (response == null) { diff --git a/mockserver-core/src/main/java/org/mockserver/socket/PortFactory.java b/mockserver-core/src/main/java/org/mockserver/socket/PortFactory.java index 1296b57ae..a84cc9d4b 100644 --- a/mockserver-core/src/main/java/org/mockserver/socket/PortFactory.java +++ b/mockserver-core/src/main/java/org/mockserver/socket/PortFactory.java @@ -14,7 +14,7 @@ public static int findFreePort() { port = server.getLocalPort(); server.close(); // allow time for the socket to be released - TimeUnit.MILLISECONDS.sleep(150); + TimeUnit.MILLISECONDS.sleep(250); } catch (Exception e) { throw new RuntimeException("Exception while trying to find a free port", e); } diff --git a/mockserver-core/src/test/java/org/mockserver/mock/action/ActionHandlerTest.java b/mockserver-core/src/test/java/org/mockserver/mock/action/ActionHandlerTest.java index 3ccbeca89..0fa5d4ce4 100644 --- a/mockserver-core/src/test/java/org/mockserver/mock/action/ActionHandlerTest.java +++ b/mockserver-core/src/test/java/org/mockserver/mock/action/ActionHandlerTest.java @@ -20,6 +20,7 @@ import org.mockserver.responsewriter.ResponseWriter; import java.net.InetSocketAddress; +import java.util.HashSet; import static org.mockito.Mockito.*; import static org.mockito.MockitoAnnotations.initMocks; @@ -79,7 +80,7 @@ public class ActionHandlerTest { @Before public void setupMocks() { mockHttpStateHandler = mock(HttpStateHandler.class); - actionHandler = new ActionHandler(mockHttpStateHandler, true); + actionHandler = new ActionHandler(mockHttpStateHandler); initMocks(this); request = request("some_path"); response = response("some_body"); @@ -103,7 +104,7 @@ public void shouldProcessForwardAction() { when(mockHttpStateHandler.firstMatchingExpectation(request)).thenReturn(expectation); // when - actionHandler.processAction(request, mockResponseWriter, null); + actionHandler.processAction(request, mockResponseWriter, null, new HashSet(), false); // then verify(mockHttpForwardActionHandler).handle(forward, request); @@ -120,7 +121,7 @@ public void shouldProcessForwardTemplateAction() { when(mockHttpStateHandler.firstMatchingExpectation(request)).thenReturn(expectation); // when - actionHandler.processAction(request, mockResponseWriter, null); + actionHandler.processAction(request, mockResponseWriter, null, new HashSet(), false); // then verify(mockHttpForwardTemplateActionHandler).handle(template, request); @@ -137,7 +138,7 @@ public void shouldProcessResponseAction() { when(mockHttpStateHandler.firstMatchingExpectation(request)).thenReturn(expectation); // when - actionHandler.processAction(request, mockResponseWriter, null); + actionHandler.processAction(request, mockResponseWriter, null, new HashSet(), false); // then verify(mockHttpResponseActionHandler).handle(response); @@ -154,7 +155,7 @@ public void shouldProcessResponseTemplateAction() { when(mockHttpStateHandler.firstMatchingExpectation(request)).thenReturn(expectation); // when - actionHandler.processAction(request, mockResponseWriter, null); + actionHandler.processAction(request, mockResponseWriter, null, new HashSet(), false); // then verify(mockHttpResponseTemplateActionHandler).handle(template, request); @@ -171,7 +172,7 @@ public void shouldProcessClassCallbackAction() { when(mockHttpStateHandler.firstMatchingExpectation(request)).thenReturn(expectation); // when - actionHandler.processAction(request, mockResponseWriter, null); + actionHandler.processAction(request, mockResponseWriter, null, new HashSet(), false); // then verify(mockHttpClassCallbackActionHandler).handle(callback, request); @@ -189,7 +190,7 @@ public void shouldProcessObjectCallbackAction() { ResponseWriter mockResponseWriter = mock(ResponseWriter.class); // when - actionHandler.processAction(request, mockResponseWriter, null); + actionHandler.processAction(request, mockResponseWriter, null, new HashSet(), false); // then verify(mockHttpStateHandler, times(1)).log(new ExpectationMatchLogEntry(request, expectation)); @@ -206,7 +207,7 @@ public void shouldProcessErrorAction() { ChannelHandlerContext mockChannelHandlerContext = mock(ChannelHandlerContext.class); // when - actionHandler.processAction(request, mockResponseWriter, mockChannelHandlerContext); + actionHandler.processAction(request, mockResponseWriter, mockChannelHandlerContext, new HashSet(), false); // then verify(mockHttpStateHandler, times(1)).log(new ExpectationMatchLogEntry(request, expectation)); @@ -220,11 +221,11 @@ public void shouldProxyRequestsWithRemoteSocketAttribute() { HttpRequest request = request("request_one"); // and - remote socket attribute - InetSocketAddress remoteAddress = new InetSocketAddress(1080); ChannelHandlerContext mockChannelHandlerContext = mock(ChannelHandlerContext.class); Channel mockChannel = mock(Channel.class); - Attribute inetSocketAddressAttribute = mock(Attribute.class); when(mockChannelHandlerContext.channel()).thenReturn(mockChannel); + InetSocketAddress remoteAddress = new InetSocketAddress(1080); + Attribute inetSocketAddressAttribute = mock(Attribute.class); when(inetSocketAddressAttribute.get()).thenReturn(remoteAddress); when(mockChannel.attr(REMOTE_SOCKET)).thenReturn(inetSocketAddressAttribute); @@ -232,7 +233,7 @@ public void shouldProxyRequestsWithRemoteSocketAttribute() { when(mockNettyHttpClient.sendRequest(request, remoteAddress)).thenReturn(response("response_one")); // when - actionHandler.processAction(request, mockResponseWriter, mockChannelHandlerContext); + actionHandler.processAction(request, mockResponseWriter, mockChannelHandlerContext, new HashSet(), true); // then verify(mockHttpStateHandler).log(new RequestResponseLogEntry(request, response("response_one"))); diff --git a/mockserver-netty/src/main/java/org/mockserver/lifecycle/LifeCycle.java b/mockserver-netty/src/main/java/org/mockserver/lifecycle/LifeCycle.java index 5604397d3..788acb471 100644 --- a/mockserver-netty/src/main/java/org/mockserver/lifecycle/LifeCycle.java +++ b/mockserver-netty/src/main/java/org/mockserver/lifecycle/LifeCycle.java @@ -85,7 +85,7 @@ public void run() { .bind(portToBind) .addListener(new ChannelFutureListener() { @Override - public void operationComplete(ChannelFuture future) throws Exception { + public void operationComplete(ChannelFuture future) { if (future.isSuccess()) { channelOpened.set(future.channel()); } else { @@ -95,11 +95,9 @@ public void operationComplete(ChannelFuture future) throws Exception { }) .channel(); - int boundPort = ((InetSocketAddress) channelOpened.get().localAddress()).getPort(); - started(boundPort); - logger.info("MockServer started on port: {}", boundPort); - + started(((InetSocketAddress) channelOpened.get().localAddress()).getPort()); channel.closeFuture().syncUninterruptibly(); + } catch (Exception e) { throw new RuntimeException("Exception while binding MockServer to port " + portToBind, e.getCause()); } @@ -115,7 +113,7 @@ public void operationComplete(ChannelFuture future) throws Exception { } protected void started(Integer port) { - + logger.info("MockServer started on port: {}", port); } protected void stopped() { diff --git a/mockserver-netty/src/main/java/org/mockserver/mockserver/MockServer.java b/mockserver-netty/src/main/java/org/mockserver/mockserver/MockServer.java index ab8779cf9..bfab3819e 100644 --- a/mockserver-netty/src/main/java/org/mockserver/mockserver/MockServer.java +++ b/mockserver-netty/src/main/java/org/mockserver/mockserver/MockServer.java @@ -5,15 +5,23 @@ import io.netty.channel.ChannelOption; import io.netty.channel.WriteBufferWaterMark; import io.netty.channel.socket.nio.NioServerSocketChannel; +import io.netty.util.AttributeKey; import org.mockserver.lifecycle.LifeCycle; +import org.mockserver.mock.HttpStateHandler; +import org.mockserver.proxy.Proxy; +import org.mockserver.proxy.http.HttpProxy; import java.util.Arrays; +import static org.mockserver.mock.HttpStateHandler.STATE_HANDLER; + /** * @author jamesdbloom */ public class MockServer extends LifeCycle { + public static final AttributeKey MOCK_SERVER = AttributeKey.valueOf("MOCK_SERVER"); + /** * Start the instance using the ports provided * @@ -31,7 +39,9 @@ public MockServer(final Integer... requestedPortBindings) { .childOption(ChannelOption.AUTO_READ, true) .childOption(ChannelOption.ALLOCATOR, PooledByteBufAllocator.DEFAULT) .option(ChannelOption.WRITE_BUFFER_WATER_MARK, new WriteBufferWaterMark(8 * 1024, 32 * 1024)) - .childHandler(new MockServerInitializer(MockServer.this)); + .childHandler(new MockServerInitializer()) + .childAttr(MOCK_SERVER, MockServer.this) + .childAttr(STATE_HANDLER, new HttpStateHandler()); bindToPorts(Arrays.asList(requestedPortBindings)); diff --git a/mockserver-netty/src/main/java/org/mockserver/mockserver/MockServerHandler.java b/mockserver-netty/src/main/java/org/mockserver/mockserver/MockServerHandler.java index 4151f0b8d..14955e216 100644 --- a/mockserver-netty/src/main/java/org/mockserver/mockserver/MockServerHandler.java +++ b/mockserver-netty/src/main/java/org/mockserver/mockserver/MockServerHandler.java @@ -10,11 +10,14 @@ import org.mockserver.mock.action.ActionHandler; import org.mockserver.model.HttpRequest; import org.mockserver.model.PortBinding; +import org.mockserver.proxy.connect.HttpConnectHandler; import org.mockserver.responsewriter.NettyResponseWriter; import org.mockserver.responsewriter.ResponseWriter; +import org.mockserver.socket.KeyAndCertificateFactory; import org.slf4j.LoggerFactory; import java.net.BindException; +import java.util.HashSet; import java.util.List; import static io.netty.handler.codec.http.HttpResponseStatus.BAD_REQUEST; @@ -23,6 +26,10 @@ import static org.mockserver.exception.ExceptionHandler.shouldIgnoreException; import static org.mockserver.model.HttpResponse.response; import static org.mockserver.model.PortBinding.portBinding; +import static org.mockserver.proxy.Proxy.PROXYING; +import static org.mockserver.proxy.Proxy.getLocalAddresses; +import static org.mockserver.proxy.Proxy.isProxyingRequest; +import static org.mockserver.unification.PortUnificationHandler.enabledSslUpstreamAndDownstream; /** * @author jamesdbloom @@ -45,7 +52,7 @@ public MockServerHandler(MockServer server, HttpStateHandler httpStateHandler) { this.server = server; this.httpStateHandler = httpStateHandler; this.logFormatter = httpStateHandler.getLogFormatter(); - this.actionHandler = new ActionHandler(httpStateHandler, false); + this.actionHandler = new ActionHandler(httpStateHandler); } @Override @@ -84,9 +91,20 @@ public void run() { } }).start(); + } else if (request.getMethod().getValue().equals("CONNECT")) { + + ctx.channel().attr(PROXYING).set(Boolean.TRUE); + // assume SSL for CONNECT request + enabledSslUpstreamAndDownstream(ctx.channel()); + // add Subject Alternative Name for SSL certificate + KeyAndCertificateFactory.addSubjectAlternativeName(request.getPath().getValue()); + ctx.pipeline().addLast(new HttpConnectHandler(request.getPath().getValue(), -1)); + ctx.pipeline().remove(this); + ctx.fireChannelRead(request); + } else { - actionHandler.processAction(request, responseWriter, ctx); + actionHandler.processAction(request, responseWriter, ctx, getLocalAddresses(ctx), isProxyingRequest(ctx)); } } @@ -101,7 +119,7 @@ public void run() { } @Override - public void channelReadComplete(ChannelHandlerContext ctx) throws Exception { + public void channelReadComplete(ChannelHandlerContext ctx) { ctx.flush(); } diff --git a/mockserver-netty/src/main/java/org/mockserver/mockserver/MockServerInitializer.java b/mockserver-netty/src/main/java/org/mockserver/mockserver/MockServerInitializer.java index 0be5b28da..e6619ec24 100644 --- a/mockserver-netty/src/main/java/org/mockserver/mockserver/MockServerInitializer.java +++ b/mockserver-netty/src/main/java/org/mockserver/mockserver/MockServerInitializer.java @@ -1,40 +1,30 @@ package org.mockserver.mockserver; +import io.netty.channel.ChannelHandler; import io.netty.channel.ChannelHandlerContext; import io.netty.channel.ChannelPipeline; -import org.mockserver.logging.LoggingHandler; -import org.mockserver.mock.HttpStateHandler; import org.mockserver.mockserver.callback.WebSocketServerHandler; import org.mockserver.server.netty.codec.MockServerServerCodec; -import org.mockserver.server.unification.PortUnificationHandler; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; +import org.mockserver.unification.PortUnificationHandler; -public class MockServerInitializer extends PortUnificationHandler { - - private final Logger logger = LoggerFactory.getLogger(this.getClass()); - private final HttpStateHandler httpStateHandler = new HttpStateHandler(); - private final MockServer mockServer; +import static org.mockserver.mock.HttpStateHandler.STATE_HANDLER; +import static org.mockserver.mockserver.MockServer.MOCK_SERVER; - MockServerInitializer(MockServer mockServer) { - this.mockServer = mockServer; - } +/** + * @author jamesdbloom + */ +@ChannelHandler.Sharable +public class MockServerInitializer extends PortUnificationHandler { @Override protected void configurePipeline(ChannelHandlerContext ctx, ChannelPipeline pipeline) { - // add logging - if (logger.isDebugEnabled()) { - pipeline.addLast(new LoggingHandler(logger)); - } - - boolean isSecure = false; - if (ctx.channel().attr(PortUnificationHandler.SSL_ENABLED).get() != null) { - isSecure = ctx.channel().attr(PortUnificationHandler.SSL_ENABLED).get(); - } - pipeline.addLast(new WebSocketServerHandler(httpStateHandler.getWebSocketClientRegistry())); - pipeline.addLast(new MockServerServerCodec(isSecure)); + pipeline.addLast(new WebSocketServerHandler(ctx.channel().attr(STATE_HANDLER).get().getWebSocketClientRegistry())); + pipeline.addLast(new MockServerServerCodec(isSslEnabledUpstream(ctx.channel()))); - // add mock server handlers - pipeline.addLast(new MockServerHandler(mockServer, httpStateHandler)); + pipeline.addLast(new MockServerHandler( + ctx.channel().attr(MOCK_SERVER).get(), + ctx.channel().attr(STATE_HANDLER).get()) + ); } + } diff --git a/mockserver-netty/src/main/java/org/mockserver/proxy/Proxy.java b/mockserver-netty/src/main/java/org/mockserver/proxy/Proxy.java index c193bf523..55a6db122 100644 --- a/mockserver-netty/src/main/java/org/mockserver/proxy/Proxy.java +++ b/mockserver-netty/src/main/java/org/mockserver/proxy/Proxy.java @@ -1,10 +1,14 @@ package org.mockserver.proxy; +import io.netty.channel.ChannelHandlerContext; import io.netty.util.AttributeKey; import org.mockserver.lifecycle.LifeCycle; import org.mockserver.mock.HttpStateHandler; +import org.mockserver.model.HttpRequest; import java.net.InetSocketAddress; +import java.util.HashSet; +import java.util.Set; /** * This class should not be constructed directly instead use HttpProxyBuilder to build and configure this class @@ -15,7 +19,24 @@ public class Proxy extends LifeCycle { public static final AttributeKey HTTP_PROXY = AttributeKey.valueOf("HTTP_PROXY"); - public static final AttributeKey STATE_HANDLER = AttributeKey.valueOf("PROXY_STATE_HANDLER"); public static final AttributeKey HTTP_CONNECT_SOCKET = AttributeKey.valueOf("HTTP_CONNECT_SOCKET"); + public static final AttributeKey LOCAL_HOST_HEADERS = AttributeKey.valueOf("LOCAL_HOST_HEADERS"); + public static final AttributeKey PROXYING = AttributeKey.valueOf("PROXYING"); + + public static boolean isProxyingRequest(ChannelHandlerContext ctx) { + if (ctx != null && ctx.channel().attr(PROXYING).get() != null) { + return ctx.channel().attr(PROXYING).get(); + } + return false; + } + + public static Set getLocalAddresses(ChannelHandlerContext ctx) { + if (ctx != null && + ctx.channel().attr(LOCAL_HOST_HEADERS) != null && + ctx.channel().attr(LOCAL_HOST_HEADERS).get() != null) { + return ctx.channel().attr(LOCAL_HOST_HEADERS).get(); + } + return new HashSet<>(); + } } diff --git a/mockserver-netty/src/main/java/org/mockserver/proxy/direct/DirectProxy.java b/mockserver-netty/src/main/java/org/mockserver/proxy/direct/DirectProxy.java index 83cb29cd9..922bf327b 100644 --- a/mockserver-netty/src/main/java/org/mockserver/proxy/direct/DirectProxy.java +++ b/mockserver-netty/src/main/java/org/mockserver/proxy/direct/DirectProxy.java @@ -11,6 +11,7 @@ import java.net.InetSocketAddress; import java.util.Arrays; +import static org.mockserver.mock.HttpStateHandler.STATE_HANDLER; import static org.mockserver.mock.action.ActionHandler.REMOTE_SOCKET; /** @@ -49,6 +50,7 @@ public DirectProxy(final String remoteHost, final Integer remotePort, final Inte .childHandler(new DirectProxyUnificationHandler()) .childAttr(HTTP_PROXY, DirectProxy.this) .childAttr(REMOTE_SOCKET, remoteSocket) +// .childAttr(PROXYING, true) .childAttr(STATE_HANDLER, new HttpStateHandler()); bindToPorts(Arrays.asList(localPorts)); diff --git a/mockserver-netty/src/main/java/org/mockserver/proxy/direct/DirectProxyUnificationHandler.java b/mockserver-netty/src/main/java/org/mockserver/proxy/direct/DirectProxyUnificationHandler.java index 0492bae39..b95396ca8 100644 --- a/mockserver-netty/src/main/java/org/mockserver/proxy/direct/DirectProxyUnificationHandler.java +++ b/mockserver-netty/src/main/java/org/mockserver/proxy/direct/DirectProxyUnificationHandler.java @@ -3,10 +3,13 @@ import io.netty.channel.ChannelHandler; import io.netty.channel.ChannelHandlerContext; import io.netty.channel.ChannelPipeline; -import org.mockserver.proxy.Proxy; +import org.mockserver.mockserver.callback.WebSocketServerHandler; import org.mockserver.proxy.http.HttpProxyHandler; -import org.mockserver.proxy.unification.PortUnificationHandler; import org.mockserver.server.netty.codec.MockServerServerCodec; +import org.mockserver.unification.PortUnificationHandler; + +import static org.mockserver.mock.HttpStateHandler.STATE_HANDLER; +import static org.mockserver.proxy.Proxy.HTTP_PROXY; /** * @author jamesdbloom @@ -16,10 +19,13 @@ public class DirectProxyUnificationHandler extends PortUnificationHandler { @Override protected void configurePipeline(ChannelHandlerContext ctx, ChannelPipeline pipeline) { + pipeline.addLast(new WebSocketServerHandler(ctx.channel().attr(STATE_HANDLER).get().getWebSocketClientRegistry())); pipeline.addLast(new MockServerServerCodec(isSslEnabledDownstream(ctx.channel()))); + pipeline.addLast(new HttpProxyHandler( - ctx.channel().attr(Proxy.HTTP_PROXY).get(), - ctx.channel().attr(Proxy.STATE_HANDLER).get() + ctx.channel().attr(HTTP_PROXY).get(), + ctx.channel().attr(STATE_HANDLER).get() )); } + } diff --git a/mockserver-netty/src/main/java/org/mockserver/proxy/http/HttpProxy.java b/mockserver-netty/src/main/java/org/mockserver/proxy/http/HttpProxy.java index d6cb655c6..f8f692566 100644 --- a/mockserver-netty/src/main/java/org/mockserver/proxy/http/HttpProxy.java +++ b/mockserver-netty/src/main/java/org/mockserver/proxy/http/HttpProxy.java @@ -2,6 +2,7 @@ import io.netty.bootstrap.ServerBootstrap; import io.netty.buffer.PooledByteBufAllocator; +import io.netty.channel.ChannelHandlerContext; import io.netty.channel.ChannelOption; import io.netty.channel.WriteBufferWaterMark; import io.netty.channel.socket.nio.NioServerSocketChannel; @@ -9,8 +10,13 @@ import org.mockserver.mock.HttpStateHandler; import org.mockserver.proxy.Proxy; +import java.net.InetAddress; import java.net.InetSocketAddress; +import java.util.ArrayList; import java.util.Arrays; +import java.util.List; + +import static org.mockserver.mock.HttpStateHandler.STATE_HANDLER; /** * @author jamesdbloom @@ -59,6 +65,7 @@ public void run() { } protected void started(Integer port) { + super.started(port); ConfigurationProperties.proxyPort(port); System.setProperty("http.proxyHost", "127.0.0.1"); System.setProperty("http.proxyPort", port.toString()); diff --git a/mockserver-netty/src/main/java/org/mockserver/proxy/http/HttpProxyHandler.java b/mockserver-netty/src/main/java/org/mockserver/proxy/http/HttpProxyHandler.java index 3ef3c1377..a0461f72a 100644 --- a/mockserver-netty/src/main/java/org/mockserver/proxy/http/HttpProxyHandler.java +++ b/mockserver-netty/src/main/java/org/mockserver/proxy/http/HttpProxyHandler.java @@ -12,13 +12,13 @@ import org.mockserver.model.PortBinding; import org.mockserver.proxy.Proxy; import org.mockserver.proxy.connect.HttpConnectHandler; -import org.mockserver.proxy.unification.PortUnificationHandler; import org.mockserver.responsewriter.NettyResponseWriter; import org.mockserver.responsewriter.ResponseWriter; import org.mockserver.socket.KeyAndCertificateFactory; import org.slf4j.LoggerFactory; import java.net.BindException; +import java.util.HashSet; import java.util.List; import static io.netty.handler.codec.http.HttpResponseStatus.BAD_REQUEST; @@ -27,6 +27,10 @@ import static org.mockserver.exception.ExceptionHandler.shouldIgnoreException; import static org.mockserver.model.HttpResponse.response; import static org.mockserver.model.PortBinding.portBinding; +import static org.mockserver.proxy.Proxy.PROXYING; +import static org.mockserver.proxy.Proxy.getLocalAddresses; +import static org.mockserver.proxy.Proxy.isProxyingRequest; +import static org.mockserver.unification.PortUnificationHandler.enabledSslUpstreamAndDownstream; /** * @author jamesdbloom @@ -49,7 +53,7 @@ public HttpProxyHandler(Proxy server, HttpStateHandler httpStateHandler) { this.server = server; this.httpStateHandler = httpStateHandler; this.logFormatter = httpStateHandler.getLogFormatter(); - this.actionHandler = new ActionHandler(httpStateHandler, true); + this.actionHandler = new ActionHandler(httpStateHandler); } @Override @@ -90,8 +94,9 @@ public void run() { } else if (request.getMethod().getValue().equals("CONNECT")) { - // assume CONNECT always for SSL - PortUnificationHandler.enabledSslUpstreamAndDownstream(ctx.channel()); + ctx.channel().attr(PROXYING).set(Boolean.TRUE); + // assume SSL for CONNECT request + enabledSslUpstreamAndDownstream(ctx.channel()); // add Subject Alternative Name for SSL certificate KeyAndCertificateFactory.addSubjectAlternativeName(request.getPath().getValue()); ctx.pipeline().addLast(new HttpConnectHandler(request.getPath().getValue(), -1)); @@ -100,7 +105,7 @@ public void run() { } else { - actionHandler.processAction(request, responseWriter, ctx); + actionHandler.processAction(request, responseWriter, ctx, getLocalAddresses(ctx), isProxyingRequest(ctx)); } } @@ -115,7 +120,7 @@ public void run() { } @Override - public void channelReadComplete(ChannelHandlerContext ctx) throws Exception { + public void channelReadComplete(ChannelHandlerContext ctx) { ctx.flush(); } diff --git a/mockserver-netty/src/main/java/org/mockserver/proxy/http/HttpProxyUnificationHandler.java b/mockserver-netty/src/main/java/org/mockserver/proxy/http/HttpProxyUnificationHandler.java index 3ecb8c7f7..123f261a4 100644 --- a/mockserver-netty/src/main/java/org/mockserver/proxy/http/HttpProxyUnificationHandler.java +++ b/mockserver-netty/src/main/java/org/mockserver/proxy/http/HttpProxyUnificationHandler.java @@ -3,9 +3,12 @@ import io.netty.channel.ChannelHandler; import io.netty.channel.ChannelHandlerContext; import io.netty.channel.ChannelPipeline; -import org.mockserver.proxy.Proxy; -import org.mockserver.proxy.unification.PortUnificationHandler; +import org.mockserver.mockserver.callback.WebSocketServerHandler; import org.mockserver.server.netty.codec.MockServerServerCodec; +import org.mockserver.unification.PortUnificationHandler; + +import static org.mockserver.mock.HttpStateHandler.STATE_HANDLER; +import static org.mockserver.proxy.Proxy.HTTP_PROXY; /** * @author jamesdbloom @@ -15,10 +18,12 @@ public class HttpProxyUnificationHandler extends PortUnificationHandler { @Override protected void configurePipeline(ChannelHandlerContext ctx, ChannelPipeline pipeline) { - pipeline.addLast(new MockServerServerCodec(isSslEnabledDownstream(ctx.channel()))); + pipeline.addLast(new WebSocketServerHandler(ctx.channel().attr(STATE_HANDLER).get().getWebSocketClientRegistry())); + pipeline.addLast(new MockServerServerCodec(isSslEnabledUpstream(ctx.channel()))); + pipeline.addLast(new HttpProxyHandler( - ctx.channel().attr(Proxy.HTTP_PROXY).get(), - ctx.channel().attr(Proxy.STATE_HANDLER).get() + ctx.channel().attr(HTTP_PROXY).get(), + ctx.channel().attr(STATE_HANDLER).get() )); } diff --git a/mockserver-netty/src/main/java/org/mockserver/proxy/relay/RelayConnectHandler.java b/mockserver-netty/src/main/java/org/mockserver/proxy/relay/RelayConnectHandler.java index e70578b7e..9e821432f 100644 --- a/mockserver-netty/src/main/java/org/mockserver/proxy/relay/RelayConnectHandler.java +++ b/mockserver-netty/src/main/java/org/mockserver/proxy/relay/RelayConnectHandler.java @@ -10,9 +10,7 @@ import io.netty.handler.codec.http.HttpObjectAggregator; import io.netty.handler.codec.http.HttpServerCodec; import org.mockserver.logging.LoggingHandler; -import org.mockserver.mock.HttpStateHandler; -import org.mockserver.proxy.http.HttpProxy; -import org.mockserver.proxy.unification.PortUnificationHandler; +import org.mockserver.mock.action.ActionHandler; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -21,7 +19,10 @@ import static org.mockserver.exception.ExceptionHandler.shouldIgnoreException; import static org.mockserver.mock.action.ActionHandler.REMOTE_SOCKET; import static org.mockserver.proxy.Proxy.HTTP_CONNECT_SOCKET; +import static org.mockserver.proxy.Proxy.PROXYING; import static org.mockserver.socket.NettySslContextFactory.nettySslContextFactory; +import static org.mockserver.unification.PortUnificationHandler.isSslEnabledDownstream; +import static org.mockserver.unification.PortUnificationHandler.isSslEnabledUpstream; @ChannelHandler.Sharable public abstract class RelayConnectHandler extends SimpleChannelInboundHandler { @@ -49,11 +50,12 @@ public void channelActive(final ChannelHandlerContext clientCtx) throws Exceptio @Override public void operationComplete(ChannelFuture channelFuture) throws Exception { removeCodecSupport(serverCtx); + serverCtx.channel().attr(PROXYING).set(Boolean.TRUE); // downstream ChannelPipeline downstreamPipeline = clientCtx.channel().pipeline(); - if (PortUnificationHandler.isSslEnabledDownstream(serverCtx.channel())) { + if (isSslEnabledDownstream(serverCtx.channel())) { downstreamPipeline.addLast(nettySslContextFactory().createClientSslContext().newHandler(clientCtx.alloc(), host, port)); } @@ -73,7 +75,7 @@ public void operationComplete(ChannelFuture channelFuture) throws Exception { // upstream ChannelPipeline upstreamPipeline = serverCtx.channel().pipeline(); - if (PortUnificationHandler.isSslEnabledUpstream(serverCtx.channel())) { + if (isSslEnabledUpstream(serverCtx.channel())) { upstreamPipeline.addLast(nettySslContextFactory().createServerSslContext().newHandler(serverCtx.alloc())); } diff --git a/mockserver-netty/src/main/java/org/mockserver/proxy/relay/UpstreamProxyRelayHandler.java b/mockserver-netty/src/main/java/org/mockserver/proxy/relay/UpstreamProxyRelayHandler.java index 6930d941a..63d4fb391 100644 --- a/mockserver-netty/src/main/java/org/mockserver/proxy/relay/UpstreamProxyRelayHandler.java +++ b/mockserver-netty/src/main/java/org/mockserver/proxy/relay/UpstreamProxyRelayHandler.java @@ -4,7 +4,6 @@ import io.netty.channel.*; import io.netty.handler.codec.http.FullHttpRequest; import io.netty.handler.ssl.SslHandler; -import org.mockserver.proxy.unification.PortUnificationHandler; import org.slf4j.Logger; import java.nio.channels.ClosedChannelException; @@ -13,6 +12,7 @@ import static org.mockserver.exception.ExceptionHandler.closeOnFlush; import static org.mockserver.exception.ExceptionHandler.shouldIgnoreException; import static org.mockserver.socket.NettySslContextFactory.nettySslContextFactory; +import static org.mockserver.unification.PortUnificationHandler.isSslEnabledDownstream; public class UpstreamProxyRelayHandler extends SimpleChannelInboundHandler { @@ -35,7 +35,7 @@ public void channelActive(ChannelHandlerContext ctx) { @Override public void channelRead0(final ChannelHandlerContext ctx, final FullHttpRequest request) { - if (PortUnificationHandler.isSslEnabledDownstream(upstreamChannel) && downstreamChannel.pipeline().get(SslHandler.class) == null) { + if (isSslEnabledDownstream(upstreamChannel) && downstreamChannel.pipeline().get(SslHandler.class) == null) { downstreamChannel.pipeline().addFirst(nettySslContextFactory().createClientSslContext().newHandler(ctx.alloc())); } downstreamChannel.writeAndFlush(request).addListener(new ChannelFutureListener() { diff --git a/mockserver-netty/src/main/java/org/mockserver/proxy/socks/SocksProxyHandler.java b/mockserver-netty/src/main/java/org/mockserver/proxy/socks/SocksProxyHandler.java index 812101b83..6cb86af5f 100644 --- a/mockserver-netty/src/main/java/org/mockserver/proxy/socks/SocksProxyHandler.java +++ b/mockserver-netty/src/main/java/org/mockserver/proxy/socks/SocksProxyHandler.java @@ -4,12 +4,15 @@ import io.netty.channel.ChannelHandlerContext; import io.netty.channel.SimpleChannelInboundHandler; import io.netty.handler.codec.socks.*; -import org.mockserver.proxy.unification.PortUnificationHandler; -import org.mockserver.socket.KeyAndCertificateFactory; +import org.mockserver.mock.action.ActionHandler; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import static org.mockserver.exception.ExceptionHandler.shouldIgnoreException; +import static org.mockserver.proxy.Proxy.PROXYING; +import static org.mockserver.socket.KeyAndCertificateFactory.addSubjectAlternativeName; +import static org.mockserver.unification.PortUnificationHandler.disableSslDownstream; +import static org.mockserver.unification.PortUnificationHandler.enabledSslDownstream; public class SocksProxyHandler extends SimpleChannelInboundHandler { @@ -41,14 +44,15 @@ protected void channelRead0(final ChannelHandlerContext ctx, SocksRequest socksR if (req.cmdType() == SocksCmdType.CONNECT) { Channel channel = ctx.channel(); + channel.attr(PROXYING).set(Boolean.TRUE); if (String.valueOf(req.port()).endsWith("80")) { - PortUnificationHandler.disableSslDownstream(channel); + disableSslDownstream(channel); } else if (String.valueOf(req.port()).endsWith("443")) { - PortUnificationHandler.enabledSslDownstream(channel); + enabledSslDownstream(channel); } // add Subject Alternative Name for SSL certificate - KeyAndCertificateFactory.addSubjectAlternativeName(req.host()); + addSubjectAlternativeName(req.host()); ctx.pipeline().addAfter(getClass().getSimpleName() + "#0", SocksConnectHandler.class.getSimpleName() + "#0", new SocksConnectHandler(req.host(), req.port())); ctx.pipeline().remove(this); diff --git a/mockserver-core/src/main/java/org/mockserver/server/unification/HttpContentLengthRemover.java b/mockserver-netty/src/main/java/org/mockserver/unification/HttpContentLengthRemover.java similarity index 88% rename from mockserver-core/src/main/java/org/mockserver/server/unification/HttpContentLengthRemover.java rename to mockserver-netty/src/main/java/org/mockserver/unification/HttpContentLengthRemover.java index d9e3e2141..2b973d036 100644 --- a/mockserver-core/src/main/java/org/mockserver/server/unification/HttpContentLengthRemover.java +++ b/mockserver-netty/src/main/java/org/mockserver/unification/HttpContentLengthRemover.java @@ -1,4 +1,4 @@ -package org.mockserver.server.unification; +package org.mockserver.unification; import com.google.common.net.HttpHeaders; import io.netty.channel.ChannelHandlerContext; @@ -13,7 +13,7 @@ */ public class HttpContentLengthRemover extends MessageToMessageEncoder { @Override - protected void encode(ChannelHandlerContext ctx, DefaultHttpMessage defaultHttpMessage, List out) throws Exception { + protected void encode(ChannelHandlerContext ctx, DefaultHttpMessage defaultHttpMessage, List out) { if (defaultHttpMessage.headers().contains(HttpHeaders.CONTENT_LENGTH, "", true)) { defaultHttpMessage.headers().remove(HttpHeaders.CONTENT_LENGTH); } diff --git a/mockserver-netty/src/main/java/org/mockserver/proxy/unification/PortUnificationHandler.java b/mockserver-netty/src/main/java/org/mockserver/unification/PortUnificationHandler.java similarity index 74% rename from mockserver-netty/src/main/java/org/mockserver/proxy/unification/PortUnificationHandler.java rename to mockserver-netty/src/main/java/org/mockserver/unification/PortUnificationHandler.java index d0570a31f..3dba0b1ef 100644 --- a/mockserver-netty/src/main/java/org/mockserver/proxy/unification/PortUnificationHandler.java +++ b/mockserver-netty/src/main/java/org/mockserver/unification/PortUnificationHandler.java @@ -1,4 +1,4 @@ -package org.mockserver.proxy.unification; +package org.mockserver.unification; import com.google.common.annotations.VisibleForTesting; import io.netty.buffer.ByteBuf; @@ -13,12 +13,22 @@ import io.netty.handler.ssl.SslHandler; import io.netty.util.AttributeKey; import org.mockserver.logging.LoggingHandler; +import org.mockserver.model.HttpRequest; import org.mockserver.proxy.socks.SocksProxyHandler; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import java.net.InetAddress; +import java.net.InetSocketAddress; +import java.net.SocketAddress; +import java.util.ArrayList; +import java.util.HashSet; +import java.util.List; +import java.util.Set; + import static org.mockserver.exception.ExceptionHandler.closeOnFlush; import static org.mockserver.exception.ExceptionHandler.shouldIgnoreException; +import static org.mockserver.proxy.Proxy.LOCAL_HOST_HEADERS; import static org.mockserver.socket.NettySslContextFactory.nettySslContextFactory; /** @@ -27,8 +37,8 @@ @ChannelHandler.Sharable public abstract class PortUnificationHandler extends SimpleChannelInboundHandler { - public static final AttributeKey SSL_ENABLED_UPSTREAM = AttributeKey.valueOf("PROXY_SSL_ENABLED_UPSTREAM"); - public static final AttributeKey SSL_ENABLED_DOWNSTREAM = AttributeKey.valueOf("SSL_ENABLED_DOWNSTREAM"); + private static final AttributeKey SSL_ENABLED_UPSTREAM = AttributeKey.valueOf("PROXY_SSL_ENABLED_UPSTREAM"); + private static final AttributeKey SSL_ENABLED_DOWNSTREAM = AttributeKey.valueOf("SSL_ENABLED_DOWNSTREAM"); @VisibleForTesting public static Logger logger = LoggerFactory.getLogger(PortUnificationHandler.class); @@ -38,8 +48,8 @@ public PortUnificationHandler() { } public static void enabledSslUpstreamAndDownstream(Channel channel) { - channel.attr(PortUnificationHandler.SSL_ENABLED_UPSTREAM).set(Boolean.TRUE); - channel.attr(PortUnificationHandler.SSL_ENABLED_DOWNSTREAM).set(Boolean.TRUE); + channel.attr(SSL_ENABLED_UPSTREAM).set(Boolean.TRUE); + channel.attr(SSL_ENABLED_DOWNSTREAM).set(Boolean.TRUE); } public static boolean isSslEnabledUpstream(Channel channel) { @@ -51,11 +61,11 @@ public static boolean isSslEnabledUpstream(Channel channel) { } public static void enabledSslDownstream(Channel channel) { - channel.attr(PortUnificationHandler.SSL_ENABLED_DOWNSTREAM).set(Boolean.TRUE); + channel.attr(SSL_ENABLED_DOWNSTREAM).set(Boolean.TRUE); } public static void disableSslDownstream(Channel channel) { - channel.attr(PortUnificationHandler.SSL_ENABLED_DOWNSTREAM).set(Boolean.FALSE); + channel.attr(SSL_ENABLED_DOWNSTREAM).set(Boolean.FALSE); } public static boolean isSslEnabledDownstream(Channel channel) { @@ -142,9 +152,9 @@ private boolean isHttp(ByteBuf msg) { private void enableSsl(ChannelHandlerContext ctx, ByteBuf msg) { ChannelPipeline pipeline = ctx.pipeline(); pipeline.addFirst(nettySslContextFactory().createServerSslContext().newHandler(ctx.alloc())); + PortUnificationHandler.enabledSslUpstreamAndDownstream(ctx.channel()); // re-unify (with SSL enabled) - PortUnificationHandler.enabledSslUpstreamAndDownstream(ctx.channel()); ctx.pipeline().fireChannelRead(msg); } @@ -163,16 +173,41 @@ private void switchToHttp(ChannelHandlerContext ctx, ByteBuf msg) { addLastIfNotPresent(pipeline, new HttpServerCodec(8192, 8192, 8192)); addLastIfNotPresent(pipeline, new HttpContentDecompressor()); + addLastIfNotPresent(pipeline, new HttpContentLengthRemover()); addLastIfNotPresent(pipeline, new HttpObjectAggregator(Integer.MAX_VALUE)); + if (logger.isDebugEnabled()) { + addLastIfNotPresent(pipeline, new LoggingHandler(logger)); + } configurePipeline(ctx, pipeline); pipeline.remove(this); - // pass message to next stage in pipeline + ctx.channel().attr(LOCAL_HOST_HEADERS).set(getLocalAddresses(ctx)); + + // fire message back through pipeline ctx.fireChannelRead(msg); } - protected void addLastIfNotPresent(ChannelPipeline pipeline, ChannelHandler channelHandler) { + private Set getLocalAddresses(ChannelHandlerContext ctx) { + Set localAddresses = new HashSet<>(); + SocketAddress localAddress = ctx.channel().localAddress(); + if (localAddress instanceof InetSocketAddress) { + InetSocketAddress inetSocketAddress = (InetSocketAddress) localAddress; + String portExtension = ""; + if (!(inetSocketAddress.getPort() == 443 && isSslEnabledUpstream(ctx.channel()) || inetSocketAddress.getPort() == 80)) { + portExtension = ":" + inetSocketAddress.getPort(); + } + InetAddress socketAddress = inetSocketAddress.getAddress(); + localAddresses.add(socketAddress.getHostAddress() + portExtension); + localAddresses.add(socketAddress.getCanonicalHostName() + portExtension); + localAddresses.add(socketAddress.getHostName() + portExtension); + localAddresses.add("localhost" + portExtension); + localAddresses.add("127.0.0.1" + portExtension); + } + return localAddresses; + } + + private void addLastIfNotPresent(ChannelPipeline pipeline, ChannelHandler channelHandler) { if (pipeline.get(channelHandler.getClass()) == null) { pipeline.addLast(channelHandler); } diff --git a/mockserver-netty/src/test/java/org/mockserver/integration/mockserver/ClientAndDirectProxyMockingIntegrationTest.java b/mockserver-netty/src/test/java/org/mockserver/integration/mockserver/ClientAndDirectProxyMockingIntegrationTest.java new file mode 100644 index 000000000..fac945a0b --- /dev/null +++ b/mockserver-netty/src/test/java/org/mockserver/integration/mockserver/ClientAndDirectProxyMockingIntegrationTest.java @@ -0,0 +1,67 @@ +package org.mockserver.integration.mockserver; + +import org.junit.AfterClass; +import org.junit.BeforeClass; +import org.junit.Ignore; +import org.mockserver.client.server.MockServerClient; +import org.mockserver.echo.http.EchoServer; +import org.mockserver.integration.ClientAndServer; +import org.mockserver.proxy.ProxyBuilder; +import org.mockserver.socket.PortFactory; + +import static org.mockserver.integration.ClientAndServer.startClientAndServer; +import static org.mockserver.model.HttpResponse.notFoundResponse; + +/** + * @author jamesdbloom + */ +public class ClientAndDirectProxyMockingIntegrationTest extends AbstractRestartableMockServerNettyIntegrationTest { + + private static final int SERVER_HTTP_PORT = PortFactory.findFreePort(); + private final static int TEST_SERVER_HTTP_PORT = PortFactory.findFreePort(); + private static EchoServer echoServer; + + @BeforeClass + public static void startServer() { + // start direct proxy and client + new ProxyBuilder() + .withLocalPort(SERVER_HTTP_PORT) + .withDirect("localhost", TEST_SERVER_HTTP_PORT) + .build(); + mockServerClient = new MockServerClient("localhost", SERVER_HTTP_PORT); + + // start echo servers + echoServer = new EchoServer(TEST_SERVER_HTTP_PORT, false); + } + + @AfterClass + public static void stopServer() { + // stop mock server and client + if (mockServerClient instanceof ClientAndServer) { + mockServerClient.stop(); + } + + // stop echo server + echoServer.stop(); + } + + @Override + public void startServerAgain() { + startClientAndServer(SERVER_HTTP_PORT); + } + + @Override + public int getMockServerPort() { + return SERVER_HTTP_PORT; + } + + @Override + public int getMockServerSecurePort() { + return SERVER_HTTP_PORT; + } + + @Override + public int getTestServerPort() { + return TEST_SERVER_HTTP_PORT; + } +} diff --git a/mockserver-netty/src/test/java/org/mockserver/integration/mockserver/ClientAndProxyMockingIntegrationTest.java b/mockserver-netty/src/test/java/org/mockserver/integration/mockserver/ClientAndProxyMockingIntegrationTest.java new file mode 100644 index 000000000..a9e6a73fa --- /dev/null +++ b/mockserver-netty/src/test/java/org/mockserver/integration/mockserver/ClientAndProxyMockingIntegrationTest.java @@ -0,0 +1,65 @@ +package org.mockserver.integration.mockserver; + +import org.junit.AfterClass; +import org.junit.BeforeClass; +import org.junit.Ignore; +import org.mockserver.client.server.MockServerClient; +import org.mockserver.echo.http.EchoServer; +import org.mockserver.integration.ClientAndServer; +import org.mockserver.model.HttpResponse; +import org.mockserver.proxy.ProxyBuilder; +import org.mockserver.socket.PortFactory; + +import static org.mockserver.integration.ClientAndServer.startClientAndServer; +import static org.mockserver.model.HttpResponse.notFoundResponse; + +/** + * @author jamesdbloom + */ +public class ClientAndProxyMockingIntegrationTest extends AbstractRestartableMockServerNettyIntegrationTest { + + private static final int SERVER_HTTP_PORT = PortFactory.findFreePort(); + private final static int TEST_SERVER_HTTP_PORT = PortFactory.findFreePort(); + private static EchoServer echoServer; + + @BeforeClass + public static void startServer() { + // start proxy and client + new ProxyBuilder().withLocalPort(SERVER_HTTP_PORT).build(); + mockServerClient = new MockServerClient("localhost", SERVER_HTTP_PORT); + + // start echo servers + echoServer = new EchoServer(TEST_SERVER_HTTP_PORT, false); + } + + @AfterClass + public static void stopServer() { + // stop mock server and client + if (mockServerClient instanceof ClientAndServer) { + mockServerClient.stop(); + } + + // stop echo server + echoServer.stop(); + } + + @Override + public void startServerAgain() { + startClientAndServer(SERVER_HTTP_PORT); + } + + @Override + public int getMockServerPort() { + return SERVER_HTTP_PORT; + } + + @Override + public int getMockServerSecurePort() { + return SERVER_HTTP_PORT; + } + + @Override + public int getTestServerPort() { + return TEST_SERVER_HTTP_PORT; + } +} diff --git a/mockserver-netty/src/test/java/org/mockserver/integration/mockserver/MockServerAutoAllocatedPortIntegrationTest.java b/mockserver-netty/src/test/java/org/mockserver/integration/mockserver/MockServerAutoAllocatedPortIntegrationTest.java index 49d77a9bb..4697f7f23 100644 --- a/mockserver-netty/src/test/java/org/mockserver/integration/mockserver/MockServerAutoAllocatedPortIntegrationTest.java +++ b/mockserver-netty/src/test/java/org/mockserver/integration/mockserver/MockServerAutoAllocatedPortIntegrationTest.java @@ -23,7 +23,7 @@ public class MockServerAutoAllocatedPortIntegrationTest extends AbstractRestarta public static void startServer() throws InterruptedException, ExecutionException { // start mock server and client mockServerClient = startClientAndServer(0); - severHttpPort = ((ClientAndServer)mockServerClient).getPort(); + severHttpPort = ((ClientAndServer) mockServerClient).getPort(); // start echo servers echoServer = new EchoServer(TEST_SERVER_HTTP_PORT, false); diff --git a/mockserver-netty/src/test/java/org/mockserver/mockserver/MockServerHandlerTest.java b/mockserver-netty/src/test/java/org/mockserver/mockserver/MockServerHandlerTest.java index 3327a1405..e536d0670 100644 --- a/mockserver-netty/src/test/java/org/mockserver/mockserver/MockServerHandlerTest.java +++ b/mockserver-netty/src/test/java/org/mockserver/mockserver/MockServerHandlerTest.java @@ -1,5 +1,6 @@ package org.mockserver.mockserver; +import com.google.common.collect.ImmutableSet; import io.netty.channel.ChannelHandlerContext; import io.netty.channel.embedded.EmbeddedChannel; import org.junit.Before; @@ -18,10 +19,13 @@ import org.mockserver.model.HttpRequest; import org.mockserver.model.HttpResponse; import org.mockserver.model.RetrieveType; +import org.mockserver.responsewriter.NettyResponseWriter; import org.mockserver.responsewriter.ResponseWriter; +import org.mockserver.server.ServletResponseWriter; import java.util.Arrays; import java.util.Collections; +import java.util.HashSet; import java.util.concurrent.TimeUnit; import static com.google.common.net.MediaType.JSON_UTF_8; @@ -32,12 +36,15 @@ import static org.mockito.Matchers.any; import static org.mockito.Matchers.anyListOf; import static org.mockito.Matchers.eq; +import static org.mockito.Mockito.isNull; import static org.mockito.Mockito.*; import static org.mockito.MockitoAnnotations.initMocks; import static org.mockserver.character.Character.NEW_LINE; import static org.mockserver.model.HttpRequest.request; import static org.mockserver.model.HttpResponse.response; import static org.mockserver.model.PortBinding.portBinding; +import static org.mockserver.proxy.Proxy.LOCAL_HOST_HEADERS; +import static org.mockserver.proxy.Proxy.PROXYING; /** * @author jamesdbloom @@ -83,17 +90,17 @@ public void shouldRetrieveRequests() { // given httpStateHandler.log(new RequestLogEntry(request("request_one"))); HttpRequest expectationRetrieveRequestsRequest = request("/retrieve") - .withMethod("PUT") - .withBody( - httpRequestSerializer.serialize(request("request_one")) - ); + .withMethod("PUT") + .withBody( + httpRequestSerializer.serialize(request("request_one")) + ); // when embeddedChannel.writeInbound(expectationRetrieveRequestsRequest); // then assertResponse(200, httpRequestSerializer.serialize(Collections.singletonList( - request("request_one") + request("request_one") ))); } @@ -104,10 +111,10 @@ public void shouldClear() { httpStateHandler.add(new Expectation(request("request_one")).thenRespond(response("response_one"))); httpStateHandler.log(new RequestLogEntry(request("request_one"))); HttpRequest clearRequest = request("/clear") - .withMethod("PUT") - .withBody( - httpRequestSerializer.serialize(request("request_one")) - ); + .withMethod("PUT") + .withBody( + httpRequestSerializer.serialize(request("request_one")) + ); // when embeddedChannel.writeInbound(clearRequest); @@ -116,10 +123,10 @@ public void shouldClear() { assertResponse(200, ""); assertThat(httpStateHandler.firstMatchingExpectation(request("request_one")), is(nullValue())); assertThat(httpStateHandler.retrieve(request("/retrieve") - .withMethod("PUT") - .withBody( - httpRequestSerializer.serialize(request("request_one")) - )), is(response().withBody("", JSON_UTF_8).withStatusCode(200))); + .withMethod("PUT") + .withBody( + httpRequestSerializer.serialize(request("request_one")) + )), is(response().withBody("", JSON_UTF_8).withStatusCode(200))); } @Test @@ -133,7 +140,7 @@ public void shouldReturnStatus() { // then assertResponse(200, portBindingSerializer.serialize( - portBinding(1080, 1090) + portBinding(1080, 1090) )); } @@ -142,10 +149,10 @@ public void shouldBindNewPorts() { // given when(mockMockServer.bindToPorts(anyListOf(Integer.class))).thenReturn(Arrays.asList(1080, 1090)); HttpRequest statusRequest = request("/bind") - .withMethod("PUT") - .withBody(portBindingSerializer.serialize( - portBinding(1080, 1090) - )); + .withMethod("PUT") + .withBody(portBindingSerializer.serialize( + portBinding(1080, 1090) + )); // when embeddedChannel.writeInbound(statusRequest); @@ -153,7 +160,7 @@ public void shouldBindNewPorts() { // then verify(mockMockServer).bindToPorts(Arrays.asList(1080, 1090)); assertResponse(200, portBindingSerializer.serialize( - portBinding(1080, 1090) + portBinding(1080, 1090) )); } @@ -161,7 +168,7 @@ public void shouldBindNewPorts() { public void shouldStop() throws InterruptedException { // given HttpRequest statusRequest = request("/stop") - .withMethod("PUT"); + .withMethod("PUT"); // when embeddedChannel.writeInbound(statusRequest); @@ -177,22 +184,22 @@ public void shouldRetrieveRecordedExpectations() { // given Expectation expectationOne = new Expectation(request("request_one")).thenRespond(response("response_one")); httpStateHandler.log(new ExpectationMatchLogEntry( - request("request_one"), - expectationOne + request("request_one"), + expectationOne )); HttpRequest expectationRetrieveExpectationsRequest = request("/retrieve") - .withMethod("PUT") - .withQueryStringParameter("type", RetrieveType.RECORDED_EXPECTATIONS.name()) - .withBody( - httpRequestSerializer.serialize(request("request_one")) - ); + .withMethod("PUT") + .withQueryStringParameter("type", RetrieveType.RECORDED_EXPECTATIONS.name()) + .withBody( + httpRequestSerializer.serialize(request("request_one")) + ); // when embeddedChannel.writeInbound(expectationRetrieveExpectationsRequest); // then assertResponse(200, expectationSerializer.serialize(Collections.singletonList( - expectationOne + expectationOne ))); } @@ -253,7 +260,7 @@ public void shouldAddExpectation() { // given Expectation expectationOne = new Expectation(request("request_one")).thenRespond(response("response_one")); HttpRequest request = request("/expectation").withMethod("PUT").withBody( - expectationSerializer.serialize(expectationOne) + expectationSerializer.serialize(expectationOne) ); // when @@ -270,31 +277,75 @@ public void shouldRetrieveActiveExpectations() { Expectation expectationOne = new Expectation(request("request_one")).thenRespond(response("response_one")); httpStateHandler.add(expectationOne); HttpRequest expectationRetrieveExpectationsRequest = request("/retrieve") - .withMethod("PUT") - .withQueryStringParameter("type", RetrieveType.ACTIVE_EXPECTATIONS.name()) - .withBody( - httpRequestSerializer.serialize(request("request_one")) - ); + .withMethod("PUT") + .withQueryStringParameter("type", RetrieveType.ACTIVE_EXPECTATIONS.name()) + .withBody( + httpRequestSerializer.serialize(request("request_one")) + ); // when embeddedChannel.writeInbound(expectationRetrieveExpectationsRequest); // then assertResponse(200, expectationSerializer.serialize(Collections.singletonList( - expectationOne + expectationOne ))); } @Test - public void shouldUseActionHandlerToHandleNonAPIRequests() { + public void shouldUseActionHandlerToHandleNonAPIRequestsWhenProxying() { + // given + HttpRequest request = request("request_one"); + embeddedChannel.attr(LOCAL_HOST_HEADERS).set(ImmutableSet.of( + "local_address:666", + "localhost:666", + "127.0.0.1:666" + )); + embeddedChannel.attr(PROXYING).set(true); + + // when + embeddedChannel.writeInbound(request); + + // then + verify(mockActionHandler).processAction( + eq(request), + any(NettyResponseWriter.class), + any(ChannelHandlerContext.class), + eq(ImmutableSet.of( + "local_address:666", + "localhost:666", + "127.0.0.1:666" + )), + eq(true) + ); + } + + @Test + public void shouldUseActionHandlerToHandleNonAPIRequestsWhenNotProxying() { // given HttpRequest request = request("request_one"); + embeddedChannel.attr(LOCAL_HOST_HEADERS).set(ImmutableSet.of( + "local_address", + "localhost", + "127.0.0.1" + )); + embeddedChannel.attr(PROXYING).set(false); // when embeddedChannel.writeInbound(request); // then - verify(mockActionHandler).processAction(eq(request), any(ResponseWriter.class), any(ChannelHandlerContext.class)); + verify(mockActionHandler).processAction( + eq(request), + any(NettyResponseWriter.class), + any(ChannelHandlerContext.class), + eq(ImmutableSet.of( + "local_address", + "localhost", + "127.0.0.1" + )), + eq(false) + ); } } diff --git a/mockserver-netty/src/test/java/org/mockserver/proxy/direct/DirectProxyUnificationHandlerTest.java b/mockserver-netty/src/test/java/org/mockserver/proxy/direct/DirectProxyUnificationHandlerTest.java index 96a45423d..18f254860 100644 --- a/mockserver-netty/src/test/java/org/mockserver/proxy/direct/DirectProxyUnificationHandlerTest.java +++ b/mockserver-netty/src/test/java/org/mockserver/proxy/direct/DirectProxyUnificationHandlerTest.java @@ -18,7 +18,7 @@ import org.mockserver.proxy.http.HttpProxyUnificationHandler; import org.mockserver.proxy.relay.RelayConnectHandler; import org.mockserver.proxy.socks.SocksProxyHandler; -import org.mockserver.proxy.unification.PortUnificationHandler; +import org.mockserver.unification.PortUnificationHandler; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -31,8 +31,8 @@ import static org.hamcrest.core.Is.is; import static org.junit.Assert.assertThat; import static org.mockito.Mockito.mock; +import static org.mockserver.mock.HttpStateHandler.STATE_HANDLER; import static org.mockserver.proxy.Proxy.HTTP_PROXY; -import static org.mockserver.proxy.Proxy.STATE_HANDLER; public class DirectProxyUnificationHandlerTest { @@ -139,20 +139,24 @@ public void shouldSwitchToHttp() { // then - should add HTTP handlers last if (LoggerFactory.getLogger(PortUnificationHandler.class).isTraceEnabled()) { - assertThat(embeddedChannel.pipeline().names(), contains( + assertThat(String.valueOf(embeddedChannel.pipeline().names()), embeddedChannel.pipeline().names(), contains( "LoggingHandler#0", "HttpServerCodec#0", "HttpContentDecompressor#0", + "HttpContentLengthRemover#0", "HttpObjectAggregator#0", + "WebSocketServerHandler#0", "MockServerServerCodec#0", "HttpProxyHandler#0", "DefaultChannelPipeline$TailContext#0" )); } else { - assertThat(embeddedChannel.pipeline().names(), contains( + assertThat(String.valueOf(embeddedChannel.pipeline().names()), embeddedChannel.pipeline().names(), contains( "HttpServerCodec#0", "HttpContentDecompressor#0", + "HttpContentLengthRemover#0", "HttpObjectAggregator#0", + "WebSocketServerHandler#0", "MockServerServerCodec#0", "HttpProxyHandler#0", "DefaultChannelPipeline$TailContext#0" diff --git a/mockserver-netty/src/test/java/org/mockserver/proxy/http/HttpProxyHandlerTest.java b/mockserver-netty/src/test/java/org/mockserver/proxy/http/HttpProxyHandlerTest.java index 676f24545..c03c02230 100644 --- a/mockserver-netty/src/test/java/org/mockserver/proxy/http/HttpProxyHandlerTest.java +++ b/mockserver-netty/src/test/java/org/mockserver/proxy/http/HttpProxyHandlerTest.java @@ -1,5 +1,6 @@ package org.mockserver.proxy.http; +import com.google.common.collect.ImmutableSet; import io.netty.channel.ChannelHandlerContext; import io.netty.channel.embedded.EmbeddedChannel; import org.junit.Before; @@ -26,6 +27,7 @@ import java.net.InetSocketAddress; import java.util.Arrays; import java.util.Collections; +import java.util.HashSet; import java.util.concurrent.TimeUnit; import static com.google.common.net.MediaType.JSON_UTF_8; @@ -41,6 +43,8 @@ import static org.mockserver.model.HttpRequest.request; import static org.mockserver.model.HttpResponse.response; import static org.mockserver.model.PortBinding.portBinding; +import static org.mockserver.proxy.Proxy.LOCAL_HOST_HEADERS; +import static org.mockserver.proxy.Proxy.PROXYING; /** * @author jamesdbloom @@ -245,17 +249,63 @@ public void shouldRetrieveLogMessages() { } @Test - public void shouldProxyRequests() { + public void shouldProxyRequestsWhenProxying() { // given HttpRequest request = request("request_one"); InetSocketAddress remoteAddress = new InetSocketAddress(1080); + embeddedChannel.attr(LOCAL_HOST_HEADERS).set(ImmutableSet.of( + "local_address:666", + "localhost:666", + "127.0.0.1:666" + )); + embeddedChannel.attr(PROXYING).set(true); + embeddedChannel.attr(REMOTE_SOCKET).set(remoteAddress); // when + embeddedChannel.writeInbound(request); + + // then + verify(mockActionHandler).processAction( + eq(request), + any(NettyResponseWriter.class), + any(ChannelHandlerContext.class), + eq(ImmutableSet.of( + "local_address:666", + "localhost:666", + "127.0.0.1:666" + )), + eq(true) + ); + } + + @Test + public void shouldProxyRequestsWhenNotProxying() { + // given + HttpRequest request = request("request_one"); + InetSocketAddress remoteAddress = new InetSocketAddress(1080); + embeddedChannel.attr(LOCAL_HOST_HEADERS).set(ImmutableSet.of( + "local_address:666", + "localhost:666", + "127.0.0.1:666" + )); + embeddedChannel.attr(PROXYING).set(false); embeddedChannel.attr(REMOTE_SOCKET).set(remoteAddress); + + // when embeddedChannel.writeInbound(request); // then - verify(mockActionHandler).processAction(eq(request), any(NettyResponseWriter.class), any(ChannelHandlerContext.class)); + verify(mockActionHandler).processAction( + eq(request), + any(NettyResponseWriter.class), + any(ChannelHandlerContext.class), + eq(ImmutableSet.of( + "local_address:666", + "localhost:666", + "127.0.0.1:666" + )), + eq(false) + ); } } diff --git a/mockserver-netty/src/test/java/org/mockserver/proxy/http/HttpProxyUnificationHandlerSOCKSErrorTest.java b/mockserver-netty/src/test/java/org/mockserver/proxy/http/HttpProxyUnificationHandlerSOCKSErrorTest.java index 6c3d7c842..e74743585 100644 --- a/mockserver-netty/src/test/java/org/mockserver/proxy/http/HttpProxyUnificationHandlerSOCKSErrorTest.java +++ b/mockserver-netty/src/test/java/org/mockserver/proxy/http/HttpProxyUnificationHandlerSOCKSErrorTest.java @@ -15,7 +15,7 @@ import org.mockserver.proxy.Proxy; import org.mockserver.proxy.relay.RelayConnectHandler; import org.mockserver.proxy.socks.SocksProxyHandler; -import org.mockserver.proxy.unification.PortUnificationHandler; +import org.mockserver.unification.PortUnificationHandler; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -31,8 +31,8 @@ import static org.mockito.Matchers.eq; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.verify; +import static org.mockserver.mock.HttpStateHandler.STATE_HANDLER; import static org.mockserver.proxy.Proxy.HTTP_PROXY; -import static org.mockserver.proxy.Proxy.STATE_HANDLER; public class HttpProxyUnificationHandlerSOCKSErrorTest { @@ -53,57 +53,57 @@ public void shouldHandleErrorsDuringSOCKSConnection() throws IOException, Interr // when - SOCKS INIT message embeddedChannel.writeInbound(Unpooled.wrappedBuffer(new byte[]{ - (byte) 0x05, // SOCKS5 - (byte) 0x02, // 1 authentication method - (byte) 0x00, // NO_AUTH - (byte) 0x02, // AUTH_PASSWORD + (byte) 0x05, // SOCKS5 + (byte) 0x02, // 1 authentication method + (byte) 0x00, // NO_AUTH + (byte) 0x02, // AUTH_PASSWORD })); // then - INIT response assertThat(ByteBufUtil.hexDump((ByteBuf) embeddedChannel.readOutbound()), is(Hex.encodeHexString(new byte[]{ - (byte) 0x05, // SOCKS5 - (byte) 0x00, // NO_AUTH + (byte) 0x05, // SOCKS5 + (byte) 0x00, // NO_AUTH }))); // and then - should add SOCKS handlers first if (LoggerFactory.getLogger(PortUnificationHandler.class).isTraceEnabled()) { - assertThat(embeddedChannel.pipeline().names(), contains( - "LoggingHandler#0", - "SocksCmdRequestDecoder#0", - "SocksMessageEncoder#0", - "SocksProxyHandler#0", - "HttpProxyUnificationHandler#0", - "DefaultChannelPipeline$TailContext#0" + assertThat(String.valueOf(embeddedChannel.pipeline().names()), embeddedChannel.pipeline().names(), contains( + "LoggingHandler#0", + "SocksCmdRequestDecoder#0", + "SocksMessageEncoder#0", + "SocksProxyHandler#0", + "HttpProxyUnificationHandler#0", + "DefaultChannelPipeline$TailContext#0" )); } else { - assertThat(embeddedChannel.pipeline().names(), contains( - "SocksCmdRequestDecoder#0", - "SocksMessageEncoder#0", - "SocksProxyHandler#0", - "HttpProxyUnificationHandler#0", - "DefaultChannelPipeline$TailContext#0" + assertThat(String.valueOf(embeddedChannel.pipeline().names()), embeddedChannel.pipeline().names(), contains( + "SocksCmdRequestDecoder#0", + "SocksMessageEncoder#0", + "SocksProxyHandler#0", + "HttpProxyUnificationHandler#0", + "DefaultChannelPipeline$TailContext#0" )); } // and when - SOCKS CONNECT command embeddedChannel.writeInbound(Unpooled.wrappedBuffer(new byte[]{ - (byte) 0x05, // SOCKS5 - (byte) 0x01, // command type CONNECT - (byte) 0x00, // reserved (must be 0x00) - (byte) 0x01, // address type IPv4 - (byte) 0x7f, (byte) 0x00, (byte) 0x00, (byte) 0x01, // ip address - (byte) (localPort & 0xFF00), (byte) localPort // port + (byte) 0x05, // SOCKS5 + (byte) 0x01, // command type CONNECT + (byte) 0x00, // reserved (must be 0x00) + (byte) 0x01, // address type IPv4 + (byte) 0x7f, (byte) 0x00, (byte) 0x00, (byte) 0x01, // ip address + (byte) (localPort & 0xFF00), (byte) localPort // port })); // then - CONNECT response assertThat(ByteBufUtil.hexDump((ByteBuf) embeddedChannel.readOutbound()), is(Hex.encodeHexString(new byte[]{ - (byte) 0x05, // SOCKS5 - (byte) 0x01, // general failure (caused by connection failure) - (byte) 0x00, // reserved (must be 0x00) - (byte) 0x01, // address type IPv4 - (byte) 0x7f, (byte) 0x00, (byte) 0x00, (byte) 0x01, // ip address - (byte) (localPort & 0xFF00), (byte) localPort // port + (byte) 0x05, // SOCKS5 + (byte) 0x01, // general failure (caused by connection failure) + (byte) 0x00, // reserved (must be 0x00) + (byte) 0x01, // address type IPv4 + (byte) 0x7f, (byte) 0x00, (byte) 0x00, (byte) 0x01, // ip address + (byte) (localPort & 0xFF00), (byte) localPort // port }))); // then - channel is closed after error @@ -129,23 +129,27 @@ public void shouldSwitchToHttp() { // then - should add HTTP handlers last if (LoggerFactory.getLogger(PortUnificationHandler.class).isTraceEnabled()) { - assertThat(embeddedChannel.pipeline().names(), contains( - "LoggingHandler#0", - "HttpServerCodec#0", - "HttpContentDecompressor#0", - "HttpObjectAggregator#0", - "MockServerServerCodec#0", - "HttpProxyHandler#0", - "DefaultChannelPipeline$TailContext#0" + assertThat(String.valueOf(embeddedChannel.pipeline().names()), embeddedChannel.pipeline().names(), contains( + "LoggingHandler#0", + "HttpServerCodec#0", + "HttpContentDecompressor#0", + "HttpContentLengthRemover#0", + "HttpObjectAggregator#0", + "WebSocketServerHandler#0", + "MockServerServerCodec#0", + "HttpProxyHandler#0", + "DefaultChannelPipeline$TailContext#0" )); } else { - assertThat(embeddedChannel.pipeline().names(), contains( - "HttpServerCodec#0", - "HttpContentDecompressor#0", - "HttpObjectAggregator#0", - "MockServerServerCodec#0", - "HttpProxyHandler#0", - "DefaultChannelPipeline$TailContext#0" + assertThat(String.valueOf(embeddedChannel.pipeline().names()), embeddedChannel.pipeline().names(), contains( + "HttpServerCodec#0", + "HttpContentDecompressor#0", + "HttpContentLengthRemover#0", + "HttpObjectAggregator#0", + "WebSocketServerHandler#0", + "MockServerServerCodec#0", + "HttpProxyHandler#0", + "DefaultChannelPipeline$TailContext#0" )); } } @@ -163,11 +167,11 @@ public void shouldSupportUnknownProtocol() { // then - should add no handlers assertThat(embeddedChannel.pipeline().names(), contains( - "DefaultChannelPipeline$TailContext#0" + "DefaultChannelPipeline$TailContext#0" )); // and - close channel assertThat(embeddedChannel.isOpen(), is(false)); } -} \ No newline at end of file +} diff --git a/mockserver-netty/src/test/java/org/mockserver/proxy/http/HttpProxyUnificationHandlerTest.java b/mockserver-netty/src/test/java/org/mockserver/proxy/http/HttpProxyUnificationHandlerTest.java index 6c0a5173a..55799665f 100644 --- a/mockserver-netty/src/test/java/org/mockserver/proxy/http/HttpProxyUnificationHandlerTest.java +++ b/mockserver-netty/src/test/java/org/mockserver/proxy/http/HttpProxyUnificationHandlerTest.java @@ -27,8 +27,8 @@ import static org.hamcrest.core.Is.is; import static org.junit.Assert.assertThat; import static org.mockito.Mockito.mock; +import static org.mockserver.mock.HttpStateHandler.STATE_HANDLER; import static org.mockserver.proxy.Proxy.HTTP_PROXY; -import static org.mockserver.proxy.Proxy.STATE_HANDLER; public class HttpProxyUnificationHandlerTest { @@ -50,7 +50,7 @@ public void shouldSwitchToSsl() { })); // then - should add SSL handlers first - assertThat(embeddedChannel.pipeline().names(), contains( + assertThat(String.valueOf(embeddedChannel.pipeline().names()), embeddedChannel.pipeline().names(), contains( "SslHandler#0", "HttpProxyUnificationHandler#0", "DefaultChannelPipeline$TailContext#0" @@ -88,7 +88,7 @@ public void shouldSwitchToSOCKS() throws IOException, InterruptedException { }))); // and then - should add SOCKS handlers first - assertThat(embeddedChannel.pipeline().names(), contains( + assertThat(String.valueOf(embeddedChannel.pipeline().names()), embeddedChannel.pipeline().names(), contains( "SocksCmdRequestDecoder#0", "SocksMessageEncoder#0", "SocksProxyHandler#0", @@ -114,10 +114,12 @@ public void shouldSwitchToHttp() { embeddedChannel.writeInbound(Unpooled.wrappedBuffer("GET /somePath HTTP/1.1\r\nHost: some.random.host\r\n\r\n".getBytes(UTF_8))); // then - should add HTTP handlers last - assertThat(embeddedChannel.pipeline().names(), contains( + assertThat(String.valueOf(embeddedChannel.pipeline().names()), embeddedChannel.pipeline().names(), contains( "HttpServerCodec#0", "HttpContentDecompressor#0", + "HttpContentLengthRemover#0", "HttpObjectAggregator#0", + "WebSocketServerHandler#0", "MockServerServerCodec#0", "HttpProxyHandler#0", "DefaultChannelPipeline$TailContext#0" @@ -136,11 +138,11 @@ public void shouldSupportUnknownProtocol() { embeddedChannel.writeInbound(Unpooled.wrappedBuffer("UNKNOWN_PROTOCOL".getBytes(UTF_8))); // then - should add no handlers - assertThat(embeddedChannel.pipeline().names(), contains( + assertThat(String.valueOf(embeddedChannel.pipeline().names()), embeddedChannel.pipeline().names(), contains( "DefaultChannelPipeline$TailContext#0" )); // and - close channel assertThat(embeddedChannel.isOpen(), is(false)); } -} \ No newline at end of file +} diff --git a/mockserver-proxy-war/src/main/java/org/mockserver/proxy/ProxyServlet.java b/mockserver-proxy-war/src/main/java/org/mockserver/proxy/ProxyServlet.java index 32b0881de..ca5cd7dd6 100644 --- a/mockserver-proxy-war/src/main/java/org/mockserver/proxy/ProxyServlet.java +++ b/mockserver-proxy-war/src/main/java/org/mockserver/proxy/ProxyServlet.java @@ -1,14 +1,13 @@ package org.mockserver.proxy; +import com.google.common.collect.ImmutableSet; import com.google.common.net.MediaType; import org.mockserver.client.serialization.PortBindingSerializer; -import org.mockserver.log.model.RequestResponseLogEntry; import org.mockserver.logging.LoggingFormatter; import org.mockserver.mappers.HttpServletRequestToMockServerRequestDecoder; import org.mockserver.mock.HttpStateHandler; import org.mockserver.mock.action.ActionHandler; import org.mockserver.model.HttpRequest; -import org.mockserver.model.HttpResponse; import org.mockserver.responsewriter.ResponseWriter; import org.mockserver.server.ServletResponseWriter; @@ -16,11 +15,12 @@ import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletResponse; +import java.util.Arrays; +import java.util.HashSet; + import static io.netty.handler.codec.http.HttpResponseStatus.BAD_REQUEST; import static io.netty.handler.codec.http.HttpResponseStatus.OK; import static io.netty.handler.codec.rtsp.RtspResponseStatuses.NOT_IMPLEMENTED; -import static org.mockserver.character.Character.NEW_LINE; -import static org.mockserver.model.HttpResponse.notFoundResponse; import static org.mockserver.model.HttpResponse.response; import static org.mockserver.model.PortBinding.portBinding; @@ -42,7 +42,7 @@ public class ProxyServlet extends HttpServlet { public ProxyServlet() { this.httpStateHandler = new HttpStateHandler(); this.logFormatter = httpStateHandler.getLogFormatter(); - this.actionHandler = new ActionHandler(httpStateHandler, true); + this.actionHandler = new ActionHandler(httpStateHandler); } @Override @@ -69,7 +69,15 @@ public void service(HttpServletRequest httpServletRequest, HttpServletResponse h } else { - actionHandler.processAction(request, responseWriter, null); + String portExtension = ""; + if (!(httpServletRequest.getLocalPort() == 443 && httpServletRequest.isSecure() || httpServletRequest.getLocalPort() == 80)) { + portExtension = ":" + httpServletRequest.getLocalPort(); + } + actionHandler.processAction(request, responseWriter, null, ImmutableSet.of( + httpServletRequest.getLocalAddr() + portExtension, + "localhost" + portExtension, + "127.0.0.1" + portExtension + ), true); } } diff --git a/mockserver-proxy-war/src/test/java/org/mockserver/proxy/ProxyServletTest.java b/mockserver-proxy-war/src/test/java/org/mockserver/proxy/ProxyServletTest.java index da395e96a..02c2c9c49 100644 --- a/mockserver-proxy-war/src/test/java/org/mockserver/proxy/ProxyServletTest.java +++ b/mockserver-proxy-war/src/test/java/org/mockserver/proxy/ProxyServletTest.java @@ -1,14 +1,13 @@ package org.mockserver.proxy; +import com.google.common.collect.ImmutableSet; import io.netty.channel.ChannelHandlerContext; import org.junit.Before; import org.junit.Test; import org.mockito.InjectMocks; -import org.mockserver.client.netty.NettyHttpClient; import org.mockserver.client.serialization.ExpectationSerializer; import org.mockserver.client.serialization.HttpRequestSerializer; import org.mockserver.client.serialization.PortBindingSerializer; -import org.mockserver.client.serialization.curl.HttpRequestToCurlSerializer; import org.mockserver.log.model.ExpectationMatchLogEntry; import org.mockserver.log.model.RequestLogEntry; import org.mockserver.logging.LoggingFormatter; @@ -22,6 +21,7 @@ import org.springframework.mock.web.MockHttpServletResponse; import java.util.Collections; +import java.util.HashSet; import static com.google.common.net.MediaType.JSON_UTF_8; import static org.apache.commons.codec.Charsets.UTF_8; @@ -225,17 +225,125 @@ public void shouldRetrieveLogMessages() { } @Test - public void shouldProxyRequests() { + public void shouldProxyRequestsOnDefaultPort() { // given HttpRequest request = request("request_one").withHeader("Host", "localhost").withMethod("GET"); MockHttpServletRequest httpServletRequest = buildHttpServletRequest("GET", "request_one", ""); httpServletRequest.addHeader("Host", "localhost"); + httpServletRequest.setLocalAddr("local_address"); + httpServletRequest.setLocalPort(80); // when proxyServlet.service(httpServletRequest, response); // then - verify(mockActionHandler).processAction(eq(request.withSecure(false).withKeepAlive(true)), any(ServletResponseWriter.class), isNull(ChannelHandlerContext.class)); + verify(mockActionHandler).processAction( + eq( + request + .withSecure(false) + .withKeepAlive(true) + ), + any(ServletResponseWriter.class), + isNull(ChannelHandlerContext.class), + eq(ImmutableSet.of( + "local_address", + "localhost", + "127.0.0.1" + )), + eq(true) + ); + } + + @Test + public void shouldProxyRequestsOnNonDefaultPort() { + // given + HttpRequest request = request("request_one").withHeader("Host", "localhost").withMethod("GET"); + MockHttpServletRequest httpServletRequest = buildHttpServletRequest("GET", "request_one", ""); + httpServletRequest.addHeader("Host", "localhost"); + httpServletRequest.setLocalAddr("local_address"); + httpServletRequest.setLocalPort(666); + + // when + proxyServlet.service(httpServletRequest, response); + + // then + verify(mockActionHandler).processAction( + eq( + request + .withSecure(false) + .withKeepAlive(true) + ), + any(ServletResponseWriter.class), + isNull(ChannelHandlerContext.class), + eq(ImmutableSet.of( + "local_address:666", + "localhost:666", + "127.0.0.1:666" + )), + eq(true) + ); + } + + @Test + public void shouldProxySecureRequestsOnDefaultPort() { + // given + HttpRequest request = request("request_one").withHeader("Host", "localhost").withMethod("GET"); + MockHttpServletRequest httpServletRequest = buildHttpServletRequest("GET", "request_one", ""); + httpServletRequest.addHeader("Host", "localhost"); + httpServletRequest.setSecure(true); + httpServletRequest.setLocalAddr("local_address"); + httpServletRequest.setLocalPort(443); + + // when + proxyServlet.service(httpServletRequest, response); + + // then + verify(mockActionHandler).processAction( + eq( + request + .withSecure(true) + .withKeepAlive(true) + ), + any(ServletResponseWriter.class), + isNull(ChannelHandlerContext.class), + eq(ImmutableSet.of( + "local_address", + "localhost", + "127.0.0.1" + )), + eq(true) + ); + } + + @Test + public void shouldProxySecureRequestsOnNonDefaultPort() { + // given + HttpRequest request = request("request_one").withHeader("Host", "localhost").withMethod("GET"); + MockHttpServletRequest httpServletRequest = buildHttpServletRequest("GET", "request_one", ""); + httpServletRequest.addHeader("Host", "localhost"); + httpServletRequest.setSecure(true); + httpServletRequest.setLocalAddr("local_address"); + httpServletRequest.setLocalPort(666); + + // when + proxyServlet.service(httpServletRequest, response); + + // then + verify(mockActionHandler).processAction( + eq( + request + .withSecure(true) + .withKeepAlive(true) + ), + any(ServletResponseWriter.class), + isNull(ChannelHandlerContext.class), + eq(ImmutableSet.of( + "local_address:666", + "localhost:666", + "127.0.0.1:666" + )), + eq(true) + ); } } diff --git a/mockserver-war/src/main/java/org/mockserver/server/MockServerServlet.java b/mockserver-war/src/main/java/org/mockserver/server/MockServerServlet.java index da618f628..f034ee98f 100644 --- a/mockserver-war/src/main/java/org/mockserver/server/MockServerServlet.java +++ b/mockserver-war/src/main/java/org/mockserver/server/MockServerServlet.java @@ -1,5 +1,6 @@ package org.mockserver.server; +import com.google.common.collect.ImmutableSet; import com.google.common.net.MediaType; import org.mockserver.client.serialization.PortBindingSerializer; import org.mockserver.logging.LoggingFormatter; @@ -13,6 +14,8 @@ import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletResponse; +import java.util.HashSet; + import static io.netty.handler.codec.http.HttpResponseStatus.*; import static org.mockserver.model.HttpResponse.response; import static org.mockserver.model.PortBinding.portBinding; @@ -35,7 +38,7 @@ public class MockServerServlet extends HttpServlet { public MockServerServlet() { this.httpStateHandler = new HttpStateHandler(); this.logFormatter = httpStateHandler.getLogFormatter(); - this.actionHandler = new ActionHandler(httpStateHandler, false); + this.actionHandler = new ActionHandler(httpStateHandler); } @Override @@ -62,7 +65,15 @@ public void service(HttpServletRequest httpServletRequest, HttpServletResponse h } else { - actionHandler.processAction(request, responseWriter, null); + String portExtension = ""; + if (!(httpServletRequest.getLocalPort() == 443 && httpServletRequest.isSecure() || httpServletRequest.getLocalPort() == 80)) { + portExtension = ":" + httpServletRequest.getLocalPort(); + } + actionHandler.processAction(request, responseWriter, null, ImmutableSet.of( + httpServletRequest.getLocalAddr() + portExtension, + "localhost" + portExtension, + "127.0.0.1" + portExtension + ), false); } } diff --git a/mockserver-war/src/test/java/org/mockserver/server/DeployableWARAbstractClientServerIntegrationTest.java b/mockserver-war/src/test/java/org/mockserver/server/DeployableWARAbstractClientServerIntegrationTest.java index 66dc27e2b..4440772db 100644 --- a/mockserver-war/src/test/java/org/mockserver/server/DeployableWARAbstractClientServerIntegrationTest.java +++ b/mockserver-war/src/test/java/org/mockserver/server/DeployableWARAbstractClientServerIntegrationTest.java @@ -3,9 +3,8 @@ import org.junit.Rule; import org.junit.Test; import org.junit.rules.ExpectedException; -import org.mockserver.client.server.ClientException; +import org.mockserver.client.ClientException; import org.mockserver.integration.server.SameJVMAbstractClientServerIntegrationTest; -import org.mockserver.model.HttpStatusCode; import static io.netty.handler.codec.http.HttpHeaderNames.CONTENT_TYPE; import static org.hamcrest.Matchers.containsString; diff --git a/mockserver-war/src/test/java/org/mockserver/server/MockServerServletTest.java b/mockserver-war/src/test/java/org/mockserver/server/MockServerServletTest.java index 588e06741..955037492 100644 --- a/mockserver-war/src/test/java/org/mockserver/server/MockServerServletTest.java +++ b/mockserver-war/src/test/java/org/mockserver/server/MockServerServletTest.java @@ -1,5 +1,6 @@ package org.mockserver.server; +import com.google.common.collect.ImmutableSet; import io.netty.channel.ChannelHandlerContext; import org.junit.Before; import org.junit.Test; @@ -17,6 +18,7 @@ import org.springframework.mock.web.MockHttpServletResponse; import java.util.Collections; +import java.util.HashSet; import static com.google.common.net.MediaType.JSON_UTF_8; import static org.apache.commons.codec.Charsets.UTF_8; @@ -281,13 +283,15 @@ public void shouldRetrieveLogMessages() { } @Test - public void shouldUseActionHandlerToHandleNonAPIRequests() { + public void shouldUseActionHandlerToHandleNonAPIRequestsOnDefaultPort() { // given MockHttpServletRequest request = buildHttpServletRequest( "GET", "request_one", "" ); + request.setLocalAddr("local_address"); + request.setLocalPort(80); // when mockServerServlet.service(request, response); @@ -301,7 +305,114 @@ public void shouldUseActionHandlerToHandleNonAPIRequests() { .withSecure(false) ), any(ServletResponseWriter.class), - isNull(ChannelHandlerContext.class) + isNull(ChannelHandlerContext.class), + eq(ImmutableSet.of( + "local_address", + "localhost", + "127.0.0.1" + )), + eq(false) + ); + } + + @Test + public void shouldUseActionHandlerToHandleNonAPIRequestsOnNonDefaultPort() { + // given + MockHttpServletRequest request = buildHttpServletRequest( + "GET", + "request_one", + "" + ); + request.setLocalAddr("local_address"); + request.setLocalPort(666); + + // when + mockServerServlet.service(request, response); + + // then + verify(mockActionHandler).processAction( + eq( + request("request_one") + .withMethod("GET") + .withKeepAlive(true) + .withSecure(false) + ), + any(ServletResponseWriter.class), + isNull(ChannelHandlerContext.class), + eq(ImmutableSet.of( + "local_address:666", + "localhost:666", + "127.0.0.1:666" + )), + eq(false) + ); + } + + @Test + public void shouldUseActionHandlerToHandleNonAPISecureRequestsOnDefaultPort() { + // given + MockHttpServletRequest request = buildHttpServletRequest( + "GET", + "request_one", + "" + ); + request.setSecure(true); + request.setLocalAddr("local_address"); + request.setLocalPort(443); + + // when + mockServerServlet.service(request, response); + + // then + verify(mockActionHandler).processAction( + eq( + request("request_one") + .withMethod("GET") + .withKeepAlive(true) + .withSecure(true) + ), + any(ServletResponseWriter.class), + isNull(ChannelHandlerContext.class), + eq(ImmutableSet.of( + "local_address", + "localhost", + "127.0.0.1" + )), + eq(false) + ); + } + + @Test + public void shouldUseActionHandlerToHandleNonAPISecureRequestsOnNonDefaultPort() { + // given + MockHttpServletRequest request = buildHttpServletRequest( + "GET", + "request_one", + "" + ); + request.setSecure(true); + request.setLocalAddr("local_address"); + request.setLocalPort(666); + + // when + mockServerServlet.service(request, response); + + // then + verify(mockActionHandler).processAction( + eq( + request("request_one") + .withMethod("GET") + .withKeepAlive(true) + .withSecure(true) + ), + any(ServletResponseWriter.class), + isNull(ChannelHandlerContext.class), + eq(ImmutableSet.of( + "local_address:666", + "localhost:666", + "127.0.0.1:666" + )), + eq(false) ); }