Skip to content

Commit

Permalink
[UNDERTOW-2425] At ServletOutputStreamImpl synchronized workflow (lis…
Browse files Browse the repository at this point in the history
…tener = null), prevent the buffer.flip() from not being cleared after an error during attempts to write.

Also, at ServletPrintWriter, verify if no progress is being made when attempting to encode returns overflow after flushing, and mark error even if there are remaining bytes in the buffer.

Signed-off-by: Flavia Rainone <frainone@redhat.com>
  • Loading branch information
fl4via committed Aug 22, 2024
1 parent bf0cfcb commit 8e75dd3
Show file tree
Hide file tree
Showing 3 changed files with 147 additions and 100 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -178,34 +178,14 @@ private void writeTooLargeForBuffer(byte[] b, int off, int len, ByteBuffer buffe
int rem = buffer.remaining();
buffer.put(b, bytesWritten + off, rem);
buffer.flip();
bytesWritten += rem;
int bufferCount = 1;
for (int i = 0; i < MAX_BUFFERS_TO_ALLOCATE; ++i) {
PooledByteBuffer pooled = bufferPool.allocate();
pooledBuffers[bufferCount - 1] = pooled;
buffers[bufferCount++] = pooled.getBuffer();
ByteBuffer cb = pooled.getBuffer();
int toWrite = len - bytesWritten;
if (toWrite > cb.remaining()) {
rem = cb.remaining();
cb.put(b, bytesWritten + off, rem);
cb.flip();
bytesWritten += rem;
} else {
cb.put(b, bytesWritten + off, toWrite);
bytesWritten = len;
cb.flip();
break;
}
}
Channels.writeBlocking(channel, buffers, 0, bufferCount);
while (bytesWritten < len) {
//ok, it did not fit, loop and loop and loop until it is done
bufferCount = 0;
for (int i = 0; i < MAX_BUFFERS_TO_ALLOCATE + 1; ++i) {
ByteBuffer cb = buffers[i];
cb.clear();
bufferCount++;
try {
bytesWritten += rem;
int bufferCount = 1;
for (int i = 0; i < MAX_BUFFERS_TO_ALLOCATE; ++i) {
PooledByteBuffer pooled = bufferPool.allocate();
pooledBuffers[bufferCount - 1] = pooled;
buffers[bufferCount++] = pooled.getBuffer();
ByteBuffer cb = pooled.getBuffer();
int toWrite = len - bytesWritten;
if (toWrite > cb.remaining()) {
rem = cb.remaining();
Expand All @@ -219,9 +199,38 @@ private void writeTooLargeForBuffer(byte[] b, int off, int len, ByteBuffer buffe
break;
}
}
Channels.writeBlocking(channel, buffers, 0, bufferCount);
writeBlocking(buffers, 0, bufferCount, bytesWritten);
// at this point, we know that all buffers[i] have 0 bytes remaining(), so it is safe to loop next just
// until we reach len, even if we stop before reaching the end of buffers array
while (bytesWritten < len) {
int oldBytesWritten = bytesWritten;
//ok, it did not fit, loop and loop and loop until it is done
bufferCount = 0;
for (int i = 0; i < MAX_BUFFERS_TO_ALLOCATE + 1; ++i) {
ByteBuffer cb = buffers[i];
cb.clear();
bufferCount++;
int toWrite = len - bytesWritten;
if (toWrite > cb.remaining()) {
rem = cb.remaining();
cb.put(b, bytesWritten + off, rem);
cb.flip();
bytesWritten += rem;
} else {
cb.put(b, bytesWritten + off, toWrite);
bytesWritten = len;
cb.flip();
// safe to break, all buffers that come next have zero remaining() bytes and hence
// won't affect the next writeBlocking call
break;
}
}
writeBlocking(buffers, 0, bufferCount, bytesWritten - oldBytesWritten);
}
} finally {
if (buffer != null)
buffer.compact();
}
buffer.clear();
} finally {
for (int i = 0; i < pooledBuffers.length; ++i) {
PooledByteBuffer p = pooledBuffers[i];
Expand All @@ -245,29 +254,36 @@ private void writeAsync(byte[] b, int off, int len) throws IOException {
buffer.put(b, off, len);
} else {
buffer.flip();
final ByteBuffer userBuffer = ByteBuffer.wrap(b, off, len);
final ByteBuffer[] bufs = new ByteBuffer[]{buffer, userBuffer};
long toWrite = Buffers.remaining(bufs);
long res;
long written = 0;
createChannel();
setFlags(FLAG_WRITE_STARTED);
do {
res = channel.write(bufs);
written += res;
if (res == 0) {
//write it out with a listener
//but we need to copy any extra data
final ByteBuffer copy = ByteBuffer.allocate(userBuffer.remaining());
copy.put(userBuffer);
copy.flip();

this.buffersToWrite = new ByteBuffer[]{buffer, copy};
clearFlags(FLAG_READY);
return;
boolean clearBuffer = true;
try {
final ByteBuffer userBuffer = ByteBuffer.wrap(b, off, len);
final ByteBuffer[] bufs = new ByteBuffer[]{buffer, userBuffer};
long toWrite = Buffers.remaining(bufs);
long res;
long written = 0;
createChannel();
setFlags(FLAG_WRITE_STARTED);
do {
res = channel.write(bufs);
written += res;
if (res == 0) {
//write it out with a listener
//but we need to copy any extra data
final ByteBuffer copy = ByteBuffer.allocate(userBuffer.remaining());
copy.put(userBuffer);
copy.flip();

this.buffersToWrite = new ByteBuffer[]{buffer, copy};
clearFlags(FLAG_READY);
clearBuffer = false;
return;
}
} while (written < toWrite);
} finally {
if (clearBuffer && buffer != null) {
buffer.compact();
}
} while (written < toWrite);
buffer.clear();
}
}
} finally {
updateWrittenAsync(len);
Expand Down Expand Up @@ -296,7 +312,7 @@ public void write(ByteBuffer[] buffers) throws IOException {
if (channel == null) {
channel = servletRequestContext.getExchange().getResponseChannel();
}
Channels.writeBlocking(channel, buffers, 0, buffers.length);
writeBlocking(buffers, 0, buffers.length, len);
setFlags(FLAG_WRITE_STARTED);
} else {
ByteBuffer buffer = buffer();
Expand All @@ -307,14 +323,18 @@ public void write(ByteBuffer[] buffers) throws IOException {
channel = servletRequestContext.getExchange().getResponseChannel();
}
if (buffer.position() == 0) {
Channels.writeBlocking(channel, buffers, 0, buffers.length);
writeBlocking(buffers, 0, buffers.length, len);
} else {
final ByteBuffer[] newBuffers = new ByteBuffer[buffers.length + 1];
buffer.flip();
newBuffers[0] = buffer;
System.arraycopy(buffers, 0, newBuffers, 1, buffers.length);
Channels.writeBlocking(channel, newBuffers, 0, newBuffers.length);
buffer.clear();
try {
newBuffers[0] = buffer;
System.arraycopy(buffers, 0, newBuffers, 1, buffers.length);
writeBlocking(newBuffers, 0, newBuffers.length, len + buffer.remaining());
} finally {
if (buffer != null)
buffer.clear();
}
}
setFlags(FLAG_WRITE_STARTED);
}
Expand All @@ -333,30 +353,34 @@ public void write(ByteBuffer[] buffers) throws IOException {
} else {
final ByteBuffer[] bufs = new ByteBuffer[buffers.length + 1];
buffer.flip();
bufs[0] = buffer;
System.arraycopy(buffers, 0, bufs, 1, buffers.length);
long toWrite = Buffers.remaining(bufs);
long res;
long written = 0;
createChannel();
setFlags(FLAG_WRITE_STARTED);
do {
res = channel.write(bufs);
written += res;
if (res == 0) {
//write it out with a listener
//but we need to copy any extra data
//TODO: should really allocate from the pool here
final ByteBuffer copy = ByteBuffer.allocate((int) Buffers.remaining(buffers));
Buffers.copy(copy, buffers, 0, buffers.length);
copy.flip();
this.buffersToWrite = new ByteBuffer[]{buffer, copy};
clearFlags(FLAG_READY);
channel.resumeWrites();
return;
}
} while (written < toWrite);
buffer.clear();
try {
bufs[0] = buffer;
System.arraycopy(buffers, 0, bufs, 1, buffers.length);
long toWrite = Buffers.remaining(bufs);
long res;
long written = 0;
createChannel();
setFlags(FLAG_WRITE_STARTED);
do {
res = channel.write(bufs);
written += res;
if (res == 0) {
//write it out with a listener
//but we need to copy any extra data
//TODO: should really allocate from the pool here
final ByteBuffer copy = ByteBuffer.allocate((int) Buffers.remaining(buffers));
Buffers.copy(copy, buffers, 0, buffers.length);
copy.flip();
this.buffersToWrite = new ByteBuffer[] { buffer, copy };
clearFlags(FLAG_READY);
channel.resumeWrites();
return;
}
} while (written < toWrite);
} finally {
if (buffer != null)
buffer.compact();
}
}
} finally {
updateWrittenAsync(len);
Expand Down Expand Up @@ -515,14 +539,18 @@ public void flushInternal() throws IOException {
//if the write fails we just compact, rather than changing the ready state
setFlags(FLAG_WRITE_STARTED);
buffer.flip();
long res;
do {
res = channel.write(buffer);
} while (buffer.hasRemaining() && res != 0);
if (!buffer.hasRemaining()) {
channel.flush();
try {
long res;
do {
res = channel.write(buffer);
} while (buffer.hasRemaining() && res != 0);
if (!buffer.hasRemaining()) {
channel.flush();
}
} finally {
if (buffer != null)
buffer.compact();
}
buffer.compact();
}
}

Expand Down Expand Up @@ -579,14 +607,18 @@ private void writeBufferBlocking(final boolean writeFinal) throws IOException {
channel = servletRequestContext.getExchange().getResponseChannel();
}
buffer.flip();
while (buffer.hasRemaining()) {
int result = writeFinal ? channel.writeFinal(buffer) : channel.write(buffer);
if (result == 0) {
channel.awaitWritable();
try {
while (buffer.hasRemaining()) {
int result = writeFinal ? channel.writeFinal(buffer) : channel.write(buffer);
if (result == 0) {
channel.awaitWritable();
}
}
} finally {
if (buffer != null)
buffer.compact();
setFlags(FLAG_WRITE_STARTED);
}
buffer.clear();
setFlags(FLAG_WRITE_STARTED);
}

/**
Expand Down Expand Up @@ -964,4 +996,10 @@ private void clearFlags(int flags) {
} while (!stateUpdater.compareAndSet(this, old, old & ~flags));
}

private void writeBlocking(ByteBuffer[] buffers, int offs, int len, int bytesToWrite) throws IOException {
int totalWritten = 0;
do {
totalWritten += Channels.writeBlocking(channel, buffers, 0, len);
} while (totalWritten < bytesToWrite);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,9 @@ public void close() {
underflow = null;
}
if (charsetEncoder != null) {
int remaining = 0;
do {
// before we get the underlying buffer, we need to flush outputStream
ByteBuffer out = outputStream.underlyingBuffer();
if (out == null) {
//servlet output stream has already been closed
Expand All @@ -113,11 +115,13 @@ public void close() {
CoderResult result = charsetEncoder.encode(buffer, out, true);
if (result.isOverflow()) {
outputStream.flushInternal();
if (out.remaining() == 0) {
if (out.remaining() == remaining) {
// no progress in flush
outputStream.close();
error = true;
return;
}
} else
remaining = out.remaining();
} else {
done = true;
}
Expand Down Expand Up @@ -177,7 +181,7 @@ public void write(final CharBuffer input) {
outputStream.updateWritten(writtenLength);
if (result.isOverflow() || !buffer.hasRemaining()) {
outputStream.flushInternal();
if (!buffer.hasRemaining()) {
if (buffer.remaining() == remaining) {
error = true;
return;
}
Expand Down
9 changes: 7 additions & 2 deletions spotbugs-exclude.xml
Original file line number Diff line number Diff line change
Expand Up @@ -925,8 +925,13 @@
</Match>
<Match>
<Bug pattern="RCN_REDUNDANT_NULLCHECK_WOULD_HAVE_BEEN_A_NPE"/>
<Class name="io.undertow.client.http.HttpClientConnection$ClientReadListener"/>
<Method name="handleEvent"/>
<Or>
<And>
<Class name="io.undertow.client.http.HttpClientConnection$ClientReadListener"/>
<Method name="handleEvent"/>
</And>
<Class name="io.undertow.servlet.spec.ServletOutputStreamImpl"/>
</Or>
</Match>
<!-- ignore benchmarks -->
<Match>
Expand Down

0 comments on commit 8e75dd3

Please sign in to comment.