From 5e5178b480bc865f0f010b9984986377fc13f348 Mon Sep 17 00:00:00 2001 From: Rune Flobakk Date: Thu, 1 Aug 2024 09:42:47 +0200 Subject: [PATCH] Ensure exception on chunk reaching beyond limit This is a subtle behavior, but previous behavior has been that the first invocation to read(byte[]) with an array that would be filled with more elements if not being limited, would actually throw an exception. This commit preserves this behavior. One might argue that the most beneficial way to handle reading the "last chunk" before reaching the limit would be to fill the the array with any remaining "allowed" bytes, and return an int indicating the amount of bytes propagated to the array, though this also indicates that the stream reached a normal last chunk of bytes, and one can assume next read would return -1 (EOF). When setting up the LimitedInputStream to throw an exception if reading past the limit, this will ensure that code which assume that they can read one chunk of bytes, and as long as the read succeeds, they should be given complete and well-formed data. This would not be the case if the LimitedInputStream limits the bytes, and if it is set up to throw an exception, this exception must be thrown. If the LimitedInputStream is set up to "silently" yield EOF on reaching the limit, the read chunk is correctly chopped accordinly, and the next read will yield EOF as expected. --- .../no/digipost/io/LimitedInputStream.java | 12 ++- .../digipost/io/LimitedInputStreamTest.java | 82 +++++++++++++++++-- 2 files changed, 88 insertions(+), 6 deletions(-) diff --git a/src/main/java/no/digipost/io/LimitedInputStream.java b/src/main/java/no/digipost/io/LimitedInputStream.java index eb91353..d42770a 100644 --- a/src/main/java/no/digipost/io/LimitedInputStream.java +++ b/src/main/java/no/digipost/io/LimitedInputStream.java @@ -19,6 +19,7 @@ import java.io.FilterInputStream; import java.io.IOException; import java.io.InputStream; +import java.util.Arrays; import java.util.function.Supplier; import static java.lang.Math.max; @@ -138,10 +139,19 @@ public int read() throws IOException { @Override public int read(byte[] b, int off, int len) throws IOException { int allowedRemaing = (int)(maxBytesCount - count); + if (len == 0) { + return allowedRemaing > 0 ? 0 : -1; + } int res; if (allowedRemaing > 0) { - res = super.read(b, off, min(len, allowedRemaing)); + int maxAllowedReadLen = min(len, allowedRemaing + 1); + res = super.read(b, off, maxAllowedReadLen); count += max(res, 1); + if (res > allowedRemaing) { + Arrays.fill(b, off + maxAllowedReadLen - (res - allowedRemaing), off + maxAllowedReadLen, (byte)0); + res = allowedRemaing; + reachedLimit(); + } } else { res = read(); } diff --git a/src/test/java/no/digipost/io/LimitedInputStreamTest.java b/src/test/java/no/digipost/io/LimitedInputStreamTest.java index cd3dd09..a356244 100644 --- a/src/test/java/no/digipost/io/LimitedInputStreamTest.java +++ b/src/test/java/no/digipost/io/LimitedInputStreamTest.java @@ -87,6 +87,39 @@ void neverReadsMoreThanTheSetLimit() { } }); } + + @Test + void discardsBytesAfterLimit() throws IOException { + byte[] sixBytes = new byte[] {65, 66, 67, 68, 69, 70}; + try ( + InputStream source = new ByteArrayInputStream(sixBytes); + InputStream limitedToTwoBytes = limit(source, bytes(4))) { + + + byte[] readBytes = new byte[6]; + byte[] expectedEmpty = new byte[2]; + assertAll( + () -> assertThat("first read yields 4 bytes", limitedToTwoBytes.read(readBytes), is(4)), + () -> assertArrayEquals(new byte[] {65, 66, 67, 68, 0, 0}, readBytes), + () -> assertThat("reading single read yields EOF", limitedToTwoBytes.read(), is(-1)), + () -> assertThat("reading chunk yields EOF", limitedToTwoBytes.read(expectedEmpty), is(-1)), + () -> assertArrayEquals(new byte[] {0, 0}, expectedEmpty)); + } + } + + @Test + void readingZeroBytesYieldsZeroOrEof() throws IOException { + byte[] twoBytes = new byte[] {65, 66}; + try ( + InputStream source = new ByteArrayInputStream(twoBytes); + InputStream limitedToOneByte = limit(source, bytes(1))) { + assertThat(limitedToOneByte.read(new byte[0]), is(0)); + limitedToOneByte.read(); + assertThat(limitedToOneByte.read(new byte[0]), is(-1)); + assertThat(limitedToOneByte.read(), is(-1)); + } + } + } @@ -117,6 +150,45 @@ public void wrapsOtherCheckedExceptionsThanIOExceptionAsRuntimeException() throw assertThat(assertThrows(RuntimeException.class, () -> testLimitedStream("xyz", () -> tooManyBytes)), where(Exception::getCause, sameInstance(tooManyBytes))); } + @Test + void throwsWhenAttemptingToReadChunkOneByteLargerThanLimit() throws IOException { + byte[] sixBytes = new byte[] {65, 66, 67, 68, 69, 70}; + try ( + InputStream source = new ByteArrayInputStream(sixBytes); + InputStream limitedToTwoBytes = limit(source, bytes(5), () -> new IllegalStateException("reached limit!"))) { + + assertThrows(IllegalStateException.class, () -> limitedToTwoBytes.read(new byte[6])); + } + } + + @Test + void doesNotThrowWhenAttemptingToReadChunkExactlyEndingAtTheLimit() throws IOException { + byte[] sixBytes = new byte[] {65, 66, 67, 68, 69, 70}; + try ( + InputStream source = new ByteArrayInputStream(sixBytes); + InputStream limitedToTwoBytes = limit(source, bytes(6), () -> new IllegalStateException("reached limit!"))) { + + limitedToTwoBytes.read(new byte[3]); + byte[] lastThreeBytes = new byte[3]; + limitedToTwoBytes.read(lastThreeBytes); + assertArrayEquals(new byte[]{68, 69, 70}, lastThreeBytes); + } + } + + @Test + void readingZeroBytesNeverThrowsUntilTryingToActuallyReadNonZeroBytes() throws IOException { + byte[] twoBytes = new byte[] {65, 66}; + try ( + InputStream source = new ByteArrayInputStream(twoBytes); + InputStream limitedToOneByte = limit(source, bytes(1), () -> new IOException("Should not be thrown"))) { + assertThat(limitedToOneByte.read(new byte[0]), is(0)); + limitedToOneByte.read(); + assertThat(limitedToOneByte.read(new byte[0]), is(-1)); + assertThrows(IOException.class, () -> limitedToOneByte.read()); + assertThrows(IOException.class, () -> limitedToOneByte.read(new byte[1])); + } + } + } @@ -181,9 +253,10 @@ void doesNotConsumeMoreFromUnderlyingInputStreamThanGivenLimit() throws IOExcept InputStream threeBytes = new ByteArrayInputStream(new byte[] {65, 66, 67}); InputStream maxTwoBytes = limit(threeBytes, bytes(2))) { - assertThat(maxTwoBytes.read(readBytes), is(2)); + assertAll( + () -> assertThat(maxTwoBytes.read(readBytes), is(2)), + () -> assertArrayEquals(new byte[] {65, 66, 0}, readBytes)); } - assertArrayEquals(new byte[] {65, 66, 0}, readBytes); } @Test @@ -198,8 +271,8 @@ void ableToResetBufferedStreamWhenLimitedStreamIsExhausted() throws IOException InputStream source = new ByteArrayInputStream(oneKiloByte); InputStream bufferedSource = new BufferedInputStream(source, 400)) { - bufferedSource.mark(limit); - try (InputStream limitedStream = limit(bufferedSource, DataSize.bytes(limit), () -> new IllegalStateException("Reached limit!"))) { + bufferedSource.mark(limit + 1); // <-- :( + try (InputStream limitedStream = limit(bufferedSource, DataSize.bytes(limit))) { assertThat(limitedStream.read(readFromLimitedStream), is(limit)); bufferedSource.reset(); bufferedSource.read(readFromBufferedStream); @@ -234,7 +307,6 @@ void rewindWhenReachingLimit() throws IOException { assertArrayEquals(readFromBufferedStream, twoKiloByte); } } - } }