diff --git a/src/main/java/no/digipost/io/LimitedInputStream.java b/src/main/java/no/digipost/io/LimitedInputStream.java index 2a71709..eb91353 100644 --- a/src/main/java/no/digipost/io/LimitedInputStream.java +++ b/src/main/java/no/digipost/io/LimitedInputStream.java @@ -21,6 +21,8 @@ import java.io.InputStream; import java.util.function.Supplier; +import static java.lang.Math.max; +import static java.lang.Math.min; import static no.digipost.DiggExceptions.asUnchecked; /** @@ -98,14 +100,16 @@ public LimitedInputStream(InputStream inputStream, long maxBytesCount, Supplier< */ @Override public int read() throws IOException { + if (count > maxBytesCount) { + return reachedLimit(); + } int res = super.read(); - if (res != -1) { - count++; - if (hasReachedLimit()) { - return -1; - } + count++; + if (res == -1 || count <= maxBytesCount) { + return res; + } else { + return reachedLimit(); } - return res; } /** @@ -133,30 +137,26 @@ public int read() throws IOException { */ @Override public int read(byte[] b, int off, int len) throws IOException { - int res = super.read(b, off, len); - if (res > 0) { - count += res; - if (hasReachedLimit()) { - return -1; - } + int allowedRemaing = (int)(maxBytesCount - count); + int res; + if (allowedRemaing > 0) { + res = super.read(b, off, min(len, allowedRemaing)); + count += max(res, 1); + } else { + res = read(); } return res; } - - private boolean hasReachedLimit() throws IOException { - if (count > maxBytesCount) { - if (throwIfTooManyBytes == SILENTLY_EOF_ON_REACHING_LIMIT) { - return true; - } - Exception tooManyBytes = throwIfTooManyBytes.get(); - if (tooManyBytes instanceof IOException) { - throw (IOException) tooManyBytes; - } else { - throw asUnchecked(tooManyBytes); - } + private int reachedLimit() throws IOException { + if (throwIfTooManyBytes == SILENTLY_EOF_ON_REACHING_LIMIT) { + return -1; + } + Exception tooManyBytes = throwIfTooManyBytes.get(); + if (tooManyBytes instanceof IOException) { + throw (IOException) tooManyBytes; } else { - return false; + throw asUnchecked(tooManyBytes); } } diff --git a/src/test/java/no/digipost/io/LimitedInputStreamTest.java b/src/test/java/no/digipost/io/LimitedInputStreamTest.java index a22b9b5..cd3dd09 100644 --- a/src/test/java/no/digipost/io/LimitedInputStreamTest.java +++ b/src/test/java/no/digipost/io/LimitedInputStreamTest.java @@ -19,6 +19,7 @@ import org.junit.jupiter.api.Test; import org.quicktheories.core.Gen; +import java.io.BufferedInputStream; import java.io.ByteArrayInputStream; import java.io.IOException; import java.io.InputStream; @@ -38,7 +39,9 @@ import static org.hamcrest.Matchers.instanceOf; import static org.hamcrest.Matchers.is; import static org.hamcrest.Matchers.sameInstance; +import static org.junit.jupiter.api.Assertions.assertAll; import static org.junit.jupiter.api.Assertions.assertArrayEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertThrows; import static org.quicktheories.QuickTheory.qt; import static org.quicktheories.generators.SourceDSL.integers; @@ -171,4 +174,67 @@ private static byte[] toByteArrayUsingSingleByteReads(InputStream toRead, DataSi return readBytes; } + @Test + void doesNotConsumeMoreFromUnderlyingInputStreamThanGivenLimit() throws IOException { + byte[] readBytes = new byte[3]; + try ( + InputStream threeBytes = new ByteArrayInputStream(new byte[] {65, 66, 67}); + InputStream maxTwoBytes = limit(threeBytes, bytes(2))) { + + assertThat(maxTwoBytes.read(readBytes), is(2)); + } + assertArrayEquals(new byte[] {65, 66, 0}, readBytes); + } + + @Test + void ableToResetBufferedStreamWhenLimitedStreamIsExhausted() throws IOException { + byte[] oneKiloByte = new byte[1024]; + Arrays.fill(oneKiloByte, (byte) 65); + + byte[] readFromLimitedStream = new byte[800]; + byte[] readFromBufferedStream = new byte[oneKiloByte.length]; + int limit = 600; + try ( + InputStream source = new ByteArrayInputStream(oneKiloByte); + InputStream bufferedSource = new BufferedInputStream(source, 400)) { + + bufferedSource.mark(limit); + try (InputStream limitedStream = limit(bufferedSource, DataSize.bytes(limit), () -> new IllegalStateException("Reached limit!"))) { + assertThat(limitedStream.read(readFromLimitedStream), is(limit)); + bufferedSource.reset(); + bufferedSource.read(readFromBufferedStream); + } + } + assertArrayEquals(readFromBufferedStream, oneKiloByte); + assertAll( + () -> assertEquals(readFromLimitedStream[0], (byte)65), + () -> assertEquals(readFromLimitedStream[limit - 1], (byte)65), + () -> assertEquals(readFromLimitedStream[limit], (byte)0), + () -> assertEquals(readFromLimitedStream[readFromLimitedStream.length - 1], (byte)0) + ); + } + + @Test + void rewindWhenReachingLimit() throws IOException { + byte[] twoKiloByte = new byte[2048]; + Arrays.fill(twoKiloByte, (byte) 65); + + int limit = 1024; + try ( + InputStream source = new ByteArrayInputStream(twoKiloByte); + InputStream bufferedSource = new BufferedInputStream(source, 512)) { + + bufferedSource.mark(limit + 1); // <-- :( + try (InputStream limitedStream = limit(bufferedSource, DataSize.bytes(limit))) { + byte[] readFromLimitedStream = toByteArray(limitedStream); + assertThat(readFromLimitedStream.length, is(limit)); + + bufferedSource.reset(); + byte[] readFromBufferedStream = toByteArray(bufferedSource); + assertArrayEquals(readFromBufferedStream, twoKiloByte); + } + } + + } + }