Skip to content

Commit

Permalink
Rewind BufferedInputStream when reaching limit
Browse files Browse the repository at this point in the history
If using a LimitedInputStream for reading a certain amount of
"header" data from an InputStream, then rewinding back to start in order
to process the entire stream. E.g. to persist it somewhere.

The problem with this approach is that a LimitedInputStream can be
instructed to throw an exception if trying to read past the limit (i.e.
it is expected to finish processing fairly early in the stream, and then
end reading, and the limit is to protect spooling through a potentially
huge amount of data), and in the event of actually reaching the limit, a
LimitedInputStream _must_ read at least one more byte in order to
determine if the underlying stream is exhausted and yields -1, or if it
has more data, and followingly the LimitedInputStream must throw an
exception.

This reeks of a design error in LimitedInputStream. Perhaps having the
variance of throwing an exception on reaching the "end" of a limited
stream is flawed, and this issue demonstrates that. A LimitedInputStream
introduces a potential earlier EOF, and it should perhaps strictly treat
it like that and yield -1 in any case it reaches the limit (i.e. the end).
If detecting if it actually reached the limit is required, it is a
separate concern, and must be done externally wrt. the InputStream.
  • Loading branch information
runeflobakk committed Jul 31, 2024
1 parent 67e2f78 commit ecb3a8e
Show file tree
Hide file tree
Showing 2 changed files with 91 additions and 25 deletions.
50 changes: 25 additions & 25 deletions src/main/java/no/digipost/io/LimitedInputStream.java
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@
import java.io.InputStream;
import java.util.function.Supplier;

import static java.lang.Math.max;
import static java.lang.Math.min;
import static no.digipost.DiggExceptions.asUnchecked;

/**
Expand Down Expand Up @@ -98,14 +100,16 @@ public LimitedInputStream(InputStream inputStream, long maxBytesCount, Supplier<
*/
@Override
public int read() throws IOException {
if (count > maxBytesCount) {
return reachedLimit();
}
int res = super.read();
if (res != -1) {
count++;
if (hasReachedLimit()) {
return -1;
}
count++;
if (res == -1 || count <= maxBytesCount) {
return res;
} else {
return reachedLimit();
}
return res;
}

/**
Expand Down Expand Up @@ -133,30 +137,26 @@ public int read() throws IOException {
*/
@Override
public int read(byte[] b, int off, int len) throws IOException {
int res = super.read(b, off, len);
if (res > 0) {
count += res;
if (hasReachedLimit()) {
return -1;
}
int allowedRemaing = (int)(maxBytesCount - count);
int res;
if (allowedRemaing > 0) {
res = super.read(b, off, min(len, allowedRemaing));
count += max(res, 1);
} else {
res = read();
}
return res;
}


private boolean hasReachedLimit() throws IOException {
if (count > maxBytesCount) {
if (throwIfTooManyBytes == SILENTLY_EOF_ON_REACHING_LIMIT) {
return true;
}
Exception tooManyBytes = throwIfTooManyBytes.get();
if (tooManyBytes instanceof IOException) {
throw (IOException) tooManyBytes;
} else {
throw asUnchecked(tooManyBytes);
}
private int reachedLimit() throws IOException {
if (throwIfTooManyBytes == SILENTLY_EOF_ON_REACHING_LIMIT) {
return -1;
}
Exception tooManyBytes = throwIfTooManyBytes.get();
if (tooManyBytes instanceof IOException) {
throw (IOException) tooManyBytes;
} else {
return false;
throw asUnchecked(tooManyBytes);
}
}

Expand Down
66 changes: 66 additions & 0 deletions src/test/java/no/digipost/io/LimitedInputStreamTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import org.junit.jupiter.api.Test;
import org.quicktheories.core.Gen;

import java.io.BufferedInputStream;
import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.io.InputStream;
Expand All @@ -38,7 +39,9 @@
import static org.hamcrest.Matchers.instanceOf;
import static org.hamcrest.Matchers.is;
import static org.hamcrest.Matchers.sameInstance;
import static org.junit.jupiter.api.Assertions.assertAll;
import static org.junit.jupiter.api.Assertions.assertArrayEquals;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.quicktheories.QuickTheory.qt;
import static org.quicktheories.generators.SourceDSL.integers;
Expand Down Expand Up @@ -171,4 +174,67 @@ private static byte[] toByteArrayUsingSingleByteReads(InputStream toRead, DataSi
return readBytes;
}

@Test
void doesNotConsumeMoreFromUnderlyingInputStreamThanGivenLimit() throws IOException {
byte[] readBytes = new byte[3];
try (
InputStream threeBytes = new ByteArrayInputStream(new byte[] {65, 66, 67});
InputStream maxTwoBytes = limit(threeBytes, bytes(2))) {

assertThat(maxTwoBytes.read(readBytes), is(2));
}
assertArrayEquals(new byte[] {65, 66, 0}, readBytes);
}

@Test
void ableToResetBufferedStreamWhenLimitedStreamIsExhausted() throws IOException {
byte[] oneKiloByte = new byte[1024];
Arrays.fill(oneKiloByte, (byte) 65);

byte[] readFromLimitedStream = new byte[800];
byte[] readFromBufferedStream = new byte[oneKiloByte.length];
int limit = 600;
try (
InputStream source = new ByteArrayInputStream(oneKiloByte);
InputStream bufferedSource = new BufferedInputStream(source, 400)) {

bufferedSource.mark(limit);
try (InputStream limitedStream = limit(bufferedSource, DataSize.bytes(limit), () -> new IllegalStateException("Reached limit!"))) {
assertThat(limitedStream.read(readFromLimitedStream), is(limit));
bufferedSource.reset();
bufferedSource.read(readFromBufferedStream);
}
}
assertArrayEquals(readFromBufferedStream, oneKiloByte);
assertAll(
() -> assertEquals(readFromLimitedStream[0], (byte)65),
() -> assertEquals(readFromLimitedStream[limit - 1], (byte)65),
() -> assertEquals(readFromLimitedStream[limit], (byte)0),
() -> assertEquals(readFromLimitedStream[readFromLimitedStream.length - 1], (byte)0)
);
}

@Test
void rewindWhenReachingLimit() throws IOException {
byte[] twoKiloByte = new byte[2048];
Arrays.fill(twoKiloByte, (byte) 65);

int limit = 1024;
try (
InputStream source = new ByteArrayInputStream(twoKiloByte);
InputStream bufferedSource = new BufferedInputStream(source, 512)) {

bufferedSource.mark(limit + 1); // <-- :(
try (InputStream limitedStream = limit(bufferedSource, DataSize.bytes(limit))) {
byte[] readFromLimitedStream = toByteArray(limitedStream);
assertThat(readFromLimitedStream.length, is(limit));

bufferedSource.reset();
byte[] readFromBufferedStream = toByteArray(bufferedSource);
assertArrayEquals(readFromBufferedStream, twoKiloByte);
}
}

}

}

0 comments on commit ecb3a8e

Please sign in to comment.