diff --git a/src/main/java/com/google/devtools/build/lib/remote/GrpcCacheClient.java b/src/main/java/com/google/devtools/build/lib/remote/GrpcCacheClient.java index d39bdd55baaba4..679fcebb4f7554 100644 --- a/src/main/java/com/google/devtools/build/lib/remote/GrpcCacheClient.java +++ b/src/main/java/com/google/devtools/build/lib/remote/GrpcCacheClient.java @@ -26,8 +26,7 @@ import build.bazel.remote.execution.v2.FindMissingBlobsResponse; import build.bazel.remote.execution.v2.GetActionResultRequest; import build.bazel.remote.execution.v2.RequestMetadata; -import build.bazel.remote.execution.v2.UpdateActionResultRequest; -import com.github.luben.zstd.ZstdInputStream; +import build.bazel.remote.execution.v2.UpdateActionResultRequest import com.google.bytestream.ByteStreamGrpc; import com.google.bytestream.ByteStreamGrpc.ByteStreamStub; import com.google.bytestream.ByteStreamProto.ReadRequest; @@ -61,9 +60,7 @@ import io.grpc.StatusRuntimeException; import io.grpc.stub.StreamObserver; -import java.io.ByteArrayInputStream; import java.io.IOException; -import java.io.InputStream; import java.io.OutputStream; import java.util.ArrayList; import java.util.List; @@ -298,6 +295,14 @@ public ListenableFuture downloadBlob( out = digestOut; } + if (options.cacheByteStreamCompression) { + try { + out = new ZstdDecompressingOutputStream(out); + } catch (IOException e) { + return Futures.immediateFailedFuture(e); + } + } + return downloadBlob(context, digest, out, digestSupplier); } @@ -352,41 +357,13 @@ private ListenableFuture requestRead( .setReadOffset(offset.get()) .build(), new StreamObserver() { - InputStream inner; - ZstdInputStream zis; - - { - initialise(); - } - - private void initialise() throws IOException { - if (options.cacheByteStreamCompression) { - zis = - new ZstdInputStream( - new InputStream() { - @Override - public int read() throws IOException { - return inner.read(); - } - }); - - zis.setContinuous(true); - } - } @Override public void onNext(ReadResponse readResponse) { ByteString data = readResponse.getData(); try { - if (options.cacheByteStreamCompression) { - inner = new ByteArrayInputStream(data.toByteArray()); - ByteString bs = ByteString.readFrom(zis); - bs.writeTo(out); - offset.addAndGet(bs.size()); - } else { - data.writeTo(out); - offset.addAndGet(data.size()); - } + data.writeTo(out); + offset.addAndGet(data.size()); } catch (IOException e) { // Cancel the call. throw new RuntimeException(e); diff --git a/src/main/java/com/google/devtools/build/lib/remote/ZstdDecompressingOutputStream.java b/src/main/java/com/google/devtools/build/lib/remote/ZstdDecompressingOutputStream.java new file mode 100644 index 00000000000000..cb644888522bcc --- /dev/null +++ b/src/main/java/com/google/devtools/build/lib/remote/ZstdDecompressingOutputStream.java @@ -0,0 +1,40 @@ +package com.google.devtools.build.lib.remote; + +import com.github.luben.zstd.ZstdInputStream; + +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; + +public class ZstdDecompressingOutputStream extends OutputStream { + private final ByteArrayOutputStream baos; + private final ZstdInputStream zis; + private final OutputStream out; + private int lastByte; + + ZstdDecompressingOutputStream(OutputStream out) throws IOException { + this.out = out; + baos = new ByteArrayOutputStream(); + zis = + new ZstdInputStream( + new InputStream() { + @Override + public int read() throws IOException { + int value = lastByte; + lastByte = -1; + return value; + } + }); + zis.setContinuous(true); + } + + @Override + public void write(int b) throws IOException { + lastByte = b; + int c; + while ((c = zis.read()) != -1) { + out.write(c); + } + } +}