From 299022ca2dc49b6cb27b2674f933755306ae8b9b Mon Sep 17 00:00:00 2001
From: Benjamin Peterson
Date: Thu, 31 Mar 2022 03:08:43 -0700
Subject: [PATCH] remote: Proactively close the ZstdInputStream in
 ZstdDecompressingOutputStream.

ZstdInputStream hangs onto some native memory, which should be released
as soon as ZstdDecompressingOutputStream is done being used rather than
when the finalizer runs.

Closes #15061.

PiperOrigin-RevId: 438521302
---
 .../google/devtools/build/lib/remote/BUILD    |  1 -
 .../build/lib/remote/GrpcCacheClient.java     | 49 +++++++++--------
 .../devtools/build/lib/remote/zstd/BUILD      |  1 -
 .../zstd/ZstdDecompressingOutputStream.java   | 53 ++++++++++++-------
 .../lib/remote/GrpcCacheClientTestExtra.java  | 38 ++++++++-----
 ...stdDecompressingOutputStreamTestExtra.java |  1 -
 6 files changed, 83 insertions(+), 60 deletions(-)

diff --git a/src/main/java/com/google/devtools/build/lib/remote/BUILD b/src/main/java/com/google/devtools/build/lib/remote/BUILD
index 7f8d359926902a..06fb72dffa432b 100644
--- a/src/main/java/com/google/devtools/build/lib/remote/BUILD
+++ b/src/main/java/com/google/devtools/build/lib/remote/BUILD
@@ -97,7 +97,6 @@ java_library(
         "//src/main/java/com/google/devtools/build/lib/vfs:pathfragment",
         "//src/main/java/com/google/devtools/common/options",
         "//src/main/protobuf:failure_details_java_proto",
-        "//third_party:apache_commons_compress",
         "//third_party:auth",
         "//third_party:caffeine",
         "//third_party:flogger",
diff --git a/src/main/java/com/google/devtools/build/lib/remote/GrpcCacheClient.java b/src/main/java/com/google/devtools/build/lib/remote/GrpcCacheClient.java
index 293562d0043e6d..5dd7dc03428ca6 100644
--- a/src/main/java/com/google/devtools/build/lib/remote/GrpcCacheClient.java
+++ b/src/main/java/com/google/devtools/build/lib/remote/GrpcCacheClient.java
@@ -37,6 +37,7 @@
 import com.google.common.collect.ImmutableSet;
 import com.google.common.collect.Iterables;
 import com.google.common.flogger.GoogleLogger;
+import com.google.common.io.CountingOutputStream;
 import com.google.common.util.concurrent.Futures;
 import com.google.common.util.concurrent.ListenableFuture;
 import com.google.common.util.concurrent.MoreExecutors;
@@ -67,10 +68,8 @@
 import java.util.List;
 import java.util.concurrent.TimeUnit;
 import java.util.concurrent.atomic.AtomicBoolean;
-import java.util.concurrent.atomic.AtomicLong;
 import java.util.function.Supplier;
 import javax.annotation.Nullable;
-import org.apache.commons.compress.utils.CountingOutputStream;

 /** A RemoteActionCache implementation that uses gRPC calls to a remote cache server. */
 @ThreadSafe
@@ -303,7 +302,7 @@ public ListenableFuture<Void> uploadActionResult(
   public ListenableFuture<Void> downloadBlob(
       RemoteActionExecutionContext context, Digest digest, OutputStream out) {
     if (digest.getSizeBytes() == 0) {
-      return Futures.immediateFuture(null);
+      return Futures.immediateVoidFuture();
     }

     @Nullable Supplier<Digest> digestSupplier = null;
@@ -313,18 +312,7 @@ public ListenableFuture<Void> downloadBlob(
       out = digestOut;
     }

-    CountingOutputStream outputStream;
-    if (options.cacheCompression) {
-      try {
-        outputStream = new ZstdDecompressingOutputStream(out);
-      } catch (IOException e) {
-        return Futures.immediateFailedFuture(e);
-      }
-    } else {
-      outputStream = new CountingOutputStream(out);
-    }
-
-    return downloadBlob(context, digest, outputStream, digestSupplier);
+    return downloadBlob(context, digest, new CountingOutputStream(out), digestSupplier);
   }

   private ListenableFuture<Long> downloadBlob(
@@ -332,7 +320,6 @@ private ListenableFuture<Long> downloadBlob(
       Digest digest,
       CountingOutputStream out,
       @Nullable Supplier<Digest> digestSupplier) {
-    AtomicLong offset = new AtomicLong(0);
     ProgressiveBackoff progressiveBackoff = new ProgressiveBackoff(retrier::newBackoff);
     ListenableFuture<Long> downloadFuture =
         Utils.refreshIfUnauthenticatedAsync(
@@ -343,7 +330,6 @@ private ListenableFuture<Long> downloadBlob(
             channel ->
                 requestRead(
                     context,
-                    offset,
                     progressiveBackoff,
                     digest,
                     out,
@@ -370,20 +356,25 @@ public static String getResourceName(String instanceName, Digest digest, boolean
   private ListenableFuture<Long> requestRead(
       RemoteActionExecutionContext context,
-      AtomicLong offset,
       ProgressiveBackoff progressiveBackoff,
       Digest digest,
-      CountingOutputStream out,
+      CountingOutputStream rawOut,
       @Nullable Supplier<Digest> digestSupplier,
       Channel channel) {
     String resourceName =
         getResourceName(options.remoteInstanceName, digest, options.cacheCompression);
     SettableFuture<Long> future = SettableFuture.create();
+    OutputStream out;
+    try {
+      out = options.cacheCompression ? new ZstdDecompressingOutputStream(rawOut) : rawOut;
+    } catch (IOException e) {
+      return Futures.immediateFailedFuture(e);
+    }
     bsAsyncStub(context, channel)
         .read(
             ReadRequest.newBuilder()
                 .setResourceName(resourceName)
-                .setReadOffset(offset.get())
+                .setReadOffset(rawOut.getCount())
                 .build(),
             new StreamObserver<ReadResponse>() {
@@ -392,7 +383,6 @@ public void onNext(ReadResponse readResponse) {
                 ByteString data = readResponse.getData();
                 try {
                   data.writeTo(out);
-                  offset.set(out.getBytesWritten());
                 } catch (IOException e) {
                   // Cancel the call.
                   throw new RuntimeException(e);
@@ -403,7 +393,7 @@ public void onNext(ReadResponse readResponse) {
               @Override
               public void onError(Throwable t) {
-                if (offset.get() == digest.getSizeBytes()) {
+                if (rawOut.getCount() == digest.getSizeBytes()) {
                   // If the file was fully downloaded, it doesn't matter if there was an error at
                   // the end of the stream.
                  logger.atInfo().withCause(t).log(
@@ -411,6 +401,7 @@ public void onError(Throwable t) {
                  onCompleted();
                  return;
                }
+                releaseOut();
                Status status = Status.fromThrowable(t);
                if (status.getCode() == Status.Code.NOT_FOUND) {
                  future.setException(new CacheNotFoundException(digest));
@@ -426,12 +417,24 @@ public void onCompleted() {
                    Utils.verifyBlobContents(digest, digestSupplier.get());
                  }
                  out.flush();
-                  future.set(offset.get());
+                  future.set(rawOut.getCount());
                } catch (IOException e) {
                  future.setException(e);
                } catch (RuntimeException e) {
                  logger.atWarning().withCause(e).log("Unexpected exception");
                  future.setException(e);
+                } finally {
+                  releaseOut();
+                }
+              }
+
+              private void releaseOut() {
+                if (out instanceof ZstdDecompressingOutputStream) {
+                  try {
+                    ((ZstdDecompressingOutputStream) out).closeShallow();
+                  } catch (IOException e) {
+                    logger.atWarning().withCause(e).log("failed to cleanly close output stream");
+                  }
                }
              }
            });
diff --git a/src/main/java/com/google/devtools/build/lib/remote/zstd/BUILD b/src/main/java/com/google/devtools/build/lib/remote/zstd/BUILD
index 6108cddc569f03..75691a65473044 100644
--- a/src/main/java/com/google/devtools/build/lib/remote/zstd/BUILD
+++ b/src/main/java/com/google/devtools/build/lib/remote/zstd/BUILD
@@ -16,7 +16,6 @@ java_library(
     name = "zstd",
     srcs = glob(["*.java"]),
     deps = [
-        "//third_party:apache_commons_compress",
         "//third_party:guava",
         "//third_party/protobuf:protobuf_java",
         "@zstd-jni",
diff --git a/src/main/java/com/google/devtools/build/lib/remote/zstd/ZstdDecompressingOutputStream.java b/src/main/java/com/google/devtools/build/lib/remote/zstd/ZstdDecompressingOutputStream.java
index ad1c333320964c..9fdb6ae4fdaa89 100644
--- a/src/main/java/com/google/devtools/build/lib/remote/zstd/ZstdDecompressingOutputStream.java
+++ b/src/main/java/com/google/devtools/build/lib/remote/zstd/ZstdDecompressingOutputStream.java
@@ -13,35 +13,35 @@
 // limitations under the License.
 package com.google.devtools.build.lib.remote.zstd;

-import com.github.luben.zstd.ZstdInputStream;
+import com.github.luben.zstd.ZstdInputStreamNoFinalizer;
 import com.google.protobuf.ByteString;
 import java.io.ByteArrayInputStream;
 import java.io.IOException;
 import java.io.InputStream;
 import java.io.OutputStream;
-import org.apache.commons.compress.utils.CountingOutputStream;

-/** A {@link CountingOutputStream} that use zstd to decompress the content. */
-public class ZstdDecompressingOutputStream extends CountingOutputStream {
+/** An {@link OutputStream} that use zstd to decompress the content. */
+public final class ZstdDecompressingOutputStream extends OutputStream {
+  private final OutputStream out;
   private ByteArrayInputStream inner;
-  private final ZstdInputStream zis;
+  private final ZstdInputStreamNoFinalizer zis;

   public ZstdDecompressingOutputStream(OutputStream out) throws IOException {
-    super(out);
+    this.out = out;
     zis =
-        new ZstdInputStream(
-            new InputStream() {
-              @Override
-              public int read() {
-                return inner.read();
-              }
-
-              @Override
-              public int read(byte[] b, int off, int len) {
-                return inner.read(b, off, len);
-              }
-            });
-    zis.setContinuous(true);
+        new ZstdInputStreamNoFinalizer(
+                new InputStream() {
+                  @Override
+                  public int read() {
+                    return inner.read();
+                  }
+
+                  @Override
+                  public int read(byte[] b, int off, int len) {
+                    return inner.read(b, off, len);
+                  }
+                })
+            .setContinuous(true);
   }

   @Override
@@ -58,6 +58,19 @@ public void write(byte[] b) throws IOException {
   public void write(byte[] b, int off, int len) throws IOException {
     inner = new ByteArrayInputStream(b, off, len);
     byte[] data = ByteString.readFrom(zis).toByteArray();
-    super.write(data, 0, data.length);
+    out.write(data, 0, data.length);
+  }
+
+  @Override
+  public void close() throws IOException {
+    closeShallow();
+    out.close();
+  }
+
+  /**
+   * Free resources related to decompression without closing the underlying {@link OutputStream}.
+   */
+  public void closeShallow() throws IOException {
+    zis.close();
+  }
 }
diff --git a/src/test/java/com/google/devtools/build/lib/remote/GrpcCacheClientTestExtra.java b/src/test/java/com/google/devtools/build/lib/remote/GrpcCacheClientTestExtra.java
index 51effa08170977..80d55edc7a0677 100644
--- a/src/test/java/com/google/devtools/build/lib/remote/GrpcCacheClientTestExtra.java
+++ b/src/test/java/com/google/devtools/build/lib/remote/GrpcCacheClientTestExtra.java
@@ -15,14 +15,12 @@
 import static com.google.common.truth.Truth.assertThat;
 import static java.nio.charset.StandardCharsets.UTF_8;
-import static org.mockito.ArgumentMatchers.any;

 import build.bazel.remote.execution.v2.Digest;
 import com.github.luben.zstd.Zstd;
 import com.google.bytestream.ByteStreamGrpc.ByteStreamImplBase;
 import com.google.bytestream.ByteStreamProto.ReadRequest;
 import com.google.bytestream.ByteStreamProto.ReadResponse;
-import com.google.devtools.build.lib.remote.Retrier.Backoff;
 import com.google.devtools.build.lib.remote.options.RemoteOptions;
 import com.google.devtools.common.options.Options;
 import com.google.protobuf.ByteString;
@@ -31,7 +29,6 @@
 import java.io.IOException;
 import java.util.Arrays;
 import org.junit.Test;
-import org.mockito.Mockito;

 /** Extra tests for {@link GrpcCacheClient} that are not tested internally. */
 public class GrpcCacheClientTestExtra extends GrpcCacheClientTest {
@@ -39,30 +36,43 @@ public class GrpcCacheClientTestExtra extends GrpcCacheClientTest {
   @Test
   public void compressedDownloadBlobIsRetriedWithProgress()
       throws IOException, InterruptedException {
-    Backoff mockBackoff = Mockito.mock(Backoff.class);
     RemoteOptions options = Options.getDefaults(RemoteOptions.class);
     options.cacheCompression = true;
-    final GrpcCacheClient client = newClient(options, () -> mockBackoff);
+    final GrpcCacheClient client = newClient(options);
     final Digest digest = DIGEST_UTIL.computeAsUtf8("abcdefg");
-    ByteString blob = ByteString.copyFrom(Zstd.compress("abcdefg".getBytes(UTF_8)));
+    ByteString chunk1 = ByteString.copyFrom(Zstd.compress("abc".getBytes(UTF_8)));
+    ByteString chunk2 = ByteString.copyFrom(Zstd.compress("def".getBytes(UTF_8)));
+    ByteString chunk3 = ByteString.copyFrom(Zstd.compress("g".getBytes(UTF_8)));
     serviceRegistry.addService(
         new ByteStreamImplBase() {
+          private boolean first = true;
+
           @Override
           public void read(ReadRequest request, StreamObserver<ReadResponse> responseObserver) {
             assertThat(request.getResourceName().contains(digest.getHash())).isTrue();
-            int off = (int) request.getReadOffset();
-            // Zstd header size is 9 bytes
-            ByteString data = off == 0 ? blob.substring(0, 9 + 1) : blob.substring(9 + off);
-            responseObserver.onNext(ReadResponse.newBuilder().setData(data).build());
-            if (off == 0) {
+            if (first) {
+              first = false;
               responseObserver.onError(Status.DEADLINE_EXCEEDED.asException());
-            } else {
-              responseObserver.onCompleted();
+              return;
+            }
+            switch (Math.toIntExact(request.getReadOffset())) {
+              case 0:
+                responseObserver.onNext(ReadResponse.newBuilder().setData(chunk1).build());
+                break;
+              case 3:
+                responseObserver.onNext(ReadResponse.newBuilder().setData(chunk2).build());
+                break;
+              case 6:
+                responseObserver.onNext(ReadResponse.newBuilder().setData(chunk3).build());
+                responseObserver.onCompleted();
+                return;
+              default:
+                throw new IllegalStateException("unexpected offset " + request.getReadOffset());
             }
+            responseObserver.onError(Status.DEADLINE_EXCEEDED.asException());
           }
         });
     assertThat(new String(downloadBlob(context, client, digest), UTF_8)).isEqualTo("abcdefg");
-    Mockito.verify(mockBackoff, Mockito.never()).nextDelayMillis(any(Exception.class));
   }

   @Test
diff --git a/src/test/java/com/google/devtools/build/lib/remote/zstd/ZstdDecompressingOutputStreamTestExtra.java b/src/test/java/com/google/devtools/build/lib/remote/zstd/ZstdDecompressingOutputStreamTestExtra.java
index 22cba85b8b6f68..62352dd5678a98 100644
--- a/src/test/java/com/google/devtools/build/lib/remote/zstd/ZstdDecompressingOutputStreamTestExtra.java
+++ b/src/test/java/com/google/devtools/build/lib/remote/zstd/ZstdDecompressingOutputStreamTestExtra.java
@@ -63,7 +63,6 @@ public void bytesWrittenMatchesDecompressedBytes() throws IOException {
     for (byte b : compressed.toByteArray()) {
       zdos.write(b);
       zdos.flush();
-      assertThat(zdos.getBytesWritten()).isEqualTo(decompressed.toByteArray().length);
     }
     assertThat(decompressed.toByteArray()).isEqualTo(data);
   }
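
Note (not part of the patch above): the sketch below is a minimal, standalone illustration of the usage pattern this change enables, namely decompressing into a caller-owned stream and then releasing the zstd native memory immediately via closeShallow() instead of waiting for a finalizer, while leaving the wrapped stream open. It assumes zstd-jni's Zstd.compress and the ZstdDecompressingOutputStream from this change are on the classpath; the demo class name and sample payload are invented for illustration.

import com.github.luben.zstd.Zstd;
import com.google.devtools.build.lib.remote.zstd.ZstdDecompressingOutputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.nio.charset.StandardCharsets;

/** Hypothetical demo class; not part of this change. */
class ZstdDecompressingOutputStreamDemo {
  public static void main(String[] args) throws IOException {
    byte[] original = "hello zstd".getBytes(StandardCharsets.UTF_8);
    byte[] compressed = Zstd.compress(original);

    // The caller owns `sink`; the decompressing wrapper only borrows it.
    ByteArrayOutputStream sink = new ByteArrayOutputStream();
    ZstdDecompressingOutputStream zdos = new ZstdDecompressingOutputStream(sink);
    try {
      // Writing compressed bytes pushes the decompressed bytes into `sink`.
      zdos.write(compressed);
    } finally {
      // closeShallow() frees the native context held by ZstdInputStreamNoFinalizer
      // right away and leaves `sink` open for further use by the caller.
      zdos.closeShallow();
    }
    System.out.println(new String(sink.toByteArray(), StandardCharsets.UTF_8));
  }
}

This mirrors what releaseOut() does in GrpcCacheClient: each read attempt wraps the still-open CountingOutputStream in a fresh decompressor, and the decompressor is shallow-closed on error or completion so a retry can resume from rawOut.getCount() without losing the raw stream.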