From ecb3a8eafa89b90d59e98c1e238957df413697b3 Mon Sep 17 00:00:00 2001 From: Rune Flobakk Date: Wed, 31 Jul 2024 11:40:15 +0200 Subject: [PATCH] Rewind BufferedInputStream when reaching limit If using a LimitedInputStream for reading a certain amount of "header" data from an InputStream, then rewinding back to start in order to process the entire stream. E.g. to persist it somewhere. The problem with this approach is that a LimitedInputStream can be instructed to throw an exception if trying to read past the limit (i.e. it is expected to finish processing fairly early in the stream, and then end reading, and the limit is to protect spooling through a potentially huge amount of data), and in the event of actually reaching the limit, a LimitedInputStream _must_ read at least one more byte in order to determine if the underlying stream is exhausted and yields -1, or if it has more data, and followingly the LimitedInputStream must throw an exception. This reeks of a design error in LimitedInputStream. Perhaps having the variance of throwing an exception on reaching the "end" of a limited stream is flawed, and this issue demonstrates that. A LimitedInputStream introduces a potential earlier EOF, and it should perhaps strictly treat it like that and yield -1 in any case it reaches the limit (i.e. the end). If detecting if it actually reached the limit is required, it is a separate concern, and must be done externally wrt. the InputStream. --- .../no/digipost/io/LimitedInputStream.java | 50 +++++++------- .../digipost/io/LimitedInputStreamTest.java | 66 +++++++++++++++++++ 2 files changed, 91 insertions(+), 25 deletions(-) diff --git a/src/main/java/no/digipost/io/LimitedInputStream.java b/src/main/java/no/digipost/io/LimitedInputStream.java index 2a71709..eb91353 100644 --- a/src/main/java/no/digipost/io/LimitedInputStream.java +++ b/src/main/java/no/digipost/io/LimitedInputStream.java @@ -21,6 +21,8 @@ import java.io.InputStream; import java.util.function.Supplier; +import static java.lang.Math.max; +import static java.lang.Math.min; import static no.digipost.DiggExceptions.asUnchecked; /** @@ -98,14 +100,16 @@ public LimitedInputStream(InputStream inputStream, long maxBytesCount, Supplier< */ @Override public int read() throws IOException { + if (count > maxBytesCount) { + return reachedLimit(); + } int res = super.read(); - if (res != -1) { - count++; - if (hasReachedLimit()) { - return -1; - } + count++; + if (res == -1 || count <= maxBytesCount) { + return res; + } else { + return reachedLimit(); } - return res; } /** @@ -133,30 +137,26 @@ public int read() throws IOException { */ @Override public int read(byte[] b, int off, int len) throws IOException { - int res = super.read(b, off, len); - if (res > 0) { - count += res; - if (hasReachedLimit()) { - return -1; - } + int allowedRemaing = (int)(maxBytesCount - count); + int res; + if (allowedRemaing > 0) { + res = super.read(b, off, min(len, allowedRemaing)); + count += max(res, 1); + } else { + res = read(); } return res; } - - private boolean hasReachedLimit() throws IOException { - if (count > maxBytesCount) { - if (throwIfTooManyBytes == SILENTLY_EOF_ON_REACHING_LIMIT) { - return true; - } - Exception tooManyBytes = throwIfTooManyBytes.get(); - if (tooManyBytes instanceof IOException) { - throw (IOException) tooManyBytes; - } else { - throw asUnchecked(tooManyBytes); - } + private int reachedLimit() throws IOException { + if (throwIfTooManyBytes == SILENTLY_EOF_ON_REACHING_LIMIT) { + return -1; + } + Exception tooManyBytes = throwIfTooManyBytes.get(); + if (tooManyBytes instanceof IOException) { + throw (IOException) tooManyBytes; } else { - return false; + throw asUnchecked(tooManyBytes); } } diff --git a/src/test/java/no/digipost/io/LimitedInputStreamTest.java b/src/test/java/no/digipost/io/LimitedInputStreamTest.java index a22b9b5..cd3dd09 100644 --- a/src/test/java/no/digipost/io/LimitedInputStreamTest.java +++ b/src/test/java/no/digipost/io/LimitedInputStreamTest.java @@ -19,6 +19,7 @@ import org.junit.jupiter.api.Test; import org.quicktheories.core.Gen; +import java.io.BufferedInputStream; import java.io.ByteArrayInputStream; import java.io.IOException; import java.io.InputStream; @@ -38,7 +39,9 @@ import static org.hamcrest.Matchers.instanceOf; import static org.hamcrest.Matchers.is; import static org.hamcrest.Matchers.sameInstance; +import static org.junit.jupiter.api.Assertions.assertAll; import static org.junit.jupiter.api.Assertions.assertArrayEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertThrows; import static org.quicktheories.QuickTheory.qt; import static org.quicktheories.generators.SourceDSL.integers; @@ -171,4 +174,67 @@ private static byte[] toByteArrayUsingSingleByteReads(InputStream toRead, DataSi return readBytes; } + @Test + void doesNotConsumeMoreFromUnderlyingInputStreamThanGivenLimit() throws IOException { + byte[] readBytes = new byte[3]; + try ( + InputStream threeBytes = new ByteArrayInputStream(new byte[] {65, 66, 67}); + InputStream maxTwoBytes = limit(threeBytes, bytes(2))) { + + assertThat(maxTwoBytes.read(readBytes), is(2)); + } + assertArrayEquals(new byte[] {65, 66, 0}, readBytes); + } + + @Test + void ableToResetBufferedStreamWhenLimitedStreamIsExhausted() throws IOException { + byte[] oneKiloByte = new byte[1024]; + Arrays.fill(oneKiloByte, (byte) 65); + + byte[] readFromLimitedStream = new byte[800]; + byte[] readFromBufferedStream = new byte[oneKiloByte.length]; + int limit = 600; + try ( + InputStream source = new ByteArrayInputStream(oneKiloByte); + InputStream bufferedSource = new BufferedInputStream(source, 400)) { + + bufferedSource.mark(limit); + try (InputStream limitedStream = limit(bufferedSource, DataSize.bytes(limit), () -> new IllegalStateException("Reached limit!"))) { + assertThat(limitedStream.read(readFromLimitedStream), is(limit)); + bufferedSource.reset(); + bufferedSource.read(readFromBufferedStream); + } + } + assertArrayEquals(readFromBufferedStream, oneKiloByte); + assertAll( + () -> assertEquals(readFromLimitedStream[0], (byte)65), + () -> assertEquals(readFromLimitedStream[limit - 1], (byte)65), + () -> assertEquals(readFromLimitedStream[limit], (byte)0), + () -> assertEquals(readFromLimitedStream[readFromLimitedStream.length - 1], (byte)0) + ); + } + + @Test + void rewindWhenReachingLimit() throws IOException { + byte[] twoKiloByte = new byte[2048]; + Arrays.fill(twoKiloByte, (byte) 65); + + int limit = 1024; + try ( + InputStream source = new ByteArrayInputStream(twoKiloByte); + InputStream bufferedSource = new BufferedInputStream(source, 512)) { + + bufferedSource.mark(limit + 1); // <-- :( + try (InputStream limitedStream = limit(bufferedSource, DataSize.bytes(limit))) { + byte[] readFromLimitedStream = toByteArray(limitedStream); + assertThat(readFromLimitedStream.length, is(limit)); + + bufferedSource.reset(); + byte[] readFromBufferedStream = toByteArray(bufferedSource); + assertArrayEquals(readFromBufferedStream, twoKiloByte); + } + } + + } + }