diff --git a/TrafficCapture/nettyWireLogging/src/test/java/org/opensearch/migrations/trafficcapture/netty/ConditionallyReliableLoggingHttpHandlerTest.java b/TrafficCapture/nettyWireLogging/src/test/java/org/opensearch/migrations/trafficcapture/netty/ConditionallyReliableLoggingHttpHandlerTest.java index ef05a2bad..ac35777b2 100644 --- a/TrafficCapture/nettyWireLogging/src/test/java/org/opensearch/migrations/trafficcapture/netty/ConditionallyReliableLoggingHttpHandlerTest.java +++ b/TrafficCapture/nettyWireLogging/src/test/java/org/opensearch/migrations/trafficcapture/netty/ConditionallyReliableLoggingHttpHandlerTest.java @@ -1,6 +1,7 @@ package org.opensearch.migrations.trafficcapture.netty; import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufInputStream; import io.netty.buffer.Unpooled; import io.netty.channel.embedded.EmbeddedChannel; import lombok.extern.slf4j.Slf4j; @@ -28,8 +29,6 @@ import java.util.stream.Collectors; import java.util.stream.Stream; -import static org.opensearch.migrations.trafficcapture.netty.TestStreamManager.consumeIntoArray; - @Slf4j public class ConditionallyReliableLoggingHttpHandlerTest { @@ -48,7 +47,7 @@ private static void writeMessageAndVerify(byte[] fullTrafficBytes, Consumer new ByteArrayInputStream(consumeIntoArray((ByteBuf) m))) + .map(m -> new ByteBufInputStream((ByteBuf) m, true)) .collect(Collectors.toList()))) .readAllBytes(); Assertions.assertArrayEquals(fullTrafficBytes, outputData); @@ -132,7 +131,7 @@ public void testThatSuppressedCaptureWorks() throws Exception { Assertions.assertEquals(0, streamMgr.flushCount.get()); // we wrote the correct data to the downstream handler/channel var outputData = new SequenceInputStream(Collections.enumeration(channel.inboundMessages().stream() - .map(m -> new ByteArrayInputStream(consumeIntoArray((ByteBuf) m))) + .map(m -> new ByteBufInputStream((ByteBuf) m, true)) .collect(Collectors.toList()))) .readAllBytes(); log.info("outputdata = " + new String(outputData, StandardCharsets.UTF_8)); @@ -161,7 +160,7 @@ public void testThatHealthCheckCaptureCanBeSuppressed(boolean singleBytes) throw // we wrote the correct data to the downstream handler/channel var consumedData = new SequenceInputStream(Collections.enumeration(channel.inboundMessages().stream() - .map(m -> new ByteArrayInputStream(consumeIntoArray((ByteBuf) m))) + .map(m -> new ByteBufInputStream((ByteBuf) m)) .collect(Collectors.toList()))) .readAllBytes(); log.info("captureddata = " + new String(consumedData, StandardCharsets.UTF_8)); diff --git a/TrafficCapture/nettyWireLogging/src/test/java/org/opensearch/migrations/trafficcapture/netty/RootWireLoggingContextTest.java b/TrafficCapture/nettyWireLogging/src/test/java/org/opensearch/migrations/trafficcapture/netty/RootWireLoggingContextTest.java index 22f58ed30..4911cbf4c 100644 --- a/TrafficCapture/nettyWireLogging/src/test/java/org/opensearch/migrations/trafficcapture/netty/RootWireLoggingContextTest.java +++ b/TrafficCapture/nettyWireLogging/src/test/java/org/opensearch/migrations/trafficcapture/netty/RootWireLoggingContextTest.java @@ -1,6 +1,7 @@ package org.opensearch.migrations.trafficcapture.netty; import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufInputStream; import io.netty.channel.embedded.EmbeddedChannel; import io.opentelemetry.sdk.trace.data.SpanData; import lombok.Getter; @@ -21,17 +22,17 @@ import java.nio.charset.StandardCharsets; import java.util.Collections; import java.util.List; +import java.util.Set; import java.util.function.Consumer; import java.util.stream.Collectors; -import static org.opensearch.migrations.trafficcapture.netty.TestStreamManager.consumeIntoArray; - @Slf4j public class RootWireLoggingContextTest { - private static void writeMessageAndVerifyTraces(byte[] fullTrafficBytes, Consumer channelWriter, boolean shouldBlock, - List expectedTraces) throws IOException { + private static void writeMessageAndVerifyTraces(byte[] fullTrafficBytes, boolean shouldBlock, boolean shouldCloseChannel, + Set expectedTraces) throws IOException { try (var rootContext = new TestRootContext(true, true)) { + Consumer channelWriter = w -> w.writeInbound(TestUtilities.getByteBuf(fullTrafficBytes, false)); var streamManager = new TestStreamManager(); var offloader = new StreamChannelConnectionCaptureSerializer("Test", "c", streamManager); @@ -43,18 +44,28 @@ private static void writeMessageAndVerifyTraces(byte[] fullTrafficBytes, Consume // we wrote the correct data to the downstream handler/channel var outputData = new SequenceInputStream(Collections.enumeration(channel.inboundMessages().stream() - .map(m -> new ByteArrayInputStream(consumeIntoArray((ByteBuf) m))) + .map(m -> new ByteBufInputStream((ByteBuf) m, true)) .collect(Collectors.toList()))) .readAllBytes(); Assertions.assertArrayEquals(fullTrafficBytes, outputData); - // For a non-blocking write, close the channel to force it to flush + // For a non-blocking write, we need to force it to flush if (!shouldBlock) { + // We should not have flushed at this point + Assertions.assertEquals(0, streamManager.flushCount.get()); + if (!shouldCloseChannel) { + // Force a manual flush without closing the channel + offloader.flushCommitAndResetStream(true).join(); + } + } + if (shouldCloseChannel) { + // Fully close the channel, which should prompt a flush channel.close().awaitUninterruptibly(); } Assertions.assertNotNull(streamManager.byteBufferAtomicReference.get(), "This would be null if the handler didn't block until the output was written"); + Assertions.assertEquals(1, streamManager.flushCount.get()); var trafficStream = TrafficStream.parseFrom(streamManager.byteBufferAtomicReference.get()); Assertions.assertTrue(trafficStream.getSubStreamCount() > 0 && @@ -65,39 +76,37 @@ private static void writeMessageAndVerifyTraces(byte[] fullTrafficBytes, Consume .map(to -> new ByteArrayInputStream(to.getRead().getData().toByteArray())) .collect(Collectors.toList()))); Assertions.assertArrayEquals(fullTrafficBytes, combinedTrafficPacketsStream.readAllBytes()); - Assertions.assertEquals(1, streamManager.flushCount.get()); - List finishedSpans = rootContext.instrumentationBundle.getFinishedSpans(); + var finishedSpans = rootContext.instrumentationBundle.getFinishedSpans(); Assertions.assertTrue(!finishedSpans.isEmpty()); Assertions.assertTrue(!rootContext.instrumentationBundle.getFinishedMetrics().isEmpty()); - List finishedSpanNames = finishedSpans.stream().map(x -> x.getName()).collect(Collectors.toList()); - expectedTraces.forEach(trace -> - Assertions.assertTrue(finishedSpanNames.contains(trace), "finishedSpans does not contain " + trace) + Assertions.assertEquals(expectedTraces.stream().sorted().collect(Collectors.joining("\n")), + finishedSpans.stream().map(SpanData::getName).sorted().collect(Collectors.joining("\n")) ); - if (shouldBlock) { - Assertions.assertTrue(finishedSpanNames.contains("blocked"), "finishedSpans does not contain 'blocked' on a blocking request"); - } } - } @Test - public void testThatAGetProducesGatheringRequestTrace() - throws IOException { + public void testThatAGetProducesGatheringRequestTrace_WithoutClosingChannel() throws IOException { byte[] fullTrafficBytes = SimpleRequests.HEALTH_CHECK.getBytes(StandardCharsets.UTF_8); - var bb = TestUtilities.getByteBuf(fullTrafficBytes, false); - writeMessageAndVerifyTraces(fullTrafficBytes, w -> w.writeInbound(bb), false, List.of("gatheringRequest")); + writeMessageAndVerifyTraces(fullTrafficBytes, false, false, + Set.of("gatheringRequest")); } @Test - public void testThatAPostProducesGatheringRequestTrace() - throws IOException { - byte[] fullTrafficBytes = SimpleRequests.SMALL_POST.getBytes(StandardCharsets.UTF_8); - var bb = TestUtilities.getByteBuf(fullTrafficBytes, false); - writeMessageAndVerifyTraces(fullTrafficBytes, w -> w.writeInbound(bb), true, List.of("gatheringRequest")); + public void testThatAGetProducesGatheringRequestTrace_WithClosingChannel() throws IOException { + byte[] fullTrafficBytes = SimpleRequests.HEALTH_CHECK.getBytes(StandardCharsets.UTF_8); + writeMessageAndVerifyTraces(fullTrafficBytes, false, true, + Set.of("captureConnection", "gatheringRequest", "waitingForResponse")); } + @Test + public void testThatAPostProducesGatheringRequestTrace_WithoutClosingChannel() throws IOException { + byte[] fullTrafficBytes = SimpleRequests.SMALL_POST.getBytes(StandardCharsets.UTF_8); + writeMessageAndVerifyTraces(fullTrafficBytes, true, false, + Set.of("gatheringRequest", "blocked")); + } } diff --git a/TrafficCapture/nettyWireLogging/src/test/java/org/opensearch/migrations/trafficcapture/netty/TestStreamManager.java b/TrafficCapture/nettyWireLogging/src/test/java/org/opensearch/migrations/trafficcapture/netty/TestStreamManager.java index d252ea5ca..3f2c6888f 100644 --- a/TrafficCapture/nettyWireLogging/src/test/java/org/opensearch/migrations/trafficcapture/netty/TestStreamManager.java +++ b/TrafficCapture/nettyWireLogging/src/test/java/org/opensearch/migrations/trafficcapture/netty/TestStreamManager.java @@ -18,13 +18,6 @@ class TestStreamManager extends OrderedStreamLifecyleManager implements AutoClos AtomicReference byteBufferAtomicReference = new AtomicReference<>(); AtomicInteger flushCount = new AtomicInteger(); - static byte[] consumeIntoArray(ByteBuf m) { - var bArr = new byte[m.readableBytes()]; - m.readBytes(bArr); - m.release(); - return bArr; - } - @Override public void close() { }