Skip to content

Commit

Permalink
remote: Proactively close the ZstdInputStream in ZstdDecompressingOut…
Browse files Browse the repository at this point in the history
…putStream. (#15372)

ZstdInputStream hangs onto some native memory, which should be released as soon as ZstdDecompressingOutputStream is done being used rather than when the finalizer runs.

Closes #15061.

PiperOrigin-RevId: 438521302

Co-authored-by: Benjamin Peterson <benjamin@engflow.com>
  • Loading branch information
ckolli5 and benjaminp authored May 2, 2022
1 parent 4fd7983 commit 0fb60cd
Show file tree
Hide file tree
Showing 6 changed files with 83 additions and 60 deletions.
1 change: 0 additions & 1 deletion src/main/java/com/google/devtools/build/lib/remote/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,6 @@ java_library(
"//src/main/java/com/google/devtools/build/lib/vfs:pathfragment",
"//src/main/java/com/google/devtools/common/options",
"//src/main/protobuf:failure_details_java_proto",
"//third_party:apache_commons_compress",
"//third_party:auth",
"//third_party:caffeine",
"//third_party:flogger",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Iterables;
import com.google.common.flogger.GoogleLogger;
import com.google.common.io.CountingOutputStream;
import com.google.common.util.concurrent.Futures;
import com.google.common.util.concurrent.ListenableFuture;
import com.google.common.util.concurrent.MoreExecutors;
Expand Down Expand Up @@ -67,10 +68,8 @@
import java.util.List;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicLong;
import java.util.function.Supplier;
import javax.annotation.Nullable;
import org.apache.commons.compress.utils.CountingOutputStream;

/** A RemoteActionCache implementation that uses gRPC calls to a remote cache server. */
@ThreadSafe
Expand Down Expand Up @@ -298,7 +297,7 @@ public ListenableFuture<Void> uploadActionResult(
public ListenableFuture<Void> downloadBlob(
RemoteActionExecutionContext context, Digest digest, OutputStream out) {
if (digest.getSizeBytes() == 0) {
return Futures.immediateFuture(null);
return Futures.immediateVoidFuture();
}

@Nullable Supplier<Digest> digestSupplier = null;
Expand All @@ -308,26 +307,14 @@ public ListenableFuture<Void> downloadBlob(
out = digestOut;
}

CountingOutputStream outputStream;
if (options.cacheCompression) {
try {
outputStream = new ZstdDecompressingOutputStream(out);
} catch (IOException e) {
return Futures.immediateFailedFuture(e);
}
} else {
outputStream = new CountingOutputStream(out);
}

return downloadBlob(context, digest, outputStream, digestSupplier);
return downloadBlob(context, digest, new CountingOutputStream(out), digestSupplier);
}

private ListenableFuture<Void> downloadBlob(
RemoteActionExecutionContext context,
Digest digest,
CountingOutputStream out,
@Nullable Supplier<Digest> digestSupplier) {
AtomicLong offset = new AtomicLong(0);
ProgressiveBackoff progressiveBackoff = new ProgressiveBackoff(retrier::newBackoff);
ListenableFuture<Long> downloadFuture =
Utils.refreshIfUnauthenticatedAsync(
Expand All @@ -338,7 +325,6 @@ private ListenableFuture<Void> downloadBlob(
channel ->
requestRead(
context,
offset,
progressiveBackoff,
digest,
out,
Expand All @@ -365,20 +351,25 @@ public static String getResourceName(String instanceName, Digest digest, boolean

private ListenableFuture<Long> requestRead(
RemoteActionExecutionContext context,
AtomicLong offset,
ProgressiveBackoff progressiveBackoff,
Digest digest,
CountingOutputStream out,
CountingOutputStream rawOut,
@Nullable Supplier<Digest> digestSupplier,
Channel channel) {
String resourceName =
getResourceName(options.remoteInstanceName, digest, options.cacheCompression);
SettableFuture<Long> future = SettableFuture.create();
OutputStream out;
try {
out = options.cacheCompression ? new ZstdDecompressingOutputStream(rawOut) : rawOut;
} catch (IOException e) {
return Futures.immediateFailedFuture(e);
}
bsAsyncStub(context, channel)
.read(
ReadRequest.newBuilder()
.setResourceName(resourceName)
.setReadOffset(offset.get())
.setReadOffset(rawOut.getCount())
.build(),
new StreamObserver<ReadResponse>() {

Expand All @@ -387,7 +378,6 @@ public void onNext(ReadResponse readResponse) {
ByteString data = readResponse.getData();
try {
data.writeTo(out);
offset.set(out.getBytesWritten());
} catch (IOException e) {
// Cancel the call.
throw new RuntimeException(e);
Expand All @@ -398,14 +388,15 @@ public void onNext(ReadResponse readResponse) {

@Override
public void onError(Throwable t) {
if (offset.get() == digest.getSizeBytes()) {
if (rawOut.getCount() == digest.getSizeBytes()) {
// If the file was fully downloaded, it doesn't matter if there was an error at
// the end of the stream.
logger.atInfo().withCause(t).log(
"ignoring error because file was fully received");
onCompleted();
return;
}
releaseOut();
Status status = Status.fromThrowable(t);
if (status.getCode() == Status.Code.NOT_FOUND) {
future.setException(new CacheNotFoundException(digest));
Expand All @@ -421,12 +412,24 @@ public void onCompleted() {
Utils.verifyBlobContents(digest, digestSupplier.get());
}
out.flush();
future.set(offset.get());
future.set(rawOut.getCount());
} catch (IOException e) {
future.setException(e);
} catch (RuntimeException e) {
logger.atWarning().withCause(e).log("Unexpected exception");
future.setException(e);
} finally {
releaseOut();
}
}

private void releaseOut() {
if (out instanceof ZstdDecompressingOutputStream) {
try {
((ZstdDecompressingOutputStream) out).closeShallow();
} catch (IOException e) {
logger.atWarning().withCause(e).log("failed to cleanly close output stream");
}
}
}
});
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ java_library(
name = "zstd",
srcs = glob(["*.java"]),
deps = [
"//third_party:apache_commons_compress",
"//third_party:guava",
"//third_party/protobuf:protobuf_java",
"@zstd-jni",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,35 +13,35 @@
// limitations under the License.
package com.google.devtools.build.lib.remote.zstd;

import com.github.luben.zstd.ZstdInputStream;
import com.github.luben.zstd.ZstdInputStreamNoFinalizer;
import com.google.protobuf.ByteString;
import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import org.apache.commons.compress.utils.CountingOutputStream;

/** A {@link CountingOutputStream} that use zstd to decompress the content. */
public class ZstdDecompressingOutputStream extends CountingOutputStream {
/** An {@link OutputStream} that use zstd to decompress the content. */
public final class ZstdDecompressingOutputStream extends OutputStream {
private final OutputStream out;
private ByteArrayInputStream inner;
private final ZstdInputStream zis;
private final ZstdInputStreamNoFinalizer zis;

public ZstdDecompressingOutputStream(OutputStream out) throws IOException {
super(out);
this.out = out;
zis =
new ZstdInputStream(
new InputStream() {
@Override
public int read() {
return inner.read();
}

@Override
public int read(byte[] b, int off, int len) {
return inner.read(b, off, len);
}
});
zis.setContinuous(true);
new ZstdInputStreamNoFinalizer(
new InputStream() {
@Override
public int read() {
return inner.read();
}

@Override
public int read(byte[] b, int off, int len) {
return inner.read(b, off, len);
}
})
.setContinuous(true);
}

@Override
Expand All @@ -58,6 +58,19 @@ public void write(byte[] b) throws IOException {
public void write(byte[] b, int off, int len) throws IOException {
inner = new ByteArrayInputStream(b, off, len);
byte[] data = ByteString.readFrom(zis).toByteArray();
super.write(data, 0, data.length);
out.write(data, 0, data.length);
}

@Override
public void close() throws IOException {
closeShallow();
out.close();
}

/**
* Free resources related to decompression without closing the underlying {@link OutputStream}.
*/
public void closeShallow() throws IOException {
zis.close();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,12 @@

import static com.google.common.truth.Truth.assertThat;
import static java.nio.charset.StandardCharsets.UTF_8;
import static org.mockito.ArgumentMatchers.any;

import build.bazel.remote.execution.v2.Digest;
import com.github.luben.zstd.Zstd;
import com.google.bytestream.ByteStreamGrpc.ByteStreamImplBase;
import com.google.bytestream.ByteStreamProto.ReadRequest;
import com.google.bytestream.ByteStreamProto.ReadResponse;
import com.google.devtools.build.lib.remote.Retrier.Backoff;
import com.google.devtools.build.lib.remote.options.RemoteOptions;
import com.google.devtools.common.options.Options;
import com.google.protobuf.ByteString;
Expand All @@ -31,38 +29,50 @@
import java.io.IOException;
import java.util.Arrays;
import org.junit.Test;
import org.mockito.Mockito;

/** Extra tests for {@link GrpcCacheClient} that are not tested internally. */
public class GrpcCacheClientTestExtra extends GrpcCacheClientTest {

@Test
public void compressedDownloadBlobIsRetriedWithProgress()
throws IOException, InterruptedException {
Backoff mockBackoff = Mockito.mock(Backoff.class);
RemoteOptions options = Options.getDefaults(RemoteOptions.class);
options.cacheCompression = true;
final GrpcCacheClient client = newClient(options, () -> mockBackoff);
final GrpcCacheClient client = newClient(options);
final Digest digest = DIGEST_UTIL.computeAsUtf8("abcdefg");
ByteString blob = ByteString.copyFrom(Zstd.compress("abcdefg".getBytes(UTF_8)));
ByteString chunk1 = ByteString.copyFrom(Zstd.compress("abc".getBytes(UTF_8)));
ByteString chunk2 = ByteString.copyFrom(Zstd.compress("def".getBytes(UTF_8)));
ByteString chunk3 = ByteString.copyFrom(Zstd.compress("g".getBytes(UTF_8)));
serviceRegistry.addService(
new ByteStreamImplBase() {
private boolean first = true;

@Override
public void read(ReadRequest request, StreamObserver<ReadResponse> responseObserver) {
assertThat(request.getResourceName().contains(digest.getHash())).isTrue();
int off = (int) request.getReadOffset();
// Zstd header size is 9 bytes
ByteString data = off == 0 ? blob.substring(0, 9 + 1) : blob.substring(9 + off);
responseObserver.onNext(ReadResponse.newBuilder().setData(data).build());
if (off == 0) {
if (first) {
first = false;
responseObserver.onError(Status.DEADLINE_EXCEEDED.asException());
} else {
responseObserver.onCompleted();
return;
}
switch (Math.toIntExact(request.getReadOffset())) {
case 0:
responseObserver.onNext(ReadResponse.newBuilder().setData(chunk1).build());
break;
case 3:
responseObserver.onNext(ReadResponse.newBuilder().setData(chunk2).build());
break;
case 6:
responseObserver.onNext(ReadResponse.newBuilder().setData(chunk3).build());
responseObserver.onCompleted();
return;
default:
throw new IllegalStateException("unexpected offset " + request.getReadOffset());
}
responseObserver.onError(Status.DEADLINE_EXCEEDED.asException());
}
});
assertThat(new String(downloadBlob(context, client, digest), UTF_8)).isEqualTo("abcdefg");
Mockito.verify(mockBackoff, Mockito.never()).nextDelayMillis(any(Exception.class));
}

@Test
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,6 @@ public void bytesWrittenMatchesDecompressedBytes() throws IOException {
for (byte b : compressed.toByteArray()) {
zdos.write(b);
zdos.flush();
assertThat(zdos.getBytesWritten()).isEqualTo(decompressed.toByteArray().length);
}
assertThat(decompressed.toByteArray()).isEqualTo(data);
}
Expand Down

0 comments on commit 0fb60cd

Please sign in to comment.